Source code for openstef_core.mixins.param_ranges
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
"""Tuning range types and metadata for hyperparameter search spaces.
Range types (``FloatRange``, ``IntRange``, ``CategoricalRange``) are frozen
dataclasses used as ``Annotated`` metadata on ``HyperParams`` fields, where
Pydantic treats plain dataclasses as opaque field metadata.
``ModelTuningInfo`` is a Pydantic model used outside ``Annotated`` context.
"""
from dataclasses import dataclass, replace
from typing import Any, Self
from pydantic import ConfigDict, Field, model_validator
from pydantic.fields import FieldInfo
from openstef_core.base_model import BaseConfig, BaseModel
@dataclass(frozen=True)
class _BoundedRange:
"""Shared validation and resolution logic for numeric range types.
Not part of the public API — use ``FloatRange`` or ``IntRange`` instead.
"""
low: Any = None
high: Any = None
log: bool = False
tune: bool = False
def __post_init__(self) -> None:
if self.low is not None and self.high is not None and self.low > self.high:
msg = f"low ({self.low}) must be <= high ({self.high})"
raise ValueError(msg)
def resolve(self, class_default: Self | None) -> Self:
"""Fill ``None`` bounds from *class_default*.
Returns:
Resolved range.
"""
if class_default is None:
return self
return replace(
self,
low=self.low if self.low is not None else class_default.low,
high=self.high if self.high is not None else class_default.high,
)
[docs]
@dataclass(frozen=True)
class FloatRange(_BoundedRange):
    """Annotate a ``HyperParams`` float field as tunable within ``[low, high]``.

    Bound validation (``low <= high``) and ``resolve`` are inherited from
    ``_BoundedRange``; this subclass only narrows the bound types to ``float``.

    Attributes:
        low: Lower bound; ``None`` means it is filled from a class default by ``resolve``.
        high: Upper bound; ``None`` means it is filled from a class default by ``resolve``.
    """

    # Redeclared only to narrow the annotations. Redefined dataclass fields keep the
    # base-class position, so the constructor order stays (low, high, log, tune).
    low: float | None = None
    high: float | None = None
[docs]
@dataclass(frozen=True)
class IntRange(_BoundedRange):
    """Annotate a ``HyperParams`` int field as tunable within ``[low, high]``.

    Bound validation (``low <= high``) and ``resolve`` are inherited from
    ``_BoundedRange``; this subclass only narrows the bound types to ``int``.
    The annotation is not enforced at runtime.

    Attributes:
        low: Lower bound; ``None`` means it is filled from a class default by ``resolve``.
        high: Upper bound; ``None`` means it is filled from a class default by ``resolve``.
    """

    # Redeclared only to narrow the annotations. Redefined dataclass fields keep the
    # base-class position, so the constructor order stays (low, high, log, tune).
    low: int | None = None
    high: int | None = None
[docs]
@dataclass(frozen=True)
class CategoricalRange:
"""Annotate a ``HyperParams`` field as tunable over discrete ``choices``."""
choices: tuple[Any, ...] | None = None
tune: bool = False
def __post_init__(self) -> None: # noqa: D105
if self.choices is not None and len(self.choices) == 0:
msg = "choices must not be empty"
raise ValueError(msg)
[docs]
def resolve(self, class_default: Self | None) -> Self:
"""Fill ``None`` choices from *class_default*.
Returns:
Resolved range.
"""
if class_default is None:
return self
return replace(
self,
choices=self.choices if self.choices is not None else class_default.choices,
)
# Union of every range annotation type defined in this module. It is the value type
# of ``ModelTuningInfo.search_space`` and the return type of ``get_tuning_range``.
# The PEP 695 ``type`` statement requires Python 3.12+.
type TuningRange = FloatRange | IntRange | CategoricalRange
[docs]
class ModelTuningInfo(BaseModel):
    """Bundle a hyperparameter config together with its resolved search space.

    The model is immutable (``frozen=True``). Validation rejects an empty
    ``search_space``, and the error message names the offending field.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)

    field_name: str = Field(description="Name of the HyperParams field on the parent config.")
    hyperparams: BaseConfig = Field(description="The HyperParams instance that owns the search space.")
    search_space: dict[str, TuningRange] = Field(description="Resolved tuning ranges keyed by parameter name.")

    @model_validator(mode="after")
    def _validate_search_space(self) -> Self:
        # Guard clause: a populated search space is the only accepted state.
        if self.search_space:
            return self
        msg = f"search_space for '{self.field_name}' must not be empty"
        raise ValueError(msg)
[docs]
def get_tuning_range(field_info: FieldInfo) -> TuningRange | None:
    """Look up the tuning range attached to a Pydantic field.

    Args:
        field_info: Pydantic field whose ``metadata`` entries are scanned in order.

    Returns:
        The first ``FloatRange``, ``IntRange`` or ``CategoricalRange`` in the
        metadata, or ``None`` if the field carries none.
    """
    return next(
        (
            item
            for item in field_info.metadata
            if isinstance(item, (FloatRange, IntRange, CategoricalRange))
        ),
        None,
    )
# Public API of this module. ``_BoundedRange`` is deliberately left out: its docstring
# marks it as an internal base for ``FloatRange`` and ``IntRange``.
__all__ = [
    "CategoricalRange",
    "FloatRange",
    "IntRange",
    "ModelTuningInfo",
    "TuningRange",
    "get_tuning_range",
]