-
Notifications
You must be signed in to change notification settings - Fork 1.4k
feat(losses/metrics): implement ignore_index support across dice and … #8757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
8e5fe51
a1b0a4f
f2caaf8
d075009
941a73b
0f6e05a
a1f6ef4
f01cbc4
af83422
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,6 +51,7 @@ def __init__( | |
| smooth_dr: float = 1e-5, | ||
| batch: bool = False, | ||
| soft_label: bool = False, | ||
| ignore_index: int | None = None, | ||
| ) -> None: | ||
| """ | ||
| Args: | ||
|
|
@@ -77,6 +78,7 @@ def __init__( | |
| before any `reduction`. | ||
| soft_label: whether the target contains non-binary values (soft labels) or not. | ||
| If True a soft label formulation of the loss will be used. | ||
| ignore_index: index of the class to ignore during calculation. | ||
|
|
||
| Raises: | ||
| TypeError: When ``other_act`` is not an ``Optional[Callable]``. | ||
|
|
@@ -101,6 +103,7 @@ def __init__( | |
| self.smooth_dr = float(smooth_dr) | ||
| self.batch = batch | ||
| self.soft_label = soft_label | ||
| self.ignore_index = ignore_index | ||
|
|
||
| def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -129,8 +132,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||
| else: | ||
| original_target = target | ||
| target = one_hot(target, num_classes=n_pred_ch) | ||
|
Comment on lines
+135
to
136
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sanitize ignored labels before
Also applies to: 138-148 🤖 Prompt for AI Agents |
||
|
|
||
| if self.ignore_index is not None: | ||
| mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target | ||
|
|
||
| if mask_src.shape[1] == 1: | ||
| mask = (mask_src != self.ignore_index).to(input.dtype) | ||
| else: | ||
| # Fallback for cases where target is already one-hot | ||
| mask = (1.0 - mask_src[:, self.ignore_index : self.ignore_index + 1]).to(input.dtype) | ||
|
|
||
| input = input * mask | ||
| target = target * mask | ||
|
|
||
| if not self.include_background: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `include_background=False` ignored.") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,19 +39,22 @@ def __init__( | |
| gamma: float = 0.75, | ||
| epsilon: float = 1e-7, | ||
| reduction: LossReduction | str = LossReduction.MEAN, | ||
| ignore_index: int | None = None, | ||
| ) -> None: | ||
| """ | ||
| Args: | ||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||
| delta : weight of the background. Defaults to 0.7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. | ||
| epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. | ||
| ignore_index: class index to ignore from the loss computation. | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| self.to_onehot_y = to_onehot_y | ||
| self.delta = delta | ||
| self.gamma = gamma | ||
| self.epsilon = epsilon | ||
| self.ignore_index = ignore_index | ||
|
|
||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| n_pred_ch = y_pred.shape[1] | ||
|
|
@@ -65,22 +68,33 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | |
| if y_true.shape != y_pred.shape: | ||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
|
|
||
| # clip the prediction to avoid NaN | ||
| # Handle ignore_index: | ||
| mask = torch.ones_like(y_true) | ||
| if self.ignore_index is not None: | ||
| # Identify valid pixels: where at least one channel is 1 | ||
| spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float() | ||
| mask = spatial_mask.expand_as(y_true) | ||
| y_pred = y_pred * mask | ||
|
Comment on lines
+71
to
+77
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🤖 Prompt for AI Agents |
||
|
|
||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||
| axis = list(range(2, len(y_pred.shape))) | ||
|
|
||
| # Calculate true positives (tp), false negatives (fn) and false positives (fp) | ||
| tp = torch.sum(y_true * y_pred, dim=axis) | ||
| fn = torch.sum(y_true * (1 - y_pred), dim=axis) | ||
| fp = torch.sum((1 - y_true) * y_pred, dim=axis) | ||
| fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis) | ||
| fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis) | ||
| dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) | ||
|
|
||
| # Calculate losses separately for each class, enhancing both classes | ||
| back_dice = 1 - dice_class[:, 0] | ||
| fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) | ||
| fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma) | ||
|
|
||
| # Average class scores | ||
| loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) | ||
| loss = torch.stack([back_dice, fore_dice], dim=-1) | ||
| if self.reduction == LossReduction.MEAN.value: | ||
| return torch.mean(loss) | ||
| if self.reduction == LossReduction.SUM.value: | ||
| return torch.sum(loss) | ||
| return loss | ||
|
|
||
|
|
||
|
|
@@ -103,26 +117,34 @@ def __init__( | |
| gamma: float = 2, | ||
| epsilon: float = 1e-7, | ||
| reduction: LossReduction | str = LossReduction.MEAN, | ||
| ignore_index: int | None = None, | ||
| ): | ||
| """ | ||
| Args: | ||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||
| delta : weight of the background. Defaults to 0.7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2. | ||
| epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. | ||
| ignore_index: class index to ignore from the loss computation. | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| self.to_onehot_y = to_onehot_y | ||
| self.delta = delta | ||
| self.gamma = gamma | ||
| self.epsilon = epsilon | ||
| self.ignore_index = ignore_index | ||
|
|
||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| n_pred_ch = y_pred.shape[1] | ||
|
|
||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||
| elif self.ignore_index is not None: | ||
| mask = (y_true != self.ignore_index).float() | ||
| y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true) | ||
| y_true = one_hot(y_true_clean, num_classes=n_pred_ch) | ||
| y_true = y_true * mask | ||
| else: | ||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||
|
|
||
|
|
@@ -132,13 +154,24 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | |
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||
| cross_entropy = -y_true * torch.log(y_pred) | ||
|
|
||
| if self.ignore_index is not None: | ||
| spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float() | ||
| cross_entropy = cross_entropy * spatial_mask | ||
|
|
||
|
Comment on lines
+157
to
+160
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The default With 🤖 Prompt for AI Agents |
||
| back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] | ||
| back_ce = (1 - self.delta) * back_ce | ||
|
|
||
| fore_ce = cross_entropy[:, 1] | ||
| fore_ce = self.delta * fore_ce | ||
|
|
||
| loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) | ||
| loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W] | ||
| if self.reduction == LossReduction.MEAN.value: | ||
| if self.ignore_index is not None: | ||
| # Normalize by the number of non-ignored pixels | ||
| return loss.sum() / spatial_mask.sum().clamp(min=1e-5) | ||
| return loss.mean() | ||
|
Comment on lines
+167
to
+172
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🤖 Prompt for AI Agents |
||
| if self.reduction == LossReduction.SUM.value: | ||
| return loss.sum() | ||
| return loss | ||
|
|
||
|
|
||
|
|
@@ -162,6 +195,7 @@ def __init__( | |
| gamma: float = 0.5, | ||
| delta: float = 0.7, | ||
| reduction: LossReduction | str = LossReduction.MEAN, | ||
| ignore_index: int | None = None, | ||
| ): | ||
| """ | ||
| Args: | ||
|
|
@@ -170,8 +204,7 @@ def __init__( | |
| weight : weight for each loss function. Defaults to 0.5. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. | ||
| delta : weight of the background. Defaults to 0.7. | ||
|
|
||
|
|
||
| ignore_index: class index to ignore from the loss computation. | ||
|
|
||
| Example: | ||
| >>> import torch | ||
|
|
@@ -187,10 +220,12 @@ def __init__( | |
| self.gamma = gamma | ||
| self.delta = delta | ||
| self.weight: float = weight | ||
| self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) | ||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) | ||
| self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta, ignore_index=ignore_index) | ||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( | ||
| gamma=self.gamma, delta=self.delta, ignore_index=ignore_index | ||
| ) | ||
| self.ignore_index = ignore_index | ||
|
|
||
| # TODO: Implement this function to support multiple classes segmentation | ||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
|
|
@@ -207,25 +242,42 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | |
| ValueError: When num_classes | ||
| ValueError: When the number of classes entered does not match the expected number | ||
| """ | ||
| if y_pred.shape != y_true.shape: | ||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
|
|
||
| if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: | ||
| raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") | ||
|
|
||
| # Transform binary inputs to 2-channel space | ||
| if y_pred.shape[1] == 1: | ||
| y_pred = one_hot(y_pred, num_classes=self.num_classes) | ||
| y_true = one_hot(y_true, num_classes=self.num_classes) | ||
|
|
||
| if torch.max(y_true) != self.num_classes - 1: | ||
| raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}") | ||
| y_pred = torch.cat([1 - y_pred, y_pred], dim=1) | ||
|
|
||
| n_pred_ch = y_pred.shape[1] | ||
| # Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block | ||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||
| if self.ignore_index is not None: | ||
| mask = (y_true != self.ignore_index).float() | ||
| y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true) | ||
| y_true = one_hot(y_true_clean, num_classes=self.num_classes) | ||
| # Keep the channel-wise mask | ||
| y_true = y_true * mask | ||
| else: | ||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||
| y_true = one_hot(y_true, num_classes=self.num_classes) | ||
|
|
||
| # Check if shapes match | ||
| if y_true.shape[1] == 1 and y_pred.shape[1] == 2: | ||
| if self.ignore_index is not None: | ||
| # Create mask for valid pixels | ||
| mask = (y_true != self.ignore_index).float() | ||
| # Set ignore_index values to 0 before conversion | ||
| y_true_clean = y_true * mask | ||
| # Convert to 2-channel | ||
| y_true = torch.cat([1 - y_true_clean, y_true_clean], dim=1) | ||
| # Apply mask to both channels so ignored pixels are all zeros | ||
| y_true = y_true * mask | ||
| else: | ||
| y_true = torch.cat([1 - y_true, y_true], dim=1) | ||
|
|
||
| if y_true.shape != y_pred.shape: | ||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
| if self.ignore_index is None and torch.max(y_true) > self.num_classes - 1: | ||
| raise ValueError(f"Invalid class index found. Maximum class should be {self.num_classes - 1}") | ||
|
|
||
| asy_focal_loss = self.asy_focal_loss(y_pred, y_true) | ||
| asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mask the focal loss, not the logits.
In the sigmoid path,
input = 0andtarget = 0still produce a positive BCE/focal term, so ignored voxels are turned into background loss instead of being excluded.loss.mean()also continues to divide by ignored elements. Apply the mask to the elementwiseloss, then reduce over valid elements only.Also applies to: 183-185, 207-218
🤖 Prompt for AI Agents