From 97a1414ec182c09609ebe141ff6acc350cc352e5 Mon Sep 17 00:00:00 2001 From: abhrac Date: Wed, 12 Jan 2022 15:02:10 +0000 Subject: [PATCH] honoring distributed flag + fixing reset_classifier --- main.py | 4 ++-- models/crossvit.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 69e2297..b17ca5e 100644 --- a/main.py +++ b/main.py @@ -295,7 +295,7 @@ def main(args): max_accuracy = checkpoint['max_accuracy'] if args.eval: - test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=True, amp=args.amp) + test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=args.distributed, amp=args.amp) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%") return @@ -316,7 +316,7 @@ def main(args): lr_scheduler.step(epoch) - test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=True, amp=args.amp) + test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=args.distributed, amp=args.amp) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%") max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') diff --git a/models/crossvit.py b/models/crossvit.py index 0d2c856..9aa3cd1 100644 --- a/models/crossvit.py +++ b/models/crossvit.py @@ -211,6 +211,7 @@ def __init__(self, img_size=(224, 224), patch_size=(8, 16), in_chans=3, num_clas super().__init__() self.num_classes = num_classes + self.embed_dim = embed_dim if not isinstance(img_size, list): img_size = to_2tuple(img_size) self.img_size = img_size @@ -281,7 +282,7 @@ def get_classifier(self): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) def forward_features(self, x): B, C, H, W = x.shape