DeepLense_Classification_Transformers_Archil_Srivastava/eval.py (35 additions, 18 deletions)

@@ -40,8 +40,7 @@ def evaluate(model, data_loader, loss_fn, device):
"""
model.eval() # Switch on evaluation model

# Initialize lists for different metrics
loss, accuracy, class_auroc, micro_auroc, macro_auroc = [], [], [], [], []
# Collect predictions and labels to compute metrics in one pass
logits, y = [], []

# Iterate over batches and accumulate metrics
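For reference, the elided loop between this hunk and the next is what fills logits and y. A minimal sketch of that accumulation pattern, assuming the loader yields (images, labels) batches and that model, data_loader, and device are the function's parameters (the real loop body is not part of this diff):

import torch  # already imported at the top of eval.py

with torch.no_grad():
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        logits.append(model(images).cpu())  # raw per-batch scores
        y.append(labels.cpu())              # matching ground-truth labels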
@@ -55,24 +54,30 @@ def evaluate(model, data_loader, loss_fn, device):

     # Concatenate all results
     logits, y = torch.cat(logits), torch.cat(y)
-    loss.append(loss_fn(logits, y))
-    accuracy.append(accuracy_fn(logits, y, num_classes=NUM_CLASSES))
-    class_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average=None))
-    macro_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average="macro"))
+    loss = loss_fn(logits, y).item()
+    accuracy = accuracy_fn(logits, y, num_classes=NUM_CLASSES).item()
+    class_auroc = auroc_fn(
+        logits, y, num_classes=NUM_CLASSES, average=None
+    ).cpu()
+    micro_auroc = auroc_fn(
+        logits, y, num_classes=NUM_CLASSES, average="micro"
+    ).item()
+    macro_auroc = auroc_fn(
+        logits, y, num_classes=NUM_CLASSES, average="macro"
+    ).item()

     result = {
         "ground_truth": y,
         "logits": logits,
-        "loss": np.mean(loss),
-        "accuracy": np.mean(accuracy),
-        "micro_auroc": np.mean(micro_auroc),
-        "macro_auroc": np.mean(macro_auroc),
+        "loss": loss,
+        "accuracy": accuracy,
+        "micro_auroc": micro_auroc,
+        "macro_auroc": macro_auroc,
     }

     # Class-wise AUROC
-    class_auroc = class_auroc[0]
     for i, label in enumerate(LABELS):
-        result[f"{label}_auroc"] = class_auroc[i]
+        result[f"{label}_auroc"] = class_auroc[i].item()

     return result
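Each metric is now computed once over the full concatenated tensors instead of via single-element lists averaged with np.mean; the old version also never appended to micro_auroc, so np.mean over that empty list produced NaN. A self-contained sketch, assuming accuracy_fn and auroc_fn wrap torchmetrics.functional.accuracy and auroc with the pre-0.11 signature (num_classes passed directly; out-of-range preds are treated as logits and softmaxed internally):

import torch
from torchmetrics.functional import accuracy, auroc  # assumed backing for accuracy_fn / auroc_fn

NUM_CLASSES = 3
logits = torch.randn(9, NUM_CLASSES)           # concatenated model outputs
y = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])  # concatenated labels, all classes present

acc = accuracy(logits, y, num_classes=NUM_CLASSES).item()
per_class = auroc(logits, y, num_classes=NUM_CLASSES, average=None).cpu()  # one AUROC per class
macro = auroc(logits, y, num_classes=NUM_CLASSES, average="macro").item()  # unweighted mean
print(acc, per_class.tolist(), macro)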

@@ -84,6 +89,12 @@ def evaluate(model, data_loader, loss_fn, device):
     # Wandb-specific params
     parser.add_argument("--runid", type=str, help="ID of train run")
     parser.add_argument("--project", type=str, default="ml4sci_deeplense_final")
+    parser.add_argument(
+        "--entity",
+        type=str,
+        default=os.environ.get("WANDB_ENTITY"),
+        help="Weights & Biases entity; defaults to WANDB_ENTITY env var or your logged-in user.",
+    )

     # Device to run on
     parser.add_argument(
@@ -92,9 +103,15 @@ def evaluate(model, data_loader, loss_fn, device):
     run_config = parser.parse_args()

     # Start wandb run
-    with wandb.init(
-        entity="_archil", project=run_config.project, id=run_config.runid, resume="must"
-    ):
+    wandb_kwargs = {
+        "project": run_config.project,
+        "id": run_config.runid,
+        "resume": "must",
+    }
+    if run_config.entity:
+        wandb_kwargs["entity"] = run_config.entity
+
+    with wandb.init(**wandb_kwargs):
         # Get best device on machine
         device = get_device(run_config.device)
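Dropping the hard-coded entity "_archil" lets anyone resume the run; the entity now comes from --entity, the WANDB_ENTITY environment variable, or the logged-in user's default, in that order. A hypothetical sketch of the resolution (the run ID abc123 is illustrative only):

import os
import wandb

entity = os.environ.get("WANDB_ENTITY")  # e.g. exported in the shell; may be None
wandb_kwargs = {"project": "ml4sci_deeplense_final", "id": "abc123", "resume": "must"}
if entity:
    wandb_kwargs["entity"] = entity      # omitting the key lets wandb use the logged-in default

with wandb.init(**wandb_kwargs) as run:
    print(run.entity)                    # the entity the run actually resolved to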

@@ -169,9 +186,9 @@ def evaluate(model, data_loader, loss_fn, device):
         roc_auc = dict()
         for idx, cls in enumerate(LABELS):
             class_truth = (metrics["ground_truth"].numpy() == idx).astype(int)
-            class_pred = torch.nn.functional.softmax(metrics["logits"]).numpy()[
-                ..., idx
-            ]
+            class_pred = torch.nn.functional.softmax(
+                metrics["logits"], dim=-1
+            ).numpy()[..., idx]
             fpr[idx], tpr[idx], _ = roc_curve(class_truth, class_pred)
             _ = axes[0].plot(
                 fpr[idx],
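Calling softmax without dim is deprecated in PyTorch and triggers a UserWarning while falling back to an implicitly chosen dimension; dim=-1 pins the normalization to the class axis. A quick check of the intended behavior:

import torch

logits = torch.randn(4, 3)                  # (num_samples, num_classes)
probs = torch.nn.functional.softmax(logits, dim=-1)
print(probs.sum(dim=-1))                    # every row sums to 1: one distribution per sample
class_pred = probs.numpy()[..., 1]          # P(class 1) per sample, as the hunk above extracts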
DeepLense_Classification_Transformers_Archil_Srivastava/train.py (41 additions, 26 deletions)

@@ -4,15 +4,13 @@
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
 from torch.nn import CrossEntropyLoss
 from tqdm import tqdm
-import numpy as np
 import argparse
 import os
 import wandb

 from data import LensDataset, WrapperDataset, get_transforms
 from constants import *
 from models import get_timm_model
-from models.transformers import get_transformer_model
 from utils import get_device, set_seed
 from eval import evaluate

Expand Down Expand Up @@ -54,10 +52,18 @@ def train_step(model, images, labels, optimizer, scheduler, criterion, device="c
     loss.backward()  # Backward pass
     optimizer.step()  # Optimize weights step
     if scheduler is not None:
-        scheduler.step(loss)  # Modify learning rate if scheduler is set
+        scheduler.step()  # Modify learning rate if scheduler is set
     return loss


+def get_image_size(dataset_name):
+    if dataset_name == "Model_I":
+        return 150
+    if dataset_name in ["Model_II", "Model_III"]:
+        return 64
+    raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+
 def train(
     model,
     train_loader,
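scheduler.step(loss) is the ReduceLROnPlateau calling convention; for CosineAnnealingWarmRestarts (the scheduler imported in this file) the positional argument is interpreted as an epoch index rather than a metric, so passing a loss silently corrupted the schedule. A minimal sketch of the per-step pattern the fix assumes:

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = torch.nn.Linear(10, 2)                              # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)  # restart every 10 steps

for _ in range(25):
    optimizer.step()     # update weights first ...
    scheduler.step()     # ... then advance the schedule; no metric argument
print(optimizer.param_groups[0]["lr"])                      # LR partway through the third cycle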
Expand Down Expand Up @@ -106,8 +112,9 @@ def train(
     # Alert wandb to log this training
     wandb.watch(model, criterion, log="all", log_freq=log_interval)

-    best_val_auroc, best_val_metrics = 0.0, dict()
+    best_val_auroc, best_val_metrics = float("-inf"), dict()
     batch_num = 0
+    best_model_path = os.path.join(wandb.run.dir, "best_model.pt")
     for epoch in range(1, epochs + 1):
         for batch_data in tqdm(train_loader, desc=f"Epoch {epoch}"):
             batch_num += 1
@@ -127,7 +134,8 @@ def train(
            log_dict = {
                "epoch": epoch,
                "batch_num": batch_num,
-                "train/loss": loss,
+                "train/loss": loss.item(),
+                "train/lr": optimizer.param_groups[0]["lr"],
                "val/loss": val_metrics["loss"],
                "val/accuracy": val_metrics["accuracy"],
                "val/micro_auroc": val_metrics["micro_auroc"],
@@ -166,12 +174,11 @@ def train(
                ]
                best_val_metrics = val_metrics
                # Save the best model so far to disk
-                torch.save(
-                    model.state_dict(), os.path.join(wandb.run.dir, "best_model.pt")
-                )
+                torch.save(model.state_dict(), best_model_path)

        # Sync best model at a lower frequency (i.e. at the end of each epoch)
-        wandb.save(os.path.join(wandb.run.dir, "best_model.pt"))
+        if os.path.exists(best_model_path):
+            wandb.save(best_model_path)

    return best_val_metrics

@@ -181,7 +188,7 @@ def train(
# W&B related parameters
parser.add_argument(
"--dataset",
choices=["Model_I", "Model_II", "Model_III", "Model_IV"],
choices=["Model_I", "Model_II", "Model_III"],
default="Model_I",
help="which data model",
)
@@ -208,6 +215,12 @@ def train(
     parser.add_argument(
         "--device", choices=["cpu", "mps", "cuda", "tpu", "best"], default="best"
     )
+    parser.add_argument(
+        "--entity",
+        type=str,
+        default=os.environ.get("WANDB_ENTITY"),
+        help="Weights & Biases entity; defaults to WANDB_ENTITY env var or your logged-in user.",
+    )

     # Augmentations
     parser.add_argument("--random_zoom", type=float, default=1)
@@ -233,25 +246,22 @@ def train(
         group = f"{group}-complex"

     # Start wandb run
-    with wandb.init(
-        entity="_archil",
-        project=run_config.project,
-        config=run_config,
-        group=group,
-        job_type=f"{run_config.dataset}",
-    ):
+    wandb_kwargs = {
+        "project": run_config.project,
+        "config": run_config,
+        "group": group,
+        "job_type": f"{run_config.dataset}",
+    }
+    if run_config.entity:
+        wandb_kwargs["entity"] = run_config.entity
+
+    with wandb.init(**wandb_kwargs):
         # Set random seed
-        if run_config.seed:
+        if run_config.seed is not None:
             set_seed(run_config.seed)

         # Select image size based on dataset
-        if run_config.dataset == "Model_I":
-            IMAGE_SIZE = 150
-        elif run_config.dataset in ["Model_II", "Model_III"]:
-            IMAGE_SIZE = 64
-        else:
-            IMAGE_SIZE = None
-            raise ValueError("Dataset not found")
+        IMAGE_SIZE = get_image_size(run_config.dataset)

         # Select best device on the machine
         device = get_device(run_config.device)
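The inline if/elif/else collapses into the get_image_size helper added earlier in this diff, which drops the dead IMAGE_SIZE = None assignment and matches the trimmed --dataset choices. A quick sanity check of the helper, assuming it is in scope:

assert get_image_size("Model_I") == 150
assert get_image_size("Model_II") == 64
assert get_image_size("Model_III") == 64
try:
    get_image_size("Model_IV")   # no longer a supported choice
except ValueError as err:
    print(err)                   # Unsupported dataset: Model_IV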
@@ -280,7 +290,12 @@ def train(
         # 90%-10% Train-validation split
         train_size = int(len(train_dataset) * 0.9)
         val_size = len(train_dataset) - train_size
-        train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
+        split_generator = None
+        if run_config.seed is not None:
+            split_generator = torch.Generator().manual_seed(run_config.seed)
+        train_dataset, val_dataset = random_split(
+            train_dataset, [train_size, val_size], generator=split_generator
+        )

         # Initialize train and validation datasets
         train_dataset = WrapperDataset(
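random_split draws from the global RNG by default, so without a seeded generator the 90/10 split differs between runs (and between training and any later evaluation). Passing an explicitly seeded torch.Generator makes it reproducible; a small self-contained demonstration with an illustrative seed of 42:

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(100))
sizes = [90, 10]

split_a, _ = random_split(dataset, sizes, generator=torch.Generator().manual_seed(42))
split_b, _ = random_split(dataset, sizes, generator=torch.Generator().manual_seed(42))
assert split_a.indices == split_b.indices   # identical split on every run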