diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py b/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py index a3f12b32..b82e58a8 100644 --- a/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py +++ b/DeepLense_Classification_Transformers_Archil_Srivastava/eval.py @@ -58,6 +58,7 @@ def evaluate(model, data_loader, loss_fn, device): loss.append(loss_fn(logits, y)) accuracy.append(accuracy_fn(logits, y, num_classes=NUM_CLASSES)) class_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average=None)) + micro_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average="micro")) macro_auroc.append(auroc_fn(logits, y, num_classes=NUM_CLASSES, average="macro")) result = { @@ -84,6 +85,7 @@ def evaluate(model, data_loader, loss_fn, device): # Wandb-specific params parser.add_argument("--runid", type=str, help="ID of train run") parser.add_argument("--project", type=str, default="ml4sci_deeplense_final") + parser.add_argument("--entity", type=str, default=os.environ.get("WANDB_ENTITY", None)) # Device to run on parser.add_argument( @@ -93,7 +95,7 @@ def evaluate(model, data_loader, loss_fn, device): # Start wandb run with wandb.init( - entity="_archil", project=run_config.project, id=run_config.runid, resume="must" + entity=run_config.entity, project=run_config.project, id=run_config.runid, resume="must" ): # Get best device on machine device = get_device(run_config.device) @@ -169,7 +171,7 @@ def evaluate(model, data_loader, loss_fn, device): roc_auc = dict() for idx, cls in enumerate(LABELS): class_truth = (metrics["ground_truth"].numpy() == idx).astype(int) - class_pred = torch.nn.functional.softmax(metrics["logits"]).numpy()[ + class_pred = torch.nn.functional.softmax(metrics["logits"], dim=-1).numpy()[ ..., idx ] fpr[idx], tpr[idx], _ = roc_curve(class_truth, class_pred) diff --git a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py index a5a6303c..2ef4fcb9 100644 --- a/DeepLense_Classification_Transformers_Archil_Srivastava/train.py +++ b/DeepLense_Classification_Transformers_Archil_Srivastava/train.py @@ -187,6 +187,7 @@ def train( ) parser.add_argument("--log_interval", type=int, default=100) parser.add_argument("--project", type=str, default="ml4sci_deeplense_final") + parser.add_argument("--entity", type=str, default=os.environ.get("WANDB_ENTITY", None)) # Timm-Specific parameters parser.add_argument("--model_name", type=str, default="vit_base_patch16_224") @@ -234,7 +235,7 @@ def train( # Start wandb run with wandb.init( - entity="_archil", + entity=run_config.entity, project=run_config.project, config=run_config, group=group,