Source code for openstef_core.mixins.stateful
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
"""State management for serializable objects.
Enables objects to save their state and restore it later. Supports model
deployment, caching, and distributed processing by providing versioned
serialization with automatic state migration.
"""
import warnings
from typing import ClassVar, TypedDict, cast
from openstef_core.types import Any
[docs]
class VersionedState(TypedDict):
"""Versioned state structure for object serialization.
Contains version metadata and the actual object state, enabling
backward compatibility through state migration.
"""
__version__: int
__class_name__: str
state: dict[str, Any]
[docs]
class Stateful:
"""Mixin for objects that can save and restore their internal state.
Provides versioned serialization with automatic state migration. Objects can be
pickled, saved to disk, transmitted over networks, or stored in databases, then
restored to their previous state. Version tracking ensures backward compatibility
when object structure changes.
Subclasses can override `_migrate_state` to handle state migrations between versions.
Increment `_VERSION` when making incompatible changes to object structure.
"""
_VERSION: ClassVar[int] = 1
[docs]
def __getstate__(self) -> VersionedState:
"""Serialize object state with version metadata.
Returns:
Versioned state dictionary containing version number, class name,
and the object's internal state.
"""
if hasattr(super(), "__getstate__"):
# In case of pydantic or other base classes implementing __getstate__
base_state = super().__getstate__() # type: ignore[misc]
# Pydantic returns None for models with no fields
if base_state is None:
base_state = {}
else:
base_state = self.__dict__.copy()
return VersionedState(
__version__=self._VERSION,
__class_name__=self.__class__.__name__,
state=cast(dict[str, Any], base_state),
)
[docs]
def __setstate__(self, state: Any) -> None:
"""Restore object from serialized state.
Handles both versioned and legacy state formats. Automatically migrates
state from older versions using `_migrate_state`. Warns when loading
legacy objects or when current version is older than saved version.
Args:
state: Serialized state, either VersionedState dict or legacy format.
"""
# Handle legacy objects without versioning
if not isinstance(state, dict) or "__version__" not in state: # pyright: ignore[reportUnnecessaryIsInstance]
warnings.warn(
f"Loading legacy {self.__class__.__name__} without version metadata.", UserWarning, stacklevel=2
)
self._restore_state(state)
return
state = cast(VersionedState, state)
saved_version: int = state["__version__"]
actual_state: dict[str, Any] = state["state"]
if saved_version < self._VERSION:
actual_state = self._migrate_state(state=actual_state, from_version=saved_version, to_version=self._VERSION)
elif saved_version > self._VERSION:
warnings.warn(
f"{self.__class__.__name__} saved with v{saved_version}, "
f"current is v{self._VERSION}. Forward compatibility not guaranteed.",
UserWarning,
stacklevel=2,
)
self._restore_state(actual_state)
def _restore_state(self, state: Any) -> None:
"""Restore object's internal state from a dictionary.
Delegates to parent class `__setstate__` if available, otherwise
updates `__dict__` directly.
Args:
state: State dictionary to restore.
"""
# Check if any parent class has __setstate__
if hasattr(super(), "__setstate__"):
super().__setstate__(state) # type: ignore[misc]
elif state: # Only update if state is not empty
self.__dict__.update(state)
@classmethod
def _migrate_state(cls, state: dict[str, Any], from_version: int, to_version: int) -> dict[str, Any]:
"""Migrate state from an older version to the current version.
Override this method in subclasses to handle state transformations when
the object structure changes. Called automatically during deserialization
when saved_version < current_version.
Args:
state: State dictionary from the older version.
from_version: Version of the saved state.
to_version: Target version (current `_VERSION`).
Returns:
Migrated state dictionary compatible with current version.
"""
_ = from_version, to_version # Important arguments, but unused in base implementation
return state