Source code for openstef_core.base_model

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

"""Configuration utilities for OpenSTEF Beam.

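This module provides a `BaseConfig` class extending Pydantic's `BaseModel`
with convenience helpers for reading from and writing to YAML files. It also
exposes two helper functions: `write_yaml_config`, which serializes a config
instance to YAML, and `read_yaml_config`, which validates YAML content against
a `BaseConfig` subclass or a Pydantic `TypeAdapter`.

In addition, it defines the shared `BaseModel` base class, the
`PydanticStringPrimitive` base for string-serialized types, and the
`FloatOrNan` annotated float type.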
"""

from pathlib import Path
from typing import Annotated, Any, Self

import yaml
from pydantic import BaseModel as PydanticBaseModel
from pydantic import BeforeValidator, ConfigDict, GetCoreSchemaHandler, TypeAdapter
from pydantic_core import core_schema


class BaseModel(PydanticBaseModel):
    """Shared Pydantic base class for OpenSTEF components.

    The model configuration:

    * clears Pydantic's protected namespaces (``protected_namespaces=()``),
    * permits arbitrary, non-Pydantic types as field annotations,
    * serializes infinite and NaN floats as ``null`` in JSON output.
    """

    model_config = ConfigDict(
        ser_json_inf_nan="null",
        arbitrary_types_allowed=True,
        protected_namespaces=(),
    )
class BaseConfig(PydanticBaseModel):
    """Base class for configuration models that are stored as YAML.

    The model configuration ignores unknown input keys (``extra="ignore"``),
    rejects arbitrary field types, and clears Pydantic's protected namespaces.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=False,
        extra="ignore",
        protected_namespaces=(),
    )

    @classmethod
    def read_yaml(cls, path: Path) -> Self:
        """Load and validate an instance of this class from a YAML file.

        Args:
            path: Location of the YAML file to parse.

        Returns:
            A validated instance of ``cls`` built from the file contents.
        """
        return read_yaml_config(path=path, class_type=cls)

    def write_yaml(self, path: Path) -> None:
        """Serialize this configuration to a YAML file.

        Args:
            path: Output file; any existing content is overwritten.
        """
        write_yaml_config(config=self, path=path)
def write_yaml_config(config: BaseConfig, path: Path) -> None:
    """Serialize a configuration object to a YAML file.

    The config is first dumped to JSON-compatible Python primitives and then
    emitted as UTF-8 encoded YAML with non-ASCII characters left unescaped.

    Args:
        config: Configuration instance to write.
        path: Output file; any existing content is overwritten.

    Example:
        >>> from pathlib import Path
        >>> class MyConfig(BaseConfig):
        ...     foo: int
        >>> write_yaml_config(MyConfig(foo=123), Path("/tmp/test.yaml"))
    """
    with path.open(mode="w", encoding="utf-8") as stream:
        # Dump inside the `with` so the file is opened (and truncated) before
        # serialization starts, exactly as callers have observed so far.
        payload = config.model_dump(mode="json")
        yaml.dump(payload, stream=stream, allow_unicode=True)
[docs] def read_yaml_config[T: BaseConfig, U](path: Path, class_type: type[T] | TypeAdapter[U]) -> T | U: """Read a configuration object from a YAML file. This function supports two kinds of targets: * A subclass of `BaseConfig`, in which case Pydantic's `model_validate` is used. * A `TypeAdapter` instance for more advanced / non-`BaseModel` schema validation. Args: path: Path to the YAML file to read. class_type: The target type (a `BaseConfig` subclass) or a `TypeAdapter`. Returns: A validated configuration instance (either ``T`` or ``U`` depending on the provided ``class_type``). """ with path.open("r", encoding="utf-8") as f: data = yaml.safe_load(f) if isinstance(class_type, TypeAdapter): return class_type.validate_python(data) return class_type.model_validate(data)
[docs] class PydanticStringPrimitive: """Base class for Pydantic-compatible types with string serialization."""
[docs] def __str__(self) -> str: """Convert to string representation.""" raise NotImplementedError("Subclasses must implement __str__")
[docs] @classmethod def from_string(cls, s: str) -> Self: """Create an instance from string representation.""" raise NotImplementedError("Subclasses must implement from_string")
[docs] @classmethod def validate(cls, v: Any, _info: Any = None) -> Self: # noqa: ANN401 """Validate and convert input to this type. Args: v: Input value to validate. _info: Additional validation info (unused). Returns: Validated instance of this type. Raises: ValueError: If input cannot be converted to this type. """ if isinstance(v, cls): return v if isinstance(v, str): return cls.from_string(v) # Subclasses should handle their specific types error_message = f"Cannot convert {v} to {cls.__name__}" raise ValueError(error_message)
[docs] @classmethod def __get_pydantic_core_schema__( cls, _source_type: type[Any], _handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: """Define Pydantic validation and serialization behavior. Returns: Core schema for Pydantic validation and serialization. """ return core_schema.with_info_plain_validator_function( function=cls.validate, serialization=core_schema.plain_serializer_function_ser_schema(cls.__str__) )
[docs] def __eq__(self, other: object) -> bool: """Check equality based on string representation. Returns: True if both objects have the same string representation, False otherwise. """ if not isinstance(other, self.__class__): return NotImplemented return str(self) == str(other)
[docs] def __hash__(self) -> int: """Return hash based on string representation.""" return hash(str(self))
def _convert_none_to_nan(v: float | None) -> float:
    """Return NaN when ``v`` is ``None``; otherwise return ``v`` unchanged."""
    return float("nan") if v is None else v


# Float annotation that converts a ``None`` input to NaN before validation.
FloatOrNan = Annotated[float, BeforeValidator(_convert_none_to_nan)]

__all__ = [
    "BaseConfig",
    "BaseModel",
    "FloatOrNan",
    "PydanticStringPrimitive",
    "read_yaml_config",
    "write_yaml_config",
]