# Source code for openstef_beam.analysis.visualizations.windowed_metric_visualization

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Windowed metric visualization provider.

This module provides visualization for windowed metrics over time, showing
how performance metrics evolve across different time windows.
"""

import logging
import operator
from collections import defaultdict
from datetime import datetime
from typing import Literal, override

import numpy as np

from openstef_beam.analysis.models import AnalysisAggregation, GroupName, RunName, TargetMetadata, VisualizationOutput
from openstef_beam.analysis.plots import (
    WindowedMetricPlotter,
)
from openstef_beam.analysis.visualizations.base import MetricIdentifier, ReportTuple, VisualizationProvider
from openstef_beam.evaluation import EvaluationSubsetReport, Window
from openstef_core.types import Quantile

# Module-level logger, named after this module so records can be filtered per module.
_logger = logging.getLogger(name=__name__)


class WindowedMetricVisualization(VisualizationProvider):
    """Creates time series plots showing metric evolution across evaluation windows.

    Displays how evaluation metrics change over time by plotting metric values
    on a timeline where each point represents performance over a specific time window.
    The visualization reveals performance trends, seasonal patterns, and helps identify
    periods where model accuracy degrades or improves.

    What you'll see:
        - Time series line plot with metric values on Y-axis and time on X-axis
        - Each point shows metric computed over a sliding evaluation window
        - Multiple lines when comparing across targets or model runs
        - Clear trends showing model performance stability over time

    Useful for identifying:
        - Performance degradation patterns over time
        - Seasonal effects in forecasting accuracy
        - Model stability across different periods
        - Optimal retraining intervals based on performance drops

    Example:
        >>> from openstef_beam.analysis import AnalysisConfig
        >>> from openstef_beam.analysis.visualizations import WindowedMetricVisualization
        >>> from openstef_beam.evaluation import Window
        >>> from datetime import timedelta
        >>>
        >>> analysis_config = AnalysisConfig(
        ...     visualization_providers=[
        ...         WindowedMetricVisualization(
        ...             name="mae_evolution",
        ...             metric="MAE",
        ...             window=Window(lag=timedelta(hours=0), size=timedelta(days=7)),
        ...         ),
        ...     ]
        ... )
    """

    # Either a plain metric name (looked up under the "global" key) or a
    # (metric_name, quantile) pair — see `_get_metric_info`.
    metric: MetricIdentifier
    # Only windowed metrics whose `window` equals this value are plotted.
    window: Window

    @property
    @override
    def supported_aggregations(self) -> set[AnalysisAggregation]:
        """Aggregation levels for which this provider implements a `create_by_*` method."""
        return {
            AnalysisAggregation.NONE,
            AnalysisAggregation.RUN_AND_NONE,
            AnalysisAggregation.TARGET,
            AnalysisAggregation.RUN_AND_TARGET,
            AnalysisAggregation.GROUP,
            AnalysisAggregation.RUN_AND_GROUP,
        }

    # ------------------------------------------------------------------ helpers

    def _get_metric_info(self) -> tuple[str, Quantile | Literal["global"]]:
        """Extract metric name and quantile/global type from the metric config.

        Returns:
            A tuple containing:
                - metric_name: The name of the metric
                - quantile_or_global: Either a Quantile object or the literal "global"
        """
        if isinstance(self.metric, str):
            return self.metric, "global"
        metric_name, quantile = self.metric
        return metric_name, quantile

    def _extract_windowed_metric_values(
        self, report: EvaluationSubsetReport, metric_name: str, quantile_or_global: Quantile | Literal["global"]
    ) -> list[tuple[datetime, float]]:
        """Extract time-value pairs for the specified metric from windowed metrics.

        Args:
            report: The evaluation subset report.
            metric_name: Name of the metric to extract.
            quantile_or_global: Either a Quantile object or "global".

        Returns:
            List of (timestamp, metric_value) tuples sorted by timestamp. Entries whose
            metric value is missing (None) are skipped. Empty when the report has no
            windowed metrics for the configured window.
        """
        windowed_metrics = report.get_windowed_metrics()
        if not windowed_metrics:
            return []

        time_value_pairs: list[tuple[datetime, float]] = []
        for window_metrics in windowed_metrics:
            # Only metrics computed for the configured window are relevant.
            if not (self.window and window_metrics.window == self.window):
                continue
            metric_value = window_metrics.metrics.get(quantile_or_global, {}).get(metric_name)
            if metric_value is not None:
                time_value_pairs.append((window_metrics.timestamp, metric_value))

        # Sort by timestamp so the series renders as a proper time line.
        time_value_pairs.sort(key=operator.itemgetter(0))
        return time_value_pairs

    def _create_plot_title(
        self, metric_name: str, quantile_or_global: Quantile | Literal["global"], suffix: str
    ) -> str:
        """Create a formatted title for the plot.

        Args:
            metric_name: Name of the metric.
            quantile_or_global: Either a Quantile object or "global".
            suffix: Additional suffix for the title.

        Returns:
            Formatted plot title.
        """
        metric_display = f"{metric_name} (q={quantile_or_global})" if quantile_or_global != "global" else metric_name
        return f"Windowed {metric_display} {self.window} over Time {suffix}"

    def _average_time_series_across_targets(
        self, reports: list[ReportTuple], metric_name: str, quantile_or_global: Quantile | Literal["global"]
    ) -> list[tuple[datetime, float]]:
        """Average windowed metric values across multiple targets at each timestamp.

        Each timestamp is averaged over the targets that have a value for it, so
        targets with shorter histories do not pull the average towards zero.

        Args:
            reports: List of (metadata, report) tuples from different targets.
            metric_name: Name of the metric to extract.
            quantile_or_global: Either a Quantile object or "global".

        Returns:
            List of (timestamp, averaged_metric_value) tuples sorted by timestamp.
        """
        timestamp_values: dict[datetime, list[float]] = defaultdict(list)
        for _metadata, report in reports:
            for timestamp, value in self._extract_windowed_metric_values(report, metric_name, quantile_or_global):
                timestamp_values[timestamp].append(value)

        # Every key was created by an append, so each list holds at least one value.
        # nanmean ignores NaN metric values from individual targets.
        return [(timestamp, float(np.nanmean(timestamp_values[timestamp]))) for timestamp in sorted(timestamp_values)]

    def _add_series(
        self,
        plotter: WindowedMetricPlotter,
        model_name: str,
        time_value_pairs: list[tuple[datetime, float]],
    ) -> None:
        """Add one (timestamp, value) series to ``plotter`` as a line named ``model_name``."""
        plotter.add_model(
            model_name=model_name,
            timestamps=[timestamp for timestamp, _ in time_value_pairs],
            metric_values=[value for _, value in time_value_pairs],
        )

    def _empty_output(self, message: str) -> VisualizationOutput:
        """Return an HTML-only placeholder output carrying ``message`` when there is nothing to plot."""
        return VisualizationOutput(name=self.name, html=f"<p>{message}</p>")

    def _plot_run_averages(
        self,
        run_to_targets: dict[RunName, list[ReportTuple]],
        title_suffix: str,
    ) -> VisualizationOutput:
        """Plot one line per run, each averaged over that run's target reports.

        Runs without reports, or whose reports yield no windowed metrics, are
        skipped with a warning.

        Args:
            run_to_targets: Target reports grouped by run name.
            title_suffix: Suffix appended to the plot title.

        Returns:
            The plotted figure, or a placeholder if no run produced any data.
        """
        metric_name, quantile_or_global = self._get_metric_info()
        plotter = WindowedMetricPlotter()
        has_data = False

        for run_name, target_reports in run_to_targets.items():
            if not target_reports:
                _logger.warning("No reports for run '%s' (%s) — skipping.", run_name, self.name)
                continue

            averaged_pairs = self._average_time_series_across_targets(
                reports=target_reports,
                metric_name=metric_name,
                quantile_or_global=quantile_or_global,
            )
            if not averaged_pairs:
                _logger.warning("No windowed averaged metrics for run '%s' (%s) — skipping.", run_name, self.name)
                continue

            self._add_series(plotter, run_name, averaged_pairs)
            has_data = True

        if not has_data:
            return self._empty_output("No windowed metrics available for any run")

        title = self._create_plot_title(metric_name, quantile_or_global, title_suffix)
        figure = plotter.plot(title=title, metric_name=metric_name)
        return VisualizationOutput(name=self.name, figure=figure)

    # ---------------------------------------------------------- aggregations

    @override
    def create_by_none(
        self,
        report: EvaluationSubsetReport,
        metadata: TargetMetadata,
    ) -> VisualizationOutput:
        """Plot the windowed metric of a single target for a single run.

        Args:
            report: Evaluation report of the target.
            metadata: Metadata of the target; `run_name` names the line, `name` goes in the title.

        Returns:
            The plotted figure, or a placeholder if the report has no windowed metrics.
        """
        metric_name, quantile_or_global = self._get_metric_info()
        time_value_pairs = self._extract_windowed_metric_values(report, metric_name, quantile_or_global)
        if not time_value_pairs:
            _logger.warning("No windowed metrics for %s (%s) — skipping visualization.", metadata.name, self.name)
            return self._empty_output(f"No windowed metrics available for {metadata.name}")

        plotter = WindowedMetricPlotter()
        self._add_series(plotter, metadata.run_name, time_value_pairs)

        title = self._create_plot_title(metric_name, quantile_or_global, f"for {metadata.name}")
        figure = plotter.plot(title=title, metric_name=metric_name)
        return VisualizationOutput(name=self.name, figure=figure)

    @override
    def create_by_run_and_none(self, reports: dict[RunName, list[ReportTuple]]) -> VisualizationOutput:
        """Plot one line per run (per report) without averaging.

        Args:
            reports: Report tuples grouped by run name.

        Returns:
            The plotted figure, or a placeholder if no run has windowed metrics.
        """
        metric_name, quantile_or_global = self._get_metric_info()
        plotter = WindowedMetricPlotter()
        has_data = False

        for run_name, report_pairs in reports.items():
            for _metadata, report in report_pairs:
                time_value_pairs = self._extract_windowed_metric_values(report, metric_name, quantile_or_global)
                if not time_value_pairs:
                    _logger.warning("No windowed metrics for run '%s' (%s) — skipping.", run_name, self.name)
                    continue
                self._add_series(plotter, run_name, time_value_pairs)
                has_data = True

        if not has_data:
            return self._empty_output("No windowed metrics available for any run")

        title = self._create_plot_title(metric_name, quantile_or_global, "by Run")
        figure = plotter.plot(title=title, metric_name=metric_name)
        return VisualizationOutput(name=self.name, figure=figure)

    @override
    def create_by_target(
        self,
        reports: list[ReportTuple],
    ) -> VisualizationOutput:
        """Plot one line per target; the title names the run of the first target.

        Args:
            reports: (metadata, report) tuples, one per target.

        Returns:
            The plotted figure, or a placeholder if no target has windowed metrics.
        """
        metric_name, quantile_or_global = self._get_metric_info()
        plotter = WindowedMetricPlotter()
        has_data = False

        for metadata, report in reports:
            time_value_pairs = self._extract_windowed_metric_values(report, metric_name, quantile_or_global)
            if not time_value_pairs:
                _logger.warning("No windowed metrics for target '%s' (%s) — skipping.", metadata.name, self.name)
                continue
            # The target name labels the line.
            self._add_series(plotter, metadata.name, time_value_pairs)
            has_data = True

        if not has_data:
            return self._empty_output("No windowed metrics available for any target")

        run_name = reports[0][0].run_name if reports else ""
        title_suffix = f"by Target for {run_name}" if run_name else "by Target"
        title = self._create_plot_title(metric_name, quantile_or_global, title_suffix)
        figure = plotter.plot(title=title, metric_name=metric_name)
        return VisualizationOutput(name=self.name, figure=figure)

    @override
    def create_by_run_and_target(
        self,
        reports: dict[RunName, list[ReportTuple]],
    ) -> VisualizationOutput:
        """Plot one line per run, averaged over the targets supplied for that run.

        The title describes these targets as the targets of a single group.

        Args:
            reports: Target report tuples grouped by run name.

        Returns:
            The plotted figure, or a placeholder if no run has windowed metrics.
        """
        return self._plot_run_averages(reports, "by run (averaged over targets in group)")

    @override
    def create_by_run_and_group(
        self,
        reports: dict[tuple[RunName, GroupName], list[ReportTuple]],
    ) -> VisualizationOutput:
        """Plot one line per run, averaged over all of its targets across every group.

        Args:
            reports: Target report tuples keyed by (run name, group name).

        Returns:
            The plotted figure, or a placeholder if no run has windowed metrics.
        """
        # Merge the groups of each run; insertion order follows first appearance of each run.
        run_to_targets: dict[RunName, list[ReportTuple]] = defaultdict(list)
        for (run_name, _group_name), target_reports in reports.items():
            run_to_targets[run_name].extend(target_reports)

        return self._plot_run_averages(run_to_targets, "by run (averaged over all targets)")

    @override
    def create_by_group(
        self,
        reports: dict[GroupName, list[ReportTuple]],
    ) -> VisualizationOutput:
        """Plot a single line averaged over all targets of all groups.

        The line is named after the run of the first target.

        Args:
            reports: Target report tuples grouped by group name.

        Returns:
            The plotted figure, or a placeholder if no target has windowed metrics.
        """
        metric_name, quantile_or_global = self._get_metric_info()

        all_target_reports: list[ReportTuple] = [
            report_tuple for report_list in reports.values() for report_tuple in report_list
        ]
        averaged_pairs = self._average_time_series_across_targets(
            reports=all_target_reports,
            metric_name=metric_name,
            quantile_or_global=quantile_or_global,
        )
        if not averaged_pairs:
            return self._empty_output("No windowed metrics available across all groups")

        # averaged_pairs is non-empty, so at least one target report exists.
        run_name = all_target_reports[0][0].run_name

        plotter = WindowedMetricPlotter()
        self._add_series(plotter, run_name, averaged_pairs)

        title = self._create_plot_title(metric_name, quantile_or_global, "averaged over all targets")
        figure = plotter.plot(title=title, metric_name=metric_name)
        return VisualizationOutput(name=self.name, figure=figure)
# Public API of this module: only the visualization provider class is exported.
__all__ = [
    "WindowedMetricVisualization",
]