11from typing import List , Optional
22
33import pytorch_lightning
4- from mlops . data . tools .tools import xnat_build_dataset
5- from mlops . data .transforms .LoadImageXNATd import LoadImageXNATd
4+ from project . utils .tools import xnat_build_dataset
5+ from project .transforms .LoadImageXNATd import LoadImageXNATd
66from monai .data import CacheDataset , pad_list_data_collate
77from monai .transforms import (
88 EnsureChannelFirstd ,
2020
2121class DataModule (pytorch_lightning .LightningDataModule ):
2222
23- def __init__ (self , data_dir : str = './' , xnat_configuration : dict = None , batch_size : int = 1 , num_workers : int = 4 ,
24- test_fraction : float = 0.1 , train_val_ratio : float = 0.2 , test_batch : int = - 1 ):
23+ def __init__ (
24+ self ,
25+ data_dir : str = "./" ,
26+ xnat_configuration : dict = None ,
27+ batch_size : int = 1 ,
28+ num_workers : int = 4 ,
29+ test_fraction : float = 0.1 ,
30+ train_val_ratio : float = 0.2 ,
31+ test_batch : int = - 1 ,
32+ ):
2533
2634 super ().__init__ ()
2735 self .data_dir = data_dir
@@ -38,18 +46,24 @@ def setup(self, stage: Optional[str] = None):
3846 :param stage:
3947 :return:
4048 """
41- # list of tuples defining action functions and their data keys
42- actions = [(self .fetch_image , 'image' ),
43- (self .fetch_label , 'label' )]
49+ actions = [(self .fetch_image , "image" ), (self .fetch_label , "label" )]
4450
45- self .xnat_data_list = xnat_build_dataset (self .xnat_configuration , actions = actions , test_batch = self .test_batch )
51+ self .xnat_data_list = xnat_build_dataset (
52+ self .xnat_configuration , actions = actions , test_batch = self .test_batch
53+ )
4654
47- self .train_samples , self .valid_samples = random_split (self .xnat_data_list , [1 - self .train_val_ratio , self .train_val_ratio ])
55+ self .train_samples , self .valid_samples = random_split (
56+ self .xnat_data_list , [1 - self .train_val_ratio , self .train_val_ratio ]
57+ )
4858
4959 self .train_transforms = Compose (
5060 [
51- LoadImageXNATd (keys = ['data' ], xnat_configuration = self .xnat_configuration ,
52- image_loader = LoadImage (image_only = True ), expected_filetype_ext = '.nii.gz' ),
61+ LoadImageXNATd (
62+ keys = ["data" ],
63+ xnat_configuration = self .xnat_configuration ,
64+ image_loader = LoadImage (image_only = True ),
65+ expected_filetype_ext = ".nii.gz" ,
66+ ),
5367 EnsureChannelFirstd (keys = ["image" , "label" ]),
5468 Spacingd (
5569 keys = ["image" , "label" ],
@@ -62,8 +76,12 @@ def setup(self, stage: Optional[str] = None):
6276
6377 self .val_transforms = Compose (
6478 [
65- LoadImageXNATd (keys = ['data' ], xnat_configuration = self .xnat_configuration ,
66- image_loader = LoadImage (image_only = True ), expected_filetype_ext = '.nii.gz' ),
79+ LoadImageXNATd (
80+ keys = ["data" ],
81+ xnat_configuration = self .xnat_configuration ,
82+ image_loader = LoadImage (image_only = True ),
83+ expected_filetype_ext = ".nii.gz" ,
84+ ),
6785 EnsureChannelFirstd (keys = ["image" , "label" ]),
6886 Spacingd (
6987 keys = ["image" , "label" ],
@@ -74,8 +92,12 @@ def setup(self, stage: Optional[str] = None):
7492 ]
7593 )
7694
77- self .train_dataset = CacheDataset (data = self .train_samples , transform = self .train_transforms )
78- self .val_dataset = CacheDataset (data = self .valid_samples , transform = self .val_transforms )
95+ self .train_dataset = CacheDataset (
96+ data = self .train_samples , transform = self .train_transforms
97+ )
98+ self .val_dataset = CacheDataset (
99+ data = self .valid_samples , transform = self .val_transforms
100+ )
79101
80102 def prepare_data (self , * args , ** kwargs ):
81103 pass
@@ -85,18 +107,27 @@ def train_dataloader(self):
85107 Define train dataloader
86108 :return:
87109 """
88- return DataLoader (self .train_dataset , batch_size = self .batch_size , shuffle = True ,
89- num_workers = self .num_workers , collate_fn = pad_list_data_collate ,
90- pin_memory = is_available ())
110+ return DataLoader (
111+ self .train_dataset ,
112+ batch_size = self .batch_size ,
113+ shuffle = True ,
114+ num_workers = self .num_workers ,
115+ collate_fn = pad_list_data_collate ,
116+ pin_memory = is_available (),
117+ )
91118
92119 def val_dataloader (self ):
93120 """
94121 Define validation dataloader
95122 :return:
96123 """
97- return DataLoader (self .val_dataset , batch_size = 1 , num_workers = self .num_workers , collate_fn = pad_list_data_collate ,
98- pin_memory = is_available ())
99-
124+ return DataLoader (
125+ self .val_dataset ,
126+ batch_size = 1 ,
127+ num_workers = self .num_workers ,
128+ collate_fn = pad_list_data_collate ,
129+ pin_memory = is_available (),
130+ )
100131
101132 @staticmethod
102133 def fetch_image (subject_data : SubjectData = None ) -> List [ImageScanData ]:
@@ -105,10 +136,35 @@ def fetch_image(subject_data: SubjectData = None) -> List[ImageScanData]:
105136 along with the 'key' that it will be used to access it.
106137 """
107138 output = []
108- for exp in subject_data .experiments :
109- for scan in subject_data .experiments [exp ].scans :
110- if 'image' in subject_data .experiments [exp ].scans [scan ].id .lower ():
111- output .append (subject_data .experiments [exp ].scans [scan ])
139+
140+ if hasattr (subject_data .experiments , "values" ):
141+ experiments = subject_data .experiments .values ()
142+ else :
143+ experiments = [
144+ subject_data .experiments [exp_id ]
145+ for exp_id in subject_data .experiments .keys ()
146+ ]
147+
148+ for experiment in experiments :
149+ try :
150+ if hasattr (experiment .scans , "values" ):
151+ scans = experiment .scans .values ()
152+ else :
153+ scans = [
154+ experiment .scans [scan_id ] for scan_id in experiment .scans .keys ()
155+ ]
156+
157+ for scan_obj in scans :
158+ try :
159+ scan_name = scan_obj .id .lower ()
160+ if "image" in scan_name :
161+ output .append (scan_obj )
162+ except Exception :
163+ continue
164+
165+ except Exception :
166+ continue
167+
112168 if len (output ) > 1 :
113169 raise TypeError
114170 return output
@@ -120,10 +176,35 @@ def fetch_label(subject_data: SubjectData = None) -> List[ImageScanData]:
120176 along with the 'key' that it will be used to access it.
121177 """
122178 output = []
123- for exp in subject_data .experiments :
124- for scan in subject_data .experiments [exp ].scans :
125- if 'label' in subject_data .experiments [exp ].scans [scan ].id .lower ():
126- output .append (subject_data .experiments [exp ].scans [scan ])
179+
180+ if hasattr (subject_data .experiments , "values" ):
181+ experiments = subject_data .experiments .values ()
182+ else :
183+ experiments = [
184+ subject_data .experiments [exp_id ]
185+ for exp_id in subject_data .experiments .keys ()
186+ ]
187+
188+ for experiment in experiments :
189+ try :
190+ if hasattr (experiment .scans , "values" ):
191+ scans = experiment .scans .values ()
192+ else :
193+ scans = [
194+ experiment .scans [scan_id ] for scan_id in experiment .scans .keys ()
195+ ]
196+
197+ for scan_obj in scans :
198+ try :
199+ scan_name = scan_obj .id .lower ()
200+ if "label" in scan_name :
201+ output .append (scan_obj )
202+ except Exception :
203+ continue
204+
205+ except Exception :
206+ continue
207+
127208 if len (output ) > 1 :
128209 raise TypeError
129210 return output
0 commit comments