Source code for openstef_models.workflows.callbacks.model_performance_callback
# SPDX-FileCopyrightText: 2026 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
"""Model performance callback for forecasting workflows.
Evaluates model performance against a specified metric and threshold at the end of fitting.
If the model's performance does not meet the defined criteria, a ModelUnderperformingError is raised.
This allows for early stopping of workflows or using a fallback model when performance is insufficient.
"""
import logging
from typing import override
from pydantic import Field, PrivateAttr
from openstef_beam.evaluation.metric_providers import MetricDirection
from openstef_core.base_model import BaseConfig
from openstef_core.exceptions import ModelUnderperformingError
from openstef_core.types import QuantileOrGlobal
from openstef_models.mixins.callbacks import WorkflowContext
from openstef_models.models.forecasting_model import ModelFitResult
from openstef_models.workflows.custom_forecasting_workflow import CustomForecastingWorkflow, ForecastingCallback
[docs]
class ModelPerformanceCallback(BaseConfig, ForecastingCallback):
"""Callback for comparing model performance against a treshold during the fit process.
This callback evaluates the model's performance using a specified metric after fitting.
If the performance metric does not meet the defined threshold, a ModelUnderperformingError is raised.
"""
metric_name: str = Field(description="The name of the performance metric to evaluate.")
threshold: float = Field(
description="The minimum acceptable value for the performance metric. If the model's performance is "
"below or above this threshold (depending on `metric_direction`), it will be considered underperforming."
)
metric_direction: MetricDirection = Field(
description=("Direction of the performance metric. Either 'higher_is_better' or 'lower_is_better'.")
)
quantile: QuantileOrGlobal = Field(
default="global",
description=(
"The quantile level to evaluate the metric on. Use 'global' for overall performance metrics, or specify a "
"quantile (e.g., 0.5 for median) for quantile-specific metrics."
),
)
_logger: logging.Logger = PrivateAttr(default_factory=lambda: logging.getLogger(__name__))
[docs]
@override
def on_fit_end(
self,
context: WorkflowContext[CustomForecastingWorkflow],
result: ModelFitResult,
) -> None:
"""Evaluate model performance at the end of fitting and raise an error if underperforming.
Args:
context: The workflow context that completed fitting.
result: Result of the fitting process containing performance metrics.
Raises:
ModelUnderperformingError: If the model's performance metric is below the defined threshold.
"""
if result.metrics_val is None:
self._logger.warning("No validation metrics found in fit results. Skipping performance evaluation.")
return
metric_value = result.metrics_val.get_metric(self.quantile, self.metric_name)
if metric_value is None:
self._logger.warning(
"Performance metric '%s' not found in fit results. Skipping performance evaluation.",
self.metric_name,
)
return
match self.metric_direction:
case "higher_is_better" if metric_value < self.threshold:
raise ModelUnderperformingError(
metric_name=self.metric_name,
metric_value=metric_value,
threshold=self.threshold,
)
case "lower_is_better" if metric_value > self.threshold:
raise ModelUnderperformingError(
metric_name=self.metric_name,
metric_value=metric_value,
threshold=self.threshold,
)
case _:
return