22
33from __future__ import annotations
44
5- import numpy as np
65import pandas as pd
76
87
# Tolerance used when deciding that a split fraction is effectively 0 or 1.
_RATIO_EPS = 1e-9


def _hold_out(
    frame: pd.DataFrame,
    y,
    fraction: float,
    seed: int | None,
    is_multilabel: bool,
):
    """Hold out a stratified *fraction* of the rows of *frame*.

    Parameters
    ----------
    frame : pd.DataFrame
        Rows to split; must be positionally aligned with *y*.
    y : array-like
        Label array (2-D when *is_multilabel* is true, otherwise 1-D) with
        one entry per row of *frame*.
    fraction : float
        Fraction of rows to hold out.
    seed : int or None
        Passed as ``random_state`` to the splitter.
    is_multilabel : bool
        Selects ``MultilabelStratifiedShuffleSplit`` over
        ``StratifiedShuffleSplit``.

    Returns
    -------
    tuple
        ``(kept_rows, kept_labels, held_out_rows)``.
    """
    # The splitters reject test_size values of 0.0 and 1.0, so the degenerate
    # cases are answered directly instead of being forwarded to them.
    if fraction <= _RATIO_EPS or len(frame) == 0:
        return frame, y, frame.iloc[:0]
    if fraction >= 1.0 - _RATIO_EPS:
        return frame.iloc[:0], y[:0], frame

    # Imported lazily so that the optional dependencies are only required
    # when an actual stratified split has to be computed.
    if is_multilabel:
        from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

        splitter = MultilabelStratifiedShuffleSplit(
            n_splits=1, test_size=fraction, random_state=seed
        )
    else:
        from sklearn.model_selection import StratifiedShuffleSplit

        splitter = StratifiedShuffleSplit(
            n_splits=1, test_size=fraction, random_state=seed
        )

    # The splitters only need an array of the right length as ``X``;
    # the label array doubles as that placeholder.
    keep_idx, held_idx = next(splitter.split(y, y))
    return frame.iloc[keep_idx], y[keep_idx], frame.iloc[held_idx]


def create_multilabel_splits(
    df: pd.DataFrame,
    label_start_col: int = 2,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int | None = 42,
) -> dict[str, pd.DataFrame]:
    """Create stratified train/validation/test splits for multilabel DataFrames.

    Columns from index *label_start_col* onwards are treated as binary label
    columns (one boolean column per label). The stratification strategy is
    chosen automatically based on the number of label columns:

    - More than one label column: ``MultilabelStratifiedShuffleSplit`` from
      the ``iterative-stratification`` package.
    - Single label column: ``StratifiedShuffleSplit`` from ``scikit-learn``.

    Parameters
    ----------
    df : pd.DataFrame
        Input data. Columns ``0`` to ``label_start_col - 1`` are treated as
        feature/metadata columns; all remaining columns are boolean label
        columns. A typical ChEBI DataFrame has columns
        ``["chebi_id", "mol", "label1", "label2", ...]``.
    label_start_col : int
        Index of the first label column (default 2). Must be non-negative
        and smaller than the number of columns.
    train_ratio : float
        Fraction of data for training (default 0.8).
    val_ratio : float
        Fraction of data for validation (default 0.1).
    test_ratio : float
        Fraction of data for testing (default 0.1).
    seed : int or None
        Random seed for reproducibility; passed as ``random_state`` to both
        splitting steps.

    Returns
    -------
    dict[str, pd.DataFrame]
        Mapping with the keys ``"train"``, ``"val"`` and ``"test"``. Every
        value is a DataFrame with a fresh ``0..n-1`` index. A split whose
        ratio is 0 is returned as an empty DataFrame with the same columns.

    Raises
    ------
    ValueError
        If the ratios do not sum to 1, any ratio is outside ``[0, 1]``, or
        *label_start_col* is out of range.
    """
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0")
    if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]):
        raise ValueError("All ratios must be between 0 and 1")

    n_cols = len(df.columns)
    # Negative positions are rejected as well: ``iloc[:, -k:]`` would silently
    # reinterpret feature/metadata columns as labels.
    if not 0 <= label_start_col < n_cols:
        raise ValueError(
            f"label_start_col={label_start_col} is out of range for a DataFrame "
            f"with {n_cols} columns"
        )

    # Work on a positional index so splitter output can be used with ``iloc``.
    df_reset = df.reset_index(drop=True)

    labels_matrix = df_reset.iloc[:, label_start_col:].values
    is_multilabel = labels_matrix.shape[1] > 1
    # StratifiedShuffleSplit requires a 1-D label array
    y = labels_matrix if is_multilabel else labels_matrix[:, 0]

    # Step 1: carve out the test set.
    df_trainval, y_trainval, df_test = _hold_out(
        df_reset, y, test_ratio, seed, is_multilabel
    )

    # Step 2: split train/val from the remaining data. The validation ratio is
    # rescaled because it now applies to the (1 - test_ratio) remainder.
    remaining = 1.0 - test_ratio
    val_ratio_adjusted = val_ratio / remaining if remaining > _RATIO_EPS else 0.0
    df_train, _, df_val = _hold_out(
        df_trainval, y_trainval, val_ratio_adjusted, seed, is_multilabel
    )

    return {
        "train": df_train.reset_index(drop=True),
        "val": df_val.reset_index(drop=True),
        "test": df_test.reset_index(drop=True),
    }
0 commit comments