@@ -16,38 +16,27 @@ def __init__(
1616 self ,
1717 model_name : str ,
1818 ckpt_path : str ,
19- target_labels_path : str ,
2019 ** kwargs ,
2120 ):
2221 super ().__init__ (model_name , ** kwargs )
2322 self .batch_size = kwargs .get ("batch_size" , None )
2423 # If batch_size is not provided, it will be set to default batch size used during training in Predictor
25- self ._predictor : Predictor = Predictor (ckpt_path , self .batch_size )
26- self .target_labels = [
27- line .strip () for line in open (target_labels_path , encoding = "utf-8" )
28- ]
29-
30- # Sanity check - ensure that the number of classes predicted by the model matches the number of target labels
31- # TODO: In future, we can include the target labels in the model metadata and avoid this.
32- raw_preds = self ._predictor .predict_smiles (["CO" ])
33- assert len (raw_preds [0 ]) == len (
34- self .target_labels
35- ), "Number of predicted classes does not match number of target labels."
24+ self .predictor : Predictor = Predictor (ckpt_path , self .batch_size )
3625
3726 @modelwise_smiles_lru_cache .batch_decorator
3827 def predict_smiles_list (self , smiles_list : list [str ]) -> list :
3928 """
4029 Returns a list with the length of smiles_list, each element is
4130 either None (=failure) or a dictionary of classes and predicted values.
4231 """
43- raw_preds : Tensor = self ._predictor .predict_smiles (smiles_list )
32+ raw_preds : Tensor = self .predictor .predict_smiles (smiles_list )
4433 if raw_preds is not None :
4534 preds = [
4635 (
4736 {
4837 label : pred
4938 for label , pred in zip (
50- self .target_labels , raw_preds [i ].tolist ()
39+ self .predictor . _classification_labels , raw_preds [i ].tolist ()
5140 )
5241 }
5342 )
@@ -56,3 +45,10 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list:
5645 return preds
5746 else :
5847 return [None for _ in smiles_list ]
48+
49+ def calculate_results (self , batch ):
50+ collator = self .predictor ._dm .reader .COLLATOR ()
51+ dat = self .predictor ._model ._process_batch (
52+ collator (batch ).to (self .predictor .device ), 0
53+ )
54+ return self .predictor ._model (dat , ** dat ["model_kwargs" ])
0 commit comments