# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
"""Validated dataset classes for time series forecasting.
Specialized dataset classes with domain-specific validation for different stages
of the forecasting pipeline. These datasets inherit from TimeSeriesDataset and add
validation to catch data quality issues early.
"""
from datetime import datetime, timedelta
from typing import Self, override
import pandas as pd
from openstef_core.datasets.timeseries_dataset import TimeSeriesDataset
from openstef_core.datasets.validation import validate_required_columns
from openstef_core.exceptions import MissingColumnsError
from openstef_core.types import EnergyComponentType, LeadTime, Quantile
# Separator between the base-forecaster name and the quantile part in ensemble
# column names, e.g. "lgbm__quantile_P50".
ENSEMBLE_COLUMN_SEP: str = "__"
class ForecastDataset(TimeSeriesDataset):
    """Time series dataset containing probabilistic forecasts with quantile estimates.

    Contains forecast results with column names following quantile naming convention
    (e.g., 'quantile_P50' for median). Enables consistent handling of probabilistic
    forecasts with uncertainty quantification.

    Invariants:
        - All feature columns other than the target and standard deviation columns
          must be valid quantile strings (e.g., 'quantile_P10')
        - Inherits all TimeSeriesDataset guarantees (sorted timestamps, consistent intervals)

    Attrs:
        forecast_start: Timestamp indicating when the forecast period starts.
        quantiles: List of Quantile values represented in the dataset.
        target_column: Name of the column holding observed target values, if present.
        standard_deviation_column: Name of the column holding forecast standard deviations.

    Example:
        >>> import pandas as pd
        >>> import numpy as np
        >>> from datetime import timedelta
        >>> forecast_data = pd.DataFrame({
        ...     'load': [100, np.nan],
        ...     'quantile_P10': [90, 95],
        ...     'quantile_P50': [100, 110],
        ...     'quantile_P90': [115, 125]
        ... }, index=pd.date_range('2025-01-01', periods=2, freq='h'))
        >>> dataset = ForecastDataset(forecast_data, timedelta(hours=1))
        >>> len(dataset.quantiles)
        3
        >>> dataset.quantiles[1]
        0.5

    See Also:
        TimeSeriesDataset: Base class for time series datasets.
        ForecastInputDataset: For preparing forecasting input data.
        Quantile: Type for handling quantile values and naming conventions.
    """

    forecast_start: datetime
    quantiles: list[Quantile]
    target_column: str
    standard_deviation_column: str

    @override
    def __init__(
        self,
        data: pd.DataFrame,
        sample_interval: timedelta = timedelta(minutes=15),
        forecast_start: datetime | None = None,
        target_column: str = "load",
        *,
        horizon_column: str = "horizon",
        available_at_column: str = "available_at",
        standard_deviation_column: str = "stdev",
    ) -> None:
        """Initialize the forecast dataset and validate quantile columns.

        Args:
            data: DataFrame with quantile forecast columns and optional target /
                standard deviation columns, indexed by timestamp.
            sample_interval: Expected spacing between consecutive samples.
            forecast_start: When the forecast period starts; defaults to the
                earliest index timestamp when not given.
            target_column: Name of the target column.
            horizon_column: Name of the lead-time column, forwarded to the base class.
            available_at_column: Name of the availability column, forwarded to the base class.
            standard_deviation_column: Name of the standard deviation column.

        Raises:
            ValueError: If any remaining feature column is not a valid quantile string.
        """
        # Metadata stored in DataFrame.attrs (round-tripped via to_pandas) takes
        # precedence over constructor arguments so serialized datasets restore
        # their original configuration.
        if "forecast_start" in data.attrs:
            self.forecast_start = datetime.fromisoformat(data.attrs["forecast_start"])
        else:
            self.forecast_start = forecast_start if forecast_start is not None else data.index.min().to_pydatetime()
        self.target_column = data.attrs.get("target_column", target_column)
        self.standard_deviation_column = data.attrs.get("standard_deviation_column", standard_deviation_column)
        super().__init__(
            data=data,
            sample_interval=sample_interval,
            horizon_column=horizon_column,
            available_at_column=available_at_column,
        )
        # Use the resolved attribute values rather than the raw arguments, so
        # that column names restored from data.attrs are excluded as well.
        exclude_columns = {self.target_column, self.standard_deviation_column}
        quantile_feature_names = [col for col in self.feature_names if col not in exclude_columns]
        if not all(Quantile.is_valid_quantile_string(col) for col in quantile_feature_names):
            raise ValueError("All feature names must be valid quantile strings.")
        self.quantiles = [Quantile.parse(col) for col in quantile_feature_names]

    @property
    def target_series(self) -> pd.Series | None:
        """Extract the target time series from the dataset.

        Returns:
            Time series containing target values with original datetime index,
            or None when the target column is absent.
        """
        if self.target_column not in self.data.columns:
            return None
        return self.data[self.target_column]

    @property
    def median_series(self) -> pd.Series:
        """Extract the median (50th percentile) forecast series.

        Returns:
            Time series containing median forecast values with original datetime index.

        Raises:
            MissingColumnsError: If the median quantile column is not found.
        """
        median_col = Quantile(0.5).format()
        if median_col not in self.feature_names:
            raise MissingColumnsError(missing_columns=[median_col])
        return self.data[median_col]

    @property
    def standard_deviation_series(self) -> pd.Series:
        """Extract the standard deviation series if it exists.

        Returns:
            Time series containing standard deviation values with original datetime index.

        Raises:
            MissingColumnsError: If the standard deviation column is not found.
        """
        if self.standard_deviation_column not in self.data.columns:
            raise MissingColumnsError(missing_columns=[self.standard_deviation_column])
        return self.data[self.standard_deviation_column]  # pyright: ignore[reportUnknownVariableType]

    @property
    def quantiles_data(self) -> pd.DataFrame:
        """Extract DataFrame containing only the quantile forecast columns.

        Returns:
            DataFrame with quantile columns and original datetime index.
        """
        quantile_columns = [q.format() for q in self.quantiles]
        return self.data[quantile_columns]

    def filter_quantiles(self, quantiles: list[Quantile]) -> Self:
        """Select a subset of quantiles from the forecast dataset.

        Args:
            quantiles: List of Quantile values to select.

        Returns:
            New ForecastDataset containing only the specified quantile columns.

        Raises:
            MissingColumnsError: If any requested quantile column is absent
                (raised by validate_required_columns).
        """
        selected_quantiles = [q.format() for q in quantiles]
        validate_required_columns(self.data, required_columns=selected_quantiles)
        all_quantiles = [q.format() for q in self.quantiles]
        drop_columns = list(set(all_quantiles) - set(selected_quantiles))
        data_filtered = self.data.drop(columns=drop_columns)
        result = self._copy_with_data(data=data_filtered)
        result.quantiles = quantiles
        return result

    @override
    def to_pandas(self) -> pd.DataFrame:
        """Serialize to a DataFrame, embedding metadata in DataFrame.attrs."""
        df = super().to_pandas()
        df.attrs["target_column"] = self.target_column
        df.attrs["forecast_start"] = self.forecast_start.isoformat()
        df.attrs["standard_deviation_column"] = self.standard_deviation_column
        return df

    @classmethod
    def from_timeseries(
        cls,
        dataset: TimeSeriesDataset,
        target_column: str = "load",
        forecast_start: datetime | None = None,
    ) -> Self:
        """Create a ForecastDataset from a generic TimeSeriesDataset.

        Args:
            dataset: Input TimeSeriesDataset to convert.
            target_column: Name of the target column to forecast.
            forecast_start: Optional timestamp indicating forecast start.

        Returns:
            Instance of ForecastDataset with the specified target column.
        """
        return cls(
            data=dataset.data,
            sample_interval=dataset.sample_interval,
            target_column=target_column,
            forecast_start=forecast_start,
        )
