# Source code for openstef_beam.analysis.visualizations.base

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Base classes for analysis visualization providers.

This module provides the foundation for creating visualization providers that
transform evaluation reports into interactive plots and charts. The base classes
define the interface and common functionality for all visualization providers.
"""

from abc import abstractmethod

from pydantic import Field

from openstef_beam.analysis.models import AnalysisAggregation, GroupName, RunName, TargetMetadata, VisualizationOutput
from openstef_beam.evaluation import EvaluationSubsetReport
from openstef_core.base_model import BaseConfig
from openstef_core.types import Quantile
from openstef_core.utils.itertools import groupby, is_all_same

ReportTuple = tuple[TargetMetadata, EvaluationSubsetReport]


[docs] class VisualizationProvider(BaseConfig): """Abstract base class for creating visualizations from evaluation reports. Provides a unified interface for generating different types of visualizations at various aggregation levels. Subclasses must implement specific visualization logic for each supported aggregation type. """ name: str = Field(description="Name of the visualization provider, used for identification in reports.")
[docs] def create( self, reports: list[ReportTuple], aggregation: AnalysisAggregation, ) -> VisualizationOutput: """Creates a visualization based on evaluation reports and aggregation level. Validates the aggregation type, groups reports appropriately, and delegates to the specific creation method. Ensures data consistency constraints are met for each aggregation type. Args: reports: List of (metadata, evaluation_report) tuples containing the data to visualize. aggregation: The aggregation level determining how reports are grouped and visualized. Returns: A visualization output containing the generated plot or HTML content. Raises: ValueError: If aggregation is not supported by this provider, if the number of reports doesn't match the aggregation requirements, if reports have inconsistent run_name for GROUP aggregation, or if reports have inconsistent group_name for RUN aggregation. """ # Validate aggregation support upfront self._validate_aggregation_support(aggregation) # Not NONE aggregations require non-empty reports if aggregation != AnalysisAggregation.NONE and len(reports) == 0: msg = f"No reports provided for {aggregation.value} aggregation." 
raise ValueError(msg) # Use match/case to dispatch aggregation handling match aggregation: case AnalysisAggregation.NONE: # Early return for unaggregated (single report) case if len(reports) != 1: raise ValueError("Cannot create unaggregated visualization for multiple reports.") metadata, report = reports[0] return self.create_by_none(report=report, metadata=metadata) case AnalysisAggregation.TARGET: _validate_same_run_names(reports) # Create visualization for each target in the same run return self.create_by_target(reports=reports) case AnalysisAggregation.GROUP: _validate_same_run_names(reports) # Group by group_name to compare performance across target categories grouped_reports = groupby(((m.group_name, (m, r)) for m, r in reports)) return self.create_by_group(reports=grouped_reports) case AnalysisAggregation.RUN_AND_NONE: _validate_same_group_names(reports) # Group by run_name to compare different models on the same target grouped_reports = groupby(((m.run_name, (m, r)) for m, r in reports)) return self.create_by_run_and_none(reports=grouped_reports) case AnalysisAggregation.RUN_AND_TARGET: _validate_same_group_names(reports) # Group by run_name to compare different models on the same target grouped_reports = groupby(((m.run_name, (m, r)) for m, r in reports)) return self.create_by_run_and_target(reports=grouped_reports) case AnalysisAggregation.RUN_AND_GROUP: # Group by both run_name and group_name for comparison matrix grouped_reports = groupby((((m.run_name, m.group_name), (m, r)) for m, r in reports)) return self.create_by_run_and_group(reports=grouped_reports)
[docs] def create_by_none( self, report: EvaluationSubsetReport, metadata: TargetMetadata, ) -> VisualizationOutput: """Creates visualization for a single target from a single run. Generates detailed analysis for individual target performance, typically showing time series, detailed metrics, or target-specific insights. Returns: Visualization focused on the specific target's performance. """ raise NotImplementedError
[docs] def create_by_target( self, reports: list[ReportTuple], ) -> VisualizationOutput: """Creates visualization comparing multiple targets from the same run. Groups reports by target metadata and creates visualizations showing performance differences across individual targets within the same model run. Args: reports: List of (metadata, report) tuples for each target in the run. Returns: Visualization comparing performance across different targets. """ raise NotImplementedError
[docs] def create_by_group(self, reports: dict[GroupName, list[ReportTuple]]) -> VisualizationOutput: """Creates visualization comparing multiple targets from the same run. Groups targets by their group_name and creates comparative visualizations showing performance differences across target categories or types. Args: reports: Dictionary mapping group names to lists of (metadata, report) tuples for that group. Returns: Visualization comparing performance across different target groups. """ raise NotImplementedError
[docs] def create_by_run_and_none(self, reports: dict[RunName, list[ReportTuple]]) -> VisualizationOutput: """Creates visualization comparing multiple runs on the same target group. Groups reports by run_name and creates comparative visualizations showing how different models or configurations perform on the same targets. Args: reports: Dictionary mapping run names to lists of (metadata, report) tuples for that run. Returns: Visualization comparing different model runs on the same targets. """ raise NotImplementedError
[docs] def create_by_run_and_target(self, reports: dict[RunName, list[ReportTuple]]) -> VisualizationOutput: """Creates visualization comparing multiple runs on the same target group. Groups reports by run_name and creates comparative visualizations showing how different models or configurations perform on the same targets. Args: reports: Dictionary mapping run names to lists of (metadata, report) tuples for that run. Returns: Visualization comparing different model runs on the same targets. """ raise NotImplementedError
[docs] def create_by_run_and_group( self, reports: dict[tuple[RunName, GroupName], list[ReportTuple]] ) -> VisualizationOutput: """Creates visualization across multiple runs and target groups. Creates matrix-style comparisons showing how different models perform across different target categories, enabling full comparative analysis. Args: reports: Dictionary mapping (run_name, group_name) tuples to lists of (metadata, report) tuples for that combination. Returns: Visualization matrix comparing runs across target groups. """ raise NotImplementedError
@property @abstractmethod def supported_aggregations(self) -> set[AnalysisAggregation]: """Returns the set of aggregation types supported by this provider. Returns: Set of supported VisualizationAggregation values. """ raise NotImplementedError def _validate_aggregation_support(self, aggregation: AnalysisAggregation) -> None: """Validate that the aggregation type is supported by this provider. Raises: ValueError: If aggregation is not supported by this provider. """ if aggregation not in self.supported_aggregations: msg = f"Aggregation {aggregation} is not supported by this provider." raise ValueError(msg)
def _validate_same_run_names( reports: list[ReportTuple], ) -> None: """Validate that all reports have the same run name. Raises: ValueError: If reports have different run names. """ run_names = [metadata.run_name for metadata, _ in reports] if not is_all_same(run_names): raise ValueError("All reports must have the same run name.") def _validate_same_group_names( reports: list[ReportTuple], ) -> None: """Validate that all reports have the same group name. Raises: ValueError: If reports have different group names. """ group_names = [metadata.group_name for metadata, _ in reports] if not is_all_same(group_names): raise ValueError("All reports must have the same group.") type MetricIdentifier = str | tuple[str, Quantile] __all__ = ["MetricIdentifier", "ReportTuple", "VisualizationProvider"]