Source code for openstef_models.transforms.validation.input_consistency_checker

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Input consistency validation for time series transforms."""

import logging
from typing import override

from pydantic import PrivateAttr

from openstef_core.base_model import BaseConfig
from openstef_core.datasets import TimeSeriesDataset
from openstef_core.datasets.validation import validate_required_columns
from openstef_core.exceptions import NotFittedError
from openstef_core.transforms import TimeSeriesTransform


class InputConsistencyChecker(BaseConfig, TimeSeriesTransform):
    """Validates input data consistency during transform operations.

    Ensures that input features match those seen during fitting and maintains
    consistent column ordering. Logs warnings and removes extra columns.

    Invariants:
        - Must be fitted before transform() can be called
        - Validates presence of all features seen during fitting
        - Logs warnings for extra columns not seen during fitting
        - Removes extra columns from output
        - Maintains consistent column ordering in output
    """

    # Feature names captured by fit(), in the order they were seen.
    _feature_names: list[str] = PrivateAttr(default_factory=list[str])
    # Flipped to True by fit(); guards transform().
    _is_fitted: bool = PrivateAttr(default=False)
    _logger: logging.Logger = PrivateAttr(default=logging.getLogger(__name__))

    @property
    @override
    def is_fitted(self) -> bool:
        """Whether fit() has been called on this instance."""
        return self._is_fitted

    @override
    def fit(self, data: TimeSeriesDataset) -> None:
        """Record the feature names of ``data`` as the expected input schema.

        Args:
            data: Dataset whose ``feature_names`` (and their order) become the
                reference for later transform() calls.
        """
        # Copy into a new list so the stored schema is independent of the
        # object held by the dataset.
        self._feature_names = list(data.feature_names)
        self._is_fitted = True

    @override
    def transform(self, data: TimeSeriesDataset) -> TimeSeriesDataset:
        """Return ``data`` restricted to, and ordered by, the fitted features.

        Columns of ``data.data`` that are not listed in ``data.feature_names``
        are kept and placed first; the fitted feature columns follow in the
        order recorded by fit(). Features that were not seen during fitting
        are reported with a warning and left out of the result.

        Args:
            data: Dataset to check and reorder.

        Returns:
            The result of ``data.copy_with`` called with the reordered frame.

        Raises:
            NotFittedError: If fit() has not been called yet.
        """
        if not self.is_fitted:
            raise NotFittedError(self.__class__.__name__)

        # Presence check for the fitted features is delegated to the shared
        # validator (presumably raises when one is missing - confirm against
        # openstef_core.datasets.validation).
        validate_required_columns(data.data, self._feature_names)

        # Built once and reused for both the set difference and the membership
        # filter below, instead of re-scanning the feature list per column.
        incoming_features = set(data.feature_names)

        extra_columns = incoming_features - set(self._feature_names)
        if extra_columns:
            # Sorted so the message is stable across runs; iteration order of
            # a set of strings varies with hash randomization.
            self._logger.warning(
                "Input data contains extra columns not seen during fitting: %s",
                sorted(extra_columns),
            )

        # Ensure features are ordered.
        internal_columns = [col for col in data.data.columns if col not in incoming_features]
        columns_ordered = [*internal_columns, *self._feature_names]
        df = data.data[columns_ordered]

        return data.copy_with(data=df)

    @override
    def features_added(self) -> list[str]:
        """Return the features this transform adds, which is always none."""
        return []