-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_local.py
More file actions
128 lines (101 loc) · 3.68 KB
/
test_local.py
File metadata and controls
128 lines (101 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Local end-to-end check of the challenge entry point.

Trains the model with ``team_code.train_challenge_model``, then runs inference
(``run_model``) and scoring (``evaluate_model``) on the original training data
and on the truncated 12h / 24h / 48h / 72h subsets. When executed as a script
it refuses to run unless the ``CINC2023_REVENGER_TEST`` environment variable is
set to a true value.
"""
import os
import sys
from pathlib import Path
# Put local checkouts of torch_ecg and bib_lookup at the front of the import
# path so they take precedence over any installed copies.
# NOTE(review): these are absolute paths on the author's machine -- adjust them
# (or install the packages) before running this anywhere else.
sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/torch_ecg/")
sys.path.insert(0, "/home/wenh06/Jupyter/wenhao/workspace/bib_lookup/")
# Root directory of the local CinC2023 data (also machine-specific).
tmp_data_dir = Path("/home/wenh06/Jupyter/wenhao/data/CinC2023/")
# The torch_ecg import below presumably resolves through the sys.path entry
# inserted above, so keep this ordering.
import numpy as np
import torch
from torch_ecg.utils.misc import dict_to_str, str2bool
from cfg import _BASE_DIR, ModelCfg, TrainCfg
from evaluate_model import evaluate_model
from run_model import run_model
# Alternative training entry point, kept for reference:
# from train_model import train_challenge_model
from team_code import train_challenge_model
from utils.misc import func_indicator
# set_entry_test_flag(True)
# Compute device used by this script: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Align torch's global default floating-point dtype (and the numpy dtype kept
# alongside it) with the precision configured for the model.
if ModelCfg.torch_dtype == torch.float64:
    # ``torch.set_default_tensor_type`` is deprecated since PyTorch 2.1;
    # ``torch.set_default_dtype`` sets the same default floating dtype.
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

TASK = "classification"  # "classification" or "regression", etc.

# Truncated copies of the data, keyed by the number of hours kept
# (12h, 24h, 48h, 72h); iterated in this order by the entry test.
trunc_data_folder = {
    limit: tmp_data_dir / f"trunc_subset_{limit}" for limit in [12, 24, 48, 72]
}
def _evaluate_outputs(data_folder: Path, output_dir: Path) -> dict:
    """Score the predictions in ``output_dir`` against the data in ``data_folder``.

    Returns a dict mapping each metric name to the corresponding value returned
    by ``evaluate_model``.

    Raises:
        ValueError: if ``evaluate_model`` returns a different number of values
            than the seven metric names expected here.
    """
    metric_names = (
        "challenge_score",
        "auroc_outcomes",
        "auprc_outcomes",
        "accuracy_outcomes",
        "f_measure_outcomes",
        "mse_cpcs",
        "mae_cpcs",
    )
    scores = tuple(evaluate_model(str(data_folder), str(output_dir)))
    # Plain ``zip`` would silently truncate; keep the strictness of tuple unpacking.
    if len(scores) != len(metric_names):
        raise ValueError(
            f"evaluate_model returned {len(scores)} values, expected {len(metric_names)}"
        )
    return dict(zip(metric_names, scores))


def _run_and_evaluate(data_folder: Path, output_dir: Path, desc: str) -> dict:
    """Run the trained model on ``data_folder``, then evaluate and print the results.

    ``desc`` only labels the progress messages (e.g. ``"original"`` or ``"12h"``).
    Returns the metric dict produced by ``_evaluate_outputs``.
    """
    print(f"run model for the {desc} data")
    run_model(
        str(TrainCfg.model_dir),
        str(data_folder),
        str(output_dir),
        allow_failures=False,
        verbose=2,
    )
    print(f"evaluate model for the {desc} data")
    eval_res = _evaluate_outputs(data_folder, output_dir)
    print(f"{desc} data evaluation results: {dict_to_str(eval_res)}")
    return eval_res


@func_indicator("testing challenge entry")
def test_entry():
    """Train the challenge model, then run and evaluate it on every data variant.

    Training uses the full training set under ``tmp_data_dir``. Inference and
    evaluation are run first on that original data and then on each truncated
    subset in ``trunc_data_folder``.
    """
    # data_folder = tmp_data_dir / "training_subset"  # subset
    data_folder = tmp_data_dir / "training"  # full set
    train_challenge_model(str(data_folder), str(TrainCfg.model_dir), verbose=2)

    output_dir = _BASE_DIR / "tmp" / "output"
    output_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): every pass writes into the same ``output_dir`` without
    # clearing it; presumably each run overwrites the per-patient outputs it
    # scores -- confirm if stale files could affect evaluation.

    _run_and_evaluate(data_folder, output_dir, "original")
    # Iterate the configured subsets directly so the list of hour limits is
    # defined in exactly one place.
    for limit, trunc_folder in trunc_data_folder.items():
        _run_and_evaluate(trunc_folder, output_dir, f"{limit}h")

    print("entry test passed")
if __name__ == "__main__":
    # ``test_entry`` trains a model and runs inference/evaluation, so it only
    # starts when the user opts in through the environment variable below.
    TEST_FLAG = str2bool(os.environ.get("CINC2023_REVENGER_TEST", False))
    if TEST_FLAG:
        test_entry()
    else:
        raise RuntimeError(
            "please set CINC2023_REVENGER_TEST to true (1, y, yes, true, etc.) to run the test, e.g."
            "\n CINC2023_REVENGER_TEST=1 python test_local.py "
        )