"""Base classes for scoring functions"""fromdataclassesimportdataclassimportnumpyasnpfromexamol.score.utils.multifiimportcollect_outputsfromexamol.store.modelsimportMoleculeRecordfromexamol.store.recipesimportPropertyRecipe# TODO (wardlt): Make this a generic class once we move to Py3.12. https://peps.python.org/pep-0695/
@dataclass
class Scorer:
    """Base class for algorithms which quickly assign a score to a molecule, typically using a machine learning model

    **Using a Scorer**

    Scoring a molecule requires transforming the molecule into a form compatible with a machine learning algorithm,
    then executing inference using the machine learning algorithm.
    We separate these two steps so that the former can run on local resources
    and the latter can run on larger remote resources.
    Running the scorer will then look something like

    .. code-block:: python

        scorer = Scorer()
        recipe = PropertyRecipe()  # Recipe that we are trying to predict
        model = ...  # The model that we'll be sending to workers
        inputs = scorer.transform_inputs(records)  # Readies records to run inference
        model_msg = scorer.prepare_message(model)  # Readies model to be sent to a remote worker
        scorer.score(model_msg, inputs)  # Can be run remotely

    Note how the ``Scorer`` class does not hold on to the model as state.
    The Scorer is just the tool which holds the code needed to train and run the model.

    Training operations are broken into separate operations for similar reasons.
    We separate the training operation from pre-processing inputs and outputs,
    and from updating a local copy of the model given the results of training.

    .. code-block:: python

        outputs = scorer.transform_outputs(records, recipe)  # Prepares labels for a specific recipe
        train_msg = scorer.prepare_message(model, training=True)  # Readies model for training
        update_msg = scorer.retrain(train_msg, inputs, outputs)  # Run remotely
        model = scorer.update(model, update_msg)  # Applied to the local copy of the model
    """

    def transform_inputs(self, record_batch: list[MoleculeRecord]) -> list:
        """Form inputs for the model based on the data in a molecule record

        Args:
            record_batch: List of records to pre-process
        Returns:
            List of inputs ready for :meth:`score` or :meth:`retrain`
        Raises:
            NotImplementedError: Always, in this base class; subclasses must override
        """
        raise NotImplementedError()

    def transform_outputs(self, records: list[MoleculeRecord], recipe: PropertyRecipe) -> np.ndarray:
        """Gather the target outputs of the model

        Args:
            records: List of records from which to extract outputs
            recipe: Target recipe for the scorer for single-fidelity learning
        Returns:
            Outputs ready for model training, as a 1D array
        """
        # `collect_outputs` is given a single-recipe list and yields a 2D array (indexed with [:, -1] below).
        # Keep only the last column, i.e. the values for `recipe`.
        # NOTE(review): assumes one row per record in `records`; confirm against collect_outputs.
        return collect_outputs(records, [recipe])[:, -1]

    def prepare_message(self, model: object, training: bool = False) -> object:
        """Get the model state as a serializable object

        Args:
            model: Model to be sent to :meth:`score` or :meth:`retrain`
            training: Whether to prepare the message for training (``True``) or inference (``False``)
        Returns:
            The model state as an object which can be serialized, then transmitted to a remote worker
        Raises:
            NotImplementedError: Always, in this base class; subclasses must override
        """
        raise NotImplementedError()

    def score(self, model_msg: object, input_data: list, **kwargs) -> np.ndarray:
        """Assign a score to molecules

        Args:
            model_msg: Model in a transmittable format, may need to be deserialized
            input_data: Batch of inputs ready for the model, as generated by :meth:`transform_inputs`
        Returns:
            The scores for the batch of inputs
        Raises:
            NotImplementedError: Always, in this base class; subclasses must override
        """
        raise NotImplementedError()

    def retrain(self, model_msg: object, input_data: list, output_data: list, **kwargs) -> object:
        """Retrain the scorer based on new training records

        Args:
            model_msg: Model to be retrained, as prepared by :meth:`prepare_message` with ``training=True``
            input_data: Training set inputs, as generated by :meth:`transform_inputs`
            output_data: Training set outputs, as generated by :meth:`transform_outputs`
                (an array, despite the ``list`` annotation)
        Returns:
            Message defining how to update the model, to be passed to :meth:`update`
        Raises:
            NotImplementedError: Always, in this base class; subclasses must override
        """
        raise NotImplementedError()

    def update(self, model: object, update_msg: object) -> object:
        """Update the local copy of a model

        Args:
            model: Model to be updated
            update_msg: Update for the model, as produced by :meth:`retrain`
        Returns:
            Updated model
        Raises:
            NotImplementedError: Always, in this base class; subclasses must override
        """
        raise NotImplementedError()
class MultiFidelityScorer(Scorer):
    """Base class for scorers which support multi-fidelity learning

    All subclasses support a ``lower_fidelities`` keyword argument to the :meth:`score` and :meth:`retrain` methods,
    which takes any lower-fidelity information that is available.
    (The base-class signatures of those methods accept it through their ``**kwargs``.)
    Subclasses should train a multi-fidelity model if given lower-fidelity data during training,
    and use the lower-fidelity data to enhance prediction accuracy during scoring.
    """