class EnergyComponentDataset(TimeSeriesDataset):
    """Time series dataset for energy generation split by component type.

    Guarantees at construction time that a column exists for every energy
    component type (wind, solar, other), so downstream component-specific
    analysis and forecasting can rely on their presence.

    Invariants:
        - Contains a column for every member of EnergyComponentType
        - Inherits all TimeSeriesDataset guarantees (sorted timestamps, consistent intervals)

    Example:
        >>> import pandas as pd
        >>> from datetime import timedelta
        >>> energy_data = pd.DataFrame({
        ...     'wind': [50, 60],
        ...     'solar': [30, 40],
        ...     'other': [20, 25]
        ... }, index=pd.date_range('2025-01-01', periods=2, freq='h'))
        >>> dataset = EnergyComponentDataset(energy_data, timedelta(hours=1))
        >>> 'wind' in dataset.feature_names
        True
        >>> len(dataset.feature_names)
        3

    See Also:
        TimeSeriesDataset: Base class for time series datasets.
        ForecastInputDataset: For general forecasting input data.
        EnergyComponentType: Enum defining required energy component types.
    """

    @override
    def __init__(
        self,
        data: pd.DataFrame,
        sample_interval: timedelta = timedelta(minutes=15),
        *,
        horizon_column: str = "horizon",
        available_at_column: str = "available_at",
    ) -> None:
        """Validate component columns and initialize the base dataset.

        Args:
            data: DataFrame indexed by timestamp with one column per component type.
            sample_interval: Expected spacing between consecutive samples.
            horizon_column: Name of the lead-time column, forwarded to the base class.
            available_at_column: Name of the availability column, forwarded to the base class.
        """
        # Fail fast if any energy component column is missing.
        required = [component.value for component in EnergyComponentType]
        validate_required_columns(data, required_columns=required)
        super().__init__(
            data=data,
            sample_interval=sample_interval,
            horizon_column=horizon_column,
            available_at_column=available_at_column,
        )
