Source code for openstef_beam.analysis.visualizations.precision_recall_curve_visualization

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Precision-recall curve visualization provider.

This module provides visualization for precision-recall curves, useful for
evaluating binary classification performance at different threshold levels.
"""

from typing import override

from openstef_beam.analysis.models import AnalysisAggregation, RunName, TargetMetadata, VisualizationOutput
from openstef_beam.analysis.plots import PrecisionRecallCurvePlotter
from openstef_beam.analysis.visualizations.base import ReportTuple, VisualizationProvider
from openstef_beam.evaluation import EvaluationSubsetReport
from openstef_core.types import Quantile


class PrecisionRecallCurveVisualization(VisualizationProvider):
    """Creates precision-recall curves for evaluating binary classification performance.

    Displays the classic precision-recall trade-off as a curve where each point
    represents performance at a different probability threshold (one point per
    forecast quantile found in the report). The closer the curve to the top-right
    corner, the better the model performs across all thresholds.

    Two evaluation modes:

    - Standard: Traditional precision/recall based on binary classification accuracy
    - Effective: Specialized for congestion management, evaluating whether forecasts
      provide actionable insights for grid operators (correct direction + sufficient
      magnitude)

    What you'll see:

    - Curve plotting precision (Y-axis) vs recall (X-axis)
    - Each point represents a different probability threshold
    - Area under curve (AUC-PR) as overall performance metric
    - Multiple curves when comparing models or targets
    - Reference lines showing random classifier performance

    Interpretation guide:

    - High precision: Few false alarms when predicting events
    - High recall: Catches most actual events that occur
    - Effective mode: Focuses on operationally useful predictions for grid management

    Example:
        >>> from openstef_beam.analysis import AnalysisConfig
        >>> from openstef_beam.analysis.visualizations import PrecisionRecallCurveVisualization
        >>>
        >>> analysis_config = AnalysisConfig(
        ...     visualization_providers=[
        ...         PrecisionRecallCurveVisualization(
        ...             name="precision_recall",
        ...             effective_precision_recall=True,  # For congestion management
        ...         ),
        ...     ]
        ... )
    """

    # When True, the "effective_precision" / "effective_recall" metrics are plotted
    # instead of the plain "precision" / "recall" metrics.
    effective_precision_recall: bool = False

    @property
    @override
    def supported_aggregations(self) -> set[AnalysisAggregation]:
        return {
            AnalysisAggregation.NONE,
            AnalysisAggregation.RUN_AND_NONE,
            AnalysisAggregation.TARGET,
        }

    @property
    def _precision_metric_name(self) -> str:
        """Key under which precision is looked up in the per-quantile metrics."""
        return "effective_precision" if self.effective_precision_recall else "precision"

    @property
    def _recall_metric_name(self) -> str:
        """Key under which recall is looked up in the per-quantile metrics."""
        return "effective_recall" if self.effective_precision_recall else "recall"

    def _extract_precision_recall_values(
        self, report: EvaluationSubsetReport
    ) -> tuple[list[float], list[float], list[Quantile]]:
        """Extract precision and recall values from evaluation report.

        Quantiles for which either metric is missing are skipped entirely, so the
        three returned lists always have the same length and element ``i`` of each
        list describes the same quantile.

        Args:
            report: The evaluation subset report

        Returns:
            A tuple of (precision_values, recall_values, quantiles)

        Raises:
            ValueError: If no global metrics are found
        """
        global_metrics = report.get_global_metric()
        if global_metrics is None:
            raise ValueError("No global metrics found in the report.")

        precision_values: list[float] = []
        recall_values: list[float] = []
        # Only quantiles that actually yield a curve point are returned; handing
        # back the unfiltered quantile list would misalign it with the values.
        plotted_quantiles: list[Quantile] = []

        for quantile in global_metrics.get_quantiles():
            quantile_metrics = global_metrics.metrics.get(quantile, {})
            precision = quantile_metrics.get(self._precision_metric_name)
            recall = quantile_metrics.get(self._recall_metric_name)
            if precision is None or recall is None:
                continue
            precision_values.append(precision)
            recall_values.append(recall)
            plotted_quantiles.append(quantile)

        return precision_values, recall_values, plotted_quantiles

    def _create_plot_title(self, context: str) -> str:
        """Create a formatted title for the plot.

        Args:
            context: The context description (e.g., target name, run name)

        Returns:
            Formatted plot title
        """
        curve_type = "Effective Precision-Recall" if self.effective_precision_recall else "Precision-Recall"
        return f"{curve_type} Curve for {context}"

    def _add_report_curve(
        self,
        plotter: PrecisionRecallCurvePlotter,
        model_name: str,
        report: EvaluationSubsetReport,
    ) -> None:
        """Extract the curve of one report and register it on the plotter.

        Args:
            plotter: Plotter that collects the curves of the figure being built
            model_name: Label passed to the plotter for this curve
            report: The evaluation subset report to read the metrics from

        Raises:
            ValueError: If the report contains no global metrics
        """
        precision_values, recall_values, quantiles = self._extract_precision_recall_values(report)
        plotter.add_model(
            model_name=model_name,
            precision_values=precision_values,
            recall_values=recall_values,
            quantiles=quantiles,
        )

    @override
    def create_by_none(
        self,
        report: EvaluationSubsetReport,
        metadata: TargetMetadata,
    ) -> VisualizationOutput:
        # Single report: one curve labelled with the run name, titled with the target name.
        plotter = PrecisionRecallCurvePlotter()
        self._add_report_curve(plotter, metadata.run_name, report)

        figure = plotter.plot(title=self._create_plot_title(metadata.name))
        return VisualizationOutput(name=self.name, figure=figure)

    @override
    def create_by_run_and_none(self, reports: dict[RunName, list[ReportTuple]]) -> VisualizationOutput:
        plotter = PrecisionRecallCurvePlotter()

        # The title names only the first run in the mapping (empty string when
        # there are no runs), even though every run contributes curves.
        first_run_name = next(iter(reports), "")

        for run_name, run_reports in reports.items():
            for _, report in run_reports:
                # Every report of a run is added under that run's name.
                self._add_report_curve(plotter, run_name, report)

        figure = plotter.plot(title=self._create_plot_title(f"Run {first_run_name}"))
        return VisualizationOutput(name=self.name, figure=figure)

    @override
    def create_by_target(
        self,
        reports: list[ReportTuple],
    ) -> VisualizationOutput:
        plotter = PrecisionRecallCurvePlotter()

        # The run name for the title is taken from the first target's metadata
        # (empty string when there are no reports).
        run_name = reports[0][0].run_name if reports else ""

        for metadata, report in reports:
            # One curve per target, labelled with the target name.
            self._add_report_curve(plotter, metadata.name, report)

        figure = plotter.plot(title=self._create_plot_title(run_name))
        return VisualizationOutput(name=self.name, figure=figure)
# Names exported by a star-import of this module: only the visualization class.
__all__ = ["PrecisionRecallCurveVisualization"]