From e4abdfc67ce21163acb2ea4adaede9bce82113df Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 23 Apr 2026 14:24:24 -0700 Subject: [PATCH 1/9] bf16 and more fp32 sections for dice --- ScaFFold/utils/data_types.py | 4 ++ ScaFFold/utils/evaluate.py | 30 ++++++----- ScaFFold/utils/trainer.py | 97 ++++++++++++++++++++---------------- 3 files changed, 75 insertions(+), 56 deletions(-) diff --git a/ScaFFold/utils/data_types.py b/ScaFFold/utils/data_types.py index 90186db..b555811 100644 --- a/ScaFFold/utils/data_types.py +++ b/ScaFFold/utils/data_types.py @@ -13,9 +13,13 @@ # SPDX-License-Identifier: (Apache-2.0) import numpy as np +import torch DEFAULT_NP_DTYPE = np.float64 # Masks are values 0 <= x <= n_categories MASK_DTYPE = np.uint16 # Volumes/img are 0 <= x <= 1 VOLUME_DTYPE = np.float32 + +# Shared AMP dtype selection for torch.autocast. +AMP_DTYPE = torch.bfloat16 diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 372aa70..2bf4e01 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -17,6 +17,7 @@ from distconv import DCTensor from tqdm import tqdm +from ScaFFold.utils.data_types import AMP_DTYPE from ScaFFold.utils.dice_score import ( SpatialAllReduce, compute_sharded_dice, @@ -30,6 +31,10 @@ def evaluate( net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy ): net.eval() + autocast_device_type = device.type if device.type != "mps" else "cpu" + autocast_kwargs = {"device_type": autocast_device_type, "enabled": amp} + if amp: + autocast_kwargs["dtype"] = AMP_DTYPE num_val_batches = len(dataloader) total_dice_score = 0.0 processed_batches = 0 @@ -41,7 +46,7 @@ def evaluate( f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}" ) - with torch.autocast(device.type if device.type != "mps" else "cpu", enabled=amp): + with torch.autocast(**autocast_kwargs): val_loss_epoch = 0.0 for batch in tqdm( dataloader, @@ -80,9 +85,7 @@ def evaluate( continue # --- 1. Sharded CE Loss --- - with torch.autocast( - device.type if device.type != "mps" else "cpu", enabled=False - ): + with torch.autocast(device_type=autocast_device_type, enabled=False): local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, reduction="sum" ) @@ -100,15 +103,18 @@ def evaluate( CE_loss = global_ce_sum / global_total_voxels # --- 2. 
Format Predictions & Labels (Strictly Multiclass) --- - mask_pred_probs = F.softmax(local_preds, dim=1).float() - mask_true_onehot = ( - F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float() - ) + with torch.autocast(device_type=autocast_device_type, enabled=False): + mask_pred_probs = F.softmax(local_preds.float(), dim=1) + mask_true_onehot = ( + F.one_hot(local_labels, n_categories + 1) + .permute(0, 4, 1, 2, 3) + .float() + ) - # Dice loss uses probabilities - dice_score_probs = compute_sharded_dice( - mask_pred_probs, mask_true_onehot, spatial_mesh - ) + # Dice loss uses probabilities + dice_score_probs = compute_sharded_dice( + mask_pred_probs, mask_true_onehot, spatial_mesh + ) dice_loss_curr = 1.0 - dice_score_probs.mean() # Eval metric (excluding background class 0) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index a5d4355..4a6b052 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -28,8 +28,12 @@ from tqdm import tqdm from ScaFFold.utils.checkpointing import CheckpointManager +from ScaFFold.utils.data_types import AMP_DTYPE from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec -from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice +from ScaFFold.utils.dice_score import ( + SpatialAllReduce, + compute_sharded_dice, +) from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local @@ -48,6 +52,9 @@ def __init__(self, model, config, device, log): self.config = config self.device = device self.log = log + self.amp_device_type = self.device.type if self.device.type != "mps" else "cpu" + self.amp_dtype = AMP_DTYPE + self.use_grad_scaler = False self.world_size = get_world_size(required=self.config.dist) self.world_rank = get_world_rank(required=self.config.dist) self.local_rank = get_local_rank(required=self.config.dist) @@ -194,7 +201,8 @@ def setup_training_components(self): ) # Set up gradient scaler for AMP (Automatic Mixed Precision) - self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.config.torch_amp) + self.use_grad_scaler = self.config.torch_amp and self.amp_dtype == torch.float16 + self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler) # Set up loss function self.criterion = ( @@ -204,9 +212,18 @@ def setup_training_components(self): ) self.log.info( - f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, Gradient Scaler Enabled: {self.config.torch_amp}" + f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}" ) + def _autocast_kwargs(self, enabled=None): + if enabled is None: + enabled = self.config.torch_amp + + kwargs = {"device_type": self.amp_device_type, "enabled": enabled} + if enabled: + kwargs["dtype"] = self.amp_dtype + return kwargs + class PyTorchTrainer(BaseTrainer): """ @@ -392,10 +409,7 @@ def warmup(self): true_masks_dc = DCTensor.from_shard(true_masks, self.ps) self._get_memsize(images_dc, "Sharded image", self.config.verbose) - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): + with torch.autocast(**self._autocast_kwargs()): # Forward on DCTensor self.log.debug(" warmup: running forward pass") masks_pred_dc = self.model(images_dc) @@ -422,10 +436,7 @@ def warmup(self): ) # 1. 
Sharded Cross Entropy - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=False, - ): + with torch.autocast(**self._autocast_kwargs(enabled=False)): local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, reduction="sum" ) @@ -444,15 +455,18 @@ def warmup(self): loss_ce = global_ce_sum / global_total_voxels # 2. Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1).float() - local_labels_one_hot = ( - F.one_hot(local_labels, num_classes=self.config.n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_scores = compute_sharded_dice( - local_preds_softmax, local_labels_one_hot, self.spatial_mesh - ) + with torch.autocast(**self._autocast_kwargs(enabled=False)): + local_preds_softmax = F.softmax(local_preds.float(), dim=1) + local_labels_one_hot = ( + F.one_hot( + local_labels, num_classes=self.config.n_categories + 1 + ) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, self.spatial_mesh + ) loss_dice = 1.0 - dice_scores.mean() # 3. Combine Loss @@ -585,10 +599,7 @@ def train(self): images_dc, "Sharded image", self.config.verbose ) - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): + with torch.autocast(**self._autocast_kwargs()): # Predict on this batch torch.cuda.reset_peak_memory_stats() gather_and_print_mem(self.log, "pre_forward") @@ -621,12 +632,7 @@ def train(self): ) # 1. Sharded Cross Entropy - with torch.autocast( - self.device.type - if self.device.type != "mps" - else "cpu", - enabled=False, - ): + with torch.autocast(**self._autocast_kwargs(enabled=False)): local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, @@ -649,22 +655,25 @@ def train(self): loss_ce = global_ce_sum / global_total_voxels # 2. Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1).float() - local_labels_one_hot = ( - F.one_hot( - local_labels, - num_classes=self.config.n_categories + 1, + with torch.autocast(**self._autocast_kwargs(enabled=False)): + local_preds_softmax = F.softmax( + local_preds.float(), dim=1 + ) + local_labels_one_hot = ( + F.one_hot( + local_labels, + num_classes=self.config.n_categories + 1, + ) + .permute(0, 4, 1, 2, 3) + .float() ) - .permute(0, 4, 1, 2, 3) - .float() - ) - # Compute sharded dice using new function - dice_scores = compute_sharded_dice( - local_preds_softmax, - local_labels_one_hot, - self.spatial_mesh, - ) + # Compute sharded dice using new function + dice_scores = compute_sharded_dice( + local_preds_softmax, + local_labels_one_hot, + self.spatial_mesh, + ) loss_dice = 1.0 - dice_scores.mean() # 3. Combine Loss From e4a4233cc9bd08a89eb08990e385157c85ddf054 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 23 Apr 2026 14:29:16 -0700 Subject: [PATCH 2/9] Refactor --- ScaFFold/utils/evaluate.py | 27 ++++++++------- ScaFFold/utils/trainer.py | 67 ++++++++++++++++++-------------------- 2 files changed, 45 insertions(+), 49 deletions(-) diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 2bf4e01..fe45881 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -84,11 +84,23 @@ def evaluate( if local_preds.size(0) == 0 or local_labels.size(0) == 0: continue - # --- 1. Sharded CE Loss --- with torch.autocast(device_type=autocast_device_type, enabled=False): + # --- 1. 
Sharded CE Loss --- local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, reduction="sum" ) + # --- 2. Sharded Dice Loss --- + mask_pred_probs = F.softmax(local_preds.float(), dim=1) + mask_true_onehot = ( + F.one_hot(local_labels, n_categories + 1) + .permute(0, 4, 1, 2, 3) + .float() + ) + + # Dice loss uses probabilities + dice_score_probs = compute_sharded_dice( + mask_pred_probs, mask_true_onehot, spatial_mesh + ) global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) # Divide by the actual global voxel count to handle uneven shards. @@ -102,19 +114,6 @@ def evaluate( ) CE_loss = global_ce_sum / global_total_voxels - # --- 2. Format Predictions & Labels (Strictly Multiclass) --- - with torch.autocast(device_type=autocast_device_type, enabled=False): - mask_pred_probs = F.softmax(local_preds.float(), dim=1) - mask_true_onehot = ( - F.one_hot(local_labels, n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() - ) - - # Dice loss uses probabilities - dice_score_probs = compute_sharded_dice( - mask_pred_probs, mask_true_onehot, spatial_mesh - ) dice_loss_curr = 1.0 - dice_score_probs.mean() # Eval metric (excluding background class 0) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 4a6b052..2fdb863 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -435,11 +435,23 @@ def warmup(self): f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." ) - # 1. Sharded Cross Entropy with torch.autocast(**self._autocast_kwargs(enabled=False)): + # 1. Sharded Cross Entropy local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, reduction="sum" ) + # 2. Sharded Dice Loss + local_preds_softmax = F.softmax(local_preds.float(), dim=1) + local_labels_one_hot = ( + F.one_hot( + local_labels, num_classes=self.config.n_categories + 1 + ) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, self.spatial_mesh + ) # Pass the spatial_mesh directly global_ce_sum = SpatialAllReduce.apply(local_ce_sum, self.spatial_mesh) @@ -454,19 +466,6 @@ def warmup(self): ) loss_ce = global_ce_sum / global_total_voxels - # 2. Sharded Dice Loss - with torch.autocast(**self._autocast_kwargs(enabled=False)): - local_preds_softmax = F.softmax(local_preds.float(), dim=1) - local_labels_one_hot = ( - F.one_hot( - local_labels, num_classes=self.config.n_categories + 1 - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_scores = compute_sharded_dice( - local_preds_softmax, local_labels_one_hot, self.spatial_mesh - ) loss_dice = 1.0 - dice_scores.mean() # 3. Combine Loss @@ -631,31 +630,14 @@ def train(self): f"Calculating sharded loss. Mem: {current_mem:.2f} GB." ) - # 1. Sharded Cross Entropy with torch.autocast(**self._autocast_kwargs(enabled=False)): + # 1. Sharded Cross Entropy local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, reduction="sum", ) - - # Pass the spatial_mesh directly - global_ce_sum = SpatialAllReduce.apply( - local_ce_sum, self.spatial_mesh - ) - - local_voxel_count = torch.tensor( - float(local_labels.numel()), - device=local_labels.device, - dtype=torch.float32, - ) - global_total_voxels = SpatialAllReduce.apply( - local_voxel_count, self.spatial_mesh - ) - loss_ce = global_ce_sum / global_total_voxels - - # 2. 
Sharded Dice Loss
-            with torch.autocast(**self._autocast_kwargs(enabled=False)):
-                local_preds_softmax = F.softmax(local_preds.float(), dim=1)
-                local_labels_one_hot = (
-                    F.one_hot(
-                        local_labels, num_classes=self.config.n_categories + 1
-                    )
-                    .permute(0, 4, 1, 2, 3)
-                    .float()
-                )
-                dice_scores = compute_sharded_dice(
-                    local_preds_softmax, local_labels_one_hot, self.spatial_mesh
-                )
             loss_dice = 1.0 - dice_scores.mean()
 
             # 3. Combine Loss
@@ -631,49 +630,47 @@ def train(self):
                                 f"Calculating sharded loss. Mem: {current_mem:.2f} GB."
                             )
 
-                        # 1. Sharded Cross Entropy
                         with torch.autocast(**self._autocast_kwargs(enabled=False)):
+                            # 1. Sharded Cross Entropy
                             local_ce_sum = F.cross_entropy(
                                 local_preds.float(),
                                 local_labels,
                                 reduction="sum",
                             )
-
-                        # Pass the spatial_mesh directly
-                        global_ce_sum = SpatialAllReduce.apply(
-                            local_ce_sum, self.spatial_mesh
-                        )
-
-                        local_voxel_count = torch.tensor(
-                            float(local_labels.numel()),
-                            device=local_labels.device,
-                            dtype=torch.float32,
-                        )
-                        global_total_voxels = SpatialAllReduce.apply(
-                            local_voxel_count, self.spatial_mesh
-                        )
-                        loss_ce = global_ce_sum / global_total_voxels
-
-                        # 2. Sharded Dice Loss
-                        with torch.autocast(**self._autocast_kwargs(enabled=False)):
+                            # Sharded Dice Loss
                             local_preds_softmax = F.softmax(
                                 local_preds.float(), dim=1
                             )
                             local_labels_one_hot = (
                                 F.one_hot(
                                     local_labels,
                                     num_classes=self.config.n_categories + 1,
                                 )
                                 .permute(0, 4, 1, 2, 3)
                                 .float()
                             )
-
-                            # Compute sharded dice using new function
+                            # Compute sharded dice
                             dice_scores = compute_sharded_dice(
                                 local_preds_softmax,
                                 local_labels_one_hot,
                                 self.spatial_mesh,
                             )
+
+                        # Pass the spatial_mesh directly
+                        global_ce_sum = SpatialAllReduce.apply(
+                            local_ce_sum, self.spatial_mesh
+                        )
+
+                        local_voxel_count = torch.tensor(
+                            float(local_labels.numel()),
+                            device=local_labels.device,
+                            dtype=torch.float32,
+                        )
+                        global_total_voxels = SpatialAllReduce.apply(
+                            local_voxel_count, self.spatial_mesh
+                        )
+                        loss_ce = global_ce_sum / global_total_voxels
+
                         loss_dice = 1.0 - dice_scores.mean()
 
                         # 3. Combine Loss

From 2b6caa704600ad716fc13a0b26ff604bf9441ff6 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 23 Apr 2026 14:32:37 -0700
Subject: [PATCH 3/9] ruff

---
 ScaFFold/utils/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 2fdb863..fc5d2d9 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -28,8 +28,8 @@
 from tqdm import tqdm
 
 from ScaFFold.utils.checkpointing import CheckpointManager
-from ScaFFold.utils.data_types import AMP_DTYPE
 from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec
+from ScaFFold.utils.data_types import AMP_DTYPE
 from ScaFFold.utils.dice_score import (
     SpatialAllReduce,
     compute_sharded_dice,

From 235a913678586aafb62f11af7faab2ad6b7e8915 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 23 Apr 2026 15:11:29 -0700
Subject: [PATCH 4/9] fix merge artifact

---
 ScaFFold/utils/evaluate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py
index fe45881..13ae7f4 100644
--- a/ScaFFold/utils/evaluate.py
+++ b/ScaFFold/utils/evaluate.py
@@ -127,7 +127,7 @@ def foreground_dice_mean(dice_scores):
             # --- Combine and Accumulate ---
             loss = CE_loss + (1.0 - batch_dice_score)
             val_loss_epoch += loss.item()
-            total_dice_score += batch_dice_score.item()
+            total_dice_score += batch_dice_score
             processed_batches += 1
 
     net.train()

From 49eccfa6741afd6a254d9f13afa0a9fe3b54fb0c Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 23 Apr 2026 19:06:51 -0700
Subject: [PATCH 5/9] Update trainer.py

---
 ScaFFold/utils/trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index fc5d2d9..ff4c4ea 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -228,8 +228,8 @@ def _autocast_kwargs(self, enabled=None):
     def _foreground_dice_mean(dice_scores):
         """Match optimization to the reported validation metric by excluding background."""
         if dice_scores.size(1) > 1:
-            return dice_scores[:, 1:].mean().item()
-        return dice_scores.mean().item()
+            return dice_scores[:, 1:].mean()
+        return dice_scores.mean()
 
 
 class PyTorchTrainer(BaseTrainer):
@@ -692,7 +692,7 @@ def train(self):
 
                         # 3. 
Combine Loss
                         loss = loss_ce + (1.0 - batch_dice_score)
-                        train_dice_total += batch_dice_score
+                        train_dice_total += batch_dice_score.item()
 
                         end_code_region("calculate_loss")

From 14e3bef879bfe49afaac725fdddfa5a286c9a209 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 30 Apr 2026 14:20:49 -0700
Subject: [PATCH 6/9] Update trainer.py

---
 ScaFFold/utils/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index ff4c4ea..3df51dc 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -201,7 +201,8 @@ def setup_training_components(self):
         )
 
         # Set up gradient scaler for AMP (Automatic Mixed Precision)
-        self.use_grad_scaler = self.config.torch_amp and self.amp_dtype == torch.float16
+        # bfloat16 does not need grad scaler
+        self.use_grad_scaler = self.config.torch_amp and self.amp_dtype != torch.bfloat16
         self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler)
 
         # Set up loss function

From 535c3aa55a478fec2fd5058e54ee0df6746b799e Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 30 Apr 2026 14:26:13 -0700
Subject: [PATCH 7/9] Update trainer.py

---
 ScaFFold/utils/trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 3df51dc..4b8c9ba 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -202,7 +202,9 @@ def setup_training_components(self):
 
         # Set up gradient scaler for AMP (Automatic Mixed Precision)
         # bfloat16 does not need grad scaler
-        self.use_grad_scaler = self.config.torch_amp and self.amp_dtype != torch.bfloat16
+        self.use_grad_scaler = (
+            self.config.torch_amp and self.amp_dtype != torch.bfloat16
+        )
         self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler)
 
         # Set up loss function

From 69bb3bada7a8cb713962d0dd79da6d33b0dbdf51 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 30 Apr 2026 14:54:41 -0700
Subject: [PATCH 8/9] Refactor

---
 ScaFFold/utils/evaluate.py | 44 ++++++++++-----------
 ScaFFold/utils/trainer.py  | 85 +++++++++++++++++-------------------
 2 files changed, 60 insertions(+), 69 deletions(-)

diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py
index 13ae7f4..62d0fdf 100644
--- a/ScaFFold/utils/evaluate.py
+++ b/ScaFFold/utils/evaluate.py
@@ -30,7 +30,6 @@ def evaluate(
     net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
 ):
-
     def foreground_dice_mean(dice_scores):
         if dice_scores.size(1) > 1:
             return dice_scores[:, 1:].mean().item()
         return dice_scores.mean().item()
@@ -89,47 +88,39 @@ def evaluate(
             if local_preds.size(0) == 0 or local_labels.size(0) == 0:
                 continue
 
+            # Calculate CE and Dice loss in single precision for numerical stability.
             with torch.autocast(device_type=autocast_device_type, enabled=False):
-                # --- 1. Sharded CE Loss ---
+                # Compute global CE loss from sharded CE loss
                 local_ce_sum = F.cross_entropy(
                     local_preds.float(), local_labels, reduction="sum"
                 )
-                # --- 2. 
Sharded Dice Loss ---
+                global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
+                local_voxel_count = torch.tensor(
+                    float(local_labels.numel()),
+                    device=local_labels.device,
+                    dtype=torch.float32,
+                )
+                global_total_voxels = SpatialAllReduce.apply(
+                    local_voxel_count, spatial_mesh
+                )
+                CE_loss = global_ce_sum / global_total_voxels
+
+                # Compute global dice loss from sharded dice loss
                 mask_pred_probs = F.softmax(local_preds.float(), dim=1)
                 mask_true_onehot = (
                     F.one_hot(local_labels, n_categories + 1)
                     .permute(0, 4, 1, 2, 3)
                     .float()
                 )
-
-                # Dice loss uses probabilities
                 dice_score_probs = compute_sharded_dice(
                     mask_pred_probs, mask_true_onehot, spatial_mesh
                 )
+                batch_dice_score = foreground_dice_mean(dice_score_probs)
 
-            global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
-
-            # Divide by the actual global voxel count to handle uneven shards.
-            local_voxel_count = torch.tensor(
-                float(local_labels.numel()),
-                device=local_labels.device,
-                dtype=torch.float32,
-            )
-            global_total_voxels = SpatialAllReduce.apply(
-                local_voxel_count, spatial_mesh
-            )
-            CE_loss = global_ce_sum / global_total_voxels
-
-            dice_loss_curr = 1.0 - dice_score_probs.mean()
-
-            # Eval metric (excluding background class 0)
-            # dice_score_probs shape is [Batch, Channels].
-            batch_dice_score = foreground_dice_mean(dice_score_probs)
-
-            # --- Combine and Accumulate ---
+            # Sum global CE Loss and Dice loss
             loss = CE_loss + (1.0 - batch_dice_score)
             val_loss_epoch += loss.item()
             total_dice_score += batch_dice_score
             processed_batches += 1
 
     net.train()

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 4b8c9ba..da1a982 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -446,44 +446,41 @@ def warmup(self):
                 f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB."
             )
 
+            # Calculate CE and Dice loss in single precision for numerical stability.
             with torch.autocast(**self._autocast_kwargs(enabled=False)):
-                # 1. Sharded Cross Entropy
+                # Compute global CE loss from sharded CE loss
                 local_ce_sum = F.cross_entropy(
                     local_preds.float(), local_labels, reduction="sum"
                 )
+                global_ce_sum = SpatialAllReduce.apply(
+                    local_ce_sum, self.spatial_mesh
+                )
+                local_voxel_count = torch.tensor(
+                    float(local_labels.numel()),
+                    device=local_labels.device,
+                    dtype=torch.float32,
+                )
+                global_total_voxels = SpatialAllReduce.apply(
+                    local_voxel_count, self.spatial_mesh
+                )
+                loss_ce = global_ce_sum / global_total_voxels
+
+                # Compute global dice loss from sharded dice loss
-                # 2. 
Sharded Dice Loss
                 local_preds_softmax = F.softmax(local_preds.float(), dim=1)
                 local_labels_one_hot = (
                     F.one_hot(
                         local_labels, num_classes=self.config.n_categories + 1
                     )
                     .permute(0, 4, 1, 2, 3)
                     .float()
                 )
                 dice_scores = compute_sharded_dice(
                     local_preds_softmax, local_labels_one_hot, self.spatial_mesh
                 )
+                batch_dice_score = self._foreground_dice_mean(dice_scores)
 
-            # Pass the spatial_mesh directly
-            global_ce_sum = SpatialAllReduce.apply(local_ce_sum, self.spatial_mesh)
-
-            local_voxel_count = torch.tensor(
-                float(local_labels.numel()),
-                device=local_labels.device,
-                dtype=torch.float32,
-            )
-            global_total_voxels = SpatialAllReduce.apply(
-                local_voxel_count, self.spatial_mesh
-            )
-            loss_ce = global_ce_sum / global_total_voxels
-
-            loss_dice = 1.0 - dice_scores.mean()
-
-            # Dice score
-            batch_dice_score = self._foreground_dice_mean(dice_scores)
-
-            # 3. Combine Loss
-            loss = loss_ce + (1.0 - batch_dice_score)
+            # Sum global CE Loss and Dice loss
+            loss = loss_ce + (1.0 - batch_dice_score)
 
             self.log.debug(
                 " warmup: loss calculation complete. Proceeding to backward pass"
@@ -630,54 +627,50 @@ def train(self):
                                 f"Calculating sharded loss. Mem: {current_mem:.2f} GB."
                             )
 
+                        # Calculate CE and Dice loss in single precision for numerical stability.
                         with torch.autocast(**self._autocast_kwargs(enabled=False)):
-                            # 1. Sharded Cross Entropy
+                            # Compute global CE loss from sharded CE loss
                             local_ce_sum = F.cross_entropy(
                                 local_preds.float(),
                                 local_labels,
                                 reduction="sum",
                             )
+                            global_ce_sum = SpatialAllReduce.apply(
+                                local_ce_sum, self.spatial_mesh
+                            )
+                            local_voxel_count = torch.tensor(
+                                float(local_labels.numel()),
+                                device=local_labels.device,
+                                dtype=torch.float32,
+                            )
+                            global_total_voxels = SpatialAllReduce.apply(
+                                local_voxel_count, self.spatial_mesh
+                            )
+                            loss_ce = global_ce_sum / global_total_voxels
+
+                            # Compute global dice loss from sharded dice loss
-                            # Sharded Dice Loss
                             local_preds_softmax = F.softmax(
                                 local_preds.float(), dim=1
                             )
                             local_labels_one_hot = (
                                 F.one_hot(
                                     local_labels,
                                     num_classes=self.config.n_categories + 1,
                                 )
                                 .permute(0, 4, 1, 2, 3)
                                 .float()
                             )
-                            # Compute sharded dice
                             dice_scores = compute_sharded_dice(
                                 local_preds_softmax,
                                 local_labels_one_hot,
                                 self.spatial_mesh,
                             )
+                            batch_dice_score = self._foreground_dice_mean(
+                                dice_scores
+                            )
 
-                        # Pass the spatial_mesh directly
-                        global_ce_sum = SpatialAllReduce.apply(
-                            local_ce_sum, self.spatial_mesh
-                        )
-
-                        local_voxel_count = torch.tensor(
-                            float(local_labels.numel()),
-                            device=local_labels.device,
-                            dtype=torch.float32,
-                        )
-                        global_total_voxels = SpatialAllReduce.apply(
-                            local_voxel_count, self.spatial_mesh
-                        )
-                        loss_ce = global_ce_sum / global_total_voxels
-
-                        loss_dice = 1.0 - dice_scores.mean()
-
-                        # Dice score
-                        batch_dice_score = self._foreground_dice_mean(dice_scores)
-
-                        # 3. 
Combine Loss
-                        loss = loss_ce + (1.0 - batch_dice_score)
-                        train_dice_total += batch_dice_score.item()
+                        # Sum global CE Loss and Dice loss
+                        loss = loss_ce + (1.0 - batch_dice_score)
+                        train_dice_total += batch_dice_score.item()
 
                         end_code_region("calculate_loss")

From 34c5281a0e845a86f62d5777dd1e3e7af4eb8a20 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 30 Apr 2026 14:56:26 -0700
Subject: [PATCH 9/9] mv .item()

---
 ScaFFold/utils/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index da1a982..b680746 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -684,7 +684,7 @@ def train(self):
                         # Sum global CE Loss and Dice loss
                         loss = loss_ce + (1.0 - batch_dice_score)
-                        train_dice_total += batch_dice_score.item()
+                        train_dice_total += batch_dice_score.detach()
 
                         end_code_region("calculate_loss")
@@ -756,7 +756,7 @@ def train(self):
         #
         # Write out data for this epoch to train stats csv
        #
-        train_dice = float(train_dice_total / len(self.train_loader))
+        train_dice = float(train_dice_total.item() / len(self.train_loader))
         self.log.info(
             f" epoch {epoch} \
                 | train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \
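-- 
For reference, the precision scheme this series converges on, reduced to a
single-device sketch. This is an illustration, not ScaFFold's actual API: the
names model, optimizer, images, and labels are placeholders, and the
distributed pieces (compute_sharded_dice, SpatialAllReduce, DCTensor sharding)
are replaced by a plain per-class Dice so the example stays self-contained.
The pattern is: bf16 autocast for the forward pass, an fp32 island for the
CE/Dice math, and a GradScaler that is only enabled for fp16.

    import torch
    import torch.nn.functional as F

    AMP_DTYPE = torch.bfloat16  # mirrors ScaFFold/utils/data_types.py


    def train_step(model, optimizer, images, labels, n_categories, use_amp=True):
        # images: [B, C, D, H, W] float; labels: [B, D, H, W] int64 in [0, n_categories]
        device_type = images.device.type if images.device.type != "mps" else "cpu"

        # bf16 shares fp32's exponent range, so gradients do not underflow and
        # loss scaling is unnecessary; the scaler stays enabled only for fp16.
        scaler = torch.amp.GradScaler(
            "cuda", enabled=use_amp and AMP_DTYPE == torch.float16
        )

        with torch.autocast(device_type, dtype=AMP_DTYPE, enabled=use_amp):
            preds = model(images)  # low-precision forward pass

            # fp32 island: softmax, one-hot, CE, and the Dice ratio are the
            # numerically sensitive parts, so autocast is disabled around them.
            with torch.autocast(device_type, enabled=False):
                loss_ce = F.cross_entropy(preds.float(), labels)
                probs = F.softmax(preds.float(), dim=1)
                onehot = (
                    F.one_hot(labels, n_categories + 1)
                    .permute(0, 4, 1, 2, 3)
                    .float()
                )
                # Plain per-class Dice over the spatial dims: [B, classes]
                intersect = (probs * onehot).sum(dim=(2, 3, 4))
                denom = probs.sum(dim=(2, 3, 4)) + onehot.sum(dim=(2, 3, 4))
                dice = (2.0 * intersect + 1e-6) / (denom + 1e-6)
                # Foreground mean, matching _foreground_dice_mean's intent
                batch_dice = dice[:, 1:].mean() if dice.size(1) > 1 else dice.mean()
                loss = loss_ce + (1.0 - batch_dice)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()  # no-op scaling when the scaler is disabled
        scaler.step(optimizer)
        scaler.update()

        # Return metrics detached; call .item() once per epoch rather than per
        # batch to avoid a host-device sync in the inner loop (cf. PATCH 9/9).
        return loss.detach(), batch_dice.detach()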