
Commit b74ef93

test: run compute tests, format and lint loss/metric updates
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent e7d6cd3 commit b74ef93

3 files changed: 33 additions & 25 deletions


monai/losses/focal_loss.py

Lines changed: 26 additions & 8 deletions
@@ -174,8 +174,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                 alpha_arg = None
                 # Move the warning INSIDE this block
                 warnings.warn(
-                    "`include_background=False`, scalar `alpha` ignored when using softmax.",
-                    stacklevel=2
+                    "`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
                 )
             loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
         else:
@@ -184,10 +183,29 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if mask is not None:
             loss = loss * mask
 
-        if self.class_weight is not None and target.shape[1] != 1:
-            cw = torch.as_tensor(self.class_weight).to(loss)
-            broadcast_dims = [-1] + [1] * len(target.shape[2:])
-            loss = cw.view(broadcast_dims) * loss
+        if self.class_weight is not None:
+            cw = torch.as_tensor(self.class_weight, device=loss.device, dtype=loss.dtype)
+            num_classes = loss.shape[1]
+
+            if cw.ndim > 0:
+                if num_classes == 1:
+                    raise ValueError("Per-class class_weight is not supported for single-channel outputs.")
+                if cw.numel() != num_classes:
+                    raise ValueError(
+                        f"The number of class_weight ({cw.numel()}) must match the number of "
+                        f"output channels ({num_classes})."
+                    )
+                if (cw < 0).any():
+                    raise ValueError("class_weight values must be non-negative.")
+            else:
+                if cw < 0:
+                    raise ValueError("class_weight values must be non-negative.")
+
+            if cw.ndim == 0:
+                loss = loss * cw
+            else:
+                broadcast_shape = [1, num_classes] + [1] * (loss.ndim - 2)
+                loss = loss * cw.view(broadcast_shape)
 
         if self.reduction == LossReduction.SUM.value:
             loss = loss.sum()
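
The rewritten block keys every shape off loss instead of target, validates per-class weights against the channel count, and rejects negative weights. A minimal standalone sketch of the broadcasting step, assuming a 4D (B, C, H, W) loss tensor; the shapes and weight values are illustrative, not from the commit:

import torch

# stand-in for the per-element focal loss, shape (batch, classes, H, W)
loss = torch.rand(2, 3, 4, 4)
class_weight = [1.0, 2.0, 0.5]  # one non-negative weight per channel

cw = torch.as_tensor(class_weight, device=loss.device, dtype=loss.dtype)
num_classes = loss.shape[1]
assert cw.numel() == num_classes  # the diff raises ValueError here instead

# view as (1, C, 1, 1) so the weights broadcast over batch and spatial dims
broadcast_shape = [1, num_classes] + [1] * (loss.ndim - 2)
weighted = loss * cw.view(broadcast_shape)
print(weighted.shape)  # torch.Size([2, 3, 4, 4])
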
@@ -281,10 +299,10 @@ def sigmoid_focal_loss(
             )
         broadcast_dims = [1, -1] + [1] * len(target.shape[2:])
         alpha_t = alpha_t.view(broadcast_dims)
-
+
         # Apply per-class weight only to positive samples
         alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))
 
         # This multiplication now works for both Scalar and Tensor cases
         loss = alpha_factor * loss
-    return loss
+    return loss
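
For context, alpha_t holds one alpha per channel and is applied only where target == 1, so negative samples keep weight 1. A toy illustration of that torch.where step (the tensor values are made up; in the real function alpha_t comes from the alpha argument):

import torch

# toy binary targets, shape (batch=1, classes=2, width=3)
target = torch.tensor([[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]])
loss = torch.ones_like(target)

alpha = torch.tensor([0.25, 0.75])  # per-class alpha
broadcast_dims = [1, -1] + [1] * len(target.shape[2:])  # view as (1, 2, 1)
alpha_t = alpha.view(broadcast_dims)

# positives get their class alpha, negatives keep weight 1
alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))
print(alpha_factor * loss)
# tensor([[[0.2500, 1.0000, 0.2500],
#          [1.0000, 0.7500, 1.0000]]])
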

monai/losses/unified_focal_loss.py

Lines changed: 7 additions & 16 deletions
@@ -69,17 +69,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
 
         # Handle ignore_index:
-        # Since the wrapper already zeroed out ignore_index in y_true,
-        # we create the mask where y_true has any valid class (sum across channels > 0)
         mask = torch.ones_like(y_true)
         if self.ignore_index is not None:
-            # We identify valid pixels: where at least one channel is 1
+            # Identify valid pixels: where at least one channel is 1
             spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float()
             mask = spatial_mask.expand_as(y_true)
-            # Ensure y_pred is also masked so it doesn't create False Positives in ignored areas
             y_pred = y_pred * mask
 
-        # clip the prediction to avoid NaN
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
         axis = list(range(2, len(y_pred.shape)))
 
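The mask construction assumes the wrapper already zeroed out ignore_index pixels in every channel of y_true, so a pixel is valid exactly when some channel is 1. A toy check of that logic (the shapes and the zeroed column are illustrative):

import torch

# one-hot ground truth, shape (batch=1, classes=2, width=4);
# column 2 was an ignore_index pixel zeroed in all channels
y_true = torch.tensor([[[1.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0]]])

spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float()
print(spatial_mask)  # tensor([[[1., 1., 0., 1.]]])

mask = spatial_mask.expand_as(y_true)  # same mask replicated per channel
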
@@ -91,7 +87,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         # Calculate losses separately for each class, enhancing both classes
         back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
+        fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma)
 
         # Average class scores
         loss = torch.stack([back_dice, fore_dice], dim=-1)
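
The fore_dice change uses the identity (1 - d) * (1 - d)^(-gamma) = (1 - d)^(1 - gamma); besides being shorter, the fused form avoids the 0 * inf = NaN the old product produced at a perfect foreground dice. A quick check with toy values (gamma = 0.75 is a representative setting, not taken from the commit):

import torch

gamma = 0.75
d = torch.tensor([0.2, 0.9, 1.0])  # per-sample foreground dice

old = (1 - d) * torch.pow(1 - d, -gamma)  # 0 * inf -> nan at d == 1
new = torch.pow(1 - d, 1 - gamma)         # 0 ** 0.25 -> 0

print(old)  # tensor([0.9457, 0.5623,    nan])
print(new)  # tensor([0.9457, 0.5623, 0.0000])
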
@@ -249,9 +245,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
             raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
 
-        # 2. Transform binary inputs to 2-channel space
+        # Transform binary inputs to 2-channel space
         if y_pred.shape[1] == 1:
-            y_pred = torch.cat([-y_pred, y_pred], dim=1)
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
 
         # Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block
         if self.to_onehot_y:
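
Replacing -y_pred with 1 - y_pred makes the synthesized background channel an actual complement probability, so the two channels sum to 1 instead of the first going negative and being silently pinned by the sub-losses' clamp. A toy comparison (the values are made up):

import torch

p = torch.tensor([[[[0.8, 0.1]]]])      # (B=1, C=1, H=1, W=2) foreground prob
two_ch = torch.cat([1 - p, p], dim=1)   # background = complement
print(two_ch[:, 0] + two_ch[:, 1])      # tensor([[[1., 1.]]])

bad = torch.cat([-p, p], dim=1)         # old form: negative "probabilities"
print(bad[:, 0])                        # tensor([[[-0.8000, -0.1000]]])
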
@@ -264,20 +260,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         else:
             y_true = one_hot(y_true, num_classes=self.num_classes)
 
-        # 3. NOW check if shapes match (They should both be [B, 2, H, W] now)
+        # Check if shapes match
+        if y_true.shape[1] == 1 and y_pred.shape[1] == 2:
+            y_true = torch.cat([1 - y_true, y_true], dim=1)
         if y_true.shape != y_pred.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
 
         if torch.max(y_true) != self.num_classes - 1:
             raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
 
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
 
tests/metrics/test_ignore_index_metrics.py

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@
     SurfaceDiceMetric,
     SurfaceDistanceMetric,
 )
-
 from monai.utils import optional_import
 
 scipy, has_scipy = optional_import("scipy")
