# Source code for openstef_models.integrations.mlflow.mlflow_storage_callback

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""MLflow integration for tracking and storing forecasting workflows.

Provides a single callback for logging model training runs, artifacts,
and metrics to MLflow.  The callback is model-agnostic — it delegates to
polymorphic methods on ``BaseForecastingModel`` and ``ModelFitResult`` so
it works unchanged for both single-model and ensemble workflows.

Key behaviours:

- Logs model hyperparameters, plus per-component hyperparameters via
  ``model.component_hyperparams`` (e.g. per-forecaster in an ensemble).
- Stores training data, plus per-component datasets via
  ``result.component_fit_results``.
- Collects evaluation metrics via ``result.metrics_to_flat_dict()``;
  subclasses embed child metrics automatically.
- Stores feature-importance plots for every explainable component
  exposed by ``model.get_explainable_components()``.
- Supports model reuse (skip re-fit if a recent run exists) and
  model selection (keep the better model based on a configurable metric
  with a bias-towards-newer penalty).
"""

import logging
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any, cast, override

from mlflow.entities import Run
from pydantic import Field, PrivateAttr

from openstef_beam.evaluation import SubsetMetric
from openstef_beam.evaluation.metric_providers import MetricDirection
from openstef_core.base_model import BaseConfig
from openstef_core.datasets.timeseries_dataset import TimeSeriesDataset
from openstef_core.datasets.versioned_timeseries_dataset import (
    VersionedTimeSeriesDataset,
)
from openstef_core.exceptions import (
    MissingColumnsError,
    ModelNotFoundError,
    SkipFitting,
)
from openstef_core.types import Q, QuantileOrGlobal
from openstef_models.integrations.mlflow.mlflow_storage import MLFlowStorage
from openstef_models.mixins.callbacks import WorkflowContext
from openstef_models.models.forecasting_model import BaseForecastingModel, ModelFitResult
from openstef_models.workflows.custom_forecasting_workflow import (
    CustomForecastingWorkflow,
    ForecastingCallback,
)


[docs] class MLFlowStorageCallback(BaseConfig, ForecastingCallback): """MLFlow callback for logging forecasting workflow events. Model-agnostic: delegates to polymorphic methods on the model and fit result for child hyperparams, child data, metrics, and feature importances. """ storage: MLFlowStorage = Field(default_factory=MLFlowStorage) model_reuse_enable: bool = Field(default=True) model_reuse_max_age: timedelta = Field(default=timedelta(days=7)) model_selection_enable: bool = Field(default=True) model_selection_metric: tuple[QuantileOrGlobal, str, MetricDirection] = Field( default=(Q(0.5), "R2", "higher_is_better"), description="Metric to monitor for model performance when retraining.", ) model_selection_old_model_penalty: float = Field( default=1.2, description="Penalty to apply to the old model's metric to bias selection towards newer models.", ) store_feature_importance_plot: bool = Field( default=True, description="Whether to store feature importance plots in MLflow artifacts if available.", ) _logger: logging.Logger = PrivateAttr(default=logging.getLogger(__name__))
[docs] @override def model_post_init(self, context: Any) -> None: pass
[docs] @override def on_fit_start( self, context: WorkflowContext[CustomForecastingWorkflow], data: VersionedTimeSeriesDataset | TimeSeriesDataset, ) -> None: if not self.model_reuse_enable: return run = self._find_run(model_id=context.workflow.model_id, run_name=context.workflow.run_name) if run is not None: now = datetime.now(tz=UTC) end_time_millis = cast(float | None, run.info.end_time) run_end_datetime = ( datetime.fromtimestamp(end_time_millis / 1000, tz=UTC) if end_time_millis is not None else None ) self._logger.info( "Found previous MLflow run %s for model %s ended at %s", cast(str, run.info.run_id), context.workflow.model_id, run_end_datetime, ) if run_end_datetime is not None and (now - run_end_datetime) <= self.model_reuse_max_age: raise SkipFitting("Model is recent enough, skipping re-fit.")
[docs] @override def on_fit_end( self, context: WorkflowContext[CustomForecastingWorkflow], result: ModelFitResult, ) -> None: if self.model_selection_enable: self._run_model_selection(workflow=context.workflow, result=result) # Create a new run run = self.storage.create_run( model_id=context.workflow.model_id, tags=context.workflow.model.tags, hyperparams=context.workflow.model.hyperparams, run_name=context.workflow.run_name, experiment_tags=context.workflow.experiment_tags, ) run_id: str = run.info.run_id self._logger.info("Created MLflow run %s for model %s", run_id, context.workflow.model_id) # Log per-component hyperparams for name, hparams in context.workflow.model.component_hyperparams.items(): prefixed = {f"{name}.{k}": str(v) for k, v in hparams.model_dump().items()} self.storage.log_hyperparams(run_id=run_id, params=prefixed) # Store the model input data run_path = self.storage.get_artifacts_path(model_id=context.workflow.model_id, run_id=run_id) data_path = run_path / self.storage.data_path data_path.mkdir(parents=True, exist_ok=True) result.input_dataset.to_parquet(path=data_path / "data.parquet") self._logger.info("Stored training data at %s for run %s", data_path, run_id) # Store per-component training data for name, component_result in result.component_fit_results.items(): component_path = data_path / name component_path.mkdir(parents=True, exist_ok=True) component_result.input_dataset.to_parquet(path=component_path / "data.parquet") # Store feature importance plots if self.store_feature_importance_plot: self._store_feature_importances(context.workflow.model, data_path) # Store the trained model self.storage.save_run_model( model_id=context.workflow.model_id, run_id=run_id, model=context.workflow.model, ) self._logger.info("Stored trained model for run %s", run_id) # Format the metrics for MLflow metrics = result.metrics_to_flat_dict() # Mark the run as finished self.storage.finalize_run(model_id=context.workflow.model_id, run_id=run_id, 
metrics=metrics) self._logger.info("Stored MLflow run %s for model %s", run_id, context.workflow.model_id)
[docs] @override def on_predict_start( self, context: WorkflowContext[CustomForecastingWorkflow], data: VersionedTimeSeriesDataset | TimeSeriesDataset, ): if context.workflow.model.is_fitted: return run = self._find_run(model_id=context.workflow.model_id, run_name=context.workflow.run_name) if run is None: raise ModelNotFoundError(model_id=context.workflow.model_id) run_id: str = run.info.run_id old_model = self.storage.load_run_model(run_id=run_id, model_id=context.workflow.model_id) if not isinstance(old_model, BaseForecastingModel): self._logger.warning( "Loaded model from run %s is not a BaseForecastingModel, cannot use for prediction", cast(str, run.info.run_id), ) return context.workflow.model = old_model # pyright: ignore[reportAttributeAccessIssue] self._logger.info( "Loaded model from MLflow run %s for model %s", run_id, context.workflow.model_id, )
def _run_model_selection(self, workflow: CustomForecastingWorkflow, result: ModelFitResult) -> None: run = self._find_run(model_id=workflow.model_id, run_name=None) if run is None: return run_id = cast(str, run.info.run_id) if not self._check_tags_compatible( run_tags=run.data.tags, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] new_tags=workflow.model.tags, run_id=run_id, ): return new_model = workflow.model new_metrics = result.metrics_full old_model = self._try_load_model(run_id=run_id, model_id=workflow.model_id) if old_model is None: return old_metrics = self._try_evaluate_model( run_id=run_id, old_model=old_model, input_data=result.input_dataset, ) if old_metrics is None: return if self._check_is_new_model_better(old_metrics=old_metrics, new_metrics=new_metrics): workflow.model = new_model # pyright: ignore[reportAttributeAccessIssue] else: workflow.model = old_model # pyright: ignore[reportAttributeAccessIssue] self._logger.info( "New model did not improve %s metric from previous run %s, reusing old model", self.model_selection_metric, run_id, ) raise SkipFitting("New model did not improve monitored metric, skipping re-fit.") @staticmethod def _store_feature_importances(model: BaseForecastingModel, data_path: Path) -> None: for name, component in model.get_explainable_components().items(): if component.feature_importances.empty: continue suffix = f"_{name}" if name else "" fig = component.plot_feature_importances() fig.write_html(data_path / f"feature_importances{suffix}.html") # pyright: ignore[reportUnknownMemberType] def _find_run(self, model_id: str, run_name: str | None) -> Run | None: """Find an MLflow run by model_id and optional run_name. Returns: The matching Run, or None if no run was found. 
""" if run_name is not None: return self.storage.search_run(model_id=model_id, run_name=run_name) runs = self.storage.search_latest_runs(model_id=model_id) return next(iter(runs), None) def _try_load_model(self, run_id: str, model_id: str) -> BaseForecastingModel | None: """Try to load a model from MLflow, returning None on failure. Returns: The loaded model, or None if loading failed. """ try: old_model = self.storage.load_run_model(run_id=run_id, model_id=model_id) except ModelNotFoundError: self._logger.warning( "Could not load model from previous run %s for model %s, skipping model selection", run_id, model_id, ) return None if not isinstance(old_model, BaseForecastingModel): self._logger.warning( "Loaded old model from run %s is not a BaseForecastingModel, skipping model selection", run_id, ) return None return old_model def _try_evaluate_model( self, run_id: str, old_model: BaseForecastingModel, input_data: TimeSeriesDataset, ) -> SubsetMetric | None: """Try to evaluate a model, returning None on failure. Returns: The evaluation metrics, or None if evaluation failed. """ try: return old_model.score(input_data) except (MissingColumnsError, ValueError) as e: self._logger.warning( "Could not evaluate old model from run %s, skipping model selection: %s", run_id, e, ) return None def _check_tags_compatible(self, run_tags: dict[str, str], new_tags: dict[str, str], run_id: str) -> bool: """Check if model tags are compatible, excluding mlflow.runName. Returns: True if tags are compatible, False otherwise. """ old_tags = {k: v for k, v in run_tags.items() if k != "mlflow.runName"} if old_tags == new_tags: return True differences = { k: (old_tags.get(k), new_tags.get(k)) for k in old_tags.keys() | new_tags.keys() if old_tags.get(k) != new_tags.get(k) } self._logger.info( "Model tags changed since run %s, skipping model selection. 
Changes: %s", run_id, differences, ) return False def _check_is_new_model_better( self, old_metrics: SubsetMetric, new_metrics: SubsetMetric, ) -> bool: """Compare old and new model metrics to determine if the new model is better. Returns: True if the new model is better, False otherwise. """ quantile, metric_name, direction = self.model_selection_metric old_metric = old_metrics.get_metric(quantile=quantile, metric_name=metric_name) new_metric = new_metrics.get_metric(quantile=quantile, metric_name=metric_name) if old_metric is None or new_metric is None: self._logger.warning( "Could not find %s metric for quantile %s in old or new model metrics, assuming improvement", metric_name, quantile, ) return True self._logger.info( "Comparing old model %s metric %.5f to new model %s metric %.5f for quantile %s", metric_name, old_metric, metric_name, new_metric, quantile, ) # Penalty biases selection towards newer models: # higher_is_better: lower the bar by dividing old metric by penalty # lower_is_better: raise the bar by multiplying old metric by penalty match direction: case "higher_is_better" if new_metric >= old_metric / self.model_selection_old_model_penalty: return True case "lower_is_better" if new_metric <= old_metric * self.model_selection_old_model_penalty: return True case _: return False
# Names exported by ``from <module> import *``.
__all__ = ["MLFlowStorageCallback"]