From ca53695751a8234f24c62fc9078daf8b04903e4d Mon Sep 17 00:00:00 2001
From: DhruvrajSinhZala24
Date: Tue, 24 Mar 2026 11:56:32 +0530
Subject: [PATCH] Fix eval metrics and wandb config for transformers (#192)

---
 .../eval.py  | 53 ++++++++++-----
 .../train.py | 67 ++++++++++++-------
 2 files changed, 76 insertions(+), 44 deletions(-)

diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py b/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py
index a3f12b32..602c62f6 100644
--- a/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py
+++ b/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py
@@ -40,8 +40,7 @@ def evaluate(model, data_loader, loss_fn, device):
     """
     model.eval()  # Switch on evaluation model
 
-    # Initialize lists for different metrics
-    loss, accuracy, class_auroc, micro_auroc, macro_auroc = [], [], [], [], []
+    # Collect predictions and labels to compute metrics in one pass
     logits, y = [], []
 
     # Iterate over batches and accumulate metrics
@@ -55,24 +54,30 @@
     # Concatenate all results
     logits, y = torch.cat(logits), torch.cat(y)
 
-    loss.append(loss_fn(logits, y))
-    accuracy.append(accuracy_fn(logits, y, num_classes=NUM_CLASSES))
-    class_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average=None))
-    macro_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average="macro"))
+    loss = loss_fn(logits, y).item()
+    accuracy = accuracy_fn(logits, y, num_classes=NUM_CLASSES).item()
+    class_auroc = auroc_fn(
+        logits, y, num_classes=NUM_CLASSES, average=None
+    ).cpu()
+    micro_auroc = auroc_fn(
+        logits, y, num_classes=NUM_CLASSES, average="micro"
+    ).item()
+    macro_auroc = auroc_fn(
+        logits, y, num_classes=NUM_CLASSES, average="macro"
+    ).item()
 
     result = {
         "ground_truth": y,
         "logits": logits,
-        "loss": np.mean(loss),
-        "accuracy": np.mean(accuracy),
-        "micro_auroc": np.mean(micro_auroc),
-        "macro_auroc": np.mean(macro_auroc),
+        "loss": loss,
+        "accuracy": accuracy,
+        "micro_auroc": micro_auroc,
+        "macro_auroc": macro_auroc,
     }
 
     # Class-wise AUROC
-    class_auroc = class_auroc[0]
     for i, label in enumerate(LABELS):
-        result[f"{label}_auroc"] = class_auroc[i]
+        result[f"{label}_auroc"] = class_auroc[i].item()
 
     return result
 
@@ -84,6 +89,12 @@
     # Wandb-specific params
     parser.add_argument("--runid", type=str, help="ID of train run")
     parser.add_argument("--project", type=str, default="ml4sci_deeplense_final")
+    parser.add_argument(
+        "--entity",
+        type=str,
+        default=os.environ.get("WANDB_ENTITY"),
+        help="Weights & Biases entity; defaults to WANDB_ENTITY env var or your logged-in user.",
+    )
 
     # Device to run on
     parser.add_argument(
@@ -92,9 +103,15 @@
     run_config = parser.parse_args()
 
     # Start wandb run
-    with wandb.init(
-        entity="_archil", project=run_config.project, id=run_config.runid, resume="must"
-    ):
+    wandb_kwargs = {
+        "project": run_config.project,
+        "id": run_config.runid,
+        "resume": "must",
+    }
+    if run_config.entity:
+        wandb_kwargs["entity"] = run_config.entity
+
+    with wandb.init(**wandb_kwargs):
         # Get best device on machine
         device = get_device(run_config.device)
 
@@ -169,9 +186,9 @@
         roc_auc = dict()
         for idx, cls in enumerate(LABELS):
             class_truth = (metrics["ground_truth"].numpy() == idx).astype(int)
-            class_pred = torch.nn.functional.softmax(metrics["logits"]).numpy()[
-                ..., idx
-            ]
+            class_pred = torch.nn.functional.softmax(
+                metrics["logits"], dim=-1
+            ).numpy()[..., idx]
             fpr[idx], tpr[idx], _ = roc_curve(class_truth, class_pred)
             _ = axes[0].plot(
                 fpr[idx],
diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py
index a5a6303c..173be126 100644
--- a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py
+++ b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py
@@ -4,7 +4,6 @@
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
 from torch.nn import CrossEntropyLoss
 from tqdm import tqdm
-import numpy as np
 import argparse
 import os
 import wandb
@@ -12,7 +11,6 @@
 from data import LensDataset, WrapperDataset, get_transforms
 from constants import *
 from models import get_timm_model
-from models.transformers import get_transformer_model
 from utils import get_device, set_seed
 from eval import evaluate
 
@@ -54,10 +52,18 @@ def train_step(model, images, labels, optimizer, scheduler, criterion, device="c
     loss.backward()  # Backward pass
     optimizer.step()  # Optimize weights step
     if scheduler is not None:
-        scheduler.step(loss)  # Modify learning rate if scheduler is set
+        scheduler.step()  # Modify learning rate if scheduler is set
 
     return loss
 
 
+def get_image_size(dataset_name):
+    if dataset_name == "Model_I":
+        return 150
+    if dataset_name in ["Model_II", "Model_III"]:
+        return 64
+    raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+
 def train(
     model,
     train_loader,
@@ -106,8 +112,9 @@ def train(
     # Alert wandb to log this training
     wandb.watch(model, criterion, log="all", log_freq=log_interval)
 
-    best_val_auroc, best_val_metrics = 0.0, dict()
+    best_val_auroc, best_val_metrics = float("-inf"), dict()
     batch_num = 0
+    best_model_path = os.path.join(wandb.run.dir, "best_model.pt")
     for epoch in range(1, epochs + 1):
         for batch_data in tqdm(train_loader, desc=f"Epoch {epoch}"):
             batch_num += 1
@@ -127,7 +134,8 @@ def train(
             log_dict = {
                 "epoch": epoch,
                 "batch_num": batch_num,
-                "train/loss": loss,
+                "train/loss": loss.item(),
+                "train/lr": optimizer.param_groups[0]["lr"],
                 "val/loss": val_metrics["loss"],
                 "val/accuracy": val_metrics["accuracy"],
                 "val/micro_auroc": val_metrics["micro_auroc"],
@@ -166,12 +174,11 @@ def train(
             ]
             best_val_metrics = val_metrics
             # Save best model so far in disk
-            torch.save(
-                model.state_dict(), os.path.join(wandb.run.dir, "best_model.pt")
-            )
+            torch.save(model.state_dict(), best_model_path)
 
         # Sync best model at a lesser frequency (i.e. at the end of each epoch)
-        wandb.save(os.path.join(wandb.run.dir, "best_model.pt"))
+        if os.path.exists(best_model_path):
+            wandb.save(best_model_path)
 
     return best_val_metrics
 
@@ -181,7 +188,7 @@
     # W&B related parameters
     parser.add_argument(
         "--dataset",
-        choices=["Model_I", "Model_II", "Model_III", "Model_IV"],
+        choices=["Model_I", "Model_II", "Model_III"],
        default="Model_I",
         help="which data model",
     )
@@ -208,6 +215,12 @@
     parser.add_argument(
         "--device", choices=["cpu", "mps", "cuda", "tpu", "best"], default="best"
     )
+    parser.add_argument(
+        "--entity",
+        type=str,
+        default=os.environ.get("WANDB_ENTITY"),
+        help="Weights & Biases entity; defaults to WANDB_ENTITY env var or your logged-in user.",
+    )
 
     # Augmentations
     parser.add_argument("--random_zoom", type=float, default=1)
@@ -233,25 +246,22 @@
         group = f"{group}-complex"
 
     # Start wandb run
-    with wandb.init(
-        entity="_archil",
-        project=run_config.project,
-        config=run_config,
-        group=group,
-        job_type=f"{run_config.dataset}",
-    ):
+    wandb_kwargs = {
+        "project": run_config.project,
+        "config": run_config,
+        "group": group,
+        "job_type": f"{run_config.dataset}",
+    }
+    if run_config.entity:
+        wandb_kwargs["entity"] = run_config.entity
+
+    with wandb.init(**wandb_kwargs):
         # Set random seed
-        if run_config.seed:
+        if run_config.seed is not None:
             set_seed(run_config.seed)
 
         # Select image size based on dataset
-        if run_config.dataset == "Model_I":
-            IMAGE_SIZE = 150
-        elif run_config.dataset in ["Model_II", "Model_III"]:
-            IMAGE_SIZE = 64
-        else:
-            IMAGE_SIZE = None
-            raise ValueError("Dataset not found")
+        IMAGE_SIZE = get_image_size(run_config.dataset)
 
         # Select best device on the machine
         device = get_device(run_config.device)
@@ -280,7 +290,12 @@
     # 90%-10% Train-validation split
     train_size = int(len(train_dataset) * 0.9)
     val_size = len(train_dataset) - train_size
-    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
+    split_generator = None
+    if run_config.seed is not None:
+        split_generator = torch.Generator().manual_seed(run_config.seed)
+    train_dataset, val_dataset = random_split(
+        train_dataset, [train_size, val_size], generator=split_generator
+    )
 
     # Initialize train and validation datasets
     train_dataset = WrapperDataset(
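
Usage after this change (illustrative, not part of the patch: "my-team" and
<run-id> are placeholders, and the commands assume they are run from the
DeepLense_Classification_Transformers_Archil_Srivastava directory). With the
hard-coded entity removed, both scripts take the W&B entity from --entity,
falling back to the WANDB_ENTITY env var or the logged-in wandb user:

    # Train on Model_I, logging to an explicit W&B entity
    python train.py --dataset Model_I --entity my-team

    # Evaluate a finished train run; entity resolved from the environment
    WANDB_ENTITY=my-team python eval.py --runid <run-id>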