
Commit bcf3839

Merge pull request #148 from ChEB-AI/feature/general_pred_pipeline

[Feature]: Generalize Prediction pipeline for Lightning CLI models

2 parents dc82d38 + 6148667

File tree

13 files changed: +505 −107 lines

.gitignore

Lines changed: 1 addition & 1 deletion

```diff
@@ -175,7 +175,7 @@ chebai.egg-info
 lightning_logs
 logs
 .isort.cfg
-/.vscode
+/.vscode/launch.json

 *.out
 *.err
```

.vscode/extensions.json

Lines changed: 11 additions & 0 deletions

```diff
@@ -0,0 +1,11 @@
+{
+  "recommendations": [
+    "ms-python.python",
+    "ms-python.vscode-pylance",
+    "charliermarsh.ruff",
+    "usernamehw.errorlens"
+  ],
+  "unwantedRecommendations": [
+    "ms-python.vscode-python2"
+  ]
+}
```

.vscode/settings.json

Lines changed: 16 additions & 0 deletions

```diff
@@ -0,0 +1,16 @@
+{
+  "python.testing.unittestArgs": [
+    "-v",
+    "-s",
+    "./tests",
+    "-p",
+    "test*.py"
+  ],
+  "python.testing.pytestEnabled": false,
+  "python.testing.unittestEnabled": true,
+  "python.analysis.typeCheckingMode": "basic",
+  "editor.formatOnSave": true,
+  "[python]": {
+    "editor.defaultFormatter": "charliermarsh.ruff"
+  }
+}
```

README.md

Lines changed: 11 additions & 6 deletions

````diff
@@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf
 ```
 A command with additional options may look like this:
 ```
-python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
+python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_weighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
 ```

 ### Fine-tuning for classification tasks, e.g. Toxicity prediction
@@ -78,11 +78,16 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con

 ### Predicting classes given SMILES strings
 ```
-python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
+python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--save_to=[path-to-output]]
 ```
-The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains
-one row for each SMILES string and one column for each class.
-The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
+
+* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
+
+* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
+
+* **`--save_to`** *(optional)*: Path where the predictions are saved as a CSV file, with one row per SMILES string and one column per predicted class. Defaults to `predictions.csv` in the current working directory.
+
+> **Note**: This prediction pipeline requires checkpoints created after PR #148, which store the list of ChEBI classes (classification labels) used during training.

 ## Evaluation

@@ -96,7 +101,7 @@ An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
 Alternatively, you can evaluate the model via the CLI:

 ```bash
-python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
+python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce_weighted.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
 ```

 > **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once.
````
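The CSV layout the README describes for `--save_to` (one row per SMILES string, one column per predicted class) can be sketched in a few lines. This is illustrative only: `predictions_to_csv`, the class names, and the scores are made-up stand-ins, not part of the chebai codebase.

```python
import csv
import io


def predictions_to_csv(smiles_list, class_names, scores):
    """Write one row per SMILES string and one column per class score,
    mirroring the CSV layout described above (names here are hypothetical)."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    # Header: an identifier column followed by one column per class.
    writer.writerow(["smiles"] + class_names)
    for smiles, row_scores in zip(smiles_list, scores):
        writer.writerow([smiles] + list(row_scores))
    return buf.getvalue()


csv_text = predictions_to_csv(
    ["CCO", "c1ccccc1"],
    ["CHEBI:33302", "CHEBI:35352"],
    [[0.91, 0.02], [0.10, 0.77]],
)
print(csv_text)
```

In the real pipeline the class names come from the `classification_labels` list stored in the checkpoint, so column order always matches the model's output dimension.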

chebai/cli.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
             apply_on="instantiate",
         )

+        parser.link_arguments(
+            "data.classes_txt_file_path",
+            "model.init_args.classes_txt_file_path",
+            apply_on="instantiate",
+        )
+
         for kind in ("train", "val", "test"):
             for average in (
                 "micro-f1",
@@ -111,8 +117,6 @@ def subcommands() -> Dict[str, Set[str]]:
             "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
             "validate": {"model", "dataloaders", "datamodule"},
             "test": {"model", "dataloaders", "datamodule"},
-            "predict": {"model", "dataloaders", "datamodule"},
-            "predict_from_file": {"model"},
         }

```

chebai/models/base.py

Lines changed: 24 additions & 5 deletions

```diff
@@ -40,15 +40,16 @@ def __init__(
         pass_loss_kwargs: bool = True,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
+        classes_txt_file_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         # super().__init__()
         if exclude_hyperparameter_logging is None:
             exclude_hyperparameter_logging = tuple()
         self.criterion = criterion
-        assert out_dim is not None, "out_dim must be specified"
-        assert input_dim is not None, "input_dim must be specified"
+        assert out_dim is not None and out_dim > 0, "out_dim must be specified"
+        assert input_dim is not None and input_dim > 0, "input_dim must be specified"
         self.out_dim = out_dim
         self.input_dim = input_dim
         print(
@@ -62,6 +63,7 @@ def __init__(
                 "train_metrics",
                 "val_metrics",
                 "test_metrics",
+                "classes_txt_file_path",
                 *exclude_hyperparameter_logging,
             ]
         )
@@ -78,6 +80,23 @@ def __init__(
         self.test_metrics = test_metrics
         self.pass_loss_kwargs = pass_loss_kwargs

+        self.classes_txt_file_path = classes_txt_file_path
+
+        # During prediction `classes_txt_file_path` is set to None
+        if classes_txt_file_path is not None:
+            with open(classes_txt_file_path, "r") as f:
+                self.labels_list = [cls.strip() for cls in f.readlines()]
+            assert len(self.labels_list) > 0, "Class labels list is empty."
+            assert len(self.labels_list) == out_dim, (
+                f"Number of class labels ({len(self.labels_list)}) does not match "
+                f"the model output dimension ({out_dim})."
+            )
+
+    def on_save_checkpoint(self, checkpoint):
+        if self.classes_txt_file_path is not None:
+            # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
+            checkpoint["classification_labels"] = self.labels_list
+
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
         # different loss)
@@ -100,7 +119,7 @@ def __init_subclass__(cls, **kwargs):

     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
-    ) -> (torch.Tensor, torch.Tensor):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gets the predictions and labels from the model output.

@@ -151,7 +170,7 @@ def _process_for_loss(
        model_output: torch.Tensor,
        labels: torch.Tensor,
        loss_kwargs: Dict[str, Any],
-    ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
+    ) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        Processes the data for loss computation.

@@ -237,7 +256,7 @@ def predict_step(
        Returns:
            Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
        """
-        return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
+        return self._execute(batch, batch_idx, log=False)

     def _execute(
         self,
```
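The new `on_save_checkpoint` hook boils down to injecting the training-time label list into the checkpoint dictionary before Lightning writes it to disk. A minimal stand-alone sketch of that mechanism, where `TinyModel` is a hypothetical stand-in for the real Lightning module:

```python
from typing import Optional


class TinyModel:
    """Hypothetical stand-in (not the real chebai class) illustrating how the
    PR stores the training-time class labels inside the checkpoint dict."""

    def __init__(self, labels: Optional[list] = None):
        # In the real model, labels_list is read from the dataset's classes.txt.
        self.labels_list = labels or []

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        # Persist the label order so prediction-time code can map output
        # columns back to ChEBI classes without needing classes.txt.
        if self.labels_list:
            checkpoint["classification_labels"] = self.labels_list


ckpt = {"state_dict": {}}
model = TinyModel(labels=["CHEBI:1", "CHEBI:2"])
model.on_save_checkpoint(ckpt)
print(ckpt["classification_labels"])
```

Because Lightning calls this hook on every checkpoint save, any checkpoint produced after this change carries its own label mapping, which is exactly what the new prediction pipeline relies on.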

chebai/preprocessing/datasets/base.py

Lines changed: 92 additions & 15 deletions

```diff
@@ -340,18 +340,19 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
             for d in tqdm.tqdm(self._load_dict(path), total=lines)
             if d["features"] is not None
         ]
-        # filter for missing features in resulting data, keep features length below token limit
-        data = [
-            val
-            for val in data
-            if val["features"] is not None
-            and (
-                self.n_token_limit is None or len(val["features"]) <= self.n_token_limit
-            )
-        ]

+        data = [val for val in data if self._filter_to_token_limit(val)]
         return data

+    def _filter_to_token_limit(self, data_instance: dict) -> bool:
+        # filter for missing features in resulting data, keep features length below token limit
+        if data_instance["features"] is not None and (
+            self.n_token_limit is None
+            or len(data_instance["features"]) <= self.n_token_limit
+        ):
+            return True
+        return False
+
     def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
         """
         Returns the train DataLoader.
@@ -401,22 +402,84 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
         Returns:
             Union[DataLoader, List[DataLoader]]: A DataLoader object for test data.
         """
+
         return self.dataloader("test", shuffle=False, **kwargs)

     def predict_dataloader(
-        self, *args, **kwargs
-    ) -> Union[DataLoader, List[DataLoader]]:
+        self,
+        smiles_list: List[str],
+        model_hparams: dict,
+        **kwargs,
+    ) -> tuple[DataLoader, list[int]]:
         """
         Returns the predict DataLoader.

         Args:
-            *args: Additional positional arguments (unused).
+            smiles_list (List[str]): List of SMILES strings to predict.
+            model_hparams (Optional[dict]): Model hyperparameters.
+                Some prediction pre-processing pipelines may require these.
             **kwargs: Additional keyword arguments, passed to dataloader().

         Returns:
-            Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
+            tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices.
         """
-        return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
+
+        data, valid_indices = self._process_input_for_prediction(
+            smiles_list, model_hparams
+        )
+        return (
+            DataLoader(
+                data,
+                collate_fn=self.reader.collator,
+                batch_size=self.batch_size,
+                **kwargs,
+            ),
+            valid_indices,
+        )
+
+    def _process_input_for_prediction(
+        self, smiles_list: list[str], model_hparams: dict
+    ) -> tuple[list, list]:
+        """
+        Process input data for prediction.
+
+        Args:
+            smiles_list (List[str]): List of SMILES strings.
+            model_hparams (dict): Model hyperparameters.
+                Some prediction pre-processing pipelines may require these.
+
+        Returns:
+            tuple[list, list]: Processed input data and valid indices.
+        """
+        data, valid_indices = [], []
+        num_of_labels = int(model_hparams["out_dim"])
+        self._dummy_labels: list = list(range(1, num_of_labels + 1))
+
+        for idx, smiles in enumerate(smiles_list):
+            result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams)
+            if result is None or result["features"] is None:
+                continue
+            if not self._filter_to_token_limit(result):
+                continue
+            data.append(result)
+            valid_indices.append(idx)
+
+        return data, valid_indices
+
+    def _preprocess_smiles_for_pred(
+        self, idx: int, smiles: str, model_hparams: Optional[dict] = None
+    ) -> dict:
+        """Preprocess prediction data."""
+        # Add dummy labels because the collate function requires them.
+        # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
+        # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
+        return self.reader.to_data(
+            {
+                "id": f"smiles_{idx}",
+                "features": smiles,
+                "labels": self._dummy_labels,
+            }
+        )

     def prepare_data(self, *args, **kwargs) -> None:
         if self._prepare_data_flag != 1:
@@ -563,6 +626,19 @@ def raw_file_names_dict(self) -> dict:
         """
         raise NotImplementedError

+    @property
+    def classes_txt_file_path(self) -> str:
+        """
+        Returns the filename for the classes text file.
+
+        Returns:
+            str: The filename for the classes text file.
+        """
+        # This property is also used in the following places:
+        # - chebai/result/prediction.py: to load class names for csv column names
+        # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
+        return os.path.join(self.processed_dir_main, "classes.txt")
+

 class MergedDataset(XYBaseDataModule):
     MERGED = []
@@ -1189,7 +1265,8 @@ def _retrieve_splits_from_csv(self) -> None:
             print(f"Applying label filter from {self.apply_label_filter}...")
             with open(self.apply_label_filter, "r") as f:
                 label_filter = [line.strip() for line in f]
-            with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf:
+
+            with open(self.classes_txt_file_path, "r") as cf:
                 classes = [line.strip() for line in cf]
             # reorder labels
             old_labels = np.stack(df_data["labels"])
```
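The refactored token-limit filter and the prediction pre-processing above can be illustrated without the datamodule machinery. In this sketch all names are hypothetical: `tokenize` stands in for the reader pipeline and may return `None` for unusable input, mirroring how instances with missing features are dropped while `valid_indices` records which inputs survived.

```python
def filter_to_token_limit(instance: dict, n_token_limit=None) -> bool:
    """Standalone version of the refactored helper: keep an instance only if
    it has features and (optionally) fits under the token limit."""
    features = instance["features"]
    return features is not None and (
        n_token_limit is None or len(features) <= n_token_limit
    )


def process_for_prediction(smiles_list, out_dim, tokenize, n_token_limit=None):
    """Sketch of _process_input_for_prediction with a pluggable tokenizer."""
    # Dummy labels: the collator requires non-None labels, so one placeholder
    # per output class is attached to every instance.
    dummy_labels = list(range(1, out_dim + 1))
    data, valid_indices = [], []
    for idx, smiles in enumerate(smiles_list):
        instance = {
            "id": f"smiles_{idx}",
            "features": tokenize(smiles),
            "labels": dummy_labels,
        }
        if not filter_to_token_limit(instance, n_token_limit):
            continue  # unparsable or too long: drop, but remember via indices
        data.append(instance)
        valid_indices.append(idx)
    return data, valid_indices


# Toy tokenizer: one "token" per character, None for an empty string.
data, idxs = process_for_prediction(
    ["CCO", "", "CCCCCC"],
    out_dim=3,
    tokenize=lambda s: list(s) or None,
    n_token_limit=5,
)
print(idxs)  # only "CCO" survives: "" has no features, "CCCCCC" exceeds the limit
```

Returning `valid_indices` alongside the data lets the caller align model outputs with the original SMILES list even when some inputs were silently dropped.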
chebai/preprocessing/migration/migrate_checkpoints.py

Lines changed: 56 additions & 0 deletions

```diff
@@ -0,0 +1,56 @@
+"""
+Docstring for chebai.preprocessing.migration.migrate_checkpoints
+
+This script migrates lightning checkpoints created before python-chebai
+version 1.2.1 to be compatible with the new version.
+
+The main change is the addition of a new key "classification_labels" in the checkpoint,
+which is required for the new version of python-chebai from version 1.2.1 onwards.
+
+For more details, see the pull request: https://github.com/ChEB-AI/python-chebai/pulls
+"""
+
+import sys
+
+import torch
+
+
+def add_class_labels_to_checkpoint(input_path, classes_file_path):
+    print(f"Loading checkpoint from {input_path}...")
+    print(f"Loading class labels from {classes_file_path}...")
+
+    with open(classes_file_path, "r") as f:
+        class_labels = [line.strip() for line in f.readlines()]
+
+    assert len(class_labels) > 0, "The classes file is empty."
+
+    # 1. Load the checkpoint
+    checkpoint = torch.load(
+        input_path, map_location=torch.device("cpu"), weights_only=False
+    )
+
+    if "classification_labels" in checkpoint:
+        print(
+            "Warning: 'classification_labels' key already exists in the checkpoint and will be overwritten."
+        )
+
+    # 2. Add your custom key/value pair
+    checkpoint["classification_labels"] = class_labels
+
+    # 3. Save the modified checkpoint
+    output_path = input_path.replace(".ckpt", "_modified.ckpt")
+    torch.save(checkpoint, output_path)
+    print(f"Successfully added classification_labels and saved to {output_path}")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 3:
+        print("Usage: python migrate_checkpoints.py <input_checkpoint> <classes_file>")
+        sys.exit(1)
+
+    input_ckpt = sys.argv[1]
+    classes_file = sys.argv[2]
+
+    add_class_labels_to_checkpoint(
+        input_path=input_ckpt, classes_file_path=classes_file
+    )
```
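The migration script's core step is a plain key injection into the checkpoint dictionary. The sketch below reproduces that logic with `pickle` in place of `torch.load`/`torch.save` so it runs without PyTorch; the function name and file names are illustrative, and the real script operates on Lightning `.ckpt` files instead.

```python
import os
import pickle
import tempfile


def add_labels_to_checkpoint(ckpt_path: str, labels: list) -> str:
    """Inject a 'classification_labels' key into a checkpoint-like dict and
    save it under a '_modified' suffix, mirroring the migration script."""
    with open(ckpt_path, "rb") as f:
        checkpoint = pickle.load(f)
    if "classification_labels" in checkpoint:
        print("Warning: 'classification_labels' will be overwritten.")
    checkpoint["classification_labels"] = labels
    out_path = ckpt_path.replace(".ckpt", "_modified.ckpt")
    with open(out_path, "wb") as f:
        pickle.dump(checkpoint, f)
    return out_path


# Round-trip demo in a temporary directory.
with tempfile.TemporaryDirectory() as d:
    src = os.path.join(d, "old.ckpt")
    with open(src, "wb") as f:
        pickle.dump({"state_dict": {}}, f)  # stand-in for an old checkpoint
    out = add_labels_to_checkpoint(src, ["CHEBI:33302"])
    with open(out, "rb") as f:
        migrated = pickle.load(f)
    print(sorted(migrated))  # ['classification_labels', 'state_dict']
```

Writing to a new `_modified.ckpt` path rather than overwriting in place means a failed migration never corrupts the original checkpoint.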
