BatchPredictor

class openstef_core.mixins.BatchPredictor

Bases: Predictor, Generic

Abstract base class for batch prediction models.

This class extends Predictor with batch prediction: many inputs are predicted in one call, and a failure on an individual item is captured in the result instead of aborting the whole batch.

Processing a batch in a single call can be more efficient than calling predict once per item for models that support vectorized inference, especially on GPUs.

Type parameters:

I: The input data type for fitting and prediction.
O: The output prediction type.

Subclasses inherit all requirements from Predictor and must implement predict_batch so that it returns one entry per input item (a prediction or a PredictError) instead of raising.

Example

Implementing a batch linear predictor

>>> class BatchLinearPredictor(BatchPredictor[list[float], float]):
...     def __init__(self):
...         self.slope = None
...         self.intercept = None
...
...     @property
...     def is_fitted(self) -> bool:
...         return self.slope is not None and self.intercept is not None
...
...     def fit(self, data: list[float]) -> None:
...         # Placeholder "fit": fixed coefficients for illustration only
...         self.slope = 1.0
...         self.intercept = 0.0
...
...     def predict(self, data: list[float]) -> float:
...         return self.slope * sum(data) + self.intercept
...
...     def predict_batch(self, data: list[list[float]]) -> BatchResult[float]:
...         # Capture per-item failures so one bad item does not abort the batch
...         result = []
...         for item in data:
...             try:
...                 result.append(self.predict(item))
...             except PredictError as e:
...                 result.append(e)
...         return result

predict_batch(data: list[I]) → BatchResult

Generate predictions for multiple input data items.

This method processes a batch of input data, generating predictions for each item. If any individual prediction fails, the error is captured and included in the results instead of failing the entire batch.

Parameters:

data (list[I]) – List of input data items to generate predictions for.

Returns:

List containing one entry per input item: a prediction of type O if it succeeded, or a PredictError instance if it failed.

Return type:

BatchResult[O]

Note

This method does not raise exceptions; errors are captured in the result to allow partial batch processing to continue.
