Source code for openstef_models.utils.data_split

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Time series dataset splitting utilities for training and evaluation.

Provides various strategies for splitting time series datasets into training,
validation, and test sets. Supports chronological splits, stratified splits
based on extreme values, and custom date-based splits.

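Key functions handle the temporal nature of forecasting data. Chronological
splits ensure that training data always precedes test data to prevent
information leakage; stratified splits instead sample whole days so that
extreme-value days are represented in every split.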
"""

from collections.abc import Callable
from datetime import datetime
from typing import cast

import numpy as np
import pandas as pd
from pydantic import Field

from openstef_core.base_model import BaseConfig
from openstef_core.datasets import TimeSeriesDataset
from openstef_core.exceptions import InsufficientlyCompleteError


[docs] def split_by_dates[T: TimeSeriesDataset]( dataset: T, dates_test: pd.DatetimeIndex, ) -> tuple[T, T]: """Split a dataset into train and test sets based on specific dates. Args: dataset: The dataset to split. dates_test: Dates to include in the test set. All other dates go to training. Returns: Tuple of (train_dataset, test_dataset). """ mask = cast("pd.Series[bool]", dataset.index.normalize().isin(dates_test)) # type: ignore train_data, test_data = dataset.data[~mask], dataset.data[mask] return dataset._copy_with_data(train_data), dataset._copy_with_data(test_data) # noqa: SLF001 - allow protected access, invariants are maintained
[docs] def split_by_date[T: TimeSeriesDataset]( dataset: T, split_date: datetime | pd.Timestamp, ) -> tuple[T, T]: """Split a dataset into train and test sets based on a specific date. Args: dataset: The dataset to split. split_date: The date to split on. Data before this date goes to train, data at/after goes to test. Returns: Tuple of (train_dataset, test_dataset). """ split_idx = cast(pd.Series, dataset.index).searchsorted(split_date, side="left") train_data = dataset.data.iloc[:split_idx] test_data = dataset.data.iloc[split_idx:] return dataset._copy_with_data(train_data), dataset._copy_with_data(test_data) # noqa: SLF001 - allow protected access, invariants are maintained
[docs] def chronological_train_test_split[T: TimeSeriesDataset]( dataset: T, test_fraction: float, ) -> tuple[T, T]: """Split a dataset into train and test sets chronologically. Divides the dataset into training and testing sets based on temporal order, ensuring that all training data comes before all testing data. This is the standard approach for time series forecasting evaluation. The split point is determined by the test_fraction parameter, placing the most recent portion of data in the test set. Args: dataset: The dataset to split. test_fraction: Fraction of data to include in the test split. Returns: Tuple of (train_dataset, test_dataset). Raises: ValueError: If test_fraction is not between 0 and 1. InsufficientlyCompleteError: If dataset has fewer than 2 unique timestamps. """ if not 0.0 <= test_fraction <= 1.0: raise ValueError("test_fraction must be between 0 and 1.") if test_fraction == 0.0: # No test set return dataset, dataset._copy_with_data(dataset.data.iloc[0:0]) # noqa: SLF001 - allow protected access, invariants are maintained index_unique = dataset.index.unique() n_total = len(index_unique) min_timestamps = 2 if n_total < min_timestamps: msg = f"Dataset has {n_total} unique timestamps, need at least {min_timestamps} to split into train/test." raise InsufficientlyCompleteError(msg) n_test = int(n_total * test_fraction) n_test = min(n_test, n_total - 1) # Ensure at least one for train if possible if n_total > 1 and n_test == 0: n_test = 1 # Ensure at least one for test if possible n_train = n_total - n_test split_date = index_unique[n_train] return split_by_date(dataset=dataset, split_date=split_date)
[docs] def stratified_train_test_split[T: TimeSeriesDataset]( dataset: T, test_fraction: float, stratification_fraction: float = 0.15, target_column: str = "load", random_state: int = 42, min_days_for_stratification: int = 4, ) -> tuple[T, T]: """Split a dataset into train and test sets with stratification on extreme values. Splits data while ensuring that extreme high and low values are proportionally represented in both training and testing sets. This helps maintain representative distributions for model evaluation, especially important for forecasting tasks where extreme events are critical. Args: dataset: The dataset to split. test_fraction: Fraction of data to include in the test split. stratification_fraction: Fraction of extreme days to consider for stratification. target_column: Column name containing the values to stratify on. random_state: Random seed for reproducible splits. min_days_for_stratification: Minimum days required for stratification. Returns: Tuple of (train_dataset, test_dataset). Raises: ValueError: If test_fraction is not between 0 and 1. Note: Falls back to chronological splitting if there are too few days for stratification. 
""" if not 0.0 <= test_fraction <= 1.0: raise ValueError("test_fraction must be between 0 and 1.") index_dates = dataset.index.normalize() n_unique_days = index_dates.nunique() # If not enough days, fall back to simple chronological split if n_unique_days < min_days_for_stratification: return chronological_train_test_split(dataset=dataset, test_fraction=test_fraction) rng = np.random.default_rng(random_state) # Get extreme day groups target_series = dataset.select_features([target_column]).select_version().data[target_column] max_days, min_days, other_days = _get_extreme_days(target_series=target_series, fraction=stratification_fraction) # Split each group proportionally between train and test _, test_max_days = _sample_dates_for_split(dates=max_days, test_fraction=test_fraction, rng=rng) _, test_min_days = _sample_dates_for_split(dates=min_days, test_fraction=test_fraction, rng=rng) _, test_other_days = _sample_dates_for_split(dates=other_days, test_fraction=test_fraction, rng=rng) # Combine all train and test dates test_dates = cast(pd.DatetimeIndex, test_max_days.union(test_min_days).union(test_other_days)) return split_by_dates(dataset=dataset, dates_test=test_dates)
def _sample_dates_for_split( dates: pd.DatetimeIndex, test_fraction: float, rng: np.random.Generator, ) -> tuple[pd.DatetimeIndex, pd.DatetimeIndex]: if dates.empty: return pd.DatetimeIndex([]), pd.DatetimeIndex([]) min_test_days = 1 if test_fraction > 0.0 else 0 n_test = max(min_test_days, int(test_fraction * len(dates))) n_test = min(n_test, len(dates) - 1) # Ensure at least one for train if possible if len(dates) == 1: # Only one date, put in train return pd.DatetimeIndex([]), dates test_dates = pd.DatetimeIndex(np.sort(rng.choice(dates, size=n_test, replace=False))) train_dates = dates.difference(test_dates, sort=True) # type: ignore return train_dates, test_dates def _get_extreme_days( target_series: pd.Series, fraction: float = 0.1, ) -> tuple[pd.DatetimeIndex, pd.DatetimeIndex, pd.DatetimeIndex]: if not isinstance(target_series.index, pd.DatetimeIndex): raise TypeError("target_series must have a DatetimeIndex.") # Compute daily min and max once daily_agg: pd.DataFrame = target_series.resample("1D").agg(["min", "max"]) # type: ignore n_days = len(daily_agg) n_extremes = max(int(fraction * n_days), 2) # Sort once max_days = cast(pd.DatetimeIndex, daily_agg["max"].nlargest(n_extremes).index) min_days = cast(pd.DatetimeIndex, daily_agg["min"].nsmallest(n_extremes).index) all_days = cast(pd.DatetimeIndex, daily_agg.index) other_days = all_days.difference(other=max_days.union(other=min_days)) # type: ignore return max_days, min_days, other_days
[docs] def train_val_test_split[T]( dataset: T, split_func: Callable[[T, float], tuple[T, T]], val_fraction: float, test_fraction: float, ) -> tuple[T, T, T]: """Split a dataset into train, validation, and test sets chronologically. Divides the dataset into training, validation, and testing sets based on temporal order, ensuring that all training data comes before all validation data, which comes before all testing data. The split points are determined by the val_fraction and test_fraction parameters. Args: dataset: The dataset to split. split_func: Function to use for splitting the dataset into two parts. val_fraction: Fraction of data to include in the validation split. test_fraction: Fraction of data to include in the test split. Returns: Tuple of (train_dataset, val_dataset, test_dataset). Raises: ValueError: If test_fraction + val_fraction is not less than 1.0. """ if not 0.0 <= val_fraction <= 1.0: raise ValueError("val_fraction must be between 0 and 1.") if not 0.0 <= test_fraction <= 1.0: raise ValueError("test_fraction must be between 0 and 1.") if test_fraction + val_fraction >= 1.0: msg = f"test_fraction ({test_fraction}) + val_fraction ({val_fraction}) must be less than 1.0" raise ValueError(msg) # First split: separate test set from train+val train_val, test = split_func(dataset, test_fraction) # Calculate adjusted validation fraction for the remaining data # We want val_fraction of the *original* dataset size # From the remaining (1 - test_fraction), we need val_fraction # So: adjusted = val_fraction / (1 - test_fraction). adjusted_val_fraction = val_fraction / (1 - test_fraction) # Second split: separate validation from training train, val = split_func(train_val, adjusted_val_fraction) return train, val, test
[docs] class DataSplitter(BaseConfig): """Handles splitting of time series data into train, validation, and test sets. Supports stratified splitting to ensure representative data distribution across splits, particularly for extreme values in forecasting scenarios. """ val_fraction: float = Field( default=0.15, description="Fraction of data to reserve for the validation set when automatic splitting is used.", ) test_fraction: float = Field( default=0.1, description="Fraction of data to reserve for the test set when automatic splitting is used.", ) stratification_fraction: float = Field( default=0.15, description="Fraction of extreme values to use for stratified splitting into train/test sets.", ) min_days_for_stratification: int = Field( default=4, description="Minimum number of unique days required to perform stratified splitting.", ) random_state: int = Field( default=42, description="Random seed for reproducible splits when stratification is used.", )
[docs] def split_dataset[T: TimeSeriesDataset]( self, data: T, data_val: T | None = None, data_test: T | None = None, target_column: str = "load", ) -> tuple[T, T | None, T | None]: """Prepare and split input data into train, validation, and test sets. Args: data: Full dataset to split. data_val: Optional pre-split validation data. data_test: Optional pre-split test data. target_column: Column name containing the target variable for stratification. Returns: Tuple of (train_data, val_data, test_data) where val_data and test_data may be None. """ # Apply splitting strategy input_data_train, input_data_val, input_data_test = train_val_test_split( dataset=data, split_func=lambda dataset, fraction: stratified_train_test_split( dataset=dataset, test_fraction=fraction, stratification_fraction=self.stratification_fraction, target_column=target_column, random_state=self.random_state, min_days_for_stratification=self.min_days_for_stratification, ), val_fraction=self.val_fraction if data_val is None else 0.0, test_fraction=self.test_fraction if data_test is None else 0.0, ) input_data_val = data_val or input_data_val input_data_test = data_test or input_data_test if input_data_val.index.empty: input_data_val = None if input_data_test.index.empty: input_data_test = None return (input_data_train, input_data_val, input_data_test)
# Public API of this module; private helpers (leading underscore) are intentionally omitted.
__all__ = [
    "DataSplitter",
    "chronological_train_test_split",
    "split_by_date",
    "split_by_dates",
    "stratified_train_test_split",
    "train_val_test_split",
]