forked from mila-iqia/atari-representation-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_probe.py
More file actions
98 lines (79 loc) · 4.56 KB
/
run_probe.py
File metadata and controls
98 lines (79 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import sys
from scripts.run_contrastive import train_encoder
from atariari.methods.utils import (get_argparser, probe_only_methods,
train_encoder_methods)
from atariari.methods.encoders import ImpalaCNN, NatureCNN, PPOEncoder
from atariari.methods.majority import majority_baseline
from atariari.benchmark.probe import ProbeTrainer
from atariari.benchmark.episodes import get_episodes
import torch
import wandb
def run_probe(args):
    """Train and evaluate a probe on top of an encoder for one Atari environment.

    Collects probe episodes, builds (or trains / loads) an encoder according to
    ``args.method``, trains a ``ProbeTrainer`` on the train/val splits, and logs
    test accuracy and F1 to wandb.

    Args:
        args: parsed argparse namespace from ``get_argparser()``; must carry the
            probe/encoder hyperparameters referenced below (probe_steps,
            env_name, method, encoder_type, weights_path, ...).

    Raises:
        ValueError: if ``args.encoder_type`` is not one of the supported names.
            (The original code left ``encoder`` unbound here and crashed later
            with an opaque ``NameError``.)
    """
    # Mirror the full CLI config into the wandb run for reproducibility.
    wandb.config.update(vars(args))
    tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels = get_episodes(steps=args.probe_steps,
                                                                                 env_name=args.env_name,
                                                                                 seed=args.seed,
                                                                                 num_processes=args.num_processes,
                                                                                 num_frame_stack=args.num_frame_stack,
                                                                                 downsample=args.downsample,
                                                                                 color=args.color,
                                                                                 entropy_threshold=args.entropy_threshold,
                                                                                 collect_mode=args.probe_collect_mode,
                                                                                 train_mode="probe",
                                                                                 checkpoint_index=args.checkpoint_index,
                                                                                 min_episode_length=args.batch_size)
    print("got episodes!")

    if args.train_encoder and args.method in train_encoder_methods:
        print("Training encoder from scratch")
        encoder = train_encoder(args)
        encoder.probing = True
        encoder.eval()

    elif args.method == "pretrained-rl-agent":
        encoder = PPOEncoder(args.env_name, args.checkpoint_index)

    elif args.method == "majority":
        # The majority-class baseline needs no encoder at all.
        encoder = None

    else:
        # Build an untrained (or checkpoint-loaded) CNN encoder of the requested type.
        observation_shape = tr_eps[0][0].shape
        if args.encoder_type == "Nature":
            encoder = NatureCNN(observation_shape[0], args)
        elif args.encoder_type == "Impala":
            encoder = ImpalaCNN(observation_shape[0], args)
        else:
            # Fail fast instead of hitting a NameError on the first use of `encoder`.
            raise ValueError("Unknown encoder type: {}".format(args.encoder_type))

        if args.weights_path == "None":
            if args.method not in probe_only_methods:
                sys.stderr.write("Probing without loading in encoder weights! Are you sure you want to do that??\n")
        else:
            print("Loading in encoder weights from probe of type {} from the following path: {}"
                  .format(args.method, args.weights_path))
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # map_location lets a GPU-trained checkpoint load on a CPU-only host.
            encoder.load_state_dict(torch.load(args.weights_path, map_location=device))
            encoder.eval()

    torch.set_num_threads(1)

    if args.method == 'majority':
        test_acc, test_f1score = majority_baseline(tr_labels, test_labels, wandb)

    else:
        trainer = ProbeTrainer(encoder=encoder,
                               epochs=args.epochs,
                               method_name=args.method,
                               lr=args.probe_lr,
                               batch_size=args.batch_size,
                               patience=args.patience,
                               wandb=wandb,
                               fully_supervised=(args.method == "supervised"),
                               save_dir=wandb.run.dir)

        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
        test_acc, test_f1score = trainer.test(test_eps, test_labels)

    print(test_acc, test_f1score)
    # test_acc / test_f1score are logged as-is; presumably dicts of per-label
    # metrics (wandb.log expects a mapping) — TODO confirm against ProbeTrainer.
    wandb.log(test_acc)
    wandb.log(test_f1score)
if __name__ == "__main__":
    parser = get_argparser()
    args = parser.parse_args()
    # Bug fix: the original condition was `(args.weights_path and args.passing_file) is None`,
    # which is True when weights_path is None OR when weights_path is truthy and
    # passing_file is None — not the intended "neither was supplied" check.
    if args.weights_path is None and args.passing_file is None:
        args.train_encoder = False
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Tag the wandb run with the key hyperparameters for easy filtering.
    tags = [device.type, 'probe', "fs: " + str(args.num_frame_stack), args.env_name,
            args.encoder_type, "batch size: " + str(args.batch_size),
            "pretraining-steps: " + str(args.pretraining_steps),
            "probe steps: " + str(args.probe_steps), "epochs: " + str(args.epochs)]
    if args.wandb_off:
        # "dryrun" keeps wandb fully local (no network sync).
        os.environ["WANDB_MODE"] = "dryrun"
    wandb.init(project=args.wandb_proj, entity=args.wandb_entity, tags=tags)
    run_probe(args)