class EnsembleForecastDataset(TimeSeriesDataset):
    """First stage output format for ensemble forecasters.

    Each forecast column is named ``{forecaster}{ENSEMBLE_COLUMN_SEP}{quantile}``
    (e.g. ``lgbm__quantile_P50``). An optional target column and an optional
    sample-weight column may be present alongside the forecast columns.

    Invariants:
        - Every base forecaster provides a column for every quantile (complete grid)
        - Inherits all TimeSeriesDataset guarantees (sorted timestamps, consistent intervals)

    Attrs:
        forecast_start: Timestamp indicating when the forecast period starts.
        quantiles: List of Quantile values represented in the dataset.
        forecaster_names: Names of the base forecasters contributing columns.
        target_column: Name of the column holding observed target values, if present.
    """

    # Name of the optional per-sample weight column produced by
    # from_forecast_datasets; excluded from quantile-column parsing.
    SAMPLE_WEIGHT_COLUMN: str = "sample_weight"

    forecast_start: datetime
    quantiles: list[Quantile]
    forecaster_names: list[str]
    target_column: str

    @override
    def __init__(
        self,
        data: pd.DataFrame,
        sample_interval: timedelta = timedelta(minutes=15),
        forecast_start: datetime | None = None,
        target_column: str = "load",
        *,
        horizon_column: str = "horizon",
        available_at_column: str = "available_at",
    ) -> None:
        """Initialize the ensemble dataset and parse forecaster/quantile columns.

        Args:
            data: DataFrame with ``{forecaster}__{quantile}`` columns and optional
                target / sample-weight columns, indexed by timestamp.
            sample_interval: Expected spacing between consecutive samples.
            forecast_start: When the forecast period starts; defaults to the
                earliest index timestamp when not given.
            target_column: Name of the target column.
            horizon_column: Name of the lead-time column, forwarded to the base class.
            available_at_column: Name of the availability column, forwarded to the base class.

        Raises:
            ValueError: If a forecast column cannot be parsed, a forecaster name
                contains the separator, or the columns do not form a complete
                forecaster-by-quantile grid.
        """
        # Metadata stored in DataFrame.attrs takes precedence over arguments so
        # serialized datasets restore their original configuration.
        if "forecast_start" in data.attrs:
            self.forecast_start = datetime.fromisoformat(data.attrs["forecast_start"])
        else:
            self.forecast_start = forecast_start if forecast_start is not None else data.index.min().to_pydatetime()
        self.target_column = data.attrs.get("target_column", target_column)
        super().__init__(
            data=data,
            sample_interval=sample_interval,
            horizon_column=horizon_column,
            available_at_column=available_at_column,
        )
        # Exclude non-forecast columns before parsing. Uses the resolved target
        # name (which may come from data.attrs) and the sample-weight column
        # added by from_forecast_datasets, which would otherwise fail parsing.
        non_forecast_columns = {self.target_column, self.SAMPLE_WEIGHT_COLUMN}
        quantile_feature_names = [col for col in self.feature_names if col not in non_forecast_columns]
        self.forecaster_names, self.quantiles = self.get_learner_and_quantile(pd.Index(quantile_feature_names))
        # Defensive: split(maxsplit=1) assigns everything after the first
        # separator to the quantile part, so a separator in a learner name
        # surfaces as an invalid quantile string during parsing above.
        for name in self.forecaster_names:
            if ENSEMBLE_COLUMN_SEP in name:
                msg = f"Forecaster name '{name}' must not contain separator '{ENSEMBLE_COLUMN_SEP}'."
                raise ValueError(msg)
        # Require a complete grid: every forecaster supplies every quantile.
        n_cols = len(self.forecaster_names) * len(self.quantiles)
        if len(quantile_feature_names) != n_cols:
            raise ValueError("Data columns do not match the expected number based on base forecasters and quantiles.")

    @property
    def target_series(self) -> pd.Series | None:
        """Return the target series if available, otherwise None."""
        if self.target_column in self.data.columns:
            return self.data[self.target_column]
        return None

    @staticmethod
    def get_learner_and_quantile(feature_names: pd.Index) -> tuple[list[str], list[Quantile]]:
        """Extract base forecaster names and quantiles from feature names.

        Column format is ``{learner}{ENSEMBLE_COLUMN_SEP}{quantile.format()}``,
        e.g. ``lgbm__quantile_P50``. Names and quantiles are returned in
        first-seen order, so the result is deterministic for a given column
        ordering (sets would yield an arbitrary, hash-dependent order).

        Args:
            feature_names: Index of feature names in the dataset.

        Returns:
            Tuple containing a list of base forecaster names and a list of quantiles.

        Raises:
            ValueError: If a column cannot be parsed or has an invalid quantile string.
        """
        # Plain dicts double as insertion-ordered sets.
        forecasters: dict[str, None] = {}
        quantiles: dict[Quantile, None] = {}
        for feature_name in feature_names:
            parts = feature_name.split(ENSEMBLE_COLUMN_SEP, maxsplit=1)
            if len(parts) != 2:  # noqa: PLR2004
                msg = f"Column missing separator '{ENSEMBLE_COLUMN_SEP}': {feature_name}"
                raise ValueError(msg)
            learner_part, quantile_part = parts
            if not Quantile.is_valid_quantile_string(quantile_part):
                msg = f"Column has no valid quantile string: {feature_name}"
                raise ValueError(msg)
            forecasters[learner_part] = None
            quantiles[Quantile.parse(quantile_part)] = None
        return list(forecasters), list(quantiles)

    @classmethod
    def from_forecast_datasets(
        cls,
        datasets: dict[str, ForecastDataset],
        target_series: pd.Series | None = None,
        sample_weights: pd.Series | None = None,
    ) -> Self:
        """Create an EnsembleForecastDataset from multiple ForecastDatasets.

        Args:
            datasets: Dict of ForecastDatasets to combine, keyed by forecaster name.
            target_series: Optional target series to include in the dataset
                (ignored when the first dataset already carries a target).
            sample_weights: Optional sample weights series to include in the dataset.

        Returns:
            EnsembleForecastDataset combining all input datasets.

        Raises:
            ValueError: If no datasets are provided.
        """
        if not datasets:
            msg = "At least one ForecastDataset is required."
            raise ValueError(msg)
        # The first dataset defines shared metadata (interval, start, target name).
        ds1 = next(iter(datasets.values()))
        additional_columns: dict[str, pd.Series] = {}
        if isinstance(ds1.target_series, pd.Series):
            additional_columns[ds1.target_column] = ds1.target_series
        elif target_series is not None:
            additional_columns[ds1.target_column] = target_series
        if sample_weights is not None:
            additional_columns[cls.SAMPLE_WEIGHT_COLUMN] = sample_weights
        combined_data = pd.DataFrame({
            f"{learner}{ENSEMBLE_COLUMN_SEP}{q.format()}": ds.data[q.format()]
            for learner, ds in datasets.items()
            for q in ds.quantiles
        }).assign(**additional_columns)
        return cls(
            data=combined_data,
            sample_interval=ds1.sample_interval,
            forecast_start=ds1.forecast_start,
            target_column=ds1.target_column,
        )

    def get_base_predictions_for_quantile(self, quantile: Quantile) -> ForecastInputDataset:
        """Get base forecaster predictions for a specific quantile.

        Args:
            quantile: Quantile to select.

        Returns:
            ForecastInputDataset containing predictions from all base forecasters
            at the specified quantile, with columns renamed to the bare
            forecaster names, plus the target column.
        """
        selected_columns = [f"{learner}{ENSEMBLE_COLUMN_SEP}{quantile.format()}" for learner in self.forecaster_names]
        selected_columns.append(self.target_column)
        prediction_data = self.data[selected_columns].copy()
        prediction_data.columns = [*self.forecaster_names, self.target_column]
        return ForecastInputDataset(
            data=prediction_data,
            sample_interval=self.sample_interval,
            target_column=self.target_column,
            forecast_start=self.forecast_start,
        )
# Public API of this module.
# NOTE(review): "ForecastInputDataset" is exported and referenced above but is
# neither defined nor imported in this file as shown — confirm it is defined
# elsewhere in this module.
__all__ = [
    "ENSEMBLE_COLUMN_SEP",
    "EnergyComponentDataset",
    "EnsembleForecastDataset",
    "ForecastDataset",
    "ForecastInputDataset",
]