Skip to content

Commit a55f0f3

Browse files
authored
Use chebi-utils library (#158)
* remove outdated JCI files
* get molecule data from SDF file
* add new tokens
* add chembl dependency
* update tests for SDF files
* fix 3-STAR preprocessing
* Revert "fix 3-STAR preprocessing" (this reverts commit 9166d9e)
* add new tokens from SDF
* fix 3-star processing
* add sanitize function
* use chebi utils library for dataset preparation
* fix subset filtering
* disable rdkit logging
* update unit tests
* update chebi utils version
1 parent 8ee3378 commit a55f0f3

File tree

9 files changed

+275
-1330
lines changed

9 files changed

+275
-1330
lines changed

chebai/preprocessing/bin/smiles_token/tokens.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4520,3 +4520,4 @@ b
45204520
[224RaH2]
45214521
[226RaH2]
45224522
[228RaH2]
4523+
[*-:0]

chebai/preprocessing/datasets/base.py

Lines changed: 4 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,9 @@ def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
907907
print(f"Missing processed data file (`{processed_name}` file)")
908908
os.makedirs(self.processed_dir_main, exist_ok=True)
909909
data_path = self._download_required_data()
910-
g = self._extract_class_hierarchy(data_path)
910+
from chebi_utils import build_chebi_graph
911+
912+
g = build_chebi_graph(data_path)
911913
data_df = self._graph_to_raw_dataset(g)
912914
self.save_processed(data_df, processed_name)
913915

@@ -921,26 +923,11 @@ def _download_required_data(self) -> str:
921923
"""
922924
pass
923925

924-
@abstractmethod
925-
def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
926-
"""
927-
Extracts the class hierarchy from the data.
928-
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
929-
the term documents.
930-
931-
Args:
932-
data_path (str): Path to the data.
933-
934-
Returns:
935-
nx.DiGraph: The class hierarchy graph.
936-
"""
937-
pass
938-
939926
@abstractmethod
940927
def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
941928
"""
942929
Converts the graph to a raw dataset.
943-
Uses the graph created by `_extract_class_hierarchy` method to extract the
930+
Uses the graph created by chebi_utils to extract the
944931
raw data in Dataframe format with additional columns corresponding to each multi-label class.
945932
946933
Args:
@@ -951,21 +938,6 @@ def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
951938
"""
952939
pass
953940

954-
@abstractmethod
955-
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
956-
"""
957-
Selects classes from the dataset based on a specified criteria.
958-
959-
Args:
960-
g (nx.Graph): The graph representing the dataset.
961-
*args: Additional positional arguments.
962-
**kwargs: Additional keyword arguments.
963-
964-
Returns:
965-
List: A sorted list of node IDs that meet the specified criteria.
966-
"""
967-
pass
968-
969941
def save_processed(self, data: pd.DataFrame, filename: str) -> None:
970942
"""
971943
Save the processed dataset to a pickle file.
@@ -1123,120 +1095,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
11231095
"""
11241096
pass
11251097

1126-
def get_test_split(
1127-
self, df: pd.DataFrame, seed: Optional[int] = None
1128-
) -> Tuple[pd.DataFrame, pd.DataFrame]:
1129-
"""
1130-
Split the input DataFrame into training and testing sets based on multilabel stratified sampling.
1131-
1132-
This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels
1133-
in the training and testing sets is approximately the same. The split is based on the "labels" column
1134-
in the DataFrame.
1135-
1136-
Args:
1137-
df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column
1138-
named "labels" with the multilabel data.
1139-
seed (int, optional): The random seed to be used for reproducibility. Default is None.
1140-
1141-
Returns:
1142-
Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames.
1143-
1144-
Raises:
1145-
ValueError: If the DataFrame does not contain a column named "labels".
1146-
"""
1147-
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
1148-
from sklearn.model_selection import StratifiedShuffleSplit
1149-
1150-
print("Get test data split")
1151-
1152-
labels_list = df["labels"].tolist()
1153-
1154-
if len(labels_list[0]) > 1:
1155-
splitter = MultilabelStratifiedShuffleSplit(
1156-
n_splits=1, test_size=self.test_split, random_state=seed
1157-
)
1158-
else:
1159-
splitter = StratifiedShuffleSplit(
1160-
n_splits=1, test_size=self.test_split, random_state=seed
1161-
)
1162-
1163-
train_indices, test_indices = next(splitter.split(labels_list, labels_list))
1164-
1165-
df_train = df.iloc[train_indices]
1166-
df_test = df.iloc[test_indices]
1167-
return df_train, df_test
1168-
1169-
def get_train_val_splits_given_test(
1170-
self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None
1171-
) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]:
1172-
"""
1173-
Split the dataset into train and validation sets, given a test set.
1174-
Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap
1175-
1176-
Args:
1177-
df (pd.DataFrame): The original dataset.
1178-
test_df (pd.DataFrame): The test dataset.
1179-
seed (int, optional): The random seed to be used for reproducibility. Default is None.
1180-
1181-
Returns:
1182-
Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and
1183-
validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train
1184-
and validation DataFrames. The keys are the names of the train and validation sets, and the values
1185-
are the corresponding DataFrames.
1186-
"""
1187-
from iterstrat.ml_stratifiers import (
1188-
MultilabelStratifiedKFold,
1189-
MultilabelStratifiedShuffleSplit,
1190-
)
1191-
from sklearn.model_selection import StratifiedShuffleSplit
1192-
1193-
print("Split dataset into train / val with given test set")
1194-
1195-
test_ids = test_df["ident"].tolist()
1196-
df_trainval = df[~df["ident"].isin(test_ids)]
1197-
labels_list_trainval = df_trainval["labels"].tolist()
1198-
1199-
if self.use_inner_cross_validation:
1200-
folds = {}
1201-
kfold = MultilabelStratifiedKFold(
1202-
n_splits=self.inner_k_folds, random_state=seed
1203-
)
1204-
for fold, (train_ids, val_ids) in enumerate(
1205-
kfold.split(
1206-
labels_list_trainval,
1207-
labels_list_trainval,
1208-
)
1209-
):
1210-
df_validation = df_trainval.iloc[val_ids]
1211-
df_train = df_trainval.iloc[train_ids]
1212-
folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train
1213-
folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = (
1214-
df_validation
1215-
)
1216-
1217-
return folds
1218-
1219-
if len(labels_list_trainval[0]) > 1:
1220-
splitter = MultilabelStratifiedShuffleSplit(
1221-
n_splits=1,
1222-
test_size=self.validation_split / (1 - self.test_split),
1223-
random_state=seed,
1224-
)
1225-
else:
1226-
splitter = StratifiedShuffleSplit(
1227-
n_splits=1,
1228-
test_size=self.validation_split / (1 - self.test_split),
1229-
random_state=seed,
1230-
)
1231-
1232-
train_indices, validation_indices = next(
1233-
splitter.split(labels_list_trainval, labels_list_trainval)
1234-
)
1235-
1236-
df_validation = df_trainval.iloc[validation_indices]
1237-
df_train = df_trainval.iloc[train_indices]
1238-
return df_train, df_validation
1239-
12401098
def _retrieve_splits_from_csv(self) -> None:
12411099
"""
12421100
Retrieve previously saved data splits from splits.csv file or from provided file path.

0 commit comments

Comments
 (0)