Skip to content
11 changes: 11 additions & 0 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: class index to ignore from the loss computation.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -122,6 +124,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.ignore_index = ignore_index
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
Expand Down Expand Up @@ -163,6 +166,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.other_act is not None:
input = self.other_act(input)

mask: torch.Tensor | None = None
if self.ignore_index is not None:
mask = (target != self.ignore_index).float()

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
Expand All @@ -180,6 +187,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

if mask is not None:
input = input * mask
target = target * mask

# reducing only spatial dimensions (not batch nor channels)
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
Expand Down
8 changes: 8 additions & 0 deletions monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
weight: Sequence[float] | float | int | torch.Tensor | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -99,6 +100,7 @@ def __init__(

use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.
ignore_index: class index to ignore from the loss computation.

Example:
>>> import torch
Expand All @@ -124,6 +126,7 @@ def __init__(
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -161,6 +164,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

if self.ignore_index is not None:
mask = (target != self.ignore_index).float()
input = input * mask
target = target * mask

Comment on lines +167 to +171
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Mask the focal loss, not the logits.

In the sigmoid path, input = 0 and target = 0 still produce a positive BCE/focal term, so ignored voxels are turned into background loss instead of being excluded. loss.mean() also continues to divide by ignored elements. Apply the mask to the elementwise loss, then reduce over valid elements only.

Also applies to: 183-185, 207-218

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/focal_loss.py` around lines 167 - 171, The current
implementation applies the ignore mask to logits/targets (input/target) which
still yields a nonzero BCE/focal term for masked entries; instead, compute the
elementwise focal/BCE loss normally (in the sigmoid and other paths) and then
multiply that per-element loss by mask to zero out ignored voxels, and when
reducing use sum(loss * mask) / mask.sum() (or clamp minimal denom) so ignored
elements don't contribute to numerator or denominator; update the forward method
in FocalLoss (and the other masked blocks referenced around the sigmoid branch
and the regions you noted at 183-185 and 207-218) to apply mask to the computed
loss tensor and perform masked reduction rather than masking input/target or
using .mean() over masked values.

loss: torch.Tensor | None = None
input = input.float()
target = target.float()
Expand Down
16 changes: 16 additions & 0 deletions monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -77,6 +78,7 @@ def __init__(
before any `reduction`.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: index of the class to ignore during calculation.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -101,6 +103,7 @@ def __init__(
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -129,8 +132,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
original_target = target
target = one_hot(target, num_classes=n_pred_ch)
Comment on lines +135 to 136
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Sanitize ignored labels before one_hot.

original_target is saved at Line 135, but Line 136 still one-hot encodes the raw labels. If callers use an out-of-range sentinel such as 255, conversion fails before the mask at Lines 138-148 runs. Replace ignored voxels with a valid class ID before one_hot, then derive the mask from original_target.

Also applies to: 138-148

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/tversky.py` around lines 135 - 136, original_target is preserved
but one_hot is called on the raw target which can contain out-of-range sentinel
values (e.g., 255) causing one_hot to fail; before calling one_hot(target,
num_classes=n_pred_ch) replace ignored/sentinel labels in target with a valid
class id (e.g., 0 or another safe class index) then call one_hot using that
sanitized target, and continue to derive the ignore mask from original_target
(not the sanitized target) when constructing the mask and applying it in the
subsequent loss logic (refer to original_target, target, one_hot, n_pred_ch, and
the mask construction used in the block around lines 138-148).


if self.ignore_index is not None:
mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target

if mask_src.shape[1] == 1:
mask = (mask_src != self.ignore_index).to(input.dtype)
else:
# Fallback for cases where target is already one-hot
mask = (1.0 - mask_src[:, self.ignore_index : self.ignore_index + 1]).to(input.dtype)

input = input * mask
target = target * mask

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
Expand Down
98 changes: 75 additions & 23 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,22 @@ def __init__(
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
ignore_index: class index to ignore from the loss computation.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.ignore_index = ignore_index

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]
Expand All @@ -65,22 +68,33 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

# clip the prediction to avoid NaN
# Handle ignore_index:
mask = torch.ones_like(y_true)
if self.ignore_index is not None:
# Identify valid pixels: where at least one channel is 1
spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float()
mask = spatial_mask.expand_as(y_true)
y_pred = y_pred * mask
Comment on lines +71 to +77
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

ignore_index is not used to build this mask.

sum(y_true, dim=1) > 0 marks every normal one-hot voxel as valid, regardless of which class should be ignored, and it also keeps sentinel-filled ignored voxels (255 + 255 > 0). The standalone AsymmetricFocalTverskyLoss path therefore still includes ignored regions. Build the mask from the original labels, or from the ignored class channel.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/unified_focal_loss.py` around lines 71 - 77, The current mask
uses (torch.sum(y_true, dim=1) > 0) which does not respect ignore_index; update
the ignore handling to build the mask from the original label indices or from
the explicit ignored-class channel instead of summing one-hot channels.
Concretely: when ignore_index is set, if you have access to the original label
tensor (e.g., labels / y_true_indices) create valid_mask = (labels !=
self.ignore_index).unsqueeze(1) and expand it to y_true shape; otherwise (when
y_true is one-hot) build the mask from the ignored class channel as valid_mask =
1 - y_true[:, self.ignore_index:self.ignore_index+1, ...] and use that to mask
y_pred and other computations (replace spatial_mask usage in this file and in
AsymmetricFocalTverskyLoss code paths).


y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
axis = list(range(2, len(y_pred.shape)))

# Calculate true positives (tp), false negatives (fn) and false positives (fp)
tp = torch.sum(y_true * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis)
fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis)
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

# Calculate losses separately for each class, enhancing both classes
back_dice = 1 - dice_class[:, 0]
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma)

# Average class scores
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
loss = torch.stack([back_dice, fore_dice], dim=-1)
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss)
if self.reduction == LossReduction.SUM.value:
return torch.sum(loss)
return loss


Expand All @@ -103,26 +117,34 @@ def __init__(
gamma: float = 2,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
ignore_index: class index to ignore from the loss computation.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.ignore_index = ignore_index

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
elif self.ignore_index is not None:
mask = (y_true != self.ignore_index).float()
y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true)
y_true = one_hot(y_true_clean, num_classes=n_pred_ch)
y_true = y_true * mask
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

Expand All @@ -132,13 +154,24 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
cross_entropy = -y_true * torch.log(y_pred)

if self.ignore_index is not None:
spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float()
cross_entropy = cross_entropy * spatial_mask

Comment on lines +157 to +160
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

The default AsymmetricFocalLoss path still includes ignored voxels.

With to_onehot_y=False, this again uses sum(y_true) > 0, so one-hot targets never drop a specific class and sentinel-filled ignored regions stay valid. Since that is the default mode, ignore_index is ineffective for the standalone loss.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/unified_focal_loss.py` around lines 157 - 160, The ignore_index
handling currently builds spatial_mask from torch.sum(y_true, dim=1) which fails
when self.to_onehot_y is False; update the block in AsymmetricFocalLoss
(unified_focal_loss.py) to branch on self.to_onehot_y: when to_onehot_y is False
compute spatial_mask = (y_true != self.ignore_index).unsqueeze(1).float() (so
label tensors drop ignored voxels), otherwise keep the existing one-hot style
mask (e.g., spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) >
0).float()); then apply that spatial_mask to cross_entropy as before.

