# Source code for openstef_core.mixins.param_ranges

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
"""Tuning range types and metadata for hyperparameter search spaces.

Range types (``FloatRange``, ``IntRange``, ``CategoricalRange``) are frozen
dataclasses used as ``Annotated`` metadata on ``HyperParams`` fields, where
Pydantic treats plain dataclasses as opaque field metadata.

``ModelTuningInfo`` is a Pydantic model used outside ``Annotated`` context.
"""

from dataclasses import dataclass, replace
from typing import Any, Self

from pydantic import ConfigDict, Field, model_validator
from pydantic.fields import FieldInfo

from openstef_core.base_model import BaseConfig, BaseModel


@dataclass(frozen=True)
class _BoundedRange:
    """Shared validation and resolution logic for numeric range types.

    Not part of the public API — use ``FloatRange`` or ``IntRange`` instead.
    """

    low: Any = None
    high: Any = None
    log: bool = False
    tune: bool = False

    def __post_init__(self) -> None:
        if self.low is not None and self.high is not None and self.low > self.high:
            msg = f"low ({self.low}) must be <= high ({self.high})"
            raise ValueError(msg)

    def resolve(self, class_default: Self | None) -> Self:
        """Fill ``None`` bounds from *class_default*.

        Returns:
            Resolved range.
        """
        if class_default is None:
            return self
        return replace(
            self,
            low=self.low if self.low is not None else class_default.low,
            high=self.high if self.high is not None else class_default.high,
        )


@dataclass(frozen=True)
class FloatRange(_BoundedRange):
    """Annotate a ``HyperParams`` float field as tunable within ``[low, high]``.

    Either bound may be left as ``None`` and filled in later via ``resolve``.
    Construction raises ``ValueError`` when both bounds are set and
    ``low > high``. The ``log`` and ``tune`` flags are inherited unchanged.
    """

    low: float | None = None
    high: float | None = None
@dataclass(frozen=True)
class IntRange(_BoundedRange):
    """Annotate a ``HyperParams`` int field as tunable within ``[low, high]``.

    Either bound may be left as ``None`` and filled in later via ``resolve``.
    Construction raises ``ValueError`` when both bounds are set and
    ``low > high``. The ``log`` and ``tune`` flags are inherited unchanged.
    """

    low: int | None = None
    high: int | None = None
[docs] @dataclass(frozen=True) class CategoricalRange: """Annotate a ``HyperParams`` field as tunable over discrete ``choices``.""" choices: tuple[Any, ...] | None = None tune: bool = False def __post_init__(self) -> None: # noqa: D105 if self.choices is not None and len(self.choices) == 0: msg = "choices must not be empty" raise ValueError(msg)
[docs] def resolve(self, class_default: Self | None) -> Self: """Fill ``None`` choices from *class_default*. Returns: Resolved range. """ if class_default is None: return self return replace( self, choices=self.choices if self.choices is not None else class_default.choices, )
# Union of every range annotation type defined in this module; values of
# ``ModelTuningInfo.search_space`` are of this type. Declared with the PEP 695
# ``type`` statement, so the right-hand side is evaluated lazily.
type TuningRange = FloatRange | IntRange | CategoricalRange
class ModelTuningInfo(BaseModel):
    """Groups a hyperparameter config with its resolved search space.

    Field assignment is disabled (``frozen=True``), and an empty
    ``search_space`` is rejected when the model is validated.
    """

    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

    field_name: str = Field(description="Name of the HyperParams field on the parent config.")
    hyperparams: BaseConfig = Field(description="The HyperParams instance that owns the search space.")
    search_space: dict[str, TuningRange] = Field(description="Resolved tuning ranges keyed by parameter name.")

    @model_validator(mode="after")
    def _validate_search_space(self) -> Self:
        """Reject a model whose search space has no entries.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If ``search_space`` is empty (Pydantic reports it to
                the caller as a ``ValidationError``).
        """
        if not self.search_space:
            msg = f"search_space for '{self.field_name}' must not be empty"
            raise ValueError(msg)
        return self
def get_tuning_range(field_info: FieldInfo) -> TuningRange | None:
    """Return the first TuningRange found in a Pydantic FieldInfo's metadata.

    Args:
        field_info: Pydantic field description whose ``metadata`` entries are
            scanned in order.

    Returns:
        The first ``FloatRange``, ``IntRange`` or ``CategoricalRange`` entry,
        or ``None`` when the field carries no tuning range.
    """
    # Keep this tuple in sync with the members of the ``TuningRange`` alias.
    for meta in field_info.metadata:
        if isinstance(meta, (FloatRange, IntRange, CategoricalRange)):
            return meta
    return None
# Public API of this module; the internal ``_BoundedRange`` base is omitted.
__all__ = [
    "CategoricalRange",
    "FloatRange",
    "IntRange",
    "ModelTuningInfo",
    "TuningRange",
    "get_tuning_range",
]