# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
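"""Testing utilities for OpenSTEF.

Provides a matcher class for comparing pandas DataFrames and Series with
equality semantics in test assertions, helpers for creating and comparing
``TimeSeriesDataset`` objects, a loader for the Liander benchmark dataset,
and setup helpers for tutorial notebooks.
"""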
import logging
from collections.abc import Sequence
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, override
import numpy as np
import pandas as pd
from openstef_core.constants import LIANDER_DATASET_REPO_ID
from openstef_core.datasets import TimeSeriesDataset, VersionedTimeSeriesDataset
[docs]
class IsSamePandas:
"""Utility class to allow comparison of pandas DataFrames in assertion / calls."""
[docs]
def __init__(self, pandas_obj: pd.DataFrame | pd.Series):
"""Matcher to check if two DataFrames are equal."""
self.pandas_obj = pandas_obj
@override
def __eq__(self, other: object) -> bool:
return isinstance(other, type(self.pandas_obj)) and self.pandas_obj.equals(other) # type: ignore
@override
def __hash__(self) -> int:
return hash(self.pandas_obj)
[docs]
def assert_timeseries_equal(actual: TimeSeriesDataset, expected: TimeSeriesDataset):
    """Check that two ``TimeSeriesDataset`` objects hold the same data and sample interval.

    The underlying frames are compared first with ``pandas.testing.assert_frame_equal``;
    the sample intervals are only compared once the frames match.

    Args:
        actual: Dataset produced by the code under test.
        expected: Reference dataset.

    Raises:
        AssertionError: If the frames differ or the sample intervals differ.
    """
    pd.testing.assert_frame_equal(actual.data, expected.data)

    got_interval = actual.sample_interval
    want_interval = expected.sample_interval
    assert got_interval == want_interval, (  # noqa: S101 - exception - testing utility
        f"Sample intervals differ: {got_interval} != {want_interval}"
    )
[docs]
def create_timeseries_dataset(
    index: pd.DatetimeIndex,
    available_ats: pd.Series | list[datetime] | pd.DatetimeIndex | None = None,
    horizons: pd.Series | list[timedelta] | None = None,
    sample_interval: timedelta = timedelta(hours=1),
    *,
    check_frequency: bool = False,
    **kwargs: pd.Series | list[Any] | pd.DatetimeIndex,
) -> TimeSeriesDataset:
    """Build a TimeSeriesDataset from column values for use in tests.

    Every keyword argument becomes a column of the underlying DataFrame. At most one
    versioning column is added: ``available_at`` when ``available_ats`` is given,
    otherwise ``horizon`` when ``horizons`` is given. If both are passed,
    ``available_ats`` takes precedence and ``horizons`` is ignored.

    Args:
        index: Datetime index for the dataset.
        available_ats: Optional available_at timestamps for each data point.
        horizons: Optional forecast horizons for each data point.
        sample_interval: Time interval between consecutive samples.
        check_frequency: Whether to check the frequency of the datetime index.
        **kwargs: Additional columns to include in the dataset.

    Returns:
        TimeSeriesDataset with the specified structure.
    """
    columns: dict[str, Any] = dict(kwargs)

    if available_ats is not None:
        columns["available_at"] = available_ats
    elif horizons is not None:
        columns["horizon"] = horizons

    frame = pd.DataFrame(data=columns, index=index)
    return TimeSeriesDataset(data=frame, sample_interval=sample_interval, check_frequency=check_frequency)
[docs]
def create_synthetic_forecasting_dataset(  # noqa: PLR0913, PLR0917 - complex function - testing utility
    start: datetime = datetime.fromisoformat("2025-01-01T00:00:00+00:00"),  # noqa: B008
    length: timedelta = timedelta(days=30 * 9),
    sample_interval: timedelta = timedelta(hours=1),
    random_seed: int = 42,
    wind_influence: float | None = -0.2,
    temp_influence: float | None = 0.3,
    radiation_influence: float | None = -0.2,
    stochastic_influence: float | None = 0.1,
    other_components: dict[str, float] | None = None,
    *,
    include_atmosphere: bool = False,
    include_price: bool = False,
    include_available_at: bool = False,
) -> TimeSeriesDataset:
    """Create synthetic forecasting dataset for testing.

    Generates time series data with configurable components influencing load. Each
    component is a standard-normal series drawn from a ``numpy`` generator seeded with
    ``random_seed``, and ``load`` is the sum of every component times its coefficient.

    Args:
        start: Start datetime for the dataset.
        length: Total duration of the dataset; the number of samples is
            ``length // sample_interval``.
        sample_interval: Time interval between consecutive samples.
        random_seed: Random seed for reproducible random components.
        wind_influence: Coefficient for the ``windspeed`` component; ``None`` omits it.
        temp_influence: Coefficient for the ``temperature`` component; ``None`` omits it.
        radiation_influence: Coefficient for the ``radiation`` component; ``None`` omits it.
        stochastic_influence: Coefficient for the ``stochastic`` component; ``None`` omits it.
        other_components: Additional components with their influence coefficients. The
            mapping is not modified. An entry whose key matches a built-in component
            name is overridden by the corresponding non-``None`` coefficient argument.
        include_atmosphere: Add ``pressure`` (~1013) and ``relative_humidity`` (~70%) columns.
        include_price: Add ``day_ahead_electricity_price`` (~50) column.
        include_available_at: Add ``available_at`` column (index + sample_interval).

    Returns:
        TimeSeriesDataset containing synthetic load and component data.
    """
    timestamps = pd.date_range(start=start, periods=length // sample_interval, freq=sample_interval, tz="UTC")
    n_samples = len(timestamps)

    # Copy the caller's mapping. The built-in components are inserted below, and writing
    # them into ``other_components`` itself would mutate the caller's dict (and leak into
    # any later call that reuses it).
    component_influence: dict[str, float] = dict(other_components or {})
    builtin_influences: dict[str, float | None] = {
        "windspeed": wind_influence,
        "temperature": temp_influence,
        "radiation": radiation_influence,
        "stochastic": stochastic_influence,
    }
    for component_name, influence in builtin_influences.items():
        if influence is not None:
            component_influence[component_name] = influence

    rng = np.random.default_rng(random_seed)
    load = pd.Series(np.zeros(n_samples), index=timestamps, name="load")
    components: dict[str, pd.Series] = {}
    # The draw order (components in dict order, then the extras below) determines the
    # values produced for a given seed, so keep it stable.
    for component_name, influence in component_influence.items():
        component = pd.Series(rng.standard_normal(size=n_samples), index=timestamps, name=component_name)
        load += component * influence
        components[component_name] = component

    extras: dict[str, Any] = {}
    if include_atmosphere:
        extras["pressure"] = 1013.0 + rng.normal(0, 5, n_samples)
        extras["relative_humidity"] = 70.0 + rng.normal(0, 10, n_samples)
    if include_price:
        extras["day_ahead_electricity_price"] = 50.0 + rng.normal(0, 10, n_samples)
    if include_available_at:
        # Each sample becomes available one sample interval after its timestamp.
        extras["available_at"] = timestamps + sample_interval

    return TimeSeriesDataset(
        data=pd.DataFrame(
            data={
                "load": load,
                **components,
                **extras,
            },
            index=timestamps,
        ),
        sample_interval=sample_interval,
    )
[docs]
def load_liander_dataset(
    *,
    target: str = "mv_feeder/OS Gorredijk",
    repo_id: str = LIANDER_DATASET_REPO_ID,
    local_dir: Path = Path("./liander_dataset"),
    extra_files: list[str] | None = None,
) -> TimeSeriesDataset:
    """Download and combine the Liander benchmark dataset into a single TimeSeriesDataset.

    Downloads load measurements, weather forecasts, electricity prices, and standard load
    profiles from HuggingFace Hub, then combines them via left join.

    Args:
        target: Sub-path within the repo identifying the installation (e.g. ``"mv_feeder/OS Gorredijk"``).
        repo_id: HuggingFace dataset repository ID.
        local_dir: Local directory for caching downloaded files.
        extra_files: Additional parquet files to download and include (paths relative to repo root).

    Returns:
        Combined dataset with all features aligned by timestamp.

    Raises:
        ImportError: When ``huggingface-hub`` is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download  # pyright: ignore[reportUnknownVariableType] # noqa: PLC0415
        from huggingface_hub.utils import logging as hf_logging  # noqa: PLC0415
    except ImportError:
        msg = "huggingface-hub is required for benchmark datasets: pip install openstef-core[benchmark]"
        raise ImportError(msg) from None

    files_to_download = [
        f"load_measurements/{target}.parquet",
        f"weather_forecasts_versioned/{target}.parquet",
        "EPEX.parquet",
        "profiles.parquet",
        *(extra_files or []),
    ]

    # Silence HF Hub warnings (e.g. the unauthenticated-requests notice) while
    # downloading, then restore the previous verbosity so this helper does not
    # permanently change huggingface_hub logging for the rest of the process.
    previous_verbosity = hf_logging.get_verbosity()
    hf_logging.set_verbosity_error()
    try:
        # Use the path reported by hf_hub_download rather than re-deriving it from local_dir.
        local_paths = [
            Path(
                hf_hub_download(  # pyright: ignore[reportCallIssue]
                    repo_id=repo_id,
                    filename=filename,
                    repo_type="dataset",
                    local_dir=local_dir,
                )
            )
            for filename in files_to_download
        ]
    finally:
        hf_logging.set_verbosity(previous_verbosity)

    datasets = [VersionedTimeSeriesDataset.read_parquet(path) for path in local_paths]
    return VersionedTimeSeriesDataset.concat(datasets, mode="left").select_version()
# Public API of this module: the names exported by ``from <module> import *``.
# ``LIANDER_DATASET_REPO_ID`` is re-exported from ``openstef_core.constants``.
# NOTE(review): ``configure_notebook_display`` is listed here but is not defined in the
# visible part of this module — confirm it exists, otherwise ``import *`` raises AttributeError.
__all__ = [
    "LIANDER_DATASET_REPO_ID",
    "IsSamePandas",
    "assert_timeseries_equal",
    "configure_notebook_display",
    "create_synthetic_forecasting_dataset",
    "create_timeseries_dataset",
    "load_liander_dataset",
    "prepare_tutorial_datasets",
    "setup_notebook_logging",
]

# Logger names that ``setup_notebook_logging`` silences (level ERROR, no propagation,
# including existing child loggers) when its ``suppress`` argument is ``None``.
_DEFAULT_NOISY_LOGGERS: tuple[str, ...] = (
    "choreographer",
    "kaleido",
    "huggingface_hub",
    "huggingface_hub.utils._http",
    "openstef_core.datasets.timeseries_dataset",
)
[docs]
def setup_notebook_logging(
name: str | None = None,
suppress: Sequence[str] | None = None,
) -> logging.Logger:
"""Configure logging for tutorial notebooks and return a named logger.
Sets the root logger to INFO level and silences the loggers in *suppress*
by raising their level to ERROR and disabling propagation. Child loggers
sharing a prefix are also silenced.
Args:
name: Logger name, typically ``__name__`` of the calling module.
suppress: Sequence of logger names to silence. Defaults to
``_DEFAULT_NOISY_LOGGERS``.
Returns:
Configured Logger instance.
"""
noisy = suppress if suppress is not None else _DEFAULT_NOISY_LOGGERS
logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s] %(message)s")
for logger_name in noisy:
lgr = logging.getLogger(logger_name)
lgr.setLevel(logging.ERROR)
lgr.propagate = False
# Also silence any existing child loggers
prefix = logger_name + "."
for key in logging.Logger.manager.loggerDict:
if key.startswith(prefix):
child = logging.getLogger(key)
child.setLevel(logging.ERROR)
child.propagate = False
return logging.getLogger(name)
[docs]
def prepare_tutorial_datasets(
    *,
    train_start_iso: str = "2024-03-01T00:00:00Z",
    train_days: int = 90,
    forecast_days: int = 14,
) -> tuple[TimeSeriesDataset, TimeSeriesDataset]:
    """Load the Liander benchmark dataset and cut it into a training and a forecast window.

    The training window starts at ``train_start_iso`` and spans ``train_days`` days. The
    forecast window starts where the training window ends and spans ``forecast_days``
    days. Both windows are taken with ``filter_by_range`` from the dataset returned by
    ``load_liander_dataset()`` called with its default arguments.

    Args:
        train_start_iso: ISO-format start date for the training period.
        train_days: Number of days in the training window.
        forecast_days: Number of days in the forecast window (starts immediately after training).

    Returns:
        Tuple of ``(train_dataset, forecast_dataset)``.
    """
    # Parse before loading so an invalid ISO string fails before any download starts.
    window_start = datetime.fromisoformat(train_start_iso)
    split_at = window_start + timedelta(days=train_days)

    full_dataset = load_liander_dataset()

    forecast_end = split_at + timedelta(days=forecast_days)
    # NOTE(review): ``split_at`` is both the end of the training window and the start of the
    # forecast window; whether that boundary sample lands in one or both windows depends on
    # the inclusivity of ``filter_by_range`` — confirm.
    training_part = full_dataset.filter_by_range(start=window_start, end=split_at)
    forecast_part = full_dataset.filter_by_range(start=split_at, end=forecast_end)
    return training_part, forecast_part