Skip to content

Commit a739a9d

Browse files
authored
Merge pull request #11 from GSTT-CSC/upgrade-pipeline
Upgrade pipeline
2 parents ecbc29e + cf74608 commit a739a9d

10 files changed

Lines changed: 655 additions & 84 deletions

File tree

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM pytorch/pytorch:latest
1+
FROM python:3.10-slim
22

33
WORKDIR /project
44

@@ -10,6 +10,6 @@ ENV PYTHONPATH="/mlflow/projects/code/:$PYTHONPATH"
1010

1111
COPY . .
1212

13-
# install requirements
13+
# install requirements with compatible versions
1414
RUN python -m pip install --upgrade pip && \
1515
python -m pip install -r requirements.txt

config/local_config.cfg

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
[server]
2-
MLFLOW_S3_ENDPOINT_URL = http://0.0.0.0:8002
3-
MLFLOW_TRACKING_URI = http://0.0.0.0:85
2+
MLFLOW_S3_ENDPOINT_URL = http://localhost:8002
3+
MLFLOW_TRACKING_URI = http://localhost:85
44
ARTIFACT_PATH = s3://mlflow
55

66
[xnat]
77
USER = admin
88
PASSWORD = admin
99
PROJECT = hipposeg
1010
VERIFY = false
11-
SERVER = http://localhost
11+
SERVER = http://localhost/
1212

1313
[project]
1414
NAME = hipposeg

project/DataModule.py

Lines changed: 110 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import List, Optional
22

33
import pytorch_lightning
4-
from mlops.data.tools.tools import xnat_build_dataset
5-
from mlops.data.transforms.LoadImageXNATd import LoadImageXNATd
4+
from project.utils.tools import xnat_build_dataset
5+
from project.transforms.LoadImageXNATd import LoadImageXNATd
66
from monai.data import CacheDataset, pad_list_data_collate
77
from monai.transforms import (
88
EnsureChannelFirstd,
@@ -20,8 +20,16 @@
2020

2121
class DataModule(pytorch_lightning.LightningDataModule):
2222

23-
def __init__(self, data_dir: str = './', xnat_configuration: dict = None, batch_size: int = 1, num_workers: int = 4,
24-
test_fraction: float = 0.1, train_val_ratio: float = 0.2, test_batch: int = -1):
23+
def __init__(
24+
self,
25+
data_dir: str = "./",
26+
xnat_configuration: dict = None,
27+
batch_size: int = 1,
28+
num_workers: int = 4,
29+
test_fraction: float = 0.1,
30+
train_val_ratio: float = 0.2,
31+
test_batch: int = -1,
32+
):
2533

2634
super().__init__()
2735
self.data_dir = data_dir
@@ -38,18 +46,24 @@ def setup(self, stage: Optional[str] = None):
3846
:param stage:
3947
:return:
4048
"""
41-
# list of tuples defining action functions and their data keys
42-
actions = [(self.fetch_image, 'image'),
43-
(self.fetch_label, 'label')]
49+
actions = [(self.fetch_image, "image"), (self.fetch_label, "label")]
4450

45-
self.xnat_data_list = xnat_build_dataset(self.xnat_configuration, actions=actions, test_batch=self.test_batch)
51+
self.xnat_data_list = xnat_build_dataset(
52+
self.xnat_configuration, actions=actions, test_batch=self.test_batch
53+
)
4654

47-
self.train_samples, self.valid_samples = random_split(self.xnat_data_list, [1-self.train_val_ratio, self.train_val_ratio])
55+
self.train_samples, self.valid_samples = random_split(
56+
self.xnat_data_list, [1 - self.train_val_ratio, self.train_val_ratio]
57+
)
4858

4959
self.train_transforms = Compose(
5060
[
51-
LoadImageXNATd(keys=['data'], xnat_configuration=self.xnat_configuration,
52-
image_loader=LoadImage(image_only=True), expected_filetype_ext='.nii.gz'),
61+
LoadImageXNATd(
62+
keys=["data"],
63+
xnat_configuration=self.xnat_configuration,
64+
image_loader=LoadImage(image_only=True),
65+
expected_filetype_ext=".nii.gz",
66+
),
5367
EnsureChannelFirstd(keys=["image", "label"]),
5468
Spacingd(
5569
keys=["image", "label"],
@@ -62,8 +76,12 @@ def setup(self, stage: Optional[str] = None):
6276

6377
self.val_transforms = Compose(
6478
[
65-
LoadImageXNATd(keys=['data'], xnat_configuration=self.xnat_configuration,
66-
image_loader=LoadImage(image_only=True), expected_filetype_ext='.nii.gz'),
79+
LoadImageXNATd(
80+
keys=["data"],
81+
xnat_configuration=self.xnat_configuration,
82+
image_loader=LoadImage(image_only=True),
83+
expected_filetype_ext=".nii.gz",
84+
),
6785
EnsureChannelFirstd(keys=["image", "label"]),
6886
Spacingd(
6987
keys=["image", "label"],
@@ -74,8 +92,12 @@ def setup(self, stage: Optional[str] = None):
7492
]
7593
)
7694

77-
self.train_dataset = CacheDataset(data=self.train_samples, transform=self.train_transforms)
78-
self.val_dataset = CacheDataset(data=self.valid_samples, transform=self.val_transforms)
95+
self.train_dataset = CacheDataset(
96+
data=self.train_samples, transform=self.train_transforms
97+
)
98+
self.val_dataset = CacheDataset(
99+
data=self.valid_samples, transform=self.val_transforms
100+
)
79101

80102
def prepare_data(self, *args, **kwargs):
81103
pass
@@ -85,18 +107,27 @@ def train_dataloader(self):
85107
Define train dataloader
86108
:return:
87109
"""
88-
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
89-
num_workers=self.num_workers, collate_fn=pad_list_data_collate,
90-
pin_memory=is_available())
110+
return DataLoader(
111+
self.train_dataset,
112+
batch_size=self.batch_size,
113+
shuffle=True,
114+
num_workers=self.num_workers,
115+
collate_fn=pad_list_data_collate,
116+
pin_memory=is_available(),
117+
)
91118

