diff --git a/monai/losses/dice.py b/monai/losses/dice.py index cd76ec1323..8d6afeaa06 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -67,6 +67,7 @@ def __init__( batch: bool = False, weight: Sequence[float] | float | int | torch.Tensor | None = None, soft_label: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -100,6 +101,7 @@ def __init__( The value/values should be no less than 0. Defaults to None. soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used. + ignore_index: class index to ignore from the loss computation. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -122,6 +124,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.ignore_index = ignore_index weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor @@ -163,6 +166,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.other_act is not None: input = self.other_act(input) + mask: torch.Tensor | None = None + if self.ignore_index is not None: + mask = (target != self.ignore_index).float() + if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") @@ -180,6 +187,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + if mask is not None: + input = input * mask + target = target * mask + # reducing only spatial dimensions (not batch nor channels) reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index caa237fca8..62e98c2c61 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -73,6 +73,7 @@ def __init__( weight: Sequence[float] | float | int | torch.Tensor | None = None, reduction: LossReduction | str = LossReduction.MEAN, use_softmax: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -99,6 +100,7 @@ def __init__( use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. + ignore_index: class index to ignore from the loss computation. Example: >>> import torch @@ -124,6 +126,7 @@ def __init__( weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor + self.ignore_index = ignore_index def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -161,6 +164,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + if self.ignore_index is not None: + mask = (target != self.ignore_index).float() + input = input * mask + target = target * mask + loss: torch.Tensor | None = None input = input.float() target = target.float() diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 154f34c526..f2c15954c0 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -51,6 +51,7 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, soft_label: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -77,6 +78,7 @@ def __init__( before any `reduction`. soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used. + ignore_index: index of the class to ignore during calculation. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -101,6 +103,7 @@ def __init__( self.smooth_dr = float(smooth_dr) self.batch = batch self.soft_label = soft_label + self.ignore_index = ignore_index def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -129,8 +132,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: + original_target = target target = one_hot(target, num_classes=n_pred_ch) + if self.ignore_index is not None: + mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target + + if mask_src.shape[1] == 1: + mask = (mask_src != self.ignore_index).to(input.dtype) + else: + # Fallback for cases where target is already one-hot + mask = (1.0 - mask_src[:, self.ignore_index : self.ignore_index + 1]).to(input.dtype) + + input = input * mask + target = target * mask + if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 745513fec0..45d2d4e778 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -39,6 +39,7 @@ def __init__( gamma: float = 0.75, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ) -> None: """ Args: @@ -46,12 +47,14 @@ def __init__( delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + ignore_index: class index to ignore from the loss computation. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.ignore_index = ignore_index def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] @@ -65,22 +68,33 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - # clip the prediction to avoid NaN + # Handle ignore_index: + mask = torch.ones_like(y_true) + if self.ignore_index is not None: + # Identify valid pixels: where at least one channel is 1 + spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float() + mask = spatial_mask.expand_as(y_true) + y_pred = y_pred * mask + y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) tp = torch.sum(y_true * y_pred, dim=axis) - fn = torch.sum(y_true * (1 - y_pred), dim=axis) - fp = torch.sum((1 - y_true) * y_pred, dim=axis) + fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis) + fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) # Calculate losses separately for each class, enhancing both classes back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) + fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma) # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) + loss = torch.stack([back_dice, fore_dice], dim=-1) + if self.reduction == LossReduction.MEAN.value: + return torch.mean(loss) + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss) return loss @@ -103,6 +117,7 @@ def __init__( gamma: float = 2, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ): """ Args: @@ -110,12 +125,14 @@ def __init__( delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2. epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + ignore_index: class index to ignore from the loss computation. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.ignore_index = ignore_index def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] @@ -123,6 +140,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + elif self.ignore_index is not None: + mask = (y_true != self.ignore_index).float() + y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true) + y_true = one_hot(y_true_clean, num_classes=n_pred_ch) + y_true = y_true * mask else: y_true = one_hot(y_true, num_classes=n_pred_ch) @@ -132,13 +154,24 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) cross_entropy = -y_true * torch.log(y_pred) + if self.ignore_index is not None: + spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float() + cross_entropy = cross_entropy * spatial_mask + back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] back_ce = (1 - self.delta) * back_ce fore_ce = cross_entropy[:, 1] fore_ce = self.delta * fore_ce - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) + loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W] + if self.reduction == LossReduction.MEAN.value: + if self.ignore_index is not None: + # Normalize by the number of non-ignored pixels + return loss.sum() / spatial_mask.sum().clamp(min=1e-5) + return loss.mean() + if self.reduction == LossReduction.SUM.value: + return loss.sum() return loss @@ -162,6 +195,7 @@ def __init__( gamma: float = 0.5, delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ): """ Args: @@ -170,8 +204,7 @@ def __init__( weight : weight for each loss function. Defaults to 0.5. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. delta : weight of the background. Defaults to 0.7. - - + ignore_index: class index to ignore from the loss computation. Example: >>> import torch @@ -187,10 +220,12 @@ def __init__( self.gamma = gamma self.delta = delta self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta, ignore_index=ignore_index) + self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( + gamma=self.gamma, delta=self.delta, ignore_index=ignore_index + ) + self.ignore_index = ignore_index - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: @@ -207,25 +242,42 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: ValueError: When num_classes ValueError: When the number of classes entered does not match the expected number """ - if y_pred.shape != y_true.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") + # Transform binary inputs to 2-channel space if y_pred.shape[1] == 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) - - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}") + y_pred = torch.cat([1 - y_pred, y_pred], dim=1) - n_pred_ch = y_pred.shape[1] + # Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + if self.ignore_index is not None: + mask = (y_true != self.ignore_index).float() + y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true) + y_true = one_hot(y_true_clean, num_classes=self.num_classes) + # Keep the channel-wise mask + y_true = y_true * mask else: - y_true = one_hot(y_true, num_classes=n_pred_ch) + y_true = one_hot(y_true, num_classes=self.num_classes) + + # Check if shapes match + if y_true.shape[1] == 1 and y_pred.shape[1] == 2: + if self.ignore_index is not None: + # Create mask for valid pixels + mask = (y_true != self.ignore_index).float() + # Set ignore_index values to 0 before conversion + y_true_clean = y_true * mask + # Convert to 2-channel + y_true = torch.cat([1 - y_true_clean, y_true_clean], dim=1) + # Apply mask to both channels so ignored pixels are all zeros + y_true = y_true * mask + else: + y_true = torch.cat([1 - y_true, y_true], dim=1) + + if y_true.shape != y_pred.shape: + raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") + if self.ignore_index is None and torch.max(y_true) > self.num_classes - 1: + raise ValueError(f"Invalid class index found. Maximum class should be {self.num_classes - 1}") asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 26ec823081..51c671c9d3 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -69,6 +69,7 @@ def __init__( compute_sample: bool = False, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -76,6 +77,7 @@ def __init__( self.compute_sample = compute_sample self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ @@ -96,7 +98,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor warnings.warn("As for classification task, compute_sample should be False.") self.compute_sample = False - return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background) + return get_confusion_matrix( + y_pred=y_pred, y=y, include_background=self.include_background, ignore_index=self.ignore_index + ) def aggregate( self, compute_sample: bool = False, reduction: MetricReduction | str | None = None @@ -131,7 +135,9 @@ def aggregate( return results -def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor: +def get_confusion_matrix( + y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_index: int | None = None +) -> torch.Tensor: """ Compute confusion matrix. A tensor with the shape [BC4] will be returned. Where, the third dimension represents the number of true positive, false positive, true negative and false negative values for @@ -145,6 +151,9 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou The values should be binarized. include_background: whether to include metric computation on the first channel of the predicted output. Defaults to True. + ignore_index: index of the class to ignore during calculation. + If ignore_index < number of classes, that class channel is excluded + else ignored regions are inferred from spatial locations where all label channels are zero. Raises: ValueError: when `y_pred` and `y` have different shapes. @@ -158,17 +167,42 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou # get confusion matrix related metric batch_size, n_class = y_pred.shape[:2] + + # Create spatial mask if ignore_index is provided + mask = None + if ignore_index is not None: + if ignore_index >= n_class: + # If ignore_index is outside channel range (e.g. 255), we assume it's a spatial mask + mask = y.sum(dim=1, keepdim=True) > 0 + else: + # If ignore_index is a valid channel, exclude that specific channel + mask = 1.0 - y[:, ignore_index : ignore_index + 1] + # convert to [BNS], where S is the number of pixels for one sample. - # As for classification tasks, S equals to 1. y_pred = y_pred.reshape(batch_size, n_class, -1) y = y.reshape(batch_size, n_class, -1) + + if mask is not None: + mask = mask.reshape(batch_size, 1, -1) + y_pred = y_pred * mask + y = y * mask + tp = (y_pred + y) == 2 tn = (y_pred + y) == 0 + if mask is not None: + # When masking, TN must only count locations where the mask is 1 + tn = tn * mask.bool() + tp = tp.sum(dim=[2]).float() tn = tn.sum(dim=[2]).float() p = y.sum(dim=[2]).float() - n = y.shape[-1] - p + + if mask is not None: + # n is total valid pixels (per sample) minus the positives for that class + n = mask.reshape(batch_size, -1).sum(dim=1, keepdim=True) - p + else: + n = y.shape[-1] - p fn = p - tp fp = n - tn diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 05eb94af48..b905474c2f 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -13,7 +13,7 @@ import torch -from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction, Weight, deprecated_arg, look_up_option from .metric import CumulativeIterationMetric @@ -41,6 +41,7 @@ class GeneralizedDiceScore(CumulativeIterationMetric): Old versions computed `mean` when `mean_batch` was provided due to bug in reduction. weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. + ignore_index: class index to ignore from the metric computation. Raises: ValueError: When the `reduction` is not one of MetricReduction enum. @@ -51,11 +52,13 @@ def __init__( include_background: bool = True, reduction: MetricReduction | str = MetricReduction.MEAN, weight_type: Weight | str = Weight.SQUARE, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) + self.ignore_index = ignore_index self.sum_over_classes = self.reduction in { MetricReduction.SUM, MetricReduction.MEAN, @@ -71,6 +74,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + ignore_index: class index to ignore from the metric computation. Returns: torch.Tensor: Generalized Dice Score averaged across batch and class @@ -84,6 +88,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor include_background=self.include_background, weight_type=self.weight_type, sum_over_classes=self.sum_over_classes, + ignore_index=self.ignore_index, ) @deprecated_arg( @@ -118,6 +123,7 @@ def compute_generalized_dice( include_background: bool = True, weight_type: Weight | str = Weight.SQUARE, sum_over_classes: bool = False, + ignore_index: int | None = None, ) -> torch.Tensor: """ Computes the Generalized Dice Score and returns a tensor with its per image values. @@ -132,6 +138,7 @@ def compute_generalized_dice( weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation. + ignore_index: class index to ignore from the metric computation. Returns: torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. @@ -147,52 +154,79 @@ def compute_generalized_dice( if y.shape != y_pred.shape: raise ValueError(f"y_pred - {y_pred.shape} - and y - {y.shape} - should have the same shapes.") - # Ignore background, if needed + # Apply ignore_index masking + if ignore_index is not None: + mask = (y != ignore_index).all(dim=1, keepdim=True).float() + y_pred = y_pred * mask + y = y * mask + + n_channels = y_pred.shape[1] + channels_to_use = list(range(n_channels)) + if not include_background: - y_pred, y = ignore_background(y_pred=y_pred, y=y) + channels_to_use.pop(0) + + if ignore_index is not None: + # If background was 0 and we ignore class 2, we need the correct absolute index + if ignore_index in channels_to_use: + channels_to_use.remove(ignore_index) + + if not channels_to_use: + return torch.zeros(y_pred.shape[0], 1, device=y_pred.device) # Reducing only spatial dimensions (not batch nor channels), compute the intersection and non-weighted denominator reduce_axis = list(range(2, y_pred.dim())) - intersection = torch.sum(y * y_pred, dim=reduce_axis) - y_o = torch.sum(y, dim=reduce_axis) - y_pred_o = torch.sum(y_pred, dim=reduce_axis) + y_o_full = torch.sum(y, dim=reduce_axis) # shape: (B, C) + intersection = torch.sum(y[:, channels_to_use, ...] * y_pred[:, channels_to_use, ...], dim=reduce_axis) + y_o = torch.sum(y[:, channels_to_use, ...], dim=reduce_axis) + y_pred_o = torch.sum(y_pred[:, channels_to_use, ...], dim=reduce_axis) + denominator = y_o + y_pred_o # Set the class weights weight_type = look_up_option(weight_type, Weight) + y_o_float = y_o_full.float() + if weight_type == Weight.SIMPLE: - w = torch.reciprocal(y_o.float()) + w_full = torch.reciprocal(y_o_float) elif weight_type == Weight.SQUARE: - w = torch.reciprocal(y_o.float() * y_o.float()) + w_full = torch.reciprocal(y_o_float * y_o_float) else: - w = torch.ones_like(y_o.float()) + w_full = torch.ones_like(y_o_float) # Replace infinite values for non-appearing classes by the maximum weight - for b in w: - infs = torch.isinf(b) - b[infs] = 0 - b[infs] = torch.max(b) + for b_idx in range(w_full.shape[0]): + batch_w = w_full[b_idx] + infs = torch.isinf(batch_w) + if infs.any(): + batch_w[infs] = 0 + max_w = torch.max(batch_w) + batch_w[infs] = max_w if max_w > 0 else 1.0 + + w = w_full[:, channels_to_use] - # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True if sum_over_classes: - numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) - denom = (denominator * w).sum(dim=1, keepdim=True) - y_pred_o = y_pred_o.sum(dim=-1, keepdim=True) + intersection = (intersection * w).sum(dim=1, keepdim=True) + denominator = (denominator * w).sum(dim=1, keepdim=True) + numer = 2.0 * intersection + denom = denominator else: numer = 2.0 * (intersection * w) denom = denominator * w - y_pred_o = y_pred_o # Compute the score - generalized_dice_score = numer / denom + generalized_dice_score = numer / (denom + 1e-6) - # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1. - # Where denom == 0 but the prediction volume is not 0, score is 0 + # Handle zero division denom_zeros = denom == 0 - generalized_dice_score[denom_zeros] = torch.where( - (y_pred_o == 0)[denom_zeros], - torch.tensor(1.0, device=generalized_dice_score.device), - torch.tensor(0.0, device=generalized_dice_score.device), - ) + if denom_zeros.any(): + if sum_over_classes: + generalized_dice_score[denom_zeros] = 1.0 + else: + generalized_dice_score[denom_zeros] = torch.where( + (y_pred_o * w)[denom_zeros] == 0, + torch.ones_like(generalized_dice_score[denom_zeros]), + torch.zeros_like(generalized_dice_score[denom_zeros]), + ) return generalized_dice_score diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 1b83c93e5b..85cd589f03 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -51,6 +51,7 @@ class HausdorffDistanceMetric(CumulativeIterationMetric): ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + ignore_index: index of the class to ignore during calculation. Defaults to ``None``. """ @@ -62,6 +63,7 @@ def __init__( directed: bool = False, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -70,6 +72,7 @@ def __init__( self.directed = directed self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ @@ -97,6 +100,12 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) if dims < 3: raise ValueError("y_pred should have at least three dimensions.") + mask = None + if self.ignore_index is not None: + mask = (y != self.ignore_index).all(dim=1, keepdim=True).float() + y_pred = y_pred * mask + y = y * mask + # compute (BxC) for each channel for each batch return compute_hausdorff_distance( y_pred=y_pred, @@ -106,6 +115,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) percentile=self.percentile, directed=self.directed, spacing=kwargs.get("spacing"), + ignore_index=self.ignore_index, + mask=mask, ) def aggregate( @@ -137,6 +148,8 @@ def compute_hausdorff_distance( percentile: float | None = None, directed: bool = False, spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + mask: torch.Tensor | None = None, + ignore_index: int | None = None, ) -> torch.Tensor: """ Compute the Hausdorff distance. @@ -162,6 +175,7 @@ def compute_hausdorff_distance( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + ignore_index: index of the class to ignore during calculation. Defaults to ``None``. """ if not include_background: @@ -179,17 +193,35 @@ def compute_hausdorff_distance( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): + yp = y_pred[b, c] + yt = y[b, c] + + if ignore_index is not None: + valid_mask = y[b].sum(dim=0) > 0 + yp = yp * valid_mask + yt = yt * valid_mask + + # if everything is ignored, define distance as 0 + if not valid_mask.any(): + hd[b, c] = torch.tensor(0.0, device=y_pred.device) + continue + _, distances, _ = get_edge_surface_distance( - y_pred[b, c], - y[b, c], + yp, + yt, distance_metric=distance_metric, spacing=spacing_list[b], symmetric=not directed, - class_index=c, + mask=mask[b, 0] if mask is not None else None, ) + + if len(distances) == 0: + hd[b, c] = torch.tensor(0.0, device=y_pred.device) + continue + percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] - max_distance = torch.max(torch.stack(percentile_distances)) - hd[b, c] = max_distance + + hd[b, c] = torch.max(torch.stack(percentile_distances)) return hd diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index fedd94fb93..d3553a2002 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -106,6 +106,7 @@ def __init__( ignore_empty: bool = True, num_classes: int | None = None, return_with_label: bool | list[str] = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -114,6 +115,7 @@ def __init__( self.ignore_empty = ignore_empty self.num_classes = num_classes self.return_with_label = return_with_label + self.ignore_index = ignore_index self.dice_helper = DiceHelper( include_background=self.include_background, reduction=MetricReduction.NONE, @@ -121,6 +123,7 @@ def __init__( apply_argmax=False, ignore_empty=self.ignore_empty, num_classes=self.num_classes, + ignore_index=self.ignore_index, ) def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] @@ -175,6 +178,7 @@ def compute_dice( include_background: bool = True, ignore_empty: bool = True, num_classes: int | None = None, + ignore_index: int | None = None, ) -> torch.Tensor: """ Computes Dice score metric for a batch of predictions. This performs the same computation as @@ -192,6 +196,7 @@ def compute_dice( num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. + ignore_index: index of the class to ignore during calculation. Returns: Dice scores per batch and per class, (shape: [batch_size, num_classes]). @@ -204,6 +209,7 @@ def compute_dice( apply_argmax=False, ignore_empty=ignore_empty, num_classes=num_classes, + ignore_index=ignore_index, )(y_pred=y_pred, y=y) @@ -262,6 +268,7 @@ def __init__( num_classes: int | None = None, sigmoid: bool | None = None, softmax: bool | None = None, + ignore_index: int | None = None, ) -> None: # handling deprecated arguments if sigmoid is not None: @@ -277,8 +284,9 @@ def __init__( self.activate = activate self.ignore_empty = ignore_empty self.num_classes = num_classes + self.ignore_index = ignore_index - def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """ Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately for each batch item and for each channel of those items. @@ -286,7 +294,12 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor Args: y_pred: input predictions with shape HW[D]. y: ground truth with shape HW[D]. + mask: binary mask where 0 indicates voxels to ignore. """ + if mask is not None: + y_pred = y_pred * mask + y = y * mask + y_o = torch.sum(y) if y_o > 0: return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred)) @@ -322,6 +335,11 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 + # Create global mask for ignored voxels if ignore_index is set + mask = None + if self.ignore_index is not None: + mask = y != self.ignore_index + first_ch = 0 if self.include_background else 1 data = [] for b in range(y_pred.shape[0]): @@ -329,7 +347,11 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] - c_list.append(self.compute_channel(x_pred, x)) + + # Extract the spatial mask for the current batch item + b_mask = mask[b, 0] if mask is not None else None + + c_list.append(self.compute_channel(x_pred, x, mask=b_mask)) data.append(torch.stack(c_list)) data = torch.stack(data, dim=0).contiguous() # type: ignore diff --git a/monai/metrics/meaniou.py b/monai/metrics/meaniou.py index 65c53f7aa5..069a8a3845 100644 --- a/monai/metrics/meaniou.py +++ b/monai/metrics/meaniou.py @@ -54,12 +54,14 @@ def __init__( reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ignore_empty: bool = True, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background self.reduction = reduction self.get_not_nans = get_not_nans self.ignore_empty = ignore_empty + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ @@ -78,7 +80,11 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.") # compute IoU (BxC) for each channel for each batch return compute_iou( - y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty + y_pred=y_pred, + y=y, + include_background=self.include_background, + ignore_empty=self.ignore_empty, + ignore_index=self.ignore_index, ) def aggregate( @@ -103,7 +109,11 @@ def aggregate( def compute_iou( - y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + ignore_empty: bool = True, + ignore_index: int | None = None, ) -> torch.Tensor: """Computes Intersection over Union (IoU) score metric from a batch of predictions. @@ -133,6 +143,13 @@ def compute_iou( if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") + if ignore_index is not None: + mask = (y != ignore_index).float() + if mask.shape != y_pred.shape: + mask = mask.expand_as(y_pred) + y_pred = y_pred * mask + y = torch.where(y == ignore_index, torch.tensor(0, device=y.device), y) + # reducing only spatial dimensions (not batch nor channels) n_len = len(y_pred.shape) reduce_axis = list(range(2, n_len)) diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index b20b47a1a5..949b93d34d 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -57,6 +57,7 @@ class SurfaceDiceMetric(CumulativeIterationMetric): If set to ``True``, the function `aggregate` will return both the aggregated NSD and the `not_nans` count. If set to ``False``, `aggregate` will only return the aggregated NSD. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. """ def __init__( @@ -67,6 +68,7 @@ def __init__( reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, use_subvoxels: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.class_thresholds = class_thresholds @@ -75,6 +77,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans self.use_subvoxels = use_subvoxels + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] r""" @@ -94,6 +97,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. Returns: @@ -108,6 +112,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) distance_metric=self.distance_metric, spacing=kwargs.get("spacing"), use_subvoxels=self.use_subvoxels, + ignore_index=self.ignore_index, ) def aggregate( @@ -142,6 +147,7 @@ def compute_surface_dice( distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, use_subvoxels: bool = False, + ignore_index: int | None = None, ) -> torch.Tensor: r""" This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as @@ -199,6 +205,7 @@ def compute_surface_dice( else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. Raises: ValueError: If `y_pred` and/or `y` are not PyTorch tensors. @@ -213,6 +220,11 @@ def compute_surface_dice( Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch index :math:`b` and class :math:`c`. """ + if ignore_index is not None: + mask = (y != ignore_index).all(dim=1, keepdim=True).float() + + y_pred = y_pred * mask + y = y * mask if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) @@ -255,6 +267,7 @@ def compute_surface_dice( use_subvoxels=use_subvoxels, symmetric=True, class_index=c, + mask=mask[b, 0] if ignore_index is not None else None, ) boundary_correct: int | torch.Tensor | float boundary_complete: int | torch.Tensor | float @@ -264,7 +277,16 @@ def compute_surface_dice( distances_gt_pred <= class_thresholds[c] ) else: - areas_pred, areas_gt = areas # type: ignore + # Handle areas being returned as a single item or a tuple + if isinstance(areas, (list, tuple)): + if len(areas) == 2: + areas_pred, areas_gt = areas + elif len(areas) == 1: + areas_pred = areas_gt = areas[0] + else: + areas_pred = areas_gt = torch.tensor([], device=y_pred.device) + else: + areas_pred = areas_gt = areas areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred] boundary_complete = areas_gt.sum() + areas_pred.sum() gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0 diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 3cb336d6a0..ef68c5c2c5 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -46,6 +46,7 @@ class SurfaceDistanceMetric(CumulativeIterationMetric): ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + ignore_index: class index to ignore from the metric computation. """ @@ -56,6 +57,7 @@ def __init__( distance_metric: str = "euclidean", reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -63,6 +65,7 @@ def __init__( self.symmetric = symmetric self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ @@ -89,6 +92,13 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) if y_pred.dim() < 3: raise ValueError("y_pred should have at least three dimensions.") + mask = None + + if self.ignore_index is not None: + mask = (y != self.ignore_index).all(dim=1, keepdim=True).float() + y_pred = y_pred * mask + y = y * mask + # compute (BxC) for each channel for each batch return compute_average_surface_distance( y_pred=y_pred, @@ -97,6 +107,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) symmetric=self.symmetric, distance_metric=self.distance_metric, spacing=kwargs.get("spacing"), + mask=mask, + ignore_index=self.ignore_index, ) def aggregate( @@ -127,6 +139,8 @@ def compute_average_surface_distance( symmetric: bool = False, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + mask: torch.Tensor | None = None, + ignore_index: int | None = None, ) -> torch.Tensor: """ This function is used to compute the Average Surface Distance from `y_pred` to `y` @@ -154,10 +168,12 @@ def compute_average_surface_distance( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + ignore_index: class index to ignore from the metric computation. """ if not include_background: - y_pred, y = ignore_background(y_pred=y_pred, y=y) + if ignore_index != 0: + y_pred, y = ignore_background(y_pred=y_pred, y=y) y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] @@ -172,15 +188,27 @@ def compute_average_surface_distance( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): + yp = y_pred[b, c] + yt = y[b, c] + + if ignore_index is not None: + valid_mask = y[b].sum(dim=0) > 0 + yp = yp * valid_mask + yt = yt * valid_mask + _, distances, _ = get_edge_surface_distance( - y_pred[b, c], - y[b, c], + yp, + yt, distance_metric=distance_metric, spacing=spacing_list[b], symmetric=symmetric, class_index=c, + mask=mask[b, 0] if mask is not None else None, ) + surface_distance = torch.cat(distances) - asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean() + asd[b, c] = ( + torch.tensor(float("nan"), device=asd.device) if surface_distance.numel() == 0 else surface_distance.mean() + ) return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index a451b1a770..ac55437aa3 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -41,6 +41,7 @@ __all__ = [ "ignore_background", + "ignore_index_mask", "do_metric_reduction", "get_mask_edges", "get_surface_distance", @@ -68,6 +69,27 @@ def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayT return y_pred, y +def ignore_index_mask( + y_pred: torch.Tensor, y: torch.Tensor, ignore_index: int | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Masks out the specified ignore_index from both predictions and ground truth. + This is a helper for #8667 to allow 'Ignore Class' functionality in metrics. + """ + if ignore_index is None: + return y_pred, y + + # Create a spatial mask (B, 1, H, W, [D]) + # Elements are 0 where target == ignore_index, else 1 + mask = (y != ignore_index).float() + + # Apply mask to zero out the ignored regions + y_pred = y_pred * mask + y = y * mask + + return y_pred, y + + def do_metric_reduction( f: torch.Tensor, reduction: MetricReduction | str = MetricReduction.MEAN ) -> tuple[torch.Tensor | Any, torch.Tensor]: @@ -143,6 +165,7 @@ def get_mask_edges( crop: bool = True, spacing: Sequence | None = None, always_return_as_numpy: bool = False, + ignore_index: int | None = None, ) -> tuple[NdarrayTensor, NdarrayTensor]: """ Compute edges from binary segmentation masks. This @@ -244,6 +267,7 @@ def get_surface_distance( seg_gt: NdarrayOrTensor, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, + mask: NdarrayOrTensor | None = None, ) -> NdarrayOrTensor: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. @@ -262,6 +286,7 @@ def get_surface_distance( (1) If a single number, isotropic spacing with that value is used. (2) If a sequence of numbers, the length of the sequence must be equal to the image dimensions. (3) If ``None``, spacing of unity is used. Defaults to ``None``. + mask: optional boolean mask. Pixels where mask is False will be ignored in the distance computation. Note: If seg_pred or seg_gt is all 0, may result in nan/inf distance. @@ -275,14 +300,17 @@ def get_surface_distance( dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32) dis = dis[seg_gt] return convert_to_dst_type(dis, seg_pred, dtype=dis.dtype)[0] + if distance_metric == "euclidean": dis = monai_distance_transform_edt((~seg_gt)[None, ...], sampling=spacing)[0] # type: ignore elif distance_metric in {"chessboard", "taxicab"}: dis = distance_transform_cdt(convert_to_numpy(~seg_gt), metric=distance_metric) else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") + dis = convert_to_dst_type(dis, seg_pred, dtype=lib.float32)[0] - return dis[seg_pred] # type: ignore + out = dis[seg_pred.bool()] + return out if out is not None else dis.new_empty((0,)) def get_edge_surface_distance( @@ -293,6 +321,8 @@ def get_edge_surface_distance( use_subvoxels: bool = False, symmetric: bool = False, class_index: int = -1, + mask: torch.Tensor | None = None, + ignore_index: int | None = None, ) -> tuple[ tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor], @@ -312,6 +342,7 @@ def get_edge_surface_distance( This will return the areas of the edges. symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. class_index: The class-index used for context when warning about empty ground truth or prediction. + mask: optional boolean mask indicating valid pixels. Returns: (edges_pred, edges_gt), (distances_pred_to_gt, [distances_gt_to_pred]), (areas_pred, areas_gt) | tuple() @@ -320,19 +351,18 @@ def get_edge_surface_distance( edges_spacing = None if use_subvoxels: edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape)) - (edges_pred, edges_gt, *areas) = get_mask_edges( - y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False - ) - if not edges_gt.any(): - warnings.warn( - f"the ground truth of class {class_index if class_index != -1 else 'Unknown'} is all 0," - " this may result in nan/inf distance." - ) - if not edges_pred.any(): - warnings.warn( - f"the prediction of class {class_index if class_index != -1 else 'Unknown'} is all 0," - " this may result in nan/inf distance." - ) + + edge_results = get_mask_edges(y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False) + edges_pred, edges_gt = edge_results[0], edge_results[1] + + if mask is not None: + if len(edge_results) > 2 and isinstance(edge_results[2], tuple): + slices = edge_results[2] + mask = mask[slices] + mask = mask.to(edges_pred.device).bool() + edges_pred = edges_pred & mask + edges_gt = edges_gt & mask + distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] if symmetric: distances = ( @@ -341,7 +371,25 @@ def get_edge_surface_distance( ) # type: ignore else: distances = (get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),) # type: ignore - return convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device) # type: ignore[no-any-return] + + distances = tuple(d if d is not None else edges_pred.new_empty((0,)) for d in distances) + + areas = edge_results[2:] if use_subvoxels else () + + # Ensure areas is always a tuple of 2 when use_subvoxels=True + if use_subvoxels and isinstance(areas, (list, tuple)): + if len(areas) == 1: + areas = (areas[0], areas[0]) + elif len(areas) != 2: + # Unexpected length, create empty tensors + areas = (torch.tensor([], device=y_pred.device), torch.tensor([], device=y_pred.device)) + + out = convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device) # type: ignore[no-any-return] + + if out is None: + out = torch.empty((0,), device=y_pred.device) + + return out def is_binary_tensor(input: torch.Tensor, name: str) -> None: diff --git a/tests/losses/test_ignore_index_losses.py b/tests/losses/test_ignore_index_losses.py new file mode 100644 index 0000000000..b07ba5c98d --- /dev/null +++ b/tests/losses/test_ignore_index_losses.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.losses import AsymmetricUnifiedFocalLoss, DiceLoss, FocalLoss, TverskyLoss + +# Defining test cases: (LossClass, args) +TEST_CASES = [ + (DiceLoss, {"sigmoid": True}), + (FocalLoss, {"use_softmax": False}), + (TverskyLoss, {"sigmoid": True}), + (AsymmetricUnifiedFocalLoss, {}), +] + + +class TestIgnoreIndexLosses(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_loss_ignore_consistency(self, loss_class, kwargs): + ignore_index = 255 + loss_func = loss_class(ignore_index=ignore_index, **kwargs) + + # Create two inputs that are identical EXCEPT in the area designated as 'ignored' + # Input shape: [Batch, Channel, H, W] + input_base = torch.randn(1, 1, 4, 4) + input_alt = input_base.clone() + input_alt[0, 0, 2:, :] += 5.0 # Significant difference in the bottom half + + # Target: Top half is valid (0,1), Bottom half is ignored (255) + target = torch.tensor( + [[[[1, 0, 1, 0], [0, 1, 0, 1], [255, 255, 255, 255], [255, 255, 255, 255]]]], dtype=torch.float + ) + + # Execute + loss_base = loss_func(input_base, target) + loss_alt = loss_func(input_alt, target) + + # ASSERTION: The losses must be identical because the difference + # occurred only in the ignored region. + torch.testing.assert_close(loss_base, loss_alt, atol=1e-5, rtol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_no_ignore_behavior(self, loss_class, kwargs): + # Ensure that when ignore_index is None, the loss functions normally + loss_func = loss_class(ignore_index=None, **kwargs) + input_data = torch.randn(1, 1, 4, 4) + target = torch.randint(0, 2, (1, 1, 4, 4)).float() + + output = loss_func(input_data, target) + self.assertFalse(torch.isnan(output)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/metrics/test_ignore_index_metrics.py b/tests/metrics/test_ignore_index_metrics.py new file mode 100644 index 0000000000..6aecba8b74 --- /dev/null +++ b/tests/metrics/test_ignore_index_metrics.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.metrics import ( + ConfusionMatrixMetric, + DiceMetric, + GeneralizedDiceScore, + HausdorffDistanceMetric, + MeanIoU, + SurfaceDiceMetric, + SurfaceDistanceMetric, +) +from monai.utils import optional_import + +scipy, has_scipy = optional_import("scipy") + +# Test cases for metrics with their specific required arguments +TEST_METRICS = [ + (DiceMetric, {"include_background": True, "reduction": "mean"}), + (MeanIoU, {"include_background": True, "reduction": "mean"}), + (GeneralizedDiceScore, {"include_background": True}), + (ConfusionMatrixMetric, {"metric_name": "accuracy"}), +] + +# Metrics that require SciPy (Hausdorff and Surface metrics) +SCIPY_METRICS = [ + (HausdorffDistanceMetric, {"include_background": True}), + (SurfaceDistanceMetric, {"include_background": True}), + (SurfaceDiceMetric, {"class_thresholds": [0.5, 0.5], "include_background": True}), +] + + +@unittest.skipUnless(has_scipy, "Scipy required for surface metrics") +class TestIgnoreIndexMetrics(unittest.TestCase): + @parameterized.expand(TEST_METRICS + SCIPY_METRICS) + def test_metric_ignore_consistency(self, metric_class, kwargs): + # Initialize metric with ignore_index + metric = metric_class(ignore_index=255, **kwargs) + + # Batch size 1, 2 Classes, 4x4 Image + # y_pred1 and y_pred2 differ ONLY in the bottom half (the ignore zone) + y_pred1 = torch.zeros((1, 2, 4, 4)) + y_pred1[:, 1, 0:2, :] = 1.0 # Top half prediction + + y_pred2 = y_pred1.clone() + y_pred2[:, 1, 2:4, :] = 1.0 # Bottom half prediction (different!) + + # Target: Top half is valid (0/1), Bottom half is 255 + y = torch.zeros((1, 2, 4, 4)) + y[:, 1, 0:2, 0:2] = 1.0 + y[:, :, 2:4, :] = 255 + + # Run metric for both predictions + metric.reset() + metric(y_pred=y_pred1, y=y) + res1 = metric.aggregate() + if isinstance(res1, list): + res1 = res1[0] + + metric.reset() + metric(y_pred=y_pred2, y=y) + res2 = metric.aggregate() + if isinstance(res2, list): + res2 = res2[0] + + # The result must be identical because the spatial difference + # is hidden by the ignore_index + torch.testing.assert_close(res1, res2, msg=f"Failed for {metric_class.__name__}") + + +if __name__ == "__main__": + unittest.main()