Source code for openstef_models.explainability.plotters.contributions_plotter

# SPDX-FileCopyrightText: 2026 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Visualizations for per-sample feature contributions (SHAP values)."""

from __future__ import annotations

from typing import TYPE_CHECKING

import plotly.graph_objects as go
from plotly.subplots import make_subplots  # pyright: ignore[reportUnknownVariableType]

from openstef_core.datasets import TimeSeriesDataset  # noqa: TC001  # runtime needed for pyright

if TYPE_CHECKING:
    import pandas as pd


class ContributionsPlotter:
    """Visualizations for per-timestep feature contributions."""

    @staticmethod
    def _split_contributions(
        contributions: TimeSeriesDataset,
        target_column: str,
        bias_column: str,
    ) -> tuple[pd.DataFrame, pd.Series | None]:
        """Separate the per-feature contributions from the target and bias columns.

        Args:
            contributions: Output of ``predict_contributions()``.
            target_column: Name of the target column to exclude (ignored when absent).
            bias_column: Name of the bias column to extract (ignored when absent).

        Returns:
            A ``(features, bias)`` tuple. ``features`` is ``contributions.data`` without the target and bias
            columns; ``bias`` is the bias column, or None when it is not present.
        """
        data = contributions.data
        bias = data[bias_column] if bias_column in data.columns else None
        cols_to_drop = [c for c in (target_column, bias_column) if c in data.columns]
        return data.drop(columns=cols_to_drop), bias

    @staticmethod
    def _rank_features(features: pd.DataFrame, top_n: int) -> pd.Series:
        """Return the ``top_n`` features' mean absolute contributions, sorted from largest to smallest."""
        return features.abs().mean().sort_values(ascending=False).head(top_n)

    @staticmethod
    def plot_heatmap(
        contributions: TimeSeriesDataset,
        top_n: int = 10,
        target_column: str = "load",
        bias_column: str = "bias",
        *,
        show_prediction: bool = True,
    ) -> go.Figure:
        """Create an interactive heatmap of feature contributions over time.

        X-axis is the prediction datetime, Y-axis shows feature names ranked by mean absolute contribution
        (most important at top). Color ranges from blue (negative) through white (zero) to red (positive).

        When ``show_prediction`` is True a line plot of the model prediction (sum of contributions + bias)
        is shown above the heatmap.

        Args:
            contributions: Output of ``predict_contributions()``.
            top_n: Number of top features to show (ranked by mean absolute contribution).
            target_column: Name of the target column to exclude. Default "load".
            bias_column: Name of the bias column. Default "bias".
            show_prediction: If True, add a prediction line subplot above the heatmap. Default True.

        Returns:
            Plotly Figure with a diverging heatmap centered at zero (and optional prediction line).
        """
        df, bias = ContributionsPlotter._split_contributions(contributions, target_column, bias_column)
        ranked: list[str] = ContributionsPlotter._rank_features(df, top_n).index.tolist()

        # Plotly renders the first y category at the bottom; reverse so the most important feature is on top.
        y_labels = list(reversed(ranked))
        heatmap = go.Heatmap(
            z=df[y_labels].T.values,
            x=df.index,
            y=y_labels,
            colorscale="RdBu_r",
            zmid=0,
            colorbar={"title": "Contribution"},
            showlegend=False,
        )

        if not show_prediction:
            return go.Figure(
                data=heatmap,
                layout={
                    "xaxis_title": "Time",
                    "yaxis_title": "Feature",
                    "margin": {"t": 30, "r": 10, "b": 40, "l": 120},
                },
            )

        # Sum over *all* feature columns (not just the top_n displayed) so the line is the full
        # sum of contributions + bias.
        prediction = df.sum(axis=1)
        if bias is not None:
            prediction += bias

        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=True,
            row_heights=[0.2, 0.8],
            vertical_spacing=0.03,
        )
        fig.add_trace(  # pyright: ignore[reportUnknownMemberType]
            go.Scatter(
                x=df.index,
                y=prediction,
                mode="lines",
                name="Prediction",
                line={"color": "black", "width": 1.5},
                showlegend=False,
            ),
            row=1,
            col=1,
        )
        fig.add_trace(heatmap, row=2, col=1)  # pyright: ignore[reportUnknownMemberType]
        fig.update_layout(  # pyright: ignore[reportUnknownMemberType]
            yaxis_title="Prediction",
            yaxis2_title="Feature",
            xaxis2_title="Time",
            margin={"t": 30, "r": 10, "b": 40, "l": 120},
        )
        return fig

    @staticmethod
    def plot_waterfall(
        contributions: TimeSeriesDataset,
        timestep: int = 0,
        top_n: int = 10,
        target_column: str = "load",
        bias_column: str = "bias",
    ) -> go.Figure:
        """Create a waterfall chart decomposing a single timestep's prediction.

        Shows how the bias (base value) is pushed up or down by each feature's contribution to arrive at
        the final prediction.

        Args:
            contributions: Output of ``predict_contributions()``.
            timestep: Row index (0-based) of the timestep to explain.
            top_n: Number of top features to show. Remaining features are aggregated into an "other" bar.
            target_column: Name of the target column to exclude. Default "load".
            bias_column: Name of the bias column used as base value. Default "bias". When the column is
                absent, a base value of 0.0 is used.

        Returns:
            Plotly Figure with waterfall chart.

        Raises:
            IndexError: If ``timestep`` is outside the rows of ``contributions``.
        """
        df, bias = ContributionsPlotter._split_contributions(contributions, target_column, bias_column)
        row = df.iloc[timestep]
        base_value = float(bias.iloc[timestep]) if bias is not None else 0.0

        # Rank by |contribution| for this specific timestep (not the mean over all timesteps).
        abs_sorted = row.abs().sort_values(ascending=False)
        top = abs_sorted.head(top_n).index.tolist()
        # Everything after the selected head, taken positionally: avoids an O(n*k) list-membership scan.
        remaining = abs_sorted.index[len(top):].tolist()

        names: list[str] = [bias_column]
        values: list[float] = [base_value]
        measures: list[str] = ["absolute"]

        for feat in top:
            names.append(feat)
            values.append(float(row[feat]))  # pyright: ignore[reportArgumentType]
            measures.append("relative")

        if remaining:
            names.append(f"other ({len(remaining)})")
            values.append(float(row[remaining].sum()))
            measures.append("relative")

        names.append("Prediction")
        values.append(base_value + float(row.sum()))
        measures.append("total")

        timestamp = contributions.data.index[timestep]

        return go.Figure(
            go.Waterfall(
                x=names,
                y=values,
                measure=measures,
                connector={"line": {"color": "grey", "width": 0.5}},
                increasing={"marker": {"color": "#ff4136"}},
                decreasing={"marker": {"color": "#0074d9"}},
                totals={"marker": {"color": "#2ecc40"}},
                textposition="outside",
                # Signed labels for deltas, plain labels for absolute/total bars.
                text=[f"{v:+.4f}" if m == "relative" else f"{v:.4f}" for v, m in zip(values, measures, strict=True)],
            ),
            layout={
                "title": f"Contributions at {timestamp}",
                "yaxis_title": "Contribution",
                "margin": {"t": 50, "r": 10, "b": 40, "l": 60},
                "showlegend": False,
            },
        )

    @staticmethod
    def plot_bar(
        contributions: TimeSeriesDataset,
        top_n: int = 10,
        target_column: str = "load",
        bias_column: str = "bias",
    ) -> go.Figure:
        """Create a horizontal bar chart of mean absolute contributions per feature.

        Features are ranked from most to least important (top to bottom).

        Args:
            contributions: Output of ``predict_contributions()``.
            top_n: Number of top features to show.
            target_column: Name of the target column to exclude. Default "load".
            bias_column: Name of the bias column to exclude. Default "bias".

        Returns:
            Plotly Figure with horizontal bar chart.
        """
        df, _ = ContributionsPlotter._split_contributions(contributions, target_column, bias_column)
        # Reverse because plotly renders horizontal bars bottom-to-top; this puts the largest on top.
        mean_abs: pd.Series = ContributionsPlotter._rank_features(df, top_n).iloc[::-1]

        return go.Figure(
            go.Bar(
                x=mean_abs.values,  # pyright: ignore[reportArgumentType]
                y=mean_abs.index.tolist(),
                orientation="h",
                marker_color="#1f77b4",
                hovertemplate="<b>%{y}</b><br>mean |SHAP|: %{x:.4f}<extra></extra>",
            ),
            layout={
                "xaxis_title": "mean |SHAP value|",
                "yaxis_title": "Feature",
                "margin": {"t": 30, "r": 10, "b": 40, "l": 120},
                "showlegend": False,
            },
        )