diff --git a/main.py b/main.py index 69e2297..b17ca5e 100644 --- a/main.py +++ b/main.py @@ -295,7 +295,7 @@ def main(args): max_accuracy = checkpoint['max_accuracy'] if args.eval: - test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=True, amp=args.amp) + test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=args.distributed, amp=args.amp) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%") return @@ -316,7 +316,7 @@ def main(args): lr_scheduler.step(epoch) - test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=True, amp=args.amp) + test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=args.distributed, amp=args.amp) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%") max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') diff --git a/models/crossvit.py b/models/crossvit.py index 0d2c856..9aa3cd1 100644 --- a/models/crossvit.py +++ b/models/crossvit.py @@ -211,6 +211,7 @@ def __init__(self, img_size=(224, 224), patch_size=(8, 16), in_chans=3, num_clas super().__init__() self.num_classes = num_classes + self.embed_dim = embed_dim if not isinstance(img_size, list): img_size = to_2tuple(img_size) self.img_size = img_size @@ -281,7 +282,7 @@ def get_classifier(self): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) def forward_features(self, x): B, C, H, W = x.shape