@@ -69,17 +69,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6969 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
7070
7171 # Handle ignore_index:
72- # Since the wrapper already zeroed out ignore_index in y_true,
73- # we create the mask where y_true has any valid class (sum across channels > 0)
7472 mask = torch .ones_like (y_true )
7573 if self .ignore_index is not None :
76- # We identify valid pixels: where at least one channel is 1
74+ # Identify valid pixels: where at least one channel is 1
7775 spatial_mask = (torch .sum (y_true , dim = 1 , keepdim = True ) > 0 ).float ()
7876 mask = spatial_mask .expand_as (y_true )
79- # Ensure y_pred is also masked so it doesn't create False Positives in ignored areas
8077 y_pred = y_pred * mask
8178
82- # clip the prediction to avoid NaN
8379 y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
8480 axis = list (range (2 , len (y_pred .shape )))
8581
@@ -91,7 +87,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
9187
9288 # Calculate losses separately for each class, enhancing both classes
9389 back_dice = 1 - dice_class [:, 0 ]
94- fore_dice = ( 1 - dice_class [:, 1 ]) * torch .pow (1 - dice_class [:, 1 ], - self .gamma )
90+ fore_dice = torch .pow (1 - dice_class [:, 1 ], 1 - self .gamma )
9591
9692 # Average class scores
9793 loss = torch .stack ([back_dice , fore_dice ], dim = - 1 )
@@ -249,9 +245,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
249245 if len (y_pred .shape ) != 4 and len (y_pred .shape ) != 5 :
250246 raise ValueError (f"input shape must be 4 or 5, but got { y_pred .shape } " )
251247
252- # 2. Transform binary inputs to 2-channel space
248+ # Transform binary inputs to 2-channel space
253249 if y_pred .shape [1 ] == 1 :
254- y_pred = torch .cat ([- y_pred , y_pred ], dim = 1 )
250+ y_pred = torch .cat ([1 - y_pred , y_pred ], dim = 1 )
255251
256252 # Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block
257253 if self .to_onehot_y :
@@ -264,20 +260,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
264260 else :
265261 y_true = one_hot (y_true , num_classes = self .num_classes )
266262
267- # 3. NOW check if shapes match (They should both be [B, 2, H, W] now)
263+ # Check if shapes match
264+ if y_true .shape [1 ] == 1 and y_pred .shape [1 ] == 2 :
265+ y_true = torch .cat ([1 - y_true , y_true ], dim = 1 )
268266 if y_true .shape != y_pred .shape :
269267 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
270268
271269 if torch .max (y_true ) != self .num_classes - 1 :
272270 raise ValueError (f"Please make sure the number of classes is { self .num_classes - 1 } " )
273271
274- n_pred_ch = y_pred .shape [1 ]
275- if self .to_onehot_y :
276- if n_pred_ch == 1 :
277- warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
278- else :
279- y_true = one_hot (y_true , num_classes = n_pred_ch )
280-
281272 asy_focal_loss = self .asy_focal_loss (y_pred , y_true )
282273 asy_focal_tversky_loss = self .asy_focal_tversky_loss (y_pred , y_true )
283274
0 commit comments