.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/10_plot_compare_ande_interpretability.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_10_plot_compare_ande_interpretability.py: ==================================================================== Interpreting ALR: Automatic Feature Selection via Weights ==================================================================== This example demonstrates the "Embedded Feature Selection" capability of Hybrid AnDE models (ALR). **The Result:** By ignoring noise, ALR achieves significantly lower Log Loss and higher Accuracy. .. GENERATED FROM PYTHON SOURCE LINES 12-146 .. image-sg:: /auto_examples/images/sphx_glr_10_plot_compare_ande_interpretability_001.png :alt: Impact of Noise on Model Weights and Performance, AnDE (Generative) Strategy: Uniform Attention, ALR (Hybrid) Strategy: Learned Attention :srcset: /auto_examples/images/sphx_glr_10_plot_compare_ande_interpretability_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none Training models... | .. code-block:: Python # Author: The scikit-bayes Developers # SPDX-License-Identifier: BSD-3-Clause import matplotlib.pyplot as plt import numpy as np from sklearn.metrics import accuracy_score, log_loss from sklearn.model_selection import train_test_split from skbn.ande import ALR, AnDE # --- 1. Generate Dataset with Heavy Noise --- np.random.seed(42) n_samples = 20000 # Large sample size to allow convergence of weights n_features = 10 # A. Informative Features (Mixed Type) X_cat = np.random.randint(0, 3, size=(n_samples, 1)) X_cont = np.random.randn(n_samples, 1) # Logic: Cat=0 & X>0.5 OR Cat=1 & X<-0.5 -> Class 1 y = np.zeros(n_samples, dtype=int) mask_0 = (X_cat.flatten() == 0) & (X_cont.flatten() > 0.5) mask_1 = (X_cat.flatten() == 1) & (X_cont.flatten() < -0.5) y[mask_0 | mask_1] = 1 # B. Noise Features X_noise = np.random.randn(n_samples, 8) X = np.hstack([X_cat, X_cont, X_noise]) # Split X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.3, random_state=42 ) # --- 2. Fit and Evaluate --- print("Training models...") # Model 1: AnDE ande = AnDE(n_dependence=1) ande.fit(X_train, y_train) acc_ande = accuracy_score(y_test, ande.predict(X_test)) ll_ande = log_loss(y_test, ande.predict_proba(X_test)) # Model 2: ALR (Higher regularization for visualization sparsity) alr = ALR(n_dependence=1, l2_reg=0.05, max_iter=200) alr.fit(X_train, y_train) acc_alr = accuracy_score(y_test, alr.predict(X_test)) ll_alr = log_loss(y_test, alr.predict_proba(X_test)) # --- 3. Visualization --- features = [f"F{i}\n(Signal)" if i < 2 else f"F{i}\n(Noise)" for i in range(n_features)] indices = np.arange(n_features) ande_weights = np.ones(n_features) # Detect if we have weights per class (Level 3) or simple weights (Level 1) n_models = len(alr.ensemble_) if alr.learned_weights_.size == n_models: # Level 1: One weight per feature. Use directly. alr_weights = alr.learned_weights_ else: # Level 3: Weights per class. Reshape and average. n_classes = len(alr.classes_) weights_matrix = alr.learned_weights_.reshape(n_models, n_classes) alr_weights = np.mean(weights_matrix, axis=1) fig, axes = plt.subplots(1, 2, figsize=(16, 7), sharey=True) # Helper for info box (MOVED TO RIGHT) def add_score_box(ax, acc, ll): textstr = "\n".join( (r"$\bf{Performance (Test):}$", f"Accuracy: {acc:.3f}", f"Log Loss: {ll:.3f}") ) props = dict(boxstyle="round", facecolor="white", alpha=0.9, edgecolor="gray") ax.text( 0.95, 0.95, textstr, transform=ax.transAxes, fontsize=12, verticalalignment="top", horizontalalignment="right", bbox=props, ) # Plot AnDE (Generative) axes[0].bar(indices, ande_weights, color="lightgray", edgecolor="gray") axes[0].set_title("AnDE (Generative)\nStrategy: Uniform Attention", fontsize=14) axes[0].set_xlabel("Feature (as Super-Parent)", fontsize=12) axes[0].set_ylabel("Weight Magnitude", fontsize=12) axes[0].set_xticks(indices) axes[0].set_xticklabels(features, rotation=45, ha="right") # Highlight Signal Features with GOLD (to show ground truth importance) axes[0].patches[0].set_facecolor("gold") axes[0].patches[0].set_edgecolor("orange") axes[0].patches[1].set_facecolor("gold") axes[0].patches[1].set_edgecolor("orange") add_score_box(axes[0], acc_ande, ll_ande) # Plot ALR (Hybrid) # Use INDIGO for learned weights bars = axes[1].bar(indices, alr_weights, color="indigo", edgecolor="black", alpha=0.8) axes[1].set_title("ALR (Hybrid)\nStrategy: Learned Attention", fontsize=14) axes[1].set_xlabel("Feature (as Super-Parent)", fontsize=12) axes[1].set_xticks(indices) axes[1].set_xticklabels(features, rotation=45, ha="right") add_score_box(axes[1], acc_alr, ll_alr) # Annotate values on ALR for bar in bars: height = bar.get_height() if height > 0.1: axes[1].annotate( f"{height:.2f}", xy=(bar.get_x() + bar.get_width() / 2, height), xytext=(0, 3), textcoords="offset points", ha="center", va="bottom", fontweight="bold", color="indigo", ) fig.suptitle("Impact of Noise on Model Weights and Performance", fontsize=18) plt.tight_layout() plt.subplots_adjust(top=0.85) plt.show() .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 1.956 seconds) .. _sphx_glr_download_auto_examples_10_plot_compare_ande_interpretability.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: 10_plot_compare_ande_interpretability.ipynb <10_plot_compare_ande_interpretability.ipynb>` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: 10_plot_compare_ande_interpretability.py <10_plot_compare_ande_interpretability.py>` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: 10_plot_compare_ande_interpretability.zip <10_plot_compare_ande_interpretability.zip>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_