back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
fore_ce = self.delta * fore_ce

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W]
if self.reduction == LossReduction.MEAN.value:
if self.ignore_index is not None:
# Normalize by the number of non-ignored pixels
return loss.sum() / spatial_mask.sum().clamp(min=1e-5)
return loss.mean()
Comment on lines +167 to +172
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

mean reduction is scaled by pixels, not valid elements.

loss.sum() / spatial_mask.sum() is 2x larger than loss.mean() when nothing is ignored because loss has two channels but the denominator counts only pixels. Divide by the expanded mask sum, or average loss after masking it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/unified_focal_loss.py` around lines 167 - 172, The
mean-reduction branch incorrectly divides loss.sum() by spatial_mask.sum(),
undercounting channels (loss has shape [B,2,H,W]); update the normalization to
account for both channel elements by either applying the spatial mask to loss
(e.g., expand spatial_mask to match loss with
spatial_mask.unsqueeze(1).expand_as(loss)), summing the masked loss and dividing
by the masked-element count, or divide loss.sum() by (spatial_mask.sum() *
loss.size(1)). Make the change in the block handling self.reduction ==
LossReduction.MEAN.value and self.ignore_index is not None (affecting variables
loss, spatial_mask, back_ce, fore_ce).

if self.reduction == LossReduction.SUM.value:
return loss.sum()
return loss


Expand All @@ -162,6 +195,7 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
):
"""
Args:
Expand All @@ -170,8 +204,7 @@ def __init__(
weight : weight for each loss function. Defaults to 0.5.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
delta : weight of the background. Defaults to 0.7.


ignore_index: class index to ignore from the loss computation.

Example:
>>> import torch
Expand All @@ -187,10 +220,12 @@ def __init__(
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta, ignore_index=ignore_index)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
gamma=self.gamma, delta=self.delta, ignore_index=ignore_index
)
self.ignore_index = ignore_index

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
Expand All @@ -207,25 +242,42 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
ValueError: When num_classes
ValueError: When the number of classes entered does not match the expected number
"""
if y_pred.shape != y_true.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

# Transform binary inputs to 2-channel space
if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)

n_pred_ch = y_pred.shape[1]
# Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
if self.ignore_index is not None:
mask = (y_true != self.ignore_index).float()
y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true)
y_true = one_hot(y_true_clean, num_classes=self.num_classes)
# Keep the channel-wise mask
y_true = y_true * mask
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
y_true = one_hot(y_true, num_classes=self.num_classes)

# Check if shapes match
if y_true.shape[1] == 1 and y_pred.shape[1] == 2:
if self.ignore_index is not None:
# Create mask for valid pixels
mask = (y_true != self.ignore_index).float()
# Set ignore_index values to 0 before conversion
y_true_clean = y_true * mask
# Convert to 2-channel
y_true = torch.cat([1 - y_true_clean, y_true_clean], dim=1)
# Apply mask to both channels so ignored pixels are all zeros
y_true = y_true * mask
else:
y_true = torch.cat([1 - y_true, y_true], dim=1)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
if self.ignore_index is None and torch.max(y_true) > self.num_classes - 1:
raise ValueError(f"Invalid class index found. Maximum class should be {self.num_classes - 1}")

asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
Expand Down
Loading
Loading