-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexercise_3_unet.py
More file actions
116 lines (97 loc) · 2.84 KB
/
exercise_3_unet.py
File metadata and controls
116 lines (97 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# import numpy as np
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision.models as models
# !pip install monai
import monai
from PIL import Image
import torchvision.transforms.functional as F
# Fix the RNG seed so augmentation and weight init are reproducible.
torch.manual_seed(42)

# Select the best available accelerator, preferring CUDA, then Apple MPS,
# and falling back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# 2-D UNet for single-channel (binary-style) segmentation of RGB input.
# Encoder/decoder channel widths double at each of the four stride-2 levels,
# with 2 residual units per block.
_unet_config = dict(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = monai.networks.nets.UNet(**_unet_config)
# Preprocessing + augmentation pipeline applied to every sample:
# PIL -> float tensor in [0, 1], random flips/rotation for augmentation,
# shift/scale to [-1, 1], then resize to the 128x128 input the model expects.
# NOTE(review): (0.5,)/(0.5,) is a generic normalization, not the ImageNet
# per-channel statistics — confirm which the downstream model assumes.
_pipeline_steps = [
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Resize((128, 128)),
]
transform = transforms.Compose(_pipeline_steps)
# Best-effort download of both SBD splits; construction with download=True
# fetches the archive if missing. Any failure (already downloaded, no
# network, ...) is reported but does not stop the script — the real dataset
# objects are built below without download.
try:
    for _split in ("train", "val"):
        datasets.SBDataset(
            root="./data/sbd",
            image_set=_split,
            download=True,
            mode="segmentation",
        )
except Exception as e:
    print(e)
def _paired_transform(image, target):
    """Apply ``transform`` to an (image, mask) pair with identical random draws.

    torchvision's RandomHorizontalFlip / RandomVerticalFlip / RandomRotation
    sample their parameters from torch's global RNG, so calling ``transform``
    separately on image and mask (as the original lambda did) flips/rotates
    them independently and misaligns the pair. Re-seeding the global RNG with
    the same value before each call makes both receive the same augmentation.

    NOTE(review): ``transform`` still applies Normalize to the mask, shifting
    label values out of [0, 1] — confirm the loss expects that.
    NOTE(review): re-seeding the global RNG here affects downstream global
    random draws; acceptable for this exercise script.
    """
    seed = torch.randint(0, 2**31 - 1, (1,)).item()
    torch.manual_seed(seed)
    image_t = transform(image)
    torch.manual_seed(seed)
    target_t = transform(target)
    return [image_t, target_t]


# Datasets are assumed to be present on disk (downloaded above).
train_data = datasets.SBDataset(
    root="./data/sbd",
    image_set="train",
    transforms=_paired_transform,
    # download=True,
    mode="segmentation",
)
test_data = datasets.SBDataset(
    root="./data/sbd",
    image_set="val",
    transforms=_paired_transform,
    # download=True,
    mode="segmentation",
)
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
# Dump one augmented (image, mask) pair for visual inspection.
# PIL's .save() returns None, so X_test / y_test are None by design.
X_test = F.to_pil_image(train_data[0][0]).save("X_test.png")
y_test = F.to_pil_image(train_data[0][1]).save("y_test.png")
learning_rate = 1e-4
epochs = 100
# DiceLoss values lie in [0, 1] (1 - Dice overlap); negative values are NOT
# expected. NOTE(review): with out_channels=1, softmax over the single channel
# is identically 1, so this softmax-variant is degenerate — it is also unused
# below (training uses the sigmoid variant, loss_fn). Kept only so the
# module-level name remains available.
loss_function = monai.losses.DiceLoss(softmax=True)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Sigmoid-activated Dice loss: correct for a single-channel output.
loss_fn = monai.losses.DiceLoss(sigmoid=True)
# Sanity-check the output shape on one CPU sample before moving to `device`.
# An explicit raise (not assert) so the check survives `python -O`.
_out_shape = tuple(model(train_data[0][0].unsqueeze(0)).shape)
if _out_shape != (1, 1, 128, 128):
    raise RuntimeError(f"Unexpected model output shape: {_out_shape}")
model = model.to(device)
# Plain training loop: forward, Dice loss, backprop, SGD step.
# Enable training-mode behavior (norm/dropout layers) explicitly — the
# original never called model.train().
model.train()
# Dataset size is loop-invariant; compute it once instead of every epoch.
size = len(train_dataloader.dataset)
for epoch in tqdm(range(epochs)):
    for batch, (X, y) in enumerate(train_dataloader):
        X = X.to(device)
        y = y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 10 == 0:
            # Use separate names instead of rebinding `loss` to a float.
            loss_val = loss.item()
            current = batch * len(X)
            print(f"Epoch: {epoch+1}, Loss: {loss_val:.6f}, Progress: [{current}/{size}]")