forked from VITA-Group/Structure-LTH
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_mask.py
More file actions
57 lines (47 loc) · 1.76 KB
/
plot_mask.py
File metadata and controls
57 lines (47 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import torch
import matplotlib.pyplot as plt
# Load the two mask checkpoint files (new_mask_4 / new_mask_10) from ~/data.
# map_location="cpu" lets the script run on machines without the device the
# tensors may have been saved from, and keeps every tensor on the host so
# matplotlib can convert it to a NumPy array when plotting.
checkpoint_4 = torch.load(
    os.path.expanduser("~/data/new_mask_4.pth.tar"), map_location="cpu"
)
checkpoint_10 = torch.load(
    os.path.expanduser("~/data/new_mask_10.pth.tar"), map_location="cpu"
)
# Output directory for the generated SVG heatmaps.
os.makedirs("vis", exist_ok=True)
def process_heatmap_data(data, skip=10):
    """Flatten and vertically stack the masks stored in ``data``.

    Entries are visited in iteration order of ``data``. The first ``skip``
    entries are ignored. Each remaining mask is reshaped to
    ``(mask.shape[0], -1)``, keeping the first dimension and flattening the
    rest. The results are then concatenated along dim 0. The resulting shape
    is printed before returning.

    Args:
        data: Mapping from mask name to tensor.
            NOTE(review): presumably a ``{layer_name: mask_tensor}`` dict
            taken from a pruning checkpoint — confirm against the caller.
        skip: Number of leading entries to drop. The default of 10 matches
            the value this script always used. Presumably it skips earlier
            layers whose flattened width differs — confirm.

    Returns:
        A 2-D tensor with ``sum(mask.shape[0])`` rows. Every kept mask must
        flatten to the same number of columns, otherwise ``torch.cat`` raises.

    Raises:
        ValueError: if no entries remain after skipping the first ``skip``.
    """
    kept = [data[name] for i, name in enumerate(data) if i >= skip]
    if not kept:
        raise ValueError(
            f"process_heatmap_data: no masks left after skipping the first {skip} entries"
        )
    # reshape (rather than view) also handles non-contiguous mask tensors.
    masks = torch.cat([mask.reshape(mask.shape[0], -1) for mask in kept], 0)
    print(masks.shape)
    return masks
# One flattened heatmap per pruning variant. The tuple is evaluated
# left-to-right, so the shapes are printed in the order imp, refill, regroup.
# NOTE(review): "refill" is read from checkpoint_4 while "imp" and "regroup"
# come from checkpoint_10 — confirm this mix is intentional.
map_imp, map_refill, map_regroup = (
    process_heatmap_data(checkpoint_10['imp']),
    process_heatmap_data(checkpoint_4['refill']),
    process_heatmap_data(checkpoint_10['regroup']),
)
def _save_heatmap_chunks(heatmap, name, n_chunks=3, chunk_rows=512):
    """Write row-bands of ``heatmap`` to ``vis/{name}_heatmap_{i}.svg``.

    Band ``i`` (for ``i`` in ``range(n_chunks)``) is
    ``heatmap[i*chunk_rows:(i+1)*chunk_rows, :]``. A band that runs past the
    last row is simply shorter or empty, as slicing dictates. Each band is
    drawn with the reversed viridis colormap on an 18x2 inch figure without
    axis ticks, and saved with a tight bounding box.

    Args:
        heatmap: 2-D tensor or array to plot.
        name: File-name prefix, e.g. "IMP", "refill" or "regroup".
        n_chunks: Number of bands to write (default 3).
        chunk_rows: Rows per band (default 512).
    """
    for i in range(n_chunks):
        band = heatmap[i * chunk_rows:(i + 1) * chunk_rows, :]
        if isinstance(band, torch.Tensor):
            # matplotlib needs a host-side NumPy array. This conversion also
            # copes with CUDA tensors or tensors that require grad.
            band = band.detach().cpu().numpy()
        fig, ax = plt.subplots(figsize=(18, 2))
        ax.imshow(band, cmap='viridis_r')
        ax.set_xticks([])
        ax.set_yticks([])
        fig.savefig(f'vis/{name}_heatmap_{i}.svg', bbox_inches='tight')
        # Close this specific figure so figures don't pile up in pyplot.
        plt.close(fig)


# Loop-invariant setting: assign once before any figure is created.
# NOTE(review): these plots render no text (ticks are removed), so this
# font setting has no visible effect; kept for parity with the original.
plt.rcParams['font.sans-serif'] = 'Times New Roman'
for _name, _heatmap in (
    ('IMP', map_imp),
    ('refill', map_refill),
    ('regroup', map_regroup),
):
    _save_heatmap_chunks(_heatmap, _name)