Skip to content

Commit 7547057

Browse files
committed
Squashed commit of the following:
commit efac8aa Merge: 0a66ef4 2ead405 Author: Simon Flügel <43573433+sfluegel05@users.noreply.github.com> Date: Thu Jan 22 16:40:22 2026 +0100 Merge pull request #138 from ChEB-AI/fix/read_data Raise error for invalid smiles and return None commit 2ead405 Author: aditya0by0 <aditya0by0@gmail.com> Date: Thu Jan 22 16:34:51 2026 +0100 avoid repetition of smiles-mol conv commit 0a66ef4 Merge: b32e6c5 203b2b3 Author: Simon Flügel <43573433+sfluegel05@users.noreply.github.com> Date: Thu Jan 22 09:43:01 2026 +0100 Merge pull request #143 from schnamo/dev tidy up config files for loss, fix missing labels issue, etc commit 203b2b3 Merge: f034269 b32e6c5 Author: Charlotte Tumescheit <18518966+schnamo@users.noreply.github.com> Date: Tue Jan 20 13:18:09 2026 +0100 Merge branch 'ChEB-AI:dev' into dev commit f034269 Author: schnamo <ch.tumescheit@gmail.com> Date: Tue Jan 20 13:05:33 2026 +0100 tidy up config files for loss, fix missing labels issue, fix a number of other small issues commit b32e6c5 Merge: c9c08dc a5ea56a Author: Simon Flügel <43573433+sfluegel05@users.noreply.github.com> Date: Mon Jan 19 10:35:01 2026 +0100 Merge pull request #141 from ChEB-AI/fix/file_not_found_for_loss BCE Loss unable to locate processed files commit a5ea56a Author: aditya0by0 <aditya0by0@gmail.com> Date: Thu Jan 15 15:54:15 2026 +0100 docstring commit 89cb005 Author: aditya0by0 <aditya0by0@gmail.com> Date: Fri Jan 9 16:21:26 2026 +0100 pre-commit format commit 0094e6c Author: aditya0by0 <aditya0by0@gmail.com> Date: Fri Jan 9 16:06:18 2026 +0100 File not found error for loss commit 9052aca Author: aditya0by0 <aditya0by0@gmail.com> Date: Thu Dec 18 15:14:24 2025 +0100 Update error msg
1 parent 096ab3d commit 7547057

11 files changed

Lines changed: 56 additions & 30 deletions

File tree

