@@ -96,9 +96,9 @@ def __init__(
9696 self .prediction_kind = prediction_kind
9797 self .data_limit = data_limit
9898 self .label_filter = label_filter
99- assert (balance_after_filter is not None ) or (
100- self . label_filter is None
101- ), "Filter balancing requires a filter"
99+ assert (balance_after_filter is not None ) or (self . label_filter is None ), (
100+ "Filter balancing requires a filter"
101+ )
102102 self .balance_after_filter = balance_after_filter
103103 self .num_workers = num_workers
104104 self .persistent_workers : bool = bool (persistent_workers )
@@ -108,13 +108,13 @@ def __init__(
108108 self .use_inner_cross_validation = (
109109 inner_k_folds > 1
110110 ) # only use cv if there are at least 2 folds
111- assert (
112- fold_index is None or self . use_inner_cross_validation is not None
113- ), "fold_index can only be set if cross validation is used"
111+ assert fold_index is None or self . use_inner_cross_validation is not None , (
112+ " fold_index can only be set if cross validation is used"
113+ )
114114 if fold_index is not None and self .inner_k_folds is not None :
115- assert (
116- fold_index < self . inner_k_folds
117- ), "fold_index can't be larger than the total number of folds"
115+ assert fold_index < self . inner_k_folds , (
116+ " fold_index can't be larger than the total number of folds"
117+ )
118118 self .fold_index = fold_index
119119 self ._base_dir = base_dir
120120 self .n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
137137
138138 @property
139139 def feature_vector_size (self ):
140- assert (
141- self . _feature_vector_size is not None
142- ), "size of feature vector must be set"
140+ assert self . _feature_vector_size is not None , (
141+ "size of feature vector must be set"
142+ )
143143 return self ._feature_vector_size
144144
145145 @property
@@ -1252,7 +1252,20 @@ def load_processed_data(
12521252 # If filename is provided
12531253 return self .load_processed_data_from_file (filename )
12541254
1255- def load_processed_data_from_file (self , filename ):
1255+ def load_processed_data_from_file (self , filename : str ) -> list [dict [str , Any ]]:
1256+ """Load processed data from a file.
1257+
1258+ The full path is not required; only the filename is needed, as it will be joined with the processed directory.
1259+
1260+ Args:
1261+ filename (str): The name of the file to load the processed data from.
1262+
1263+ Returns:
1264+ List[Dict[str, Any]]: The loaded processed data.
1265+
1266+ Example:
1267+ data = self.load_processed_data_from_file('data.pt')
1268+ """
12561269 return torch .load (
12571270 os .path .join (self .processed_dir , filename ), weights_only = False
12581271 )
0 commit comments