Skip to content

Commit 6571cfb

Browse files
committed
changes moved from #135
1 parent ed89b16 commit 6571cfb

File tree

9 files changed

+371
-93
lines changed

9 files changed

+371
-93
lines changed

README.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,19 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con
7878

7979
### Predicting classes given SMILES strings
8080
```
81-
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
81+
python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] ----smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
8282
```
83-
The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the
84-
one row for each SMILES string and one column for each class.
85-
The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
83+
84+
* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
85+
86+
* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
87+
88+
* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. Default path will be the current working directory with file name as `predictions.csv`.
89+
90+
* **`--classes_path`** *(optional)*: Path to the dataset’s `classes.txt` file, which maps model output indices to ChEBI IDs.
91+
* Checkpoints created after PR #135 will have the classification labels stored in them and hence this parameter is not required.
92+
* If provided, the CSV columns will be named using the ChEBI IDs.
93+
* If omitted, then script will located the file automatically. If unable to locate then the columns will be numbered sequentially.
8694

8795
## Evaluation
8896

chebai/cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
5959
apply_on="instantiate",
6060
)
6161

62+
parser.link_arguments(
63+
"data.classes_txt_file_path",
64+
"model.init_args.classes_txt_file_path",
65+
apply_on="instantiate",
66+
)
67+
6268
for kind in ("train", "val", "test"):
6369
for average in (
6470
"micro-f1",
@@ -112,7 +118,6 @@ def subcommands() -> Dict[str, Set[str]]:
112118
"validate": {"model", "dataloaders", "datamodule"},
113119
"test": {"model", "dataloaders", "datamodule"},
114120
"predict": {"model", "dataloaders", "datamodule"},
115-
"predict_from_file": {"model"},
116121
}
117122

118123

