forked from xiyuanzh/UniMTS
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
115 lines (89 loc) · 5.33 KB
/
evaluate.py
File metadata and controls
115 lines (89 loc) · 5.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import torch
import argparse
import os
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset
from huggingface_hub import hf_hub_download
import zipfile
from data import load, load_multiple
from utils import compute_metrics_np
from contrastive import ContrastiveModule
def main(args):
    """Evaluate a pre-trained UniMTS checkpoint on the real-world benchmark datasets.

    On first run the released checkpoint, config and evaluation data are
    downloaded from the Hugging Face Hub into the current directory. Every
    dataset in the benchmark list is then pushed through the model, and
    accuracy, macro precision/recall/F1, R@1..R@5 and MRR are printed and
    logged to Weights & Biases.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``padding_size``, ``data_path``, ``batch_size``, ``run_tag``,
            ``stage``, ``checkpoint``, ``gyro`` and ``stft``; the whole
            namespace is also handed to ``ContrastiveModule``.

    Note:
        A CUDA device is required: the model and every batch are moved to
        the GPU with ``.cuda()``.
    """
    repo_id = "xiyuanz/UniMTS"
    checkpoint_file = "checkpoint/UniMTS.pth"
    config_file = "config.json"
    data_file = "UniMTS_data.zip"

    # Download the released artifacts only when they are not present locally.
    if not os.path.exists("checkpoint"):
        hf_hub_download(repo_id=repo_id, filename=checkpoint_file, local_dir="./")
        hf_hub_download(repo_id=repo_id, filename=config_file, local_dir="./")

    if not os.path.exists("UniMTS_data"):
        # hf_hub_download returns the local path of the fetched file; use it
        # instead of repeating the archive name as a second literal.
        archive_path = hf_hub_download(repo_id=repo_id, filename=data_file, local_dir="./")
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall("./")

    # load real data
    dataset_list = [
        'Opp_g', 'UCIHAR', 'MotionSense', 'w-HAR', 'Shoaib', 'har70plus',
        'realworld', 'TNDA-HAR', 'PAMAP', 'USCHAD', 'Mhealth', 'Harth',
        'ut-complex', 'Wharf', 'WISDM', 'DSADS', 'UTD-MHAD', 'MMAct',
    ]
    (
        real_inputs_list,
        real_masks_list,
        real_labels_list,
        label_list_list,
        all_text_list,
        _,
    ) = load_multiple(dataset_list, args.padding_size, args.data_path)

    # One non-shuffled loader per dataset so predictions stay aligned with
    # the corresponding entry of real_labels_list.
    test_real_dataloader_list = [
        DataLoader(
            TensorDataset(real_inputs, real_masks, real_labels),
            batch_size=args.batch_size,
            shuffle=False,
        )
        for real_inputs, real_masks, real_labels in zip(
            real_inputs_list, real_masks_list, real_labels_list
        )
    ]

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{date}",
    )

    model = ContrastiveModule(args).cuda()
    # Deserialize onto the CPU first so loading does not depend on the device
    # the checkpoint was saved from; load_state_dict copies the values into
    # the model's (CUDA) parameters.
    state_dict = torch.load(args.checkpoint, map_location="cpu")
    model.model.load_state_dict(state_dict)

    model.eval()
    with torch.no_grad():
        for ds, real_labels, test_real_dataloader, _label_list, all_text in zip(
            dataset_list,
            real_labels_list,
            test_real_dataloader_list,
            label_list_list,
            all_text_list,
        ):
            # Convert the ground truth once and reuse it for every metric.
            labels_np = real_labels.numpy()
            pred_whole, logits_whole = [], []

            # The mask and per-batch label are not needed for inference, so
            # they are neither moved to the GPU nor otherwise touched.
            for imu, _mask, _label in test_real_dataloader:
                imu = imu.cuda()

                if not args.gyro:
                    # Keep the first three channels of every group of six.
                    # NOTE(review): presumably these are the accelerometer
                    # axes and the other three the gyroscope axes - confirm
                    # against the data layout produced by load_multiple.
                    b, t, c = imu.shape
                    indices = np.array(
                        [range(i, i + 3) for i in range(0, c, 6)]
                    ).flatten()
                    imu = imu[:, :, indices]

                b, t, c = imu.shape
                if args.stft:
                    # Per-channel STFT magnitude, appended as extra channels.
                    input_stft = imu.permute(0, 2, 1).reshape(b * c, t)
                    input_stft = torch.abs(
                        torch.stft(
                            input_stft,
                            n_fft=25,
                            hop_length=28,
                            onesided=False,
                            center=True,
                            return_complex=True,
                        )
                    )
                    # The second reshape only succeeds when
                    # n_freq * n_frames == t, so these STFT settings are tied
                    # to the sequence length in use.
                    input_stft = input_stft.reshape(
                        b, c, input_stft.shape[-2], input_stft.shape[-1]
                    ).reshape(b, c, t).permute(0, 2, 1)
                    imu = torch.cat((imu, input_stft), dim=-1)

                # (b, t, 22 * k) -> (b, k, t, 22, 1).
                # NOTE(review): 22 is presumably the number of joints/sensor
                # locations expected by the model - confirm.
                imu = imu.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

                logits_per_imu, _ = model(imu, all_text)
                pred_whole.append(
                    torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
                )
                # Move logits off the GPU immediately instead of keeping the
                # whole dataset's logits in device memory.
                logits_whole.append(logits_per_imu.detach().cpu())

            pred = np.concatenate(pred_whole)
            acc = accuracy_score(labels_np, pred)
            prec = precision_score(labels_np, pred, average='macro')
            rec = recall_score(labels_np, pred, average='macro')
            f1 = f1_score(labels_np, pred, average='macro')

            print(f"{ds} acc: {acc}, {ds} prec: {prec}, {ds} rec: {rec}, {ds} f1: {f1}")
            wandb.log({f"{ds} acc": acc, f"{ds} prec": prec, f"{ds} rec": rec, f"{ds} f1": f1})

            logits_np = torch.cat(logits_whole).numpy()
            r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_np, labels_np)

            print(f"{ds} R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
            wandb.log({f"{ds} R@1": r_at_1, f"{ds} R@2": r_at_2, f"{ds} R@3": r_at_3, f"{ds} R@4": r_at_4, f"{ds} R@5": r_at_5, f"{ds} MRR": mrr_score})
if __name__ == "__main__":
    # Script entry point: parse the evaluation options and run main().
    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # data
    # The default is a real int rather than the string '200', so the option
    # no longer relies on argparse converting string defaults through `type`.
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--data_path', type=str, default='./UniMTS_data/', help='/path/to/data/')

    # training
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--stage', type=str, default='evaluation', help='training or evaluation stage')
    # --gyro and --stft are integer switches: 0 disables, any non-zero value enables.
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/UniMTS.pth', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)