Source code for examol.score.base

"""Base classes for scoring functions"""
from dataclasses import dataclass

import numpy as np

from examol.score.utils.multifi import collect_outputs
from examol.store.models import MoleculeRecord
from examol.store.recipes import PropertyRecipe


# TODO (wardlt): Make this a generic class once we move to Py3.12. https://peps.python.org/pep-0695/
@dataclass
class Scorer:
    """Base class for algorithms which quickly assign a score to a molecule, typically using a machine learning model

    **Using a Scorer**

    Scoring a molecule requires transforming the molecule into a form compatible with a machine learning algorithm,
    then executing inference using the machine learning algorithm.
    We separate these two steps so that the former can run on local resources
    and the latter can run on larger remote resources.
    Running the scorer will then look something like

    .. code-block:: python

        scorer = Scorer()
        recipe = PropertyRecipe()  # Recipe that we are trying to predict
        model = ...  # The model that we'll be sending to workers
        inputs = scorer.transform_inputs(records)  # Readies records to run inference
        model_msg = scorer.prepare_message(model)  # Readies model to be sent to a remote worker
        scorer.score(model_msg, inputs)  # Can be run remotely

    Note how the ``Scorer`` class does not hold on to the model as state.
    The Scorer is just the tool which holds the code needed to train and run the model.

    Training operations are broken into separate operations for similar reasons.
    We separate the training operation from pre-processing inputs and outputs,
    and from updating a local copy of the model given the results of training.

    .. code-block:: python

        outputs = scorer.transform_outputs(records, recipe)  # Prepares labels for a specific recipe
        model_msg = scorer.prepare_message(model, training=True)  # Readies model to be sent for training
        update_msg = scorer.retrain(model_msg, inputs, outputs)  # Run remotely
        model = scorer.update(model, update_msg)  # Applies the training results to the local copy
    """

    def transform_inputs(self, record_batch: list[MoleculeRecord]) -> list:
        """Form inputs for the model based on the data in a molecule record

        Args:
            record_batch: List of records to pre-process
        Returns:
            List of inputs ready for :meth:`score` or :meth:`retrain`
        Raises:
            NotImplementedError: Subclasses must implement this method
        """
        raise NotImplementedError()

    def transform_outputs(self, records: list[MoleculeRecord], recipe: PropertyRecipe) -> np.ndarray:
        """Gather the target outputs of the model

        Args:
            records: List of records from which to extract outputs
            recipe: Target recipe for the scorer for single-fidelity learning
        Returns:
            Outputs ready for model training, as a 1D array
        """
        # Only one recipe is requested; presumably collect_outputs yields one column per recipe,
        #  so the last (only) column holds the target values
        return collect_outputs(records, [recipe])[:, -1]

    def prepare_message(self, model: object, training: bool = False) -> object:
        """Get the model state as a serializable object

        Args:
            model: Model to be sent to the :meth:`score` or :meth:`retrain` function
            training: Whether to prepare the message for training (``True``) or inference (``False``)
        Returns:
            The model state as an object which can be serialized then transmitted to a remote worker
        Raises:
            NotImplementedError: Subclasses must implement this method
        """
        raise NotImplementedError()

    def score(self, model_msg: object, input_data: list, **kwargs) -> np.ndarray:
        """Assign a score to molecules

        Args:
            model_msg: Model in a transmittable format, may need to be deserialized
            input_data: Batch of inputs ready for the model, as generated by :meth:`transform_inputs`
            **kwargs: Additional options understood by specific subclasses
                (e.g., ``lower_fidelities`` for :class:`MultiFidelityScorer`)
        Returns:
            The scores for the batch of inputs
        Raises:
            NotImplementedError: Subclasses must implement this method
        """
        raise NotImplementedError()

    def retrain(self, model_msg: object, input_data: list, output_data: list, **kwargs) -> object:
        """Retrain the scorer based on new training records

        Args:
            model_msg: Model to be retrained
            input_data: Training set inputs, as generated by :meth:`transform_inputs`
            output_data: Training set outputs, as generated by :meth:`transform_outputs`
            **kwargs: Additional options understood by specific subclasses
                (e.g., ``lower_fidelities`` for :class:`MultiFidelityScorer`)
        Returns:
            Message defining how to update the model, to be passed to :meth:`update`
        Raises:
            NotImplementedError: Subclasses must implement this method
        """
        raise NotImplementedError()

    def update(self, model: object, update_msg: object) -> object:
        """Update a local copy of a model

        Args:
            model: Model to be updated
            update_msg: Update for the model, as produced by :meth:`retrain`
        Returns:
            Updated model
        Raises:
            NotImplementedError: Subclasses must implement this method
        """
        raise NotImplementedError()
class MultiFidelityScorer(Scorer):
    """Base class for scorers which support multi-fidelity learning

    All subclasses support a ``lower_fidelities`` keyword argument to the :meth:`score` and :meth:`retrain` functions
    that takes any lower-fidelity information available.
    Subclasses should train a multi-fidelity model if provided lower-fidelity data during training
    and use the lower-fidelity data to enhance prediction accuracy during scoring.
    """

    def score(self, model_msg: object, input_data: list, lower_fidelities: np.ndarray | None = None, **kwargs) -> np.ndarray:
        """Assign a score to molecules, optionally using lower-fidelity data

        Args:
            model_msg: Model in a transmittable format, may need to be deserialized
            input_data: Batch of inputs ready for the model, as generated by :meth:`transform_inputs`
            lower_fidelities: Lower-fidelity property values for the inputs, or ``None`` if unavailable.
                NOTE(review): presumably one row per entry in ``input_data``; confirm the expected shape
            **kwargs: Additional options understood by specific subclasses
        Returns:
            The scores for the batch of inputs
        Raises:
            NotImplementedError: Subclasses must implement this method
        """
        raise NotImplementedError()

    def retrain(self, model_msg: object, input_data: list, output_data: list, lower_fidelities: np.ndarray | None = None, **kwargs) -> object:
        """Retrain the scorer, optionally using lower-fidelity data to build a multi-fidelity model

        Args:
            model_msg: Model to be retrained
            input_data: Training set inputs, as generated by :meth:`transform_inputs`
            output_data: Training set outputs, as generated by :meth:`transform_outputs`
            lower_fidelities: Lower-fidelity property values for the training inputs, or ``None`` if unavailable.
                NOTE(review): presumably one row per entry in ``input_data``; confirm the expected shape
            **kwargs: Additional options understood by specific subclasses
        Returns:
            Message defining how to update the model, to be passed to :meth:`update`
        Raises:
            NotImplementedError: Subclasses must implement this method
        """
        raise NotImplementedError()