diff --git a/README.md b/README.md index 920e323..7da7731 100644 --- a/README.md +++ b/README.md @@ -56,17 +56,21 @@ In other words, we require a rotation and translation, or more formally $$ \mathbf{R}\mathbf{x} + \mathbf{o} = \mathbf{g} .$$ With a rotation matrix $\mathbf{R} \in \mathbb{R}^{3,3}$, the local coordinate vector $\mathbf{x \in \mathbb{R}^{3}}$, the offset $\mathbf{o} \in \mathbb{R}^{3}$, and the global coordinate line $\mathbf{g}$. -Evaluate this transform for every coordinate box line. Use the `box_lines` function from the +Evaluate this transform for every coordinate box line. + +1. Use the `box_lines` function from the `util.py` module to generate a bounding box at the origin. All points in every line must be transformed using the above relationship. -The region of interest is the overlap of all boxes in the global coordinate system. Use [np.amin](https://numpy.org/doc/stable/reference/generated/numpy.amin.html) and [np.amax](https://numpy.org/doc/stable/reference/generated/numpy.amax.html) to find roi-box points $\mathbf{r} \in \mathbb{R}^{3}$. +2. The region of interest is the overlap of all boxes in the global coordinate system. Use [np.amin](https://numpy.org/doc/stable/reference/generated/numpy.amin.html) and [np.amax](https://numpy.org/doc/stable/reference/generated/numpy.amax.html) to find roi-box points $\mathbf{r} \in \mathbb{R}^{3}$. -To obtain array indices, transform all box points back into the local system. Or, more formally: +3. To obtain array indices, transform all box points back into the local system. Or, more formally: -$$ \mathbf{R}^{-1} \mathbf{r} - \mathbf{o} = \mathbf{x}_{\text{roi}} $$ + ```math + \mathbf{R}^{-1} \mathbf{r} - \mathbf{o} = \mathbf{x}_{\text{roi}} + ``` -With the inverse of the rotation matrix $\mathbf{R}^{-1}$ use [np.linalg.inv](https://numpy.org/doc/stable/reference/generated/numpy.linalg.inv.html) to compute it. $\mathbf{x}_{\text{roi}} \in \mathbb{R}^{3}$ is a point on the boundary of the local roi-box we seek. -Transform all boundary points. + With the inverse of the rotation matrix $\mathbf{R}^{-1}$ use [np.linalg.inv](https://numpy.org/doc/stable/reference/generated/numpy.linalg.inv.html) to compute it. $\mathbf{x}_{\text{roi}} \in \mathbb{R}^{3}$ is a point on the boundary of the local roi-box we seek. + Transform all boundary points. Using the smallest and largest coordinate values of the roi box in local coordinates now allows array indexing. Following Meyer et al. we discard all but the axial `t2w` scans. @@ -76,7 +80,13 @@ Test your implementation by setting the if-condition wrapping the plotting utili ### Task 3: Implement the UNet. Navigate to the `train.py` file in the `src` folder. Finish the `UNet3D` class, as discussed in the lecture. -Use [torch.nn.Conv3d](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html), [torch.nn.ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html), [torch.nn.MaxPool3d](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html) and [th.nn.UpSample](https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html) to build the model. For upsampling, we suggest to use `mode='nearest'` algorithm for reproducibility purpose. + +1. In the `__init__` function, you need to define the building blocks of the UNet architecture. To do this you can use [torch.nn.Sequential](https://docs.pytorch.org/docs/stable/generated/torch.nn.Sequential.html) to stack multiple layers together, so that you can call them with a single forward pass. For defining the blocks look at the slide from the lecture to see how many layers and which types of layers and dimensions you need to use. +Use [torch.nn.Conv3d](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html), [torch.nn.ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) and [torch.nn.MaxPool3d](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html) to build the blocks. + +2. Next, go to the `__upsize` function and implement the upsampling using [th.nn.Upsample](https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html) which is needed for the second half of the UNet. For upsampling, we suggest to use `mode='nearest'` algorithm for reproducibility purpose. + +3. Finally, implement the `forward` function to define the forward pass of the UNet. You can use the building blocks you defined in the `__init__` function and the upsampling function you implemented in the previous step to build the forward pass. Remember to use skip connections as shown in the lecture slides. ### Task 4: Implement the focal-loss. @@ -84,11 +94,11 @@ Open the `util.py` module in `src` and implement the `softmax_focal_loss` functi $$\mathcal{L}(\mathbf{o},\mathbf{I})=-\mathbf{I}\cdot(1-\sigma_s(\mathbf{o}))^\gamma\cdot\alpha\cdot\ln(\sigma_s(\mathbf{o})) $$ -with output logits $\mathbf{o}$, the corresponding labels $\mathbf{I}$ and the softmax function $\sigma_s$. +with output logits $\mathbf{o}$, the corresponding labels $\mathbf{I}$ and the softmax function $\sigma_s$ (`torch.nn.functional.softmax`). ### Task 5: Run and test the training script. -Execute the training script with by running `scripts/train.slurm` (locally or using `sbatch`). +Execute the training script by running `scripts/train.slurm` (locally or using `sbatch`). After training you can test your model by changing the `checkpoint_name` variable in `src/sample.py` to the desired model checkpoint and running `scripts/test.slurm`. diff --git a/src/meanIoU.py b/src/meanIoU.py index d3c80ee..f1349c0 100644 --- a/src/meanIoU.py +++ b/src/meanIoU.py @@ -12,14 +12,14 @@ def compute_iou(preds: th.Tensor, target: th.Tensor) -> th.Tensor: """Calculate meanIoU for a given batch. Args: - preds (jnp.ndarray): Predictions from network - target (jnp.ndarray): Labels + preds (th.Tensor): Predictions from network + target (th.Tensor): Labels Returns: - jnp.ndarray: Mean Intersection over Union values + th.Tensor: Mean Intersection over Union values """ assert preds.shape == target.shape - # TODO: Implement meanIoU + # 6. TODO: Implement meanIoU return th.tensor(0.0) @@ -48,7 +48,7 @@ def compute_iou(preds: th.Tensor, target: th.Tensor) -> th.Tensor: batched_labels[batch_index], ) preds = model(imgs) - preds = preds.permute((0, 2, 3, 4, 1)) + preds = preds.permute((0, 3, 4, 2, 1)) preds = th.argmax(preds, dim=-1) ious.append(compute_iou(preds, lbls)) diff --git a/src/sample.py b/src/sample.py index ab3ac41..54be591 100644 --- a/src/sample.py +++ b/src/sample.py @@ -10,6 +10,7 @@ from util import softmax_focal_loss if __name__ == "__main__": + # Change the checkpoint name to match the desired model. checkpoint_name = "./weights/unet_softmaxfl_124.pkl" device = th.device("cuda") if th.cuda.is_available() else th.device("cpu") mean = th.Tensor([206.12558]).to(device) diff --git a/src/train.py b/src/train.py index 76aa211..31bfe54 100644 --- a/src/train.py +++ b/src/train.py @@ -74,6 +74,7 @@ def __init__(self): input_feat = 1 init_feat = 16 out_neurons = 5 + # 3.1 TODO: Define the building blocks of the UNet architecture. # TODO: Initialize downscaling blocks # TODO: Initialize upscaling blocks @@ -86,7 +87,12 @@ def forward(self, x: th.Tensor) -> th.Tensor: Returns: th.Tensor: Segmented output. """ - # TODO: Implement 3D UNet as discussed in the lecture + x = x.permute( + (0, 1, 4, 2, 3) + ) # Permute such that x has shape (batch, channels, depth, height, width) + + # 3.3 TODO: Implement the forward pass of the UNet using the building blocks + # defined in the __init__ function and the upsampling function. return th.tensor(0.0) def __upsize(self, input_: th.Tensor) -> th.Tensor: @@ -98,7 +104,7 @@ def __upsize(self, input_: th.Tensor) -> th.Tensor: Returns: th.Tensor: Upsampled image. """ - # TODO: Upsample the height and width using th.nn.Upsample with nearest mode. + # 3.2 TODO: Upsample the height and width using th.nn.Upsample with nearest mode. return th.tensor(0.0) @@ -170,7 +176,6 @@ def train(): preds, labels_y, th.ones((preds.shape[-1])).to(device) ) ) - # loss = loss_fn(preds, labels_y.type(th.LongTensor).to(device)) loss.backward() opt.step() @@ -192,13 +197,11 @@ def train(): val_out, label_val, th.ones((val_out.shape[-1])).to(device) ) ) - # label_val = val_data["annotation"].to(device) - # val_loss = loss_fn(val_out, label_val.type(th.LongTensor).to(device)) + val_loss_list.append((e, val_loss.item())) writer.write_scalars(e, {"validation_loss": val_loss.item()}) val_out = val_out.cpu() print(f"Validation loss: {val_loss.item()}") - # val_out = val_out.permute((0, 2, 3, 4, 1)) for i in range(len(val_keys)): writer.write_images( e, diff --git a/src/util.py b/src/util.py index e238e8a..447875c 100644 --- a/src/util.py +++ b/src/util.py @@ -137,14 +137,14 @@ def compute_roi(images: Tuple[Image, Image, Image]): rects = [] for pos, size in enumerate(sizes): lines = box_lines(size) - # TODO: Rotate and shift the lines. + # 2.1 TODO: Rotate and shift the lines. rotated = [] shifted = [] rects.append(shifted) # find the intersection. - rects_stacked = np.stack(rects) # Had to rename because of mypy - # TODO: Find the axis maxima and minima + rects_stacked = np.stack(rects) + # 2.2 TODO: Find the axis maxima and minima bbs = [ ( np.zeros_like(rect[0, 0]), @@ -164,7 +164,7 @@ def compute_roi(images: Tuple[Image, Image, Image]): rects_stacked = np.concatenate([rects_stacked, np.expand_dims(roi_bb_lines, 0)]) spacings = [image.GetSpacing() for image in images] - # compute roi coordinates in image space. + # 2.3 TODO: compute roi coordinates in image space. img_coord_rois = [ ( np.zeros_like(roi_bb[0]), # TODO: Implement me @@ -254,14 +254,8 @@ def softmax_focal_loss( gamma: float = 2, ) -> th.Tensor: """Compute a softmax focal loss.""" - # chex.assert_type([logits], float) - # # see also the original sigmoid implementation at: - # # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py - # chex.assert_type([logits], float) - # focus = jnp.power(1.0 - jax.nn.softmax(logits, axis=-1), gamma) - # loss = -labels * focus * alpha * jax.nn.log_softmax(logits, axis=-1) - # return jnp.sum(loss, axis=-1) + logits = logits.float() labels = labels.float() - # TODO: Implement softmax focal loss. + # 4. TODO: Implement softmax focal loss. return th.tensor(0.0)