# Source code for openstef_models.explainability.mixins

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Mixins for adding explainability features to forecasting models.

Provides base classes that enable models to expose feature importance scores
and generate visualization plots.
"""

from abc import ABC, abstractmethod

import pandas as pd
import plotly.graph_objects as go

from openstef_core.datasets import ForecastInputDataset, TimeSeriesDataset
from openstef_core.types import Q, Quantile
from openstef_models.explainability.plotters.feature_importance_plotter import FeatureImportancePlotter


class ExplainableForecaster(ABC):
    """Mixin for forecasters that can explain feature importance.

    Provides a standardized interface for accessing and visualizing feature
    importance scores across different forecasting models.
    """

    @property
    @abstractmethod
    def feature_importances(self) -> pd.DataFrame:
        """Get feature importance scores for this model.

        Returns DataFrame with feature names as index and quantiles as columns.
        Each quantile represents the importance distribution across multiple
        model training runs or folds.

        Returns:
            DataFrame with feature names as index and quantile columns.
            Values represent normalized importance scores summing to 1.0.

        Note:
            The returned DataFrame must have feature names as index and
            quantile columns in format 'quantile_PXX' (e.g., 'quantile_P50',
            'quantile_P95'). All quantile values must be between 0 and 1.
        """
        raise NotImplementedError

    def plot_feature_importances(self, quantile: Quantile = Q(0.5)) -> go.Figure:
        """Create interactive treemap visualization of feature importances.

        Args:
            quantile: Which quantile of importance scores to display.
                Defaults to median (0.5).

        Returns:
            Plotly Figure containing treemap with feature importance scores.
            Color intensity indicates relative importance of each feature.
        """
        # Delegates rendering to the shared plotter so all models produce
        # a consistent visualization of their importance scores.
        return FeatureImportancePlotter().plot(scores=self.feature_importances, quantile=quantile)
class ContributionsMixin(ABC):
    """Mixin for forecasters that can explain per-sample feature contributions.

    Unlike ``ExplainableForecaster`` which provides aggregate feature
    importance, this mixin provides per-sample decomposition of predictions —
    i.e., how much each feature contributed to the prediction for each
    individual sample.

    For tree-based models (XGBoost), this corresponds to SHAP TreeExplainer
    values. For linear models (GBLinear), this is the coefficient x feature
    value decomposition. For ensembles, this shows each base model's
    contribution weight.
    """

    @abstractmethod
    def predict_contributions(self, data: ForecastInputDataset) -> TimeSeriesDataset:
        """Compute per-sample feature contributions for the given input data.

        Returns a TimeSeriesDataset where columns are feature names (or model
        names for ensemble contributions) and rows correspond to the same time
        index as the input. Values represent the additive contribution of each
        feature to the prediction at that timestep.

        Args:
            data: Preprocessed input data (same format as ``predict()`` takes).

        Returns:
            TimeSeriesDataset with feature contributions. Columns are features,
            rows are timesteps. A ``bias`` column may be included for the model
            intercept/base value.
        """
        # Consistent with ExplainableForecaster.feature_importances: abstract
        # members raise explicitly if invoked without an override.
        raise NotImplementedError