Skip to content

Commit 0ba4f11

Browse files
committed
only calculate entended molecule graph if needed, sanitize molecule with custom method in fg rules
1 parent d2bbad5 commit 0ba4f11

2 files changed

Lines changed: 68 additions & 48 deletions

File tree

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 61 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -127,45 +127,53 @@ def enc_if_not_none(encode, value):
127127
if value is not None and len(value) > 0
128128
else None
129129
)
130-
131-
# augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy)
132-
if isinstance(self.reader, _AugmentorReader):
133-
returned_results = []
134-
for mol in features:
135-
try:
136-
r = self.reader._create_augmented_graph(mol)
137-
except Exception as e:
138-
r = None
139-
returned_results.append(r)
140-
mols = [augmented_mol[1] for augmented_mol in returned_results if augmented_mol is not None]
141-
else:
142-
mols = features
143130

144-
for property in self.properties:
145-
if not os.path.isfile(self.get_property_path(property)):
146-
rank_zero_info(f"Processing property {property.name}")
147-
# read all property values first, then encode
148-
rank_zero_info(f"\tReading property values of {property.name}...")
149-
property_values = [
150-
self.reader.read_property(mol, property)
151-
for mol in tqdm.tqdm(mols)
152-
]
153-
rank_zero_info(f"\tEncoding property values of {property.name}...")
154-
property.encoder.on_start(property_values=property_values)
155-
encoded_values = [
156-
enc_if_not_none(property.encoder.encode, value)
157-
for value in tqdm.tqdm(property_values)
131+
if any(
132+
not os.path.isfile(self.get_property_path(property))
133+
for property in self.properties
134+
):
135+
# augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy)
136+
if isinstance(self.reader, _AugmentorReader):
137+
returned_results = []
138+
for mol in features:
139+
try:
140+
r = self.reader._create_augmented_graph(mol)
141+
except Exception as e:
142+
r = None
143+
returned_results.append(r)
144+
mols = [
145+
augmented_mol[1]
146+
for augmented_mol in returned_results
147+
if augmented_mol is not None
158148
]
159-
160-
torch.save(
161-
[
162-
{property.name: torch.cat(feat), "ident": id}
163-
for feat, id in zip(encoded_values, idents)
164-
if feat is not None
165-
],
166-
self.get_property_path(property),
167-
)
168-
property.on_finish()
149+
else:
150+
mols = features
151+
152+
for property in self.properties:
153+
if not os.path.isfile(self.get_property_path(property)):
154+
rank_zero_info(f"Processing property {property.name}")
155+
# read all property values first, then encode
156+
rank_zero_info(f"\tReading property values of {property.name}...")
157+
property_values = [
158+
self.reader.read_property(mol, property)
159+
for mol in tqdm.tqdm(mols)
160+
]
161+
rank_zero_info(f"\tEncoding property values of {property.name}...")
162+
property.encoder.on_start(property_values=property_values)
163+
encoded_values = [
164+
enc_if_not_none(property.encoder.encode, value)
165+
for value in tqdm.tqdm(property_values)
166+
]
167+
168+
torch.save(
169+
[
170+
{property.name: torch.cat(feat), "ident": id}
171+
for feat, id in zip(encoded_values, idents)
172+
if feat is not None
173+
],
174+
self.get_property_path(property),
175+
)
176+
property.on_finish()
169177

170178
@property
171179
def processed_properties_dir(self) -> str:
@@ -268,7 +276,9 @@ def __init__(
268276
assert (
269277
distribution is not None
270278
and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS
271-
), "When using padding for features, a valid distribution must be specified."
279+
), (
280+
"When using padding for features, a valid distribution must be specified."
281+
)
272282
self.distribution = distribution
273283
if self.pad_node_features:
274284
print(
@@ -297,7 +307,9 @@ def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData:
297307
A GeomData object with merged features.
298308
"""
299309
if isinstance(row["features"], tuple):
300-
geom_data, _ = row["features"] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
310+
geom_data, _ = row[
311+
"features"
312+
] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
301313
else:
302314
geom_data = row["features"]
303315
assert isinstance(geom_data, GeomData)
@@ -560,7 +572,9 @@ def _merge_props_into_base(
560572
if geom_data is None:
561573
return None
562574
if isinstance(geom_data, tuple):
563-
geom_data = geom_data[0] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
575+
geom_data = geom_data[
576+
0
577+
] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
564578
assert isinstance(geom_data, GeomData)
565579

566580
is_atom_node = geom_data.is_atom_node
@@ -573,9 +587,9 @@ def _merge_props_into_base(
573587
edge_attr = geom_data.edge_attr
574588

575589
# Initialize node feature matrix
576-
assert (
577-
max_len_node_properties is not None
578-
), "Maximum len of node properties should not be None"
590+
assert max_len_node_properties is not None, (
591+
"Maximum len of node properties should not be None"
592+
)
579593
x = torch.zeros((num_nodes, max_len_node_properties))
580594

581595
# Track column offsets for each node type
@@ -630,9 +644,9 @@ def _merge_props_into_base(
630644
raise TypeError(f"Unsupported property type: {type(property).__name__}")
631645

632646
total_used_columns = max(atom_offset, fg_offset, graph_offset)
633-
assert (
634-
total_used_columns <= max_len_node_properties
635-
), f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}"
647+
assert total_used_columns <= max_len_node_properties, (
648+
f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}"
649+
)
636650

637651
return GeomData(
638652
x=x,
@@ -833,4 +847,4 @@ class ChEBI100_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver100):
833847
class ChEBI25_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOverX):
834848
READER = AtomFGReader_WithFGEdges_WithGraphNode
835849

836-
THRESHOLD = 25
850+
THRESHOLD = 25

chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from rdkit.Chem import AllChem
99
from rdkit.Chem import MolToSmiles as m2s
1010

11+
from chebi_utils.sdf_extractor import _sanitize_molecule
12+
1113
from .fg_constants import ELEMENTS, FLAG_NO_FG
1214

1315

@@ -1911,7 +1913,11 @@ def get_structure(mol):
19111913
structure[frag] = {"atom": atom_idx, "is_ring_fg": False}
19121914

19131915
# Convert fragment SMILES back to mol to match with fused ring atom indices
1914-
frag_mol = Chem.MolFromSmiles(frag)
1916+
frag_mol = Chem.MolFromSmiles(frag, sanitize=False)
1917+
try:
1918+
frag_mol = _sanitize_molecule(frag_mol)
1919+
except:
1920+
pass
19151921
frag_rings = frag_mol.GetRingInfo().AtomRings()
19161922
if len(frag_rings) >= 1:
19171923
structure[frag]["is_ring_fg"] = True

0 commit comments

Comments
 (0)