# Source code for openstef_beam.analysis.visualizations.summary_table_visualization

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Summary table visualization module for displaying metrics in tabular format."""

from typing import Any, NamedTuple, override

import pandas as pd

from openstef_beam.analysis.models import AnalysisAggregation, GroupName, RunName, TargetMetadata, VisualizationOutput
from openstef_beam.analysis.plots import (
    SummaryTablePlotter,
)
from openstef_beam.analysis.visualizations.base import ReportTuple, VisualizationProvider
from openstef_beam.evaluation import EvaluationSubsetReport
from openstef_core.types import QuantileOrGlobal


class MetricAggregation(NamedTuple):
    """Summary statistics computed over a collection of metric values.

    Being a ``NamedTuple``, instances are immutable and unpack positionally in
    the order ``(mean, min, max, median)``.
    """

    # Arithmetic mean of the aggregated values.
    mean: float
    # Smallest aggregated value.
    min: float
    # Largest aggregated value.
    max: float
    # Median of the aggregated values.
    median: float


class SummaryTableVisualization(VisualizationProvider):
    """Creates HTML tables summarizing evaluation metrics.

    Generates sortable HTML tables presenting evaluation metrics organized by quantiles,
    targets, and model runs. Tables automatically aggregate statistics and provide
    formatted overviews for reports and documentation.

    What you'll see:
    - Sortable columns showing metrics with mean, min, max, and median aggregations
    - Color-coded formatting for easy comparison
    - Automatic organization by quantiles (when applicable) and targets
    - Export-ready HTML format

    Aggregation behavior:
    - Single target: Simple metric table with quantile breakdown
    - Multiple targets: Comparative statistics across targets
    - Multiple runs: Model comparison with aggregated performance
    - Target groups: Hierarchical organization by categories

    Example:
        >>> from openstef_beam.analysis import AnalysisConfig
        >>> from openstef_beam.analysis.visualizations import SummaryTableVisualization
        >>> analysis_config = AnalysisConfig(
        ...     visualization_providers=[
        ...         SummaryTableVisualization(name="performance_summary"),
        ...     ]
        ... )
        >>> # Tables will show all available metrics organized by:
        >>> # - Quantile levels (0.1, 0.5, 0.9, global)
        >>> # - Performance metrics (MAE, RMSE, rCRPS, etc.)
        >>> # - Target groupings and model runs
    """

    @property
    @override
    def supported_aggregations(self) -> set[AnalysisAggregation]:
        """Aggregation modes for which this provider implements a ``create_by_*`` method."""
        return {
            AnalysisAggregation.NONE,
            AnalysisAggregation.RUN_AND_NONE,
            AnalysisAggregation.TARGET,
            AnalysisAggregation.GROUP,
            AnalysisAggregation.RUN_AND_GROUP,
            AnalysisAggregation.RUN_AND_TARGET,
        }

    @staticmethod
    def _aggregate_metric_values(values: list[float]) -> MetricAggregation:
        """Compute statistical aggregations for a collection of metric values.

        Args:
            values: List of numeric metric values to aggregate

        Returns:
            MetricAggregation containing mean, min, max, and median statistics.
            An empty ``values`` list yields all-zero statistics.
        """
        if not values:
            return MetricAggregation(mean=0.0, min=0.0, max=0.0, median=0.0)

        series = pd.Series(values)
        return MetricAggregation(
            mean=float(series.mean()),
            min=float(series.min()),
            max=float(series.max()),
            median=float(series.median()),
        )

    @staticmethod
    def _format_quantile(quantile: QuantileOrGlobal) -> str:
        """Format quantile values for display by removing trailing zeros.

        Args:
            quantile: Either a float quantile value or "global" string

        Returns:
            Formatted string representation suitable for display
        """
        if isinstance(quantile, float):
            # Fixed-point formatting (six decimals), then drop trailing zeros and a
            # dangling decimal point: 0.5 -> "0.500000" -> "0.5"; 1.0 -> "1".
            return f"{quantile:f}".rstrip("0").rstrip(".")
        return str(quantile)

    def _extract_metrics_from_report(self, report: EvaluationSubsetReport) -> list[dict[str, Any]]:
        """Extract all metrics from a report into a list of dictionaries.

        Args:
            report: The evaluation subset report to extract metrics from

        Returns:
            List of dictionaries with the keys "Quantile|Global", "Metric" and "Value",
            or an empty list if the report has no global metric.
        """
        global_metric = report.get_global_metric()
        if global_metric is None or not hasattr(global_metric, "metrics"):
            return []

        return [
            {
                "Quantile|Global": self._format_quantile(quantile),
                "Metric": metric_name,
                "Value": metric_value,
            }
            for quantile, metrics in global_metric.metrics.items()
            for metric_name, metric_value in metrics.items()
        ]

    @staticmethod
    def _create_sorted_dataframe(rows: list[dict[str, Any]], columns: list[str], sort_by: list[str]) -> pd.DataFrame:
        """Create and sort a DataFrame from metric rows.

        Args:
            rows: List of dictionaries containing metric data
            columns: Column names for the DataFrame, in the desired display order
            sort_by: Column names to sort by

        Returns:
            Sorted pandas DataFrame whose columns follow ``columns``. When ``rows`` is
            empty, an empty frame with those columns is returned.
        """
        # Always pass ``columns`` so the column order is the declared one rather than
        # the key insertion order of the row dicts, and so an empty ``rows`` list
        # still yields a frame with the expected header.
        dataframe = pd.DataFrame(rows, columns=columns)  # type: ignore[arg-type]
        # Nothing to sort in an empty frame.
        return dataframe if dataframe.empty else dataframe.sort_values(by=sort_by)

    @override
    def create_by_none(
        self,
        report: EvaluationSubsetReport,
        metadata: TargetMetadata,
    ) -> VisualizationOutput:
        """Create a summary table with the metrics of a single report.

        Args:
            report: The evaluation subset report to tabulate
            metadata: Metadata of the target; not used by this visualization

        Returns:
            VisualizationOutput containing the generated table
        """
        dataframe = self._create_sorted_dataframe(
            rows=self._extract_metrics_from_report(report),
            columns=["Quantile|Global", "Metric", "Value"],
            sort_by=["Metric", "Quantile|Global"],
        )
        plotter = SummaryTablePlotter(dataframe)
        return VisualizationOutput(name=self.name, html=plotter.plot())

    @override
    def create_by_target(
        self,
        reports: list[ReportTuple],
    ) -> VisualizationOutput:
        """Create a summary table listing the metrics of every target.

        Args:
            reports: List of (metadata, report) tuples, one per target

        Returns:
            VisualizationOutput containing the generated table
        """
        # Build fresh dicts instead of mutating the extracted rows.
        rows: list[dict[str, Any]] = [
            {"Target": metadata.name, **row}
            for metadata, report in reports
            for row in self._extract_metrics_from_report(report)
        ]

        dataframe = self._create_sorted_dataframe(
            rows=rows,
            columns=["Target", "Quantile|Global", "Metric", "Value"],
            sort_by=["Metric", "Quantile|Global", "Target"],
        )
        plotter = SummaryTablePlotter(dataframe)
        return VisualizationOutput(name=self.name, html=plotter.plot())

    def create_by_group(self, reports: dict[GroupName, list[ReportTuple]]) -> VisualizationOutput:
        """Create summary table with aggregated metrics by group.

        Args:
            reports: Dictionary mapping group names to lists of (metadata, report) tuples

        Returns:
            VisualizationOutput containing the generated aggregated table
        """
        # Collect all metric values per group, quantile/global, and metric.
        metric_dict: dict[tuple[str, str, GroupName], list[float]] = {}
        for group, target_reports in reports.items():
            for _metadata, report in target_reports:
                for row in self._extract_metrics_from_report(report):
                    key = (row["Quantile|Global"], row["Metric"], group)
                    metric_dict.setdefault(key, []).append(float(row["Value"]))

        # Aggregate values and build one row per (group, quantile, metric).
        rows: list[dict[str, Any]] = []
        for (quantile, metric_name, group), values in metric_dict.items():
            agg = self._aggregate_metric_values(values)
            rows.append({
                "Group": group,
                "Quantile|Global": quantile,
                "Metric": metric_name,
                "Mean": agg.mean,
                "Min": agg.min,
                "Max": agg.max,
                "Median": agg.median,
            })

        dataframe = self._create_sorted_dataframe(
            rows=rows,
            columns=["Group", "Quantile|Global", "Metric", "Mean", "Min", "Max", "Median"],
            sort_by=["Metric", "Quantile|Global", "Group"],
        )
        plotter = SummaryTablePlotter(dataframe)
        return VisualizationOutput(name=self.name, html=plotter.plot())

    @override
    def create_by_run_and_none(self, reports: dict[RunName, list[ReportTuple]]) -> VisualizationOutput:
        """Create a summary table listing the metrics of every target for every run.

        Args:
            reports: Dictionary mapping run names to lists of (metadata, report) tuples

        Returns:
            VisualizationOutput containing the generated table
        """
        # Build fresh dicts instead of mutating the extracted rows.
        rows: list[dict[str, Any]] = [
            {"Target": metadata.name, "Run": run_name, **row}
            for run_name, target_reports in reports.items()
            for metadata, report in target_reports
            for row in self._extract_metrics_from_report(report)
        ]

        dataframe = self._create_sorted_dataframe(
            rows=rows,
            columns=["Target", "Quantile|Global", "Metric", "Run", "Value"],
            sort_by=["Metric", "Quantile|Global", "Target", "Run"],
        )
        plotter = SummaryTablePlotter(dataframe)
        return VisualizationOutput(name=self.name, html=plotter.plot())

    def create_by_run_and_target(
        self,
        reports: dict[RunName, list[ReportTuple]],
    ) -> VisualizationOutput:
        """Create summary table comparing different runs on the same targets.

        Delegates to ``create_by_run_and_none``; both produce the same table layout.

        Args:
            reports: Dictionary mapping run names to their report lists.

        Returns:
            Visualization output with summary table comparing runs.
        """
        return self.create_by_run_and_none(
            reports=reports,
        )

    def create_by_run_and_group(
        self, reports: dict[tuple[RunName, GroupName], list[ReportTuple]]
    ) -> VisualizationOutput:
        """Create summary table with aggregated metrics by run and group combinations.

        Args:
            reports: Dictionary mapping (run_name, group_name) tuples to lists of (metadata, report) tuples

        Returns:
            VisualizationOutput containing the generated aggregated comparison table
        """
        # Collect all metric values per (run, group), quantile/global, and metric.
        metric_dict: dict[tuple[str, str, RunName, GroupName], list[float]] = {}
        for (run_name, group), target_reports in reports.items():
            for _metadata, report in target_reports:
                for row in self._extract_metrics_from_report(report):
                    key = (row["Quantile|Global"], row["Metric"], run_name, group)
                    metric_dict.setdefault(key, []).append(float(row["Value"]))

        # Aggregate values and build one row per (run, group, quantile, metric).
        rows: list[dict[str, Any]] = []
        for (quantile, metric_name, run_name, group), values in metric_dict.items():
            agg = self._aggregate_metric_values(values)
            rows.append({
                "Group": group,
                "Quantile|Global": quantile,
                "Metric": metric_name,
                "Run": run_name,
                "Mean": agg.mean,
                "Min": agg.min,
                "Max": agg.max,
                "Median": agg.median,
            })

        dataframe = self._create_sorted_dataframe(
            rows=rows,
            columns=["Group", "Quantile|Global", "Metric", "Run", "Mean", "Min", "Max", "Median"],
            sort_by=["Metric", "Quantile|Global", "Group", "Run"],
        )
        plotter = SummaryTablePlotter(dataframe)
        return VisualizationOutput(name=self.name, html=plotter.plot())
# Public API of this module: only the visualization provider is exported by
# ``from ... import *``; ``MetricAggregation`` is left out of the export list.
__all__ = ["SummaryTableVisualization"]