92119
def val_dataloader(self):
93120
"""
94121
Define validation dataloader
95122
:return:
96123
"""
97-
return DataLoader(self.val_dataset, batch_size=1, num_workers=self.num_workers, collate_fn=pad_list_data_collate,
98-
pin_memory=is_available())
99-
124+
return DataLoader(
125+
self.val_dataset,
126+
batch_size=1,
127+
num_workers=self.num_workers,
128+
collate_fn=pad_list_data_collate,
129+
pin_memory=is_available(),
130+
)
100131

101132
@staticmethod
102133
def fetch_image(subject_data: SubjectData = None) -> List[ImageScanData]:
@@ -105,10 +136,35 @@ def fetch_image(subject_data: SubjectData = None) -> List[ImageScanData]:
105136
along with the 'key' that it will be used to access it.
106137
"""
107138
output = []
108-
for exp in subject_data.experiments:
109-
for scan in subject_data.experiments[exp].scans:
110-
if 'image' in subject_data.experiments[exp].scans[scan].id.lower():
111-
output.append(subject_data.experiments[exp].scans[scan])
139+
140+
if hasattr(subject_data.experiments, "values"):
141+
experiments = subject_data.experiments.values()
142+
else:
143+
experiments = [
144+
subject_data.experiments[exp_id]
145+
for exp_id in subject_data.experiments.keys()
146+
]
147+
148+
for experiment in experiments:
149+
try:
150+
if hasattr(experiment.scans, "values"):
151+
scans = experiment.scans.values()
152+
else:
153+
scans = [
154+
experiment.scans[scan_id] for scan_id in experiment.scans.keys()
155+
]
156+
157+
for scan_obj in scans:
158+
try:
159+
scan_name = scan_obj.id.lower()
160+
if "image" in scan_name:
161+
output.append(scan_obj)
162+
except Exception:
163+
continue
164+
165+
except Exception:
166+
continue
167+
112168
if len(output) > 1:
113169
raise TypeError
114170
return output
@@ -120,10 +176,35 @@ def fetch_label(subject_data: SubjectData = None) -> List[ImageScanData]:
120176
along with the 'key' that it will be used to access it.
121177
"""
122178
output = []
123-
for exp in subject_data.experiments:
124-
for scan in subject_data.experiments[exp].scans:
125-
if 'label' in subject_data.experiments[exp].scans[scan].id.lower():
126-
output.append(subject_data.experiments[exp].scans[scan])
179+
180+
if hasattr(subject_data.experiments, "values"):
181+
experiments = subject_data.experiments.values()
182+
else:
183+
experiments = [
184+
subject_data.experiments[exp_id]
185+
for exp_id in subject_data.experiments.keys()
186+
]
187+
188+
for experiment in experiments:
189+
try:
190+
if hasattr(experiment.scans, "values"):
191+
scans = experiment.scans.values()
192+
else:
193+
scans = [
194+
experiment.scans[scan_id] for scan_id in experiment.scans.keys()
195+
]
196+
197+
for scan_obj in scans:
198+
try:
199+
scan_name = scan_obj.id.lower()
200+
if "label" in scan_name:
201+
output.append(scan_obj)
202+
except Exception:
203+
continue
204+
205+
except Exception:
206+
continue
207+
127208
if len(output) > 1:
128209
raise TypeError
129210
return output