chebai/models/base.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,16 @@ def __init__(
4040
pass_loss_kwargs: bool = True,
4141
optimizer_kwargs: Optional[Dict[str, Any]] = None,
4242
exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
43+
classes_txt_file_path: Optional[str] = None,
4344
**kwargs,
4445
):
4546
super().__init__(**kwargs)
4647
# super().__init__()
4748
if exclude_hyperparameter_logging is None:
4849
exclude_hyperparameter_logging = tuple()
4950
self.criterion = criterion
50-
assert out_dim is not None, "out_dim must be specified"
51-
assert input_dim is not None, "input_dim must be specified"
51+
assert out_dim is not None and out_dim > 0, "out_dim must be specified"
52+
assert input_dim is not None and input_dim > 0, "input_dim must be specified"
5253
self.out_dim = out_dim
5354
self.input_dim = input_dim
5455
print(
@@ -77,6 +78,17 @@ def __init__(
7778
self.validation_metrics = val_metrics
7879
self.test_metrics = test_metrics
7980
self.pass_loss_kwargs = pass_loss_kwargs
81+
with open(classes_txt_file_path, "r") as f:
82+
self.labels_list = [cls.strip() for cls in f.readlines()]
83+
assert len(self.labels_list) > 0, "Class labels list is empty."
84+
assert len(self.labels_list) == out_dim, (
85+
f"Number of class labels ({len(self.labels_list)}) does not match "
86+
f"the model output dimension ({out_dim})."
87+
)
88+
89+
def on_save_checkpoint(self, checkpoint):
90+
# https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
91+
checkpoint["classification_labels"] = self.labels_list
8092

8193
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
8294
# avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
@@ -100,7 +112,7 @@ def __init_subclass__(cls, **kwargs):
100112

101113
def _get_prediction_and_labels(
102114
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
103-
) -> (torch.Tensor, torch.Tensor):
115+
) -> tuple[torch.Tensor, torch.Tensor]:
104116
"""
105117
Gets the predictions and labels from the model output.
106118
@@ -151,7 +163,7 @@ def _process_for_loss(
151163
model_output: torch.Tensor,
152164
labels: torch.Tensor,
153165
loss_kwargs: Dict[str, Any],
154-
) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
166+
) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
155167
"""
156168
Processes the data for loss computation.
157169
@@ -237,7 +249,15 @@ def predict_step(
237249
Returns:
238250
Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
239251
"""
240-
return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
252+
assert isinstance(batch, XYData)
253+
batch = batch.to(self.device)
254+
data = self._process_batch(batch, batch_idx)
255+
model_output = self(data, **data.get("model_kwargs", dict()))
256+
257+
# Dummy labels to avoid errors in _get_prediction_and_labels
258+
labels = torch.zeros((len(batch), self.out_dim)).to(self.device)
259+
pr, _ = self._get_prediction_and_labels(data, labels, model_output)
260+
return pr
241261

242262
def _execute(
243263
self,

chebai/preprocessing/datasets/base.py

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -340,18 +340,19 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
340340
for d in tqdm.tqdm(self._load_dict(path), total=lines)
341341
if d["features"] is not None
342342
]
343-
# filter for missing features in resulting data, keep features length below token limit
344-
data = [
345-
val
346-
for val in data
347-
if val["features"] is not None
348-
and (
349-
self.n_token_limit is None or len(val["features"]) <= self.n_token_limit
350-
)
351-
]
352343

344+
data = [val for val in data if self._filter_to_token_limit(val)]
353345
return data
354346

347+
def _filter_to_token_limit(self, data_instance: dict) -> bool:
348+
# filter for missing features in resulting data, keep features length below token limit
349+
if data_instance["features"] is not None and (
350+
self.n_token_limit is None
351+
or len(data_instance["features"]) <= self.n_token_limit
352+
):
353+
return True
354+
return False
355+
355356
def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
356357
"""
357358
Returns the train DataLoader.
@@ -401,22 +402,77 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
401402
Returns:
402403
Union[DataLoader, List[DataLoader]]: A DataLoader object for test data.
403404
"""
405+
404406
return self.dataloader("test", shuffle=False, **kwargs)
405407

406408
def predict_dataloader(
407-
self, *args, **kwargs
408-
) -> Union[DataLoader, List[DataLoader]]:
409+
self,
410+
smiles_list: List[str],
411+
model_hparams: Optional[dict] = None,
412+
**kwargs,
413+
) -> tuple[DataLoader, list[int]]:
409414
"""
410415
Returns the predict DataLoader.
411416
412417
Args:
413-
*args: Additional positional arguments (unused).
418+
smiles_list (List[str]): List of SMILES strings to predict.
419+
model_hparams (Optional[dict]): Model hyperparameters.
420+
Some prediction pre-processing pipelines may require these.
414421
**kwargs: Additional keyword arguments, passed to dataloader().
415422
416423
Returns:
417-
Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
424+
tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices.
418425
"""
419-
return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
426+
427+
data, valid_indices = self._process_input_for_prediction(
428+
smiles_list, model_hparams
429+
)
430+
return (
431+
DataLoader(
432+
data,
433+
collate_fn=self.reader.collator,
434+
batch_size=self.batch_size,
435+
**kwargs,
436+
),
437+
valid_indices,
438+
)
439+
440+
def _process_input_for_prediction(
441+
self, smiles_list: list[str], model_hparams: Optional[dict] = None
442+
) -> tuple[list, list]:
443+
"""
444+
Process input data for prediction.
445+
446+
Args:
447+
smiles_list (List[str]): List of SMILES strings.
448+
model_hparams (Optional[dict]): Model hyperparameters.
449+
Some prediction pre-processing pipelines may require these.
450+
451+
Returns:
452+
tuple[list, list]: Processed input data and valid indices.
453+
"""
454+
data, valid_indices = [], []
455+
for idx, smiles in enumerate(smiles_list):
456+
result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams)
457+
if result is None or result["features"] is None:
458+
continue
459+
if not self._filter_to_token_limit(result):
460+
continue
461+
data.append(result)
462+
valid_indices.append(idx)
463+
464+
return data, valid_indices
465+
466+
def _preprocess_smiles_for_pred(
467+
self, idx, smiles: str, model_hparams: Optional[dict] = None
468+
) -> dict:
469+
"""Preprocess prediction data."""
470+
# Add dummy labels because the collate function requires them.
471+
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
472+
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
473+
return self.reader.to_data(
474+
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
475+
)
420476

421477
def prepare_data(self, *args, **kwargs) -> None:
422478
if self._prepare_data_flag != 1:
@@ -563,6 +619,19 @@ def raw_file_names_dict(self) -> dict:
563619
"""
564620
raise NotImplementedError
565621

622+
@property
623+
def classes_txt_file_path(self) -> str:
624+
"""
625+
Returns the filename for the classes text file.
626+
627+
Returns:
628+
str: The filename for the classes text file.
629+
"""
630+
# This property also used in following places:
631+
# - results/prediction.py: to load class names for csv columns names
632+
# - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
633+
return os.path.join(self.processed_dir_main, "classes.txt")
634+
566635

567636
class MergedDataset(XYBaseDataModule):
568637
MERGED = []
@@ -1189,7 +1258,8 @@ def _retrieve_splits_from_csv(self) -> None:
11891258
print(f"Applying label filter from {self.apply_label_filter}...")
11901259
with open(self.apply_label_filter, "r") as f:
11911260
label_filter = [line.strip() for line in f]
1192-
with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf:
1261+
1262+
with open(self.classes_txt_file_path, "r") as cf:
11931263
classes = [line.strip() for line in cf]
11941264
# reorder labels
11951265
old_labels = np.stack(df_data["labels"])

0 commit comments

Comments
 (0)