Skip to content

Commit 28ed332

Browse files
committed
predictors don't need targets fp
1 parent 2c7b499 commit 28ed332

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

chebifier/prediction_models/electra_predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ class ElectraPredictor(NNPredictor):
3737
def __init__(self, model_name: str, ckpt_path: str, **kwargs):
3838
super().__init__(model_name, ckpt_path, **kwargs)
3939
print(
40-
f"Initialised Electra model {self.model_name} (device: {self._predictor.device})"
40+
f"Initialised Electra model {self.model_name} (device: {self.predictor.device})"
4141
)
4242

4343
def explain_smiles(self, smiles) -> dict:
4444
from chebai.preprocessing.reader import EMBEDDING_OFFSET
4545

46-
token_dict = self._predictor._dm.reader.to_data(
46+
token_dict = self.predictor._dm.reader.to_data(
4747
dict(features=smiles, labels=None)
4848
)
4949
tokens = np.array(token_dict["features"]).astype(int).tolist()

chebifier/prediction_models/nn_predictor.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,27 @@ def __init__(
1616
self,
1717
model_name: str,
1818
ckpt_path: str,
19-
target_labels_path: str,
2019
**kwargs,
2120
):
2221
super().__init__(model_name, **kwargs)
2322
self.batch_size = kwargs.get("batch_size", None)
2423
# If batch_size is not provided, it will be set to default batch size used during training in Predictor
25-
self._predictor: Predictor = Predictor(ckpt_path, self.batch_size)
26-
self.target_labels = [
27-
line.strip() for line in open(target_labels_path, encoding="utf-8")
28-
]
29-
30-
# Sanity check - ensure that the number of classes predicted by the model matches the number of target labels
31-
# TODO: In future, we can include the target labels in the model metadata and avoid this.
32-
raw_preds = self._predictor.predict_smiles(["CO"])
33-
assert len(raw_preds[0]) == len(
34-
self.target_labels
35-
), "Number of predicted classes does not match number of target labels."
24+
self.predictor: Predictor = Predictor(ckpt_path, self.batch_size)
3625

3726
@modelwise_smiles_lru_cache.batch_decorator
3827
def predict_smiles_list(self, smiles_list: list[str]) -> list:
3928
"""
4029
Returns a list with the length of smiles_list, each element is
4130
either None (=failure) or a dictionary of classes and predicted values.
4231
"""
43-
raw_preds: Tensor = self._predictor.predict_smiles(smiles_list)
32+
raw_preds: Tensor = self.predictor.predict_smiles(smiles_list)
4433
if raw_preds is not None:
4534
preds = [
4635
(
4736
{
4837
label: pred
4938
for label, pred in zip(
50-
self.target_labels, raw_preds[i].tolist()
39+
self.predictor._classification_labels, raw_preds[i].tolist()
5140
)
5241
}
5342
)
@@ -56,3 +45,10 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list:
5645
return preds
5746
else:
5847
return [None for _ in smiles_list]
48+
49+
def calculate_results(self, batch):
50+
collator = self.predictor._dm.reader.COLLATOR()
51+
dat = self.predictor._model._process_batch(
52+
collator(batch).to(self.predictor.device), 0
53+
)
54+
return self.predictor._model(dat, **dat["model_kwargs"])

0 commit comments

Comments (0)