Skip to content

Commit bdb7be7

Browse files
committed
add docstring
1 parent 89cb005 commit bdb7be7

File tree

1 file changed

+26
-13
lines changed
  • chebai/preprocessing/datasets

1 file changed

+26
-13
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def __init__(
9696
self.prediction_kind = prediction_kind
9797
self.data_limit = data_limit
9898
self.label_filter = label_filter
99-
assert (balance_after_filter is not None) or (
100-
self.label_filter is None
101-
), "Filter balancing requires a filter"
99+
assert (balance_after_filter is not None) or (self.label_filter is None), (
100+
"Filter balancing requires a filter"
101+
)
102102
self.balance_after_filter = balance_after_filter
103103
self.num_workers = num_workers
104104
self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
108108
self.use_inner_cross_validation = (
109109
inner_k_folds > 1
110110
) # only use cv if there are at least 2 folds
111-
assert (
112-
fold_index is None or self.use_inner_cross_validation is not None
113-
), "fold_index can only be set if cross validation is used"
111+
assert fold_index is None or self.use_inner_cross_validation is not None, (
112+
"fold_index can only be set if cross validation is used"
113+
)
114114
if fold_index is not None and self.inner_k_folds is not None:
115-
assert (
116-
fold_index < self.inner_k_folds
117-
), "fold_index can't be larger than the total number of folds"
115+
assert fold_index < self.inner_k_folds, (
116+
"fold_index can't be larger than the total number of folds"
117+
)
118118
self.fold_index = fold_index
119119
self._base_dir = base_dir
120120
self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
137137

138138
@property
139139
def feature_vector_size(self):
140-
assert (
141-
self._feature_vector_size is not None
142-
), "size of feature vector must be set"
140+
assert self._feature_vector_size is not None, (
141+
"size of feature vector must be set"
142+
)
143143
return self._feature_vector_size
144144

145145
@property
@@ -1252,7 +1252,20 @@ def load_processed_data(
12521252
# If filename is provided
12531253
return self.load_processed_data_from_file(filename)
12541254

1255-
def load_processed_data_from_file(self, filename):
1255+
def load_processed_data_from_file(self, filename: str) -> list[dict[str, Any]]:
1256+
"""Load processed data from a file.
1257+
1258+
Only the file name is required, not the full path: it is joined with the processed directory (``self.processed_dir``).
1259+
1260+
Args:
1261+
filename (str): The name of the file to load the processed data from.
1262+
1263+
Returns:
1264+
list[dict[str, Any]]: The loaded processed data.
1265+
1266+
Example:
1267+
data = self.load_processed_data_from_file('data.pt')
1268+
"""
12561269
return torch.load(
12571270
os.path.join(self.processed_dir, filename), weights_only=False
12581271
)

0 commit comments

Comments
 (0)