-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathutils.py
More file actions
54 lines (44 loc) · 1.89 KB
/
utils.py
File metadata and controls
54 lines (44 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from mcts import MCTS
import numpy as np
from kaggle_environments import make, evaluate
import torch
class KaggleAgent:
    """Pick moves for a board game from a policy model, optionally via MCTS.

    Args:
        model: Callable taking a batched encoded-state tensor and returning a
            ``(policy, value)`` pair; its ``device`` attribute is read to
            place the input tensor. Only used when ``args['search']`` is falsy.
        game: Rules object providing ``row_count``, ``column_count``,
            ``action_size``, ``change_perspective``, ``get_encoded_state``
            and ``get_valid_moves``.
        args: Mapping with ``'search'`` (use MCTS when truthy) and
            ``'temperature'`` (``0`` = greedy, ``inf`` = uniform over moves
            with non-zero probability, otherwise the policy is raised to
            ``1 / temperature`` before sampling). It is also forwarded to
            ``MCTS`` unchanged.
    """

    def __init__(self, model, game, args):
        self.model = model
        self.game = game
        self.args = args
        # The search tree is only built when it will actually be used.
        if self.args['search']:
            self.mcts = MCTS(model, game, args)

    def run(self, obs, conf):
        """Return the action chosen for the position described by ``obs``.

        Args:
            obs: Observation exposing ``obs['mark']`` (``1`` is treated as
                player ``1``, anything else as player ``-1``) and
                ``obs['board']``, a flat sequence reshaped to
                ``row_count x column_count`` in which ``2`` marks the
                opponent of ``1``.
            conf: Unused; kept so the method keeps its two-argument signature.

        Returns:
            int: Index of the selected action, always a built-in ``int``.

        Raises:
            ValueError: If the game reports no valid move for the position.
        """
        player = 1 if obs['mark'] == 1 else -1
        board = np.array(obs['board']).reshape(
            self.game.row_count, self.game.column_count
        )
        board[board == 2] = -1
        state = self.game.change_perspective(board, player)

        policy = self._masked_policy(state)
        temperature = self.args['temperature']

        if temperature == 0:
            return int(np.argmax(policy))
        if temperature == float('inf'):
            return int(np.random.choice(np.flatnonzero(policy > 0)))

        # Dividing by the maximum first keeps the largest entry at exactly 1,
        # so a small temperature cannot underflow the whole vector to zero.
        # The final normalisation makes the result identical otherwise.
        sharpened = (policy / policy.max()) ** (1 / temperature)
        sharpened /= np.sum(sharpened)
        return int(np.random.choice(self.game.action_size, p=sharpened))

    def _raw_policy(self, state):
        """Return an unmasked 1-D float64 policy over all actions."""
        if self.args['search']:
            # Copy so the array owned by the search object is never mutated.
            return np.array(self.mcts.search(state), dtype=np.float64)

        encoded = torch.tensor(
            self.game.get_encoded_state(state), device=self.model.device
        ).unsqueeze(0)
        with torch.no_grad():
            logits, _ = self.model(encoded)
            # NOTE(review): assumes the policy head returns unnormalised
            # logits of shape (1, action_size) - confirm against the model.
            probs = torch.softmax(logits, dim=1).squeeze(0)
        return probs.detach().cpu().numpy().astype(np.float64)

    def _masked_policy(self, state):
        """Zero out invalid moves and renormalise the policy to sum to 1."""
        valid = np.asarray(self.game.get_valid_moves(state), dtype=np.float64)
        policy = self._raw_policy(state) * valid

        total = policy.sum()
        if total > 0:
            return policy / total

        # All probability mass sat on invalid moves: fall back to a uniform
        # choice among the valid ones instead of dividing by zero.
        valid_total = valid.sum()
        if valid_total > 0:
            return valid / valid_total
        raise ValueError("no valid moves available")
def evaluateKaggle(gameName, players, num_iterations=1):
    """Play one rendered game, or run several episodes and print a tally.

    Args:
        gameName: Environment name forwarded to ``make`` / ``evaluate``.
        players: Player list forwarded to ``env.run`` / ``evaluate``.
        num_iterations: Number of episodes. With exactly ``1`` a single game
            is run in a debug environment and its rendering is returned.

    Returns:
        The result of ``env.render(mode="ipython")`` when ``num_iterations``
        is ``1``; otherwise ``None`` after printing the summary.
    """
    if num_iterations != 1:
        episode_rewards = evaluate(gameName, players, num_episodes=num_iterations)
        # Only the first column is tallied: entries equal to 1, 0 and -1 are
        # counted as Player 1's wins, draws and losses respectively.
        results = np.array(episode_rewards)[:, 0]
        print(f"""
Player 1 | Wins: {np.sum(results == 1)} | Draws: {np.sum(results == 0)} | Losses: {np.sum(results == -1)}
Player 2 | Wins: {np.sum(results == -1)} | Draws: {np.sum(results == 0)} | Losses: {np.sum(results == 1)}
""")
        return None

    env = make(gameName, debug=True)
    env.run(players)
    return env.render(mode="ipython")