28 changes: 19 additions & 9 deletions README.md
@@ -56,17 +56,21 @@
In other words, we require a rotation and translation, or more formally
$$ \mathbf{R}\mathbf{x} + \mathbf{o} = \mathbf{g} .$$

With the rotation matrix $\mathbf{R} \in \mathbb{R}^{3 \times 3}$, the local coordinate vector $\mathbf{x} \in \mathbb{R}^{3}$, the offset $\mathbf{o} \in \mathbb{R}^{3}$, and the global coordinate vector $\mathbf{g}$.
Evaluate this transform for every coordinate box line.

1. Use the `box_lines` function from the
`util.py` module to generate a bounding box at the origin. All points in every line must be transformed using the above relationship.

2. The region of interest is the overlap of all boxes in the global coordinate system. Use [np.amin](https://numpy.org/doc/stable/reference/generated/numpy.amin.html) and [np.amax](https://numpy.org/doc/stable/reference/generated/numpy.amax.html) to find roi-box points $\mathbf{r} \in \mathbb{R}^{3}$.

3. To obtain array indices, transform all box points back into the local system. Or, more formally:

```math
\mathbf{R}^{-1} (\mathbf{r} - \mathbf{o}) = \mathbf{x}_{\text{roi}}
```

Here $\mathbf{R}^{-1}$ is the inverse of the rotation matrix; use [np.linalg.inv](https://numpy.org/doc/stable/reference/generated/numpy.linalg.inv.html) to compute it. $\mathbf{x}_{\text{roi}} \in \mathbb{R}^{3}$ is a point on the boundary of the local roi-box we seek.
Transform all boundary points.

Using the smallest and largest coordinate values of the roi box in
local coordinates now allows array indexing. Following Meyer et al. we discard all but the axial `t2w` scans.
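The three steps above can be sketched as follows. This is a minimal NumPy sketch: the function name and the way rotations, offsets, and box points are passed in are illustrative assumptions, not the repository's actual interface, which builds these from the SimpleITK image metadata and `box_lines`.

```python
import numpy as np

def roi_in_local_coords(rotations, offsets, box_lines_list):
    """Transform boxes to global space, intersect them, and map the
    resulting roi box back into each local coordinate system.

    rotations:      list of (3, 3) rotation matrices R
    offsets:        list of (3,) offset vectors o
    box_lines_list: list of (n_points, 3) arrays of local box points
    """
    # Step 1: g = R x + o for every point of every local bounding box.
    global_boxes = [
        pts @ R.T + o
        for R, o, pts in zip(rotations, offsets, box_lines_list)
    ]
    # Step 2: intersect the boxes -- per axis, the largest minimum and
    # the smallest maximum over all global boxes (np.amin / np.amax).
    lo = np.amax([np.amin(box, axis=0) for box in global_boxes], axis=0)
    hi = np.amin([np.amax(box, axis=0) for box in global_boxes], axis=0)
    roi = np.array(
        [[x, y, z] for x in (lo[0], hi[0])
         for y in (lo[1], hi[1]) for z in (lo[2], hi[2])]
    )
    # Step 3: x_roi = R^{-1} (r - o), one local roi box per image.
    return [
        (roi - o) @ np.linalg.inv(R).T
        for R, o in zip(rotations, offsets)
    ]
```

The smallest and largest values of each returned array then give the index bounds per image.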
@@ -76,19 +80,25 @@
Test your implementation by setting the if-condition wrapping the plotting utility.
### Task 3: Implement the UNet.
Navigate to the `train.py` file in the `src` folder.
Finish the `UNet3D` class, as discussed in the lecture.

1. In the `__init__` function, define the building blocks of the UNet architecture. You can use [torch.nn.Sequential](https://docs.pytorch.org/docs/stable/generated/torch.nn.Sequential.html) to stack multiple layers so that you can call them with a single forward pass. For the block definitions, consult the lecture slides to see how many layers you need and which layer types and dimensions to use.
Use [torch.nn.Conv3d](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html), [torch.nn.ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) and [torch.nn.MaxPool3d](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html) to build the blocks.

2. Next, go to the `__upsize` function and implement the upsampling using [th.nn.Upsample](https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html), which is needed for the second half of the UNet. For upsampling, we suggest `mode='nearest'` for reproducibility.

3. Finally, implement the `forward` function to define the forward pass of the UNet. You can use the building blocks you defined in the `__init__` function and the upsampling function you implemented in the previous step to build the forward pass. Remember to use skip connections as shown in the lecture slides.
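The three steps might be sketched like this. The toy two-level network below uses illustrative feature sizes and block counts, not the lecture's exact `UNet3D` architecture:

```python
import torch as th
import torch.nn as nn

def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Two 3x3x3 convolutions with ReLU, stacked via nn.Sequential."""
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(),
    )

class TinyUNet3D(nn.Module):
    def __init__(self, in_ch: int = 1, feat: int = 8, out_ch: int = 5):
        super().__init__()
        self.down1 = conv_block(in_ch, feat)
        self.down2 = conv_block(feat, 2 * feat)
        self.pool = nn.MaxPool3d(2)
        # 'nearest' upsampling keeps results reproducible.
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        # The skip connection concatenates feat extra channels.
        self.up1 = conv_block(2 * feat + feat, feat)
        self.out = nn.Conv3d(feat, out_ch, kernel_size=1)

    def forward(self, x: th.Tensor) -> th.Tensor:
        d1 = self.down1(x)                       # (b, feat, D, H, W)
        d2 = self.down2(self.pool(d1))           # (b, 2*feat, D/2, H/2, W/2)
        u1 = self.up(d2)                         # back to (D, H, W)
        u1 = self.up1(th.cat([u1, d1], dim=1))   # skip connection
        return self.out(u1)
```

A deeper network repeats the pool/upsample pattern with one skip connection per level.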

### Task 4: Implement the focal-loss.

Open the `util.py` module in `src` and implement the `softmax_focal_loss` function as discussed in the lecture:

$$\mathcal{L}(\mathbf{o},\mathbf{I})=-\mathbf{I}\cdot(1-\sigma_s(\mathbf{o}))^\gamma\cdot\alpha\cdot\ln(\sigma_s(\mathbf{o})) $$

with output logits $\mathbf{o}$, the corresponding labels $\mathbf{I}$ and the softmax function $\sigma_s$ (`torch.nn.functional.softmax`).
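One way to realize this formula is sketched below; using `log_softmax` for numerical stability and summing over the class axis are choices of this sketch, not prescribed by the assignment:

```python
import torch as th
import torch.nn.functional as F

def softmax_focal_loss_sketch(
    logits: th.Tensor, labels: th.Tensor, alpha: th.Tensor, gamma: float = 2.0
) -> th.Tensor:
    """Focal loss on one-hot labels, classes along the last axis."""
    log_p = F.log_softmax(logits, dim=-1)   # ln(sigma_s(o)), numerically stable
    focus = (1.0 - th.exp(log_p)) ** gamma  # (1 - sigma_s(o))^gamma
    # Only the true-class terms survive the multiplication with one-hot labels.
    return (-labels * focus * alpha * log_p).sum(dim=-1)
```

For `gamma=0` and uniform `alpha` this reduces to ordinary cross-entropy; increasing `gamma` down-weights well-classified voxels.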

### Task 5: Run and test the training script.

Execute the training script by running `scripts/train.slurm` (locally or using `sbatch`).

After training you can test your model by changing the `checkpoint_name` variable in `src/sample.py` to the desired model checkpoint and running `scripts/test.slurm`.

Expand Down
10 changes: 5 additions & 5 deletions src/meanIoU.py
@@ -12,14 +12,14 @@ def compute_iou(preds: th.Tensor, target: th.Tensor) -> th.Tensor:
"""Calculate meanIoU for a given batch.

Args:
preds (th.Tensor): Predictions from network
target (th.Tensor): Labels

Returns:
th.Tensor: Mean Intersection over Union values
"""
assert preds.shape == target.shape
# 6. TODO: Implement meanIoU
return th.tensor(0.0)


@@ -48,7 +48,7 @@
batched_labels[batch_index],
)
preds = model(imgs)
preds = preds.permute((0, 3, 4, 2, 1))
preds = th.argmax(preds, dim=-1)
ious.append(compute_iou(preds, lbls))
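A possible way to fill in the `compute_iou` TODO, assuming integer class maps with labels `0..num_classes-1`; skipping classes absent from both tensors (rather than counting them as 0 or 1) is a convention of this sketch, not prescribed by the assignment:

```python
import torch as th

def compute_iou_sketch(
    preds: th.Tensor, target: th.Tensor, num_classes: int = 5
) -> th.Tensor:
    """Mean IoU over all classes present in prediction or target."""
    assert preds.shape == target.shape
    ious = []
    for c in range(num_classes):
        pred_c = preds == c
        targ_c = target == c
        union = (pred_c | targ_c).sum()
        if union == 0:
            continue  # class absent everywhere: skip it entirely
        ious.append((pred_c & targ_c).sum().float() / union.float())
    return th.stack(ious).mean()
```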

1 change: 1 addition & 0 deletions src/sample.py
@@ -10,6 +10,7 @@
from util import softmax_focal_loss

if __name__ == "__main__":
# Change the checkpoint name to match the desired model.
checkpoint_name = "./weights/unet_softmaxfl_124.pkl"
device = th.device("cuda") if th.cuda.is_available() else th.device("cpu")
mean = th.Tensor([206.12558]).to(device)
15 changes: 9 additions & 6 deletions src/train.py
@@ -74,6 +74,7 @@ def __init__(self):
input_feat = 1
init_feat = 16
out_neurons = 5
# 3.1 TODO: Define the building blocks of the UNet architecture.
# TODO: Initialize downscaling blocks
# TODO: Initialize upscaling blocks

@@ -86,7 +87,12 @@ def forward(self, x: th.Tensor) -> th.Tensor:
Returns:
th.Tensor: Segmented output.
"""
x = x.permute(
(0, 1, 4, 2, 3)
) # Permute such that x has shape (batch, channels, depth, height, width)

# 3.3 TODO: Implement the forward pass of the UNet using the building blocks
# defined in the __init__ function and the upsampling function.
return th.tensor(0.0)

def __upsize(self, input_: th.Tensor) -> th.Tensor:
Expand All @@ -98,7 +104,7 @@ def __upsize(self, input_: th.Tensor) -> th.Tensor:
Returns:
th.Tensor: Upsampled image.
"""
# 3.2 TODO: Upsample the height and width using th.nn.Upsample with nearest mode.
return th.tensor(0.0)


@@ -170,7 +176,6 @@ def train():
preds, labels_y, th.ones((preds.shape[-1])).to(device)
)
)
loss.backward()
opt.step()

@@ -192,13 +197,11 @@ def train():
val_out, label_val, th.ones((val_out.shape[-1])).to(device)
)
)

val_loss_list.append((e, val_loss.item()))
writer.write_scalars(e, {"validation_loss": val_loss.item()})
val_out = val_out.cpu()
print(f"Validation loss: {val_loss.item()}")
for i in range(len(val_keys)):
writer.write_images(
e,
18 changes: 6 additions & 12 deletions src/util.py
@@ -137,14 +137,14 @@ def compute_roi(images: Tuple[Image, Image, Image]):
rects = []
for pos, size in enumerate(sizes):
lines = box_lines(size)
# 2.1 TODO: Rotate and shift the lines.
rotated = []
shifted = []
rects.append(shifted)

# find the intersection.
rects_stacked = np.stack(rects)
# 2.2 TODO: Find the axis maxima and minima
bbs = [
(
np.zeros_like(rect[0, 0]),
@@ -164,7 +164,7 @@
rects_stacked = np.concatenate([rects_stacked, np.expand_dims(roi_bb_lines, 0)])

spacings = [image.GetSpacing() for image in images]
# 2.3 TODO: compute roi coordinates in image space.
img_coord_rois = [
(
np.zeros_like(roi_bb[0]), # TODO: Implement me
@@ -254,14 +254,8 @@ def softmax_focal_loss(
gamma: float = 2,
) -> th.Tensor:
"""Compute a softmax focal loss."""

logits = logits.float()
labels = labels.float()
# 4. TODO: Implement softmax focal loss.
return th.tensor(0.0)