-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmain.py
More file actions
98 lines (73 loc) · 3.22 KB
/
main.py
File metadata and controls
98 lines (73 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from functools import reduce
import minerl
import gym
import os
import tensorflow as tf
import yaml
from policy.agent import create_flat_agent
from hierarchy.subtask_agent import ItemAgent
from hierarchy.subtasks_extraction import TrajectoryInformation
from utils.fake_env import FakeEnv
import argparse
from utils.config_validation import Pipeline, Task
from utils.wrappers import wrap_env
from utils.tf_util import config_gpu
# Restrict CUDA to device "0", with devices enumerated in PCI bus order, and
# then run the project's GPU configuration helper (utils.tf_util.config_gpu —
# presumably sets TensorFlow GPU options; see that module).
# NOTE(review): tensorflow is imported above this point; these variables only
# take effect if CUDA has not been initialised yet — TODO confirm.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})
config_gpu()
def load_trajectories(task, max_trj=300):
    """Wrap a task's MineRL demonstrations in ``TrajectoryInformation`` objects.

    Args:
        task: task config whose ``environment`` and ``data_dir`` attributes are
            passed to ``minerl.data.make``.
        max_trj: maximum number of trajectories to take, in the order returned
            by ``get_trajectory_names()``.

    Returns:
        A list with one ``TrajectoryInformation`` per selected trajectory name.
    """
    pipeline = minerl.data.make(task.environment, task.data_dir)
    selected_names = pipeline.get_trajectory_names()[:max_trj]
    return [
        TrajectoryInformation(env_name=pipeline.environment, trajectory_name=name)
        for name in selected_names
    ]
def run_task(task: Task, *, subtasks=("cobblestone", "iron_ore")):
    """Train or pre-train the agent(s) described by *task*.

    Dispatches on ``task.agent_type`` ('flat' or 'hierarchical') and
    ``task.source`` ('expert' -> pre-train from MineRL demonstrations,
    'agent' -> train by interacting with the environment). For any other
    combination no training is performed. Every environment this function
    creates is closed before it returns, including on error.

    Args:
        task: validated task configuration.
        subtasks: hierarchical + expert only. Names of the subtasks to
            pre-train a dedicated agent for. ``None`` trains every subtask
            found in the demonstrations. Defaults to the previously
            hard-coded ``("cobblestone", "iron_ore")``.
    """
    if task.agent_type == 'flat':
        _run_flat_task(task)
    elif task.agent_type == 'hierarchical':
        _run_hierarchical_task(task, subtasks)


def _run_flat_task(task):
    """Pre-train or train a single flat agent, then save it to ``save_dir``."""
    env = wrap_env(gym.make(task.environment), task.cfg.wrappers)
    try:
        agent = create_flat_agent(task, env)
        if task.source == 'expert':
            for trajectory in load_trajectories(task):
                # todo replace 'log' with parameter name
                demo = trajectory.trajectory_by_subtask.get('log')
                if not demo:
                    # Same skip rule as the hierarchical branch. Previously a
                    # trajectory without a 'log' segment raised KeyError.
                    continue
                agent.add_demo(wrap_env(FakeEnv(data=demo), task.cfg.wrappers))
            agent.pre_train(task)
        elif task.source == 'agent':
            summary_writer = tf.summary.create_file_writer('train/')
            with summary_writer.as_default():
                agent.train(env, task)
        else:
            # Unknown source: nothing to train or save (finally still closes env).
            return
    finally:
        # Previously only the 'agent' path closed the env. The expert path leaked it.
        env.close()
    agent.save(task.cfg.agent.save_dir)


def _run_hierarchical_task(task, subtasks):
    """Pre-train one agent per subtask (expert) or train an ``ItemAgent`` (agent)."""
    if task.source == 'expert':
        env = wrap_env(gym.make(task.environment), task.cfg.wrappers)
        try:
            trajectories = load_trajectories(task)
            # set().union(...) also handles an empty trajectory list. The former
            # reduce() with no initial value raised TypeError in that case.
            unique_subtasks = set().union(
                *(t.trajectory_by_subtask.keys() for t in trajectories))
            if subtasks is not None:
                unique_subtasks &= set(subtasks)
            # Sorted so the training order does not depend on string hash seeds.
            for subtask in sorted(unique_subtasks):
                agent = create_flat_agent(task, env)
                for trj in trajectories:
                    demo = trj.trajectory_by_subtask.get(subtask)
                    if not demo:
                        continue
                    agent.add_demo(wrap_env(FakeEnv(data=demo), task.cfg.wrappers))
                agent.pre_train(task)
                # NOTE(review): plain concatenation kept so paths match whatever
                # reads them back; this assumes save_dir ends with a separator.
                agent.save(task.cfg.agent.save_dir + subtask + '/')
        finally:
            env.close()
    elif task.source == 'agent':
        env = gym.make(task.environment)
        try:
            item_agent = ItemAgent(task)
            item_agent.train(env, task)
        finally:
            # NOTE(review): if ItemAgent.train already closes env, this relies
            # on close() being safe to call twice — TODO confirm.
            env.close()
def run_pipeline(pipeline: Pipeline):
    """Execute each task listed in ``pipeline.pipeline`` in order via ``run_task``."""
    tasks = pipeline.pipeline
    for current in tasks:
        run_task(current)
def main(argv=None):
    """Command-line entry point: load a YAML pipeline config and run it.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the usual argparse behaviour.

    Exits via ``parser.error`` when the config file does not contain a YAML
    mapping (for example an empty file).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, action="store", help='yaml file with settings', required=False, default='configs/eval-diamond.yaml')
    params = parser.parse_args(argv)
    with open(params.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty document, or a list/scalar for
    # non-mapping documents. Pipeline(**config) would then fail with an opaque TypeError.
    if not isinstance(config, dict):
        parser.error(f"config file {params.config!r} must contain a YAML mapping")
    with tf.device('/gpu'):
        # noinspection Pydantic
        run_pipeline(Pipeline(**config))
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()