forked from zhouyh310/SleepHGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
52 lines (41 loc) · 1.28 KB
/
config.py
File metadata and controls
52 lines (41 loc) · 1.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from dataclasses import dataclass, field
from omegaconf import MISSING
from typing import List, Any
@dataclass(frozen=True)
class ConstFields:
    """Immutable, dataset-wide constants shared by every task config.

    Declared ``frozen=True`` so instances cannot be mutated after creation
    (assignment raises ``dataclasses.FrozenInstanceError``).
    """

    # Number of subjects in the dataset.
    n_subjects: int = 10
    # One node type per signal modality: EEG, EOG, EMG, ECG.
    n_node_types: int = 4
    # Number of edge/relation types in the heterogeneous graph
    # (numerically n_node_types ** 2 with the defaults; kept as an explicit value).
    n_relation_types: int = 16
@dataclass
class BaseConfig:
    """Shared default settings for a training task.

    String values containing ``${...}`` are OmegaConf interpolations that refer
    to other keys of this same config (``${task_name}``, ``${output_root}``).
    ``task_name`` defaults to OmegaConf's ``MISSING`` marker, i.e. a mandatory
    value that has to be provided by a subclass or a command-line override.
    """

    # Dataset-wide constants (an immutable ConstFields instance).
    const: ConstFields = field(default_factory=ConstFields)

    # Task identity; mandatory (OmegaConf ``MISSING``).
    task_name: str = MISSING

    # Hydra settings: put the run output directory at outputs/<task_name>.
    hydra: Any = field(
        default_factory=lambda: dict(run=dict(dir=r'outputs/${task_name}'))
    )

    # Input locations. The *_dirname values are presumably resolved relative
    # to data_root by the data loader -- TODO confirm against the loader.
    data_root: str = './data'
    feature_dirname: str = 'psd'
    label_dirname: str = 'label'
    adj_mat_dirname: str = 'nmi_adj_mat/threshold_0.1'

    # Output locations; criterion_root and plot_root hang off output_root
    # through interpolation.
    output_root: str = r'outputs/${task_name}/.task'
    criterion_root: str = r'${output_root}/criterion'
    plot_root: str = r'${output_root}/plot'

    # Training / optimisation settings.
    shuffle: bool = True
    k_fold: int = 10
    max_epochs: int = 200
    batch_size: int = 2048
    lr: float = 5e-4
    l2_decay: float = 1e-3  # presumably the optimiser's weight decay -- confirm

    # Model architecture settings.
    n_HGTs: int = 3  # presumably the number of stacked HGT layers -- confirm
    n_heads: int = 8
    emb_dim: int = 128
    # Hidden sizes of the linear head; a fresh list is built for every instance.
    lin_dims: List[int] = field(default_factory=lambda: list((512, 128)))
    lin_dropout: float = 0.2
@dataclass
class MyConfig(BaseConfig):
    """Task-specific configuration.

    Give the task a name and override whichever ``BaseConfig`` defaults you
    want to change, either here or from the command line.
    """

    # Re-declared here as the place to set the task name; it remains OmegaConf
    # ``MISSING`` (mandatory) until a value is supplied.
    task_name: str = MISSING