diff --git a/ScaFFold/utils/data_types.py b/ScaFFold/utils/data_types.py
index 90186db..b555811 100644
--- a/ScaFFold/utils/data_types.py
+++ b/ScaFFold/utils/data_types.py
@@ -13,9 +13,13 @@
 # SPDX-License-Identifier: (Apache-2.0)
 
 import numpy as np
+import torch
 
 DEFAULT_NP_DTYPE = np.float64
 # Masks are values 0 <= x <= n_categories
 MASK_DTYPE = np.uint16
 # Volumes/img are 0 <= x <= 1
 VOLUME_DTYPE = np.float32
+
+# Shared AMP dtype selection for torch.autocast.
+AMP_DTYPE = torch.bfloat16
diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py
index 56198cc..62d0fdf 100644
--- a/ScaFFold/utils/evaluate.py
+++ b/ScaFFold/utils/evaluate.py
@@ -17,6 +17,7 @@
 from distconv import DCTensor
 from tqdm import tqdm
 
+from ScaFFold.utils.data_types import AMP_DTYPE
 from ScaFFold.utils.dice_score import (
     SpatialAllReduce,
     compute_sharded_dice,
@@ -29,13 +30,16 @@
 def evaluate(
     net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
 ):
-
     def foreground_dice_mean(dice_scores):
         if dice_scores.size(1) > 1:
             return dice_scores[:, 1:].mean().item()
         return dice_scores.mean().item()
 
     net.eval()
+    autocast_device_type = device.type if device.type != "mps" else "cpu"
+    autocast_kwargs = {"device_type": autocast_device_type, "enabled": amp}
+    if amp:
+        autocast_kwargs["dtype"] = AMP_DTYPE
     num_val_batches = len(dataloader)
     total_dice_score = 0.0
     processed_batches = 0
@@ -47,7 +51,7 @@ def foreground_dice_mean(dice_scores):
         f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}"
     )
 
-    with torch.autocast(device.type if device.type != "mps" else "cpu", enabled=amp):
+    with torch.autocast(**autocast_kwargs):
         val_loss_epoch = 0.0
         for batch in tqdm(
             dataloader,
@@ -85,44 +89,39 @@ def foreground_dice_mean(dice_scores):
                 if local_preds.size(0) == 0 or local_labels.size(0) == 0:
                     continue
 
-                # --- 1. Sharded CE Loss ---
-                with torch.autocast(
-                    device.type if device.type != "mps" else "cpu", enabled=False
-                ):
+                # Calculate the CE and Dice losses in single precision for numerical stability.
+                with torch.autocast(device_type=autocast_device_type, enabled=False):
+                    # Compute the global CE loss from the sharded CE loss.
                     local_ce_sum = F.cross_entropy(
                         local_preds.float(), local_labels, reduction="sum"
                     )
-                    global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
-
-                    # Divide by the actual global voxel count to handle uneven shards.
-                    local_voxel_count = torch.tensor(
-                        float(local_labels.numel()),
-                        device=local_labels.device,
-                        dtype=torch.float32,
-                    )
-                    global_total_voxels = SpatialAllReduce.apply(
-                        local_voxel_count, spatial_mesh
-                    )
-                    CE_loss = global_ce_sum / global_total_voxels
-
-                # --- 2. Format Predictions & Labels (Strictly Multiclass) ---
-                mask_pred_probs = F.softmax(local_preds, dim=1).float()
-                mask_true_onehot = (
-                    F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float()
-                )
-
-                # Dice loss uses probabilities
-                dice_score_probs = compute_sharded_dice(
-                    mask_pred_probs, mask_true_onehot, spatial_mesh
-                )
-                # Eval metric (excluding background class 0)
-                # dice_score_probs shape is [Batch, Channels].
-                batch_dice_score = foreground_dice_mean(dice_score_probs)
-
-                # --- Combine and Accumulate ---
-                loss = CE_loss + (1.0 - batch_dice_score)
-                val_loss_epoch += loss.item()
-                total_dice_score += batch_dice_score.item()
+                    global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
+                    # Divide by the actual global voxel count to handle uneven shards.
+                    local_voxel_count = torch.tensor(
+                        float(local_labels.numel()),
+                        device=local_labels.device,
+                        dtype=torch.float32,
+                    )
+                    global_total_voxels = SpatialAllReduce.apply(
+                        local_voxel_count, spatial_mesh
+                    )
+                    CE_loss = global_ce_sum / global_total_voxels
+
+                    # Compute the global Dice score from the sharded Dice score.
+                    mask_pred_probs = F.softmax(local_preds.float(), dim=1)
+                    mask_true_onehot = (
+                        F.one_hot(local_labels, n_categories + 1)
+                        .permute(0, 4, 1, 2, 3)
+                        .float()
+                    )
+                    dice_score_probs = compute_sharded_dice(
+                        mask_pred_probs, mask_true_onehot, spatial_mesh
+                    )
+                    # dice_score_probs shape is [Batch, Channels]; the eval
+                    # metric excludes background class 0.
+                    batch_dice_score = foreground_dice_mean(dice_score_probs)
 
+                # Sum the global CE loss and Dice loss.
+                loss = CE_loss + (1.0 - batch_dice_score)
+                val_loss_epoch += loss.item()
+                total_dice_score += batch_dice_score
                 processed_batches += 1
 
     net.train()
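A note on the autocast kwargs built above: `dtype` is attached only when AMP is enabled, so the disabled path keeps each backend's default autocast dtype, and MPS falls back to CPU autocast. Below is a minimal, self-contained sketch of the same pattern with toy tensors; `AMP_DTYPE` is inlined here rather than imported, and nothing else in it is project code.

import torch

AMP_DTYPE = torch.bfloat16  # mirrors ScaFFold.utils.data_types.AMP_DTYPE

def autocast_kwargs(device, amp_enabled):
    # MPS autocast is not used; fall back to CPU autocast instead.
    device_type = device.type if device.type != "mps" else "cpu"
    kwargs = {"device_type": device_type, "enabled": amp_enabled}
    if amp_enabled:
        # Pin the dtype only when AMP is on.
        kwargs["dtype"] = AMP_DTYPE
    return kwargs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(8, 8, device=device)
with torch.autocast(**autocast_kwargs(device, amp_enabled=True)):
    y = x @ x  # matmuls autocast to bfloat16
with torch.autocast(**autocast_kwargs(device, amp_enabled=False)):
    z = y.float() ** 2  # loss-style math stays in float32
print(y.dtype, z.dtype)  # torch.bfloat16 torch.float32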
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 27482ab..b680746 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -29,7 +29,11 @@
 from ScaFFold.utils.checkpointing import CheckpointManager
 from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec
-from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice
+from ScaFFold.utils.data_types import AMP_DTYPE
+from ScaFFold.utils.dice_score import (
+    SpatialAllReduce,
+    compute_sharded_dice,
+)
 from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size
 
 # Local
@@ -48,6 +52,9 @@ def __init__(self, model, config, device, log):
         self.config = config
         self.device = device
         self.log = log
+        self.amp_device_type = self.device.type if self.device.type != "mps" else "cpu"
+        self.amp_dtype = AMP_DTYPE
+        self.use_grad_scaler = False
         self.world_size = get_world_size(required=self.config.dist)
         self.world_rank = get_world_rank(required=self.config.dist)
         self.local_rank = get_local_rank(required=self.config.dist)
@@ -194,7 +201,11 @@ def setup_training_components(self):
         )
 
         # Set up gradient scaler for AMP (Automatic Mixed Precision)
-        self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.config.torch_amp)
+        # bfloat16 keeps float32's exponent range, so it does not need a grad scaler.
+        self.use_grad_scaler = (
+            self.config.torch_amp and self.amp_dtype != torch.bfloat16
+        )
+        self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler)
 
         # Set up loss function
         self.criterion = (
@@ -204,15 +215,24 @@ def setup_training_components(self):
         )
 
         self.log.info(
-            f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, Gradient Scaler Enabled: {self.config.torch_amp}"
+            f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}"
         )
 
+    def _autocast_kwargs(self, enabled=None):
+        if enabled is None:
+            enabled = self.config.torch_amp
+
+        kwargs = {"device_type": self.amp_device_type, "enabled": enabled}
+        if enabled:
+            kwargs["dtype"] = self.amp_dtype
+        return kwargs
+
     @staticmethod
     def _foreground_dice_mean(dice_scores):
         """Match optimization to the reported validation metric by excluding background."""
         if dice_scores.size(1) > 1:
-            return dice_scores[:, 1:].mean().item()
-        return dice_scores.mean().item()
+            return dice_scores[:, 1:].mean()
+        return dice_scores.mean()
 
 
 class PyTorchTrainer(BaseTrainer):
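The scaler gating above follows from dtype ranges: float16's 5-bit exponent underflows small gradients without loss scaling, while bfloat16 shares float32's 8-bit exponent. A toy sketch of the gated path, assuming a CUDA device and a hypothetical one-layer model; with `enabled=False` the scaler calls are pass-throughs around the usual backward/step.

import torch

amp_dtype = torch.bfloat16  # swap to torch.float16 to exercise the scaler
use_scaler = amp_dtype != torch.bfloat16
scaler = torch.amp.GradScaler("cuda", enabled=use_scaler)

model = torch.nn.Linear(8, 2).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=amp_dtype):
    loss = model(x).square().mean()

scaler.scale(loss).backward()  # no-op scaling when disabled
scaler.step(opt)               # falls back to opt.step() when disabled
scaler.update()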
@@ -399,10 +419,7 @@ def warmup(self):
             true_masks_dc = DCTensor.from_shard(true_masks, self.ps)
             self._get_memsize(images_dc, "Sharded image", self.config.verbose)
 
-            with torch.autocast(
-                self.device.type if self.device.type != "mps" else "cpu",
-                enabled=self.config.torch_amp,
-            ):
+            with torch.autocast(**self._autocast_kwargs()):
                 # Forward on DCTensor
                 self.log.debug("  warmup: running forward pass")
                 masks_pred_dc = self.model(images_dc)
@@ -428,42 +445,41 @@ def warmup(self):
                     f"  warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB."
                 )
 
-                # 1. Sharded Cross Entropy
-                with torch.autocast(
-                    self.device.type if self.device.type != "mps" else "cpu",
-                    enabled=False,
-                ):
+                # Calculate the CE and Dice losses in single precision for numerical stability.
+                with torch.autocast(**self._autocast_kwargs(enabled=False)):
+                    # Compute the global CE loss from the sharded CE loss.
                     local_ce_sum = F.cross_entropy(
                         local_preds.float(), local_labels, reduction="sum"
                     )
-
-                    # Pass the spatial_mesh directly
-                    global_ce_sum = SpatialAllReduce.apply(local_ce_sum, self.spatial_mesh)
-
-                    local_voxel_count = torch.tensor(
-                        float(local_labels.numel()),
-                        device=local_labels.device,
-                        dtype=torch.float32,
-                    )
-                    global_total_voxels = SpatialAllReduce.apply(
-                        local_voxel_count, self.spatial_mesh
-                    )
-                    loss_ce = global_ce_sum / global_total_voxels
-
-                # 2. Sharded Dice Loss
-                local_preds_softmax = F.softmax(local_preds, dim=1).float()
-                local_labels_one_hot = (
-                    F.one_hot(local_labels, num_classes=self.config.n_categories + 1)
-                    .permute(0, 4, 1, 2, 3)
-                    .float()
-                )
-                dice_scores = compute_sharded_dice(
-                    local_preds_softmax, local_labels_one_hot, self.spatial_mesh
-                )
-                batch_dice_score = self._foreground_dice_mean(dice_scores)
+                    global_ce_sum = SpatialAllReduce.apply(
+                        local_ce_sum, self.spatial_mesh
+                    )
+                    # Divide by the actual global voxel count to handle uneven shards.
+                    local_voxel_count = torch.tensor(
+                        float(local_labels.numel()),
+                        device=local_labels.device,
+                        dtype=torch.float32,
+                    )
+                    global_total_voxels = SpatialAllReduce.apply(
+                        local_voxel_count, self.spatial_mesh
+                    )
+                    loss_ce = global_ce_sum / global_total_voxels
 
-                # 3. Combine Loss
-                loss = loss_ce + (1.0 - batch_dice_score)
+                    # Compute the global Dice score from the sharded Dice score.
+                    local_preds_softmax = F.softmax(local_preds.float(), dim=1)
+                    local_labels_one_hot = (
+                        F.one_hot(
+                            local_labels, num_classes=self.config.n_categories + 1
+                        )
+                        .permute(0, 4, 1, 2, 3)
+                        .float()
+                    )
+                    dice_scores = compute_sharded_dice(
+                        local_preds_softmax, local_labels_one_hot, self.spatial_mesh
+                    )
+                    batch_dice_score = self._foreground_dice_mean(dice_scores)
+
+                # Sum the global CE loss and Dice loss.
+                loss = loss_ce + (1.0 - batch_dice_score)
 
                 self.log.debug(
                     "  warmup: loss calculation complete. Proceeding to backward pass"
                 )
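Both the warmup and train paths now reduce a CE sum plus a voxel count and divide once at the end, which stays exact under uneven spatial shards (averaging per-shard means would not). A standalone sketch of that arithmetic: `global_mean_ce` is a hypothetical helper, and plain `torch.distributed.all_reduce` stands in for the autograd-aware `SpatialAllReduce`, so this version illustrates the math but does not propagate gradients through the reduction.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def global_mean_ce(local_logits, local_labels):
    # Sum (not mean) locally so shards of different sizes weigh correctly,
    # then divide by the globally reduced voxel count.
    local_ce_sum = F.cross_entropy(
        local_logits.float(), local_labels, reduction="sum"
    )
    local_voxels = torch.tensor(
        float(local_labels.numel()), device=local_labels.device
    )
    if dist.is_initialized():
        dist.all_reduce(local_ce_sum)  # defaults to SUM
        dist.all_reduce(local_voxels)
    return local_ce_sum / local_voxels

logits = torch.randn(2, 5, 8, 8, 8)          # [B, C, D, H, W] toy shard
labels = torch.randint(0, 5, (2, 8, 8, 8))   # matching voxel labels
print(global_mean_ce(logits, labels))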
@@ -592,10 +608,7 @@ def train(self):
                         images_dc, "Sharded image", self.config.verbose
                     )
 
-                    with torch.autocast(
-                        self.device.type if self.device.type != "mps" else "cpu",
-                        enabled=self.config.torch_amp,
-                    ):
+                    with torch.autocast(**self._autocast_kwargs()):
                         # Predict on this batch
                         torch.cuda.reset_peak_memory_stats()
                         gather_and_print_mem(self.log, "pre_forward")
@@ -627,56 +640,51 @@ def train(self):
                             f"Calculating sharded loss. Mem: {current_mem:.2f} GB."
                         )
 
-                        # 1. Sharded Cross Entropy
-                        with torch.autocast(
-                            self.device.type
-                            if self.device.type != "mps"
-                            else "cpu",
-                            enabled=False,
-                        ):
+                        # Calculate the CE and Dice losses in single precision for numerical stability.
+                        with torch.autocast(**self._autocast_kwargs(enabled=False)):
+                            # Compute the global CE loss from the sharded CE loss.
                             local_ce_sum = F.cross_entropy(
                                 local_preds.float(),
                                 local_labels,
                                 reduction="sum",
                             )
-
-                            # Pass the spatial_mesh directly
-                            global_ce_sum = SpatialAllReduce.apply(
-                                local_ce_sum, self.spatial_mesh
-                            )
-
-                            local_voxel_count = torch.tensor(
-                                float(local_labels.numel()),
-                                device=local_labels.device,
-                                dtype=torch.float32,
-                            )
-                            global_total_voxels = SpatialAllReduce.apply(
-                                local_voxel_count, self.spatial_mesh
-                            )
-                            loss_ce = global_ce_sum / global_total_voxels
-
-                        # 2. Sharded Dice Loss
-                        local_preds_softmax = F.softmax(local_preds, dim=1).float()
-                        local_labels_one_hot = (
-                            F.one_hot(
-                                local_labels,
-                                num_classes=self.config.n_categories + 1,
-                            )
-                            .permute(0, 4, 1, 2, 3)
-                            .float()
-                        )
-
-                        # Compute sharded dice using new function
-                        dice_scores = compute_sharded_dice(
-                            local_preds_softmax,
-                            local_labels_one_hot,
-                            self.spatial_mesh,
-                        )
-                        batch_dice_score = self._foreground_dice_mean(dice_scores)
+                            global_ce_sum = SpatialAllReduce.apply(
+                                local_ce_sum, self.spatial_mesh
+                            )
+                            # Divide by the actual global voxel count to handle uneven shards.
+                            local_voxel_count = torch.tensor(
+                                float(local_labels.numel()),
+                                device=local_labels.device,
+                                dtype=torch.float32,
+                            )
+                            global_total_voxels = SpatialAllReduce.apply(
+                                local_voxel_count, self.spatial_mesh
+                            )
+                            loss_ce = global_ce_sum / global_total_voxels
 
-                        # 3. Combine Loss
-                        loss = loss_ce + (1.0 - batch_dice_score)
-                        train_dice_total += batch_dice_score
+                            # Compute the global Dice score from the sharded Dice score.
+                            local_preds_softmax = F.softmax(
+                                local_preds.float(), dim=1
+                            )
+                            local_labels_one_hot = (
+                                F.one_hot(
+                                    local_labels,
+                                    num_classes=self.config.n_categories + 1,
+                                )
+                                .permute(0, 4, 1, 2, 3)
+                                .float()
+                            )
+                            dice_scores = compute_sharded_dice(
+                                local_preds_softmax,
+                                local_labels_one_hot,
+                                self.spatial_mesh,
+                            )
+                            batch_dice_score = self._foreground_dice_mean(
+                                dice_scores
+                            )
 
+                        # Sum the global CE loss and Dice loss. Detach before
+                        # accumulating so the epoch total does not keep each
+                        # batch's autograd graph alive.
+                        loss = loss_ce + (1.0 - batch_dice_score)
+                        train_dice_total += batch_dice_score.detach()
 
                         end_code_region("calculate_loss")
@@ -748,7 +756,7 @@ def train(self):
             #
             # Write out data for this epoch to train stats csv
             #
-            train_dice = float(train_dice_total / len(self.train_loader))
+            train_dice = float(train_dice_total.item() / len(self.train_loader))
             self.log.info(
                 f"  epoch {epoch} \
                     | train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \
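One more note on `_foreground_dice_mean`: returning a tensor instead of `.item()` is what lets the Dice term actually contribute gradients; the old float return silently reduced training to CE-only. A toy check with illustrative shapes and a simplified soft-Dice (none of these names are project code):

import torch

def foreground_dice_mean(dice_scores):
    # dice_scores: [Batch, Channels]; skip background channel 0 when present.
    if dice_scores.size(1) > 1:
        return dice_scores[:, 1:].mean()
    return dice_scores.mean()

probs = torch.rand(2, 3, 4, requires_grad=True)  # [B, C, voxels], toy values
target = torch.rand(2, 3, 4)
inter = (probs * target).sum(dim=2)
dice = (2 * inter + 1e-6) / (probs.sum(2) + target.sum(2) + 1e-6)  # [B, C]

loss = 1.0 - foreground_dice_mean(dice)
loss.backward()
print(probs.grad is not None)  # True: the Dice term now produces gradients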