# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
"""Transform that adds holiday features to time series data.
This module provides functionality to compute holiday features based on the
datetime index of time series datasets. It uses the `holidays` library to identify
public holidays for a specified country and adds binary features indicating
whether a date is a holiday or a specific holiday.
"""
import re
import unicodedata
from datetime import UTC, datetime
from functools import lru_cache
from typing import override
import holidays
import pandas as pd
from pydantic import Field
from pydantic_extra_types.country import CountryAlpha2
from openstef_core.base_model import BaseConfig
from openstef_core.datasets import TimeSeriesDataset
from openstef_core.transforms import TimeSeriesTransform
_DEFAULT_RESERVED_DATETIME_FEATURES: frozenset[str] = frozenset({"sunday", "week_day", "weekend_day"})
[docs]
class HolidayFeatureAdder(BaseConfig, TimeSeriesTransform):
"""Transform that adds holiday features to time series data.
Computes features that indicate whether a date is a holiday
based on the specified country and categories.
The features added are:
- is_holiday: 1 if the day is a public holiday, 0 otherwise.
- is_{holiday_name}: 1 if the day is a specific holiday, 0 otherwise.
Example:
>>> import pandas as pd
>>> from datetime import timedelta
>>> from openstef_core.datasets import TimeSeriesDataset
>>> from openstef_models.transforms.time_domain import (
... HolidayFeatureAdder
... )
>>>
>>> # Create sample dataset
>>> data = pd.DataFrame({
... 'load': [100] * 5
... }, index=pd.date_range('2025-12-24', periods=5, freq='D'))
>>> dataset = TimeSeriesDataset(data, sample_interval=timedelta(days=1))
>>> transform = HolidayFeatureAdder(country_code='NL')
>>> transformed_data = transform.transform(dataset).data
>>> transformed_data["is_holiday"].tolist()
[0, 1, 1, 0, 0]
>>> transformed_data["is_christmas_day"].tolist()
[0, 1, 0, 0, 0]
>>> transformed_data["is_second_day_of_christmas"].tolist()
[0, 0, 1, 0, 0]
"""
country_code: CountryAlpha2 = Field(description="Country code for holiday calculation.")
reserved_datetime_features: frozenset[str] = Field(
default=_DEFAULT_RESERVED_DATETIME_FEATURES,
description="Holiday names (after sanitization) that collide with other features and must be excluded.",
)
[docs]
@override
def features_added(self) -> list[str]:
return [
"is_holiday",
*[f"is_{name}" for name in get_holiday_names(self.country_code, self.reserved_datetime_features)],
]
[docs]
def sanitize_holiday_name(holiday_name: str) -> str:
"""Convert holiday name to a valid Python identifier for feature names.
Args:
holiday_name: Original holiday name from holidays library.
Returns:
Sanitized feature name (lowercase, underscores, no special chars).
"""
# Normalize Unicode and convert to ASCII
normalized = unicodedata.normalize("NFKD", holiday_name)
ascii_text = normalized.encode("ascii", "ignore").decode("ascii")
# Replace non-alphanumeric with underscores and clean up
return re.sub(r"_+", "_", re.sub(r"[^a-z0-9]+", "_", ascii_text.lower())).strip("_")
[docs]
@lru_cache(maxsize=5)
def get_holiday_names(
country_code: CountryAlpha2, reserved_features: frozenset[str] = _DEFAULT_RESERVED_DATETIME_FEATURES
) -> list[str]:
"""Get list of original holiday names for the country.
Returns:
List of unique holiday names in their original form.
"""
# Years used to fetch holiday names. If new holidays are added later, this needs to be updated.
current_year = datetime.now(tz=UTC).year
reference_years = range(current_year - 5, current_year + 1)
country_holidays = holidays.country_holidays(
country=str(country_code), categories=["public"], years=reference_years, language="en_US"
)
return sorted([
sanitize_holiday_name(holiday_name)
for holiday_name in set(country_holidays.values())
if sanitize_holiday_name(holiday_name) and sanitize_holiday_name(holiday_name) not in reserved_features
])
[docs]
def get_holidays(
index: pd.DatetimeIndex,
country: CountryAlpha2,
reserved_features: frozenset[str] = _DEFAULT_RESERVED_DATETIME_FEATURES,
) -> pd.DataFrame:
"""Get holidays for the dataset's date range as a cleaned DataFrame.
Args:
index: DatetimeIndex of the dataset.
country: Country code for holiday calculation.
reserved_features: Holiday names to exclude from results.
Returns:
DataFrame with columns: date, holiday_name, sanitized_name.
Empty DataFrame if no holidays found.
Sanitized names are pre-cleaned (no empty strings, duplicates handled).
"""
years = range(int(index.year.min()), int(index.year.max()) + 1)
country_holidays = holidays.country_holidays(str(country), categories=["public"], years=years, language="en_US")
all_country_holiday_names = get_holiday_names(country, reserved_features)
if not country_holidays:
return pd.DataFrame(columns=["date", "holiday_name", "sanitized_name"])
return pd.DataFrame(
data=[
{
"date": holiday_date,
"holiday_name": holiday_name,
"sanitized_name": sanitize_holiday_name(holiday_name),
}
for holiday_date, holiday_name in country_holidays.items()
if sanitize_holiday_name(holiday_name) in all_country_holiday_names
]
)