project/Network.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,18 @@ def __init__(self, **kwargs):
3333
norm=Norm.BATCH,
3434
)
3535
self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
36-
self.post_pred = Compose([EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(argmax=True, to_onehot=3)])
37-
self.post_label = Compose([EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(to_onehot=3)])
38-
self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
36+
self.post_pred = Compose(
37+
[
38+
EnsureType("tensor", device=torch.device("cpu")),
39+
AsDiscrete(argmax=True, to_onehot=3),
40+
]
41+
)
42+
self.post_label = Compose(
43+
[EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(to_onehot=3)]
44+
)
45+
self.dice_metric = DiceMetric(
46+
include_background=False, reduction="mean", get_not_nans=False
47+
)
3948
self.best_val_dice = 0
4049
self.best_val_epoch = 0
4150

@@ -52,31 +61,50 @@ def training_step(self, batch, batch_idx):
5261
images, labels = batch["image"], batch["label"]
5362
output = self.forward(images)
5463
loss = self.loss_function(output, labels)
55-
self.log('train_loss', loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
64+
self.log(
65+
"train_loss",
66+
loss.item(),
67+
on_step=True,
68+
on_epoch=True,
69+
logger=True,
70+
sync_dist=True,
71+
)
5672
return {"loss": loss}
5773

5874
def validation_step(self, batch, batch_idx):
5975
images, labels = batch["image"], batch["label"]
6076
roi_size = (-1, -1, -1)
6177
sw_batch_size = 1
6278
outputs = sliding_window_inference(
63-
images, roi_size, sw_batch_size, self.forward)
79+
images, roi_size, sw_batch_size, self.forward
80+
)
6481
loss = self.loss_function(outputs, labels)
6582
outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
6683
labels = [self.post_label(i) for i in decollate_batch(labels)]
6784
self.dice_metric(y_pred=outputs, y=labels)
68-
return {"val_loss": loss, "val_number": len(outputs)}
85+
output_dict = {"val_loss": loss, "val_number": len(outputs)}
86+
87+
if not hasattr(self, "validation_step_outputs"):
88+
self.validation_step_outputs = []
89+
self.validation_step_outputs.append(output_dict)
6990

70-
def validation_epoch_end(self, outputs):
91+
return output_dict
92+
93+
def on_validation_epoch_end(self):
7194
val_loss, num_items = 0, 0
95+
outputs = getattr(self, "validation_step_outputs", [])
7296
for output in outputs:
7397
val_loss += output["val_loss"].sum().item()
7498
num_items += output["val_number"]
7599
mean_val_dice = self.dice_metric.aggregate().item()
76100
self.dice_metric.reset()
77101
mean_val_loss = torch.tensor(val_loss / num_items)
78-
self.log_dict({
79-
"mean_val_dice": mean_val_dice,
80-
"mean_val_loss": mean_val_loss,
81-
})
102+
self.log_dict(
103+
{
104+
"mean_val_dice": mean_val_dice,
105+
"mean_val_loss": mean_val_loss,
106+
}
107+
)
108+
# Clear the stored outputs
109+
self.validation_step_outputs = []
82110
return

0 commit comments

Comments (0)