chebai/models/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def _execute(
298298
loss_kwargs = dict()
299299
if self.pass_loss_kwargs:
300300
loss_kwargs = loss_kwargs_candidates
301+
loss_kwargs["current_epoch"] = self.trainer.current_epoch
301302
loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
302303
if isinstance(loss, tuple):
303304
unnamed_loss_index = 1

chebai/models/electra.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ def __init__(
241241
self.config = ElectraConfig(**config, output_attentions=True)
242242
self.word_dropout = nn.Dropout(config.get("word_dropout", 0))
243243
self.model_type = model_type
244-
self.pass_loss_kwargs = True
245244

246245
in_d = self.config.hidden_size
247246
self.output = nn.Sequential(

chebai/preprocessing/datasets/base.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def __init__(
9696
self.prediction_kind = prediction_kind
9797
self.data_limit = data_limit
9898
self.label_filter = label_filter
99-
assert (balance_after_filter is not None) or (self.label_filter is None), (
100-
"Filter balancing requires a filter"
101-
)
99+
assert (balance_after_filter is not None) or (
100+
self.label_filter is None
101+
), "Filter balancing requires a filter"
102102
self.balance_after_filter = balance_after_filter
103103
self.num_workers = num_workers
104104
self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
108108
self.use_inner_cross_validation = (
109109
inner_k_folds > 1
110110
) # only use cv if there are at least 2 folds
111-
assert fold_index is None or self.use_inner_cross_validation is not None, (
112-
"fold_index can only be set if cross validation is used"
113-
)
111+
assert (
112+
fold_index is None or self.use_inner_cross_validation is not None
113+
), "fold_index can only be set if cross validation is used"
114114
if fold_index is not None and self.inner_k_folds is not None:
115-
assert fold_index < self.inner_k_folds, (
116-
"fold_index can't be larger than the total number of folds"
117-
)
115+
assert (
116+
fold_index < self.inner_k_folds
117+
), "fold_index can't be larger than the total number of folds"
118118
self.fold_index = fold_index
119119
self._base_dir = base_dir
120120
self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
137137

138138
@property
139139
def feature_vector_size(self):
140-
assert self._feature_vector_size is not None, (
141-
"size of feature vector must be set"
142-
)
140+
assert (
141+
self._feature_vector_size is not None
142+
), "size of feature vector must be set"
143143
return self._feature_vector_size
144144

145145
@property
@@ -1322,7 +1322,20 @@ def load_processed_data(
13221322
# If filename is provided
13231323
return self.load_processed_data_from_file(filename)
13241324

1325-
def load_processed_data_from_file(self, filename):
1325+
def load_processed_data_from_file(self, filename: str) -> list[dict[str, Any]]:
1326+
"""Load processed data from a file.
1327+
1328+
The full path is not required; only the filename is needed, as it will be joined with the processed directory.
1329+
1330+
Args:
1331+
filename (str): The name of the file to load the processed data from.
1332+
1333+
Returns:
1334+
List[Dict[str, Any]]: The loaded processed data.
1335+
1336+
Example:
1337+
data = self.load_processed_data_from_file('data.pt')
1338+
"""
13261339
return torch.load(
13271340
os.path.join(self.processed_dir, filename), weights_only=False
13281341
)

chebai/preprocessing/reader.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,26 +199,26 @@ def _read_data(self, raw_data: str) -> List[int]:
199199
Returns:
200200
List[int]: A list of integers representing the indices of the SMILES tokens.
201201
"""
202-
if self.canonicalize_smiles:
203-
try:
204-
mol = Chem.MolFromSmiles(raw_data.strip())
205-
if mol is not None:
206-
raw_data = Chem.MolToSmiles(mol, canonical=True)
207-
except Exception as e:
208-
print(f"RDKit failed to process {raw_data}")
209-
print(f"\t{e}")
210202
try:
211203
mol = Chem.MolFromSmiles(raw_data.strip())
212204
if mol is None:
213205
raise ValueError(f"Invalid SMILES: {raw_data}")
214-
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
215206
except ValueError as e:
216207
print(f"could not process {raw_data}")
217208
print(f"\tError: {e}")
218209
return None
219210

220-
def _back_to_smiles(self, smiles_encoded):
211+
if self.canonicalize_smiles:
212+
try:
213+
raw_data = Chem.MolToSmiles(mol, canonical=True)
214+
except Exception as e:
215+
print(f"RDKit failed to canonicalize the SMILES: {raw_data}")
216+
print(f"\t{e}")
217+
return None
221218

219+
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
220+
221+
def _back_to_smiles(self, smiles_encoded):
222222
token_file = self.reader.token_path
223223
token_coding = {}
224224
counter = 0

configs/loss/bce_new.yml

Lines changed: 0 additions & 1 deletion
This file was deleted.

configs/loss/bce_try.yml

Lines changed: 0 additions & 1 deletion
This file was deleted.

configs/loss/bce_unweighted.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
class_path: torch.nn.BCEWithLogitsLoss
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
class_path: chebai.loss.bce_weighted.BCEWeighted
22
init_args:
3-
beta: 1000
3+
beta: 0.99
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
class_path: chebai.loss.focal_loss.FocalLoss
22
init_args:
33
task_type: multi-label
4-
num_classes: 12

configs/model/electra.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
class_path: chebai.models.Electra
22
init_args:
3-
model_type: regression
3+
model_type: classification
44
optimizer_kwargs:
55
lr: 1e-4
66
config:
@@ -9,4 +9,4 @@ init_args:
99
num_attention_heads: 8
1010
num_hidden_layers: 6
1111
type_vocab_size: 1
12-
hidden_size: 256
12+
hidden_size: 256

0 commit comments

Comments (0)