# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
"""Forecasting workflow presets and configurations.
Provides predefined configurations and factory functions for common forecasting workflows,
including XGBoost, GBLinear, and Flatliner models with appropriate preprocessing pipelines.
"""
from datetime import timedelta
from decimal import Decimal
from typing import Literal
from pydantic import Field
from pydantic_extra_types.coordinate import Coordinate, Latitude, Longitude
from pydantic_extra_types.country import CountryAlpha2
from openstef_beam.evaluation.metric_providers import (
MetricDirection,
MetricProvider,
ObservedProbabilityProvider,
R2Provider,
)
from openstef_core.base_model import BaseConfig
from openstef_core.datasets.timeseries_dataset import TimeSeriesDataset
from openstef_core.mixins import TransformPipeline
from openstef_core.mixins.transform import Transform
from openstef_core.types import LeadTime, Q, Quantile, QuantileOrGlobal
from openstef_models.integrations.mlflow import MLFlowStorage, MLFlowStorageCallback
from openstef_models.mixins import ModelIdentifier
from openstef_models.models import ForecastingModel
from openstef_models.models.forecasting.constant_quantile_forecaster import ConstantQuantileForecaster
from openstef_models.models.forecasting.flatliner_forecaster import FlatlinerForecaster
from openstef_models.models.forecasting.gblinear_forecaster import GBLinearForecaster, GBLinearHyperParams
from openstef_models.models.forecasting.lgbm_forecaster import LGBMForecaster, LGBMHyperParams
from openstef_models.models.forecasting.lgbmlinear_forecaster import LGBMLinearForecaster, LGBMLinearHyperParams
from openstef_models.models.forecasting.median_forecaster import MedianForecaster
from openstef_models.models.forecasting.xgboost_forecaster import XGBoostForecaster, XGBoostHyperParams
from openstef_models.transforms.energy_domain import WindPowerFeatureAdder
from openstef_models.transforms.general import (
EmptyFeatureRemover,
Imputer,
NaNDropper,
OutlierHandler,
SampleWeightConfig,
SampleWeighter,
Scaler,
Selector,
Shifter,
)
from openstef_models.transforms.postprocessing import ConfidenceIntervalApplicator, QuantileSorter
from openstef_models.transforms.time_domain import (
CyclicFeaturesAdder,
DatetimeFeaturesAdder,
HolidayFeatureAdder,
RollingAggregatesAdder,
)
from openstef_models.transforms.time_domain.lags_adder import LagsAdder
from openstef_models.transforms.time_domain.rolling_aggregates_adder import AggregationFunction
from openstef_models.transforms.validation import CompletenessChecker, FlatlineChecker, InputConsistencyChecker
from openstef_models.transforms.weather_domain import (
AtmosphereDerivedFeaturesAdder,
DaylightFeatureAdder,
RadiationDerivedFeaturesAdder,
)
from openstef_models.utils.data_split import DataSplitter
from openstef_models.utils.feature_selection import Exclude, FeatureSelection, Include
from openstef_models.workflows.callbacks import ModelPerformanceCallback
from openstef_models.workflows.custom_forecasting_workflow import (
CustomForecastingWorkflow,
ForecastingCallback,
)
[docs]
class LocationConfig(BaseConfig):
"""Configuration for location information in forecasting workflows."""
name: str = Field(
default="test_location",
description="Name of the forecasting location or workflow.",
)
description: str = Field(default="", description="Description of the forecasting workflow.")
coordinate: Coordinate = Field(
default=Coordinate(
latitude=Latitude(Decimal("52.132633")),
longitude=Longitude(Decimal("5.291266")),
),
description="Geographic coordinate of the location.",
)
country_code: CountryAlpha2 = Field(
default=CountryAlpha2("NL"),
description="Country code for holiday feature generation.",
)
@property
def tags(self) -> dict[str, str]:
"""Generate tags dictionary from location information."""
return {
"location_name": self.name,
"location_description": self.description,
"location_coordinate": str(self.coordinate),
"location_country_code": str(self.country_code),
}
[docs]
class ForecastingWorkflowConfig(BaseConfig): # PredictionJob
"""Configuration for forecasting workflows.
Defines all parameters needed to set up a forecasting model, including model type,
hyperparameters, location information, data columns, and feature engineering settings.
"""
kind: Literal["single"] = Field(default="single", description="Discriminator tag for config type.")
model_id: ModelIdentifier = Field(description="Unique identifier for the forecasting model.")
run_name: str | None = Field(
default=None, description="Optional name for this workflow run, can be used for versioning."
)
# Model configuration
model: Literal["xgboost", "gblinear", "flatliner", "median", "constant_quantile", "lgbm", "lgbmlinear"] = Field(
description="Type of forecasting model to use."
)
quantiles: list[Quantile] = Field(
default=[Q(0.5)],
description="List of quantiles to predict for probabilistic forecasting.",
)
sample_interval: timedelta = Field(
default=timedelta(minutes=15),
description="Time interval between consecutive data samples.",
)
horizons: list[LeadTime] = Field(
default=[LeadTime.from_string("PT48H")],
description="List of forecast horizons to predict.",
)
xgboost_hyperparams: XGBoostHyperParams = Field(
default=XGBoostHyperParams(),
description="Hyperparameters for XGBoost forecaster.",
)
gblinear_hyperparams: GBLinearHyperParams = Field(
default=GBLinearHyperParams(),
description="Hyperparameters for GBLinear forecaster.",
)
lgbm_hyperparams: LGBMHyperParams = Field(
default=LGBMHyperParams(),
description="Hyperparameters for LightGBM forecaster.",
)
lgbmlinear_hyperparams: LGBMLinearHyperParams = Field(
default=LGBMLinearHyperParams(),
description="Hyperparameters for LightGBM forecaster.",
)
location: LocationConfig = Field(
default=LocationConfig(),
description="Location information for the forecasting workflow.",
)
# Data properties
target_column: str = Field(default="load", description="Name of the target variable column in datasets.")
energy_price_column: str = Field(
default="day_ahead_electricity_price",
description="Name of the energy price column in datasets.",
)
radiation_column: str = Field(default="radiation", description="Name of the radiation column in datasets.")
wind_speed_column: str = Field(default="windspeed", description="Name of the wind speed column in datasets.")
pressure_column: str = Field(default="pressure", description="Name of the pressure column in datasets.")
temperature_column: str = Field(default="temperature", description="Name of the temperature column in datasets.")
relative_humidity_column: str = Field(
default="relative_humidity",
description="Name of the relative humidity column in datasets.",
)
selected_features: FeatureSelection = Field(
default=FeatureSelection.ALL,
description="Feature selection for which features to include/exclude.",
)
predict_history: timedelta = Field(
default=timedelta(days=14),
description="Amount of historical data available at prediction time.",
)
cutoff_history: timedelta = Field(
default=timedelta(days=0),
description="Amount of historical data to exclude from training and prediction due to incomplete features "
"from lag-based preprocessing. When using lag transforms (e.g., lag-14), the first N days contain NaN values. "
"Set this to match your maximum lag duration (e.g., timedelta(days=14)). "
"Default of 0 assumes no invalid rows are created by preprocessing. "
"Note: should be same as predict_history if you are using lags. We default to disabled to keep the same "
"behaviour as openstef 3.0.",
)
# Feature engineering and validation
completeness_threshold: float = Field(
default=0.5,
description="Minimum fraction of data that should be available for making a regular forecast.",
)
completeness_threshold_target_constant_quantile: float = Field(
default=0.03,
description="Minimum fraction of target data that should be available for making a constant quantile forecast.",
)
flatliner_threshold: timedelta = Field(
default=timedelta(hours=24),
description="Number of minutes that the load has to be constant to detect a flatliner.",
)
detect_non_zero_flatliner: bool = Field(
default=False,
description="If True, flatliners are also detected on non-zero values (median of the load).",
)
predict_nonzero_flatliner: bool = Field(
default=False,
description="If True, predict the median of load measurements instead of zero (only for flatliner model).",
)
# Feature engineering
shifters: list[Shifter] = Field(
default=[],
description="List of feature shifts to align aggregation intervals. "
"Each Shifter can target different features with different aggregation periods.",
)
rolling_aggregate_features: list[AggregationFunction] = Field(
default=[],
description="If not None, rolling aggregate(s) of load will be used as features in the model.",
)
clip_features: FeatureSelection = Field(
default=FeatureSelection.ALL,
description="Feature selection for which features to clip to their learned range.",
)
nan_on_outlier_features: FeatureSelection = Field(
default=FeatureSelection.NONE,
description="Feature selection for which features to replace out-of-range values with NaN. "
"Defaults to no features (disabled).",
)
max_day_lags: int = Field(
default=14,
description="Maximum number of days to look back for day-based lags. "
"Default is 14 days (two weekly cycles). Set to 7 for a single weekly cycle.",
ge=1,
)
sample_weight_config: SampleWeightConfig = Field(
default_factory=lambda data: SampleWeightConfig(weight_exponent=1.0)
if data.get("model") == "gblinear"
else SampleWeightConfig(weight_exponent=0.0),
description="Sample weighting configuration. Controls how training samples are weighted. "
"Defaults to weight_exponent=1.0 for gblinear, 0.0 (uniform) for other models.",
)
# Data splitting strategy
data_splitter: DataSplitter = Field(
default=DataSplitter(
# Copied from OpenSTEF3 pipeline defaults
val_fraction=0.15,
test_fraction=0.0,
stratification_fraction=0.15,
min_days_for_stratification=4,
),
description="Configuration for splitting data into training, validation, and test sets.",
)
# Evaluation
evaluation_metrics: list[MetricProvider] = Field(
default_factory=lambda: [R2Provider(), ObservedProbabilityProvider()],
description="List of metric providers for evaluating model score.",
)
# Callbacks
mlflow_storage: MLFlowStorage | None = Field(
default_factory=MLFlowStorage,
description="Configuration for MLflow experiment tracking and model storage.",
)
model_reuse_enable: bool = Field(
default=True,
description="Whether to enable reuse of previously trained models.",
)
model_reuse_max_age: timedelta = Field(
default=timedelta(days=7),
description="Maximum age of a model to be considered for reuse.",
)
model_selection_enable: bool = Field(
default=True,
description="Whether to enable automatic model selection based on performance.",
)
model_selection_metric: tuple[QuantileOrGlobal, str, MetricDirection] = Field(
default=(Q(0.5), "R2", "higher_is_better"),
description="Metric to monitor for model performance when retraining.",
)
model_selection_old_model_penalty: float = Field(
default=1.2,
description="Penalty to apply to the old model's metric to bias selection towards newer models.",
)
model_performance_callback_enabled: bool = Field(
default=False,
description=(
"Whether to enable the ModelPerformanceCallback that evaluates model performance at the end of fitting."
),
)
model_performance_callback_metric_threshold: tuple[QuantileOrGlobal, str, MetricDirection, float] = Field(
default=(Q(0.5), "R2", "higher_is_better", 0.0),
description=("Metric to monitor for model performance threshold at the end of fitting. "),
)
verbosity: Literal[0, 1, 2, 3, True] = Field(
default=0, description="Verbosity level. 0=silent, 1=warning, 2=info, 3=debug"
)
# Metadata
tags: dict[str, str] = Field(
default_factory=dict,
description="Optional metadata tags for the model run.",
)
experiment_tags: dict[str, str] = Field(
default_factory=dict,
description="Optional metadata tags for experiment tracking.",
)
def _checks(config: ForecastingWorkflowConfig) -> list[Transform[TimeSeriesDataset, TimeSeriesDataset]]:
return [
Selector(selection=config.selected_features),
InputConsistencyChecker(),
FlatlineChecker(
load_column=config.target_column,
flatliner_threshold=config.flatliner_threshold,
detect_non_zero_flatliner=config.detect_non_zero_flatliner,
error_on_flatliner=True,
),
CompletenessChecker(completeness_threshold=config.completeness_threshold),
]
[docs]
def create_forecasting_workflow(
config: ForecastingWorkflowConfig,
) -> CustomForecastingWorkflow:
"""Create a forecasting workflow from configuration.
Builds a complete forecasting pipeline including preprocessing, forecaster, and postprocessing
transforms based on the provided configuration.
Args:
config: Configuration object containing all workflow parameters.
Returns:
Configured forecasting workflow ready for training and prediction.
Raises:
ValueError: If an unsupported model type is specified.
"""
checks = _checks(config)
feature_aligners = config.shifters
feature_adders = [
LagsAdder(
history_available=config.predict_history,
horizons=config.horizons,
add_trivial_lags=config.model
not in {"gblinear", "stacking", "learned_weights"}, # GBLinear uses only 7day lag.
target_column=config.target_column,
max_day_lags=config.max_day_lags,
custom_lags=[timedelta(days=7)] if config.model in {"gblinear", "stacking", "learned_weights"} else [],
lag_fallback_offset=timedelta(days=7)
if config.model in {"gblinear", "stacking", "learned_weights"}
else None,
),
WindPowerFeatureAdder(
windspeed_reference_column=config.wind_speed_column,
),
AtmosphereDerivedFeaturesAdder(
pressure_column=config.pressure_column,
relative_humidity_column=config.relative_humidity_column,
temperature_column=config.temperature_column,
),
RadiationDerivedFeaturesAdder(
coordinate=config.location.coordinate,
radiation_column=config.radiation_column,
),
CyclicFeaturesAdder(),
DaylightFeatureAdder(
coordinate=config.location.coordinate,
),
*(
[
RollingAggregatesAdder(
feature=config.target_column,
aggregation_functions=config.rolling_aggregate_features,
horizons=config.horizons,
),
]
if config.rolling_aggregate_features
else []
),
]
feature_standardizers = [
OutlierHandler(
selection=Include(config.energy_price_column).combine(config.clip_features),
mode="standard",
outlier_action="clip",
),
Scaler(selection=Exclude(config.target_column), method="standard"),
SampleWeighter(
target_column=config.target_column,
config=config.sample_weight_config,
),
EmptyFeatureRemover(),
]
if config.model == "xgboost":
nan_outlier_handlers = [
*(
[OutlierHandler(mode="standard", selection=config.nan_on_outlier_features, outlier_action="nan")]
if config.nan_on_outlier_features != FeatureSelection.NONE
else []
),
]
preprocessing = [
*checks,
*feature_aligners,
*feature_adders,
HolidayFeatureAdder(country_code=config.location.country_code),
DatetimeFeaturesAdder(onehot_encode=False),
*nan_outlier_handlers,
*feature_standardizers,
]
forecaster = XGBoostForecaster(
quantiles=config.quantiles,
horizons=config.horizons,
hyperparams=config.xgboost_hyperparams,
verbosity=config.verbosity,
)
postprocessing = [
QuantileSorter(),
ConfidenceIntervalApplicator(
quantiles=config.quantiles,
add_quantiles_from_std=False,
),
]
elif config.model == "lgbmlinear":
preprocessing = [
*checks,
*feature_aligners,
*feature_adders,
HolidayFeatureAdder(country_code=config.location.country_code),
DatetimeFeaturesAdder(onehot_encode=False),
*feature_standardizers,
]
forecaster = LGBMLinearForecaster(
quantiles=config.quantiles,
horizons=config.horizons,
hyperparams=config.lgbmlinear_hyperparams,
)
postprocessing = [QuantileSorter()]
elif config.model == "lgbm":
preprocessing = [
*checks,
*feature_aligners,
*feature_adders,
HolidayFeatureAdder(country_code=config.location.country_code),
DatetimeFeaturesAdder(onehot_encode=False),
*feature_standardizers,
]
forecaster = LGBMForecaster(
quantiles=config.quantiles,
horizons=config.horizons,
hyperparams=config.lgbm_hyperparams,
)
postprocessing = [QuantileSorter()]
elif config.model == "gblinear":
preprocessing = [
*checks,
*feature_aligners,
*feature_adders,
*feature_standardizers,
Imputer(
selection=Exclude(config.target_column),
imputation_strategy="mean",
fill_future_values=Include(config.energy_price_column),
),
NaNDropper(
selection=Exclude(config.target_column),
),
]
forecaster = GBLinearForecaster(
quantiles=config.quantiles,
horizons=config.horizons,
hyperparams=config.gblinear_hyperparams,
verbosity=config.verbosity,
)
postprocessing = [
QuantileSorter(),
ConfidenceIntervalApplicator(
quantiles=config.quantiles,
add_quantiles_from_std=False,
),
]
elif config.model == "median":
preprocessing = [
LagsAdder(
history_available=config.predict_history,
horizons=config.horizons,
add_trivial_lags=True,
target_column=config.target_column,
)
]
forecaster = MedianForecaster(
quantiles=config.quantiles,
horizons=config.horizons,
)
postprocessing = []
elif config.model == "constant_quantile":
preprocessing = [
CompletenessChecker(
columns=[config.target_column],
completeness_threshold=config.completeness_threshold_target_constant_quantile,
),
]
forecaster = ConstantQuantileForecaster(
quantiles=config.quantiles,
horizons=config.horizons,
)
postprocessing = []
elif config.model == "flatliner":
preprocessing = []
forecaster = FlatlinerForecaster(
quantiles=config.quantiles,
horizons=config.horizons,
predict_median=config.predict_nonzero_flatliner,
)
postprocessing = [
QuantileSorter(),
ConfidenceIntervalApplicator(
quantiles=config.quantiles,
add_quantiles_from_std=False,
),
]
else:
msg = f"Unsupported model type: {config.model}"
raise ValueError(msg)
tags = {
**config.location.tags,
"model_type": config.model,
**config.tags,
}
callbacks: list[ForecastingCallback] = []
if config.mlflow_storage is not None:
callbacks.append(
MLFlowStorageCallback(
storage=config.mlflow_storage,
model_reuse_enable=config.model_reuse_enable,
model_reuse_max_age=config.model_reuse_max_age,
model_selection_enable=config.model_selection_enable,
model_selection_metric=config.model_selection_metric,
model_selection_old_model_penalty=config.model_selection_old_model_penalty,
)
)
if config.model_performance_callback_enabled:
quantile, metric_name, metric_direction, threshold = config.model_performance_callback_metric_threshold
callbacks.append(
ModelPerformanceCallback(
metric_name=metric_name,
threshold=threshold,
metric_direction=metric_direction,
quantile=quantile,
)
)
return CustomForecastingWorkflow(
model=ForecastingModel(
preprocessing=TransformPipeline(transforms=preprocessing),
forecaster=forecaster,
postprocessing=TransformPipeline(transforms=postprocessing),
target_column=config.target_column,
data_splitter=config.data_splitter,
cutoff_history=config.cutoff_history,
# Evaluation
evaluation_metrics=config.evaluation_metrics,
# Other
tags=tags,
),
model_id=config.model_id,
run_name=config.run_name,
callbacks=callbacks,
experiment_tags=config.experiment_tags,
)