-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path conv_rl_agent.py
70 lines (60 loc) · 2.06 KB
/
conv_rl_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
from typing import Dict
import sys
from argparse import Namespace
from agents.conv_rl_agent.agent import Agent
from agents.lux.config import EnvConfig
from agents.lux.kit import GameState, process_obs, to_json, from_json, process_action, obs_to_game_state
import os
# Ask pygame not to print its support prompt; this only takes effect if it is
# set before pygame is first imported.
os.environ.update(PYGAME_HIDE_SUPPORT_PROMPT="hide")

### DO NOT REMOVE THE FOLLOWING CODE ###
# Kaggle imports this file and calls agent_fn directly, so per-player state is
# kept at module level, keyed by player, to support potentially multiple players.
agent_dict = {}  # player -> Agent created on that player's step 0
agent_prev_obs = {}  # player -> last processed observation, fed back into process_obs
def agent_fn(observation, configurations):
    """
    Agent definition for Kaggle submission; produces one player's actions for one turn.

    Args:
        observation: object exposing ``step``, ``player``,
            ``remainingOverageTime`` and ``obs``, where ``obs`` is a
            JSON-encoded string of this turn's raw observation.
        configurations: mapping whose ``"env_cfg"`` entry is read on step 0
            to build the ``EnvConfig`` for a freshly created ``Agent``.

    Returns:
        The agent's actions for this turn, passed through ``process_action``.

    Raises:
        KeyError: if called for a player at a step other than 0 before a
            step-0 call has created that player's agent.
    """
    # agent_dict / agent_prev_obs are only mutated by key, never rebound,
    # so no ``global`` declaration is required.
    step = observation.step
    player = observation.player
    remaining_overage_time = observation.remainingOverageTime

    if step == 0:
        # Step 0 (re)creates this player's agent and resets its stored observation.
        env_cfg = EnvConfig.from_dict(configurations["env_cfg"])
        agent_dict[player] = Agent(player, env_cfg)
        agent_prev_obs[player] = dict()
    agent = agent_dict[player]

    # process_obs is given the previously stored observation together with the
    # new raw one (presumably to apply per-turn updates — confirm in agents.lux.kit);
    # its result becomes the stored observation for the next turn.
    obs = process_obs(
        player, agent_prev_obs[player], step, json.loads(observation.obs))
    agent_prev_obs[player] = obs
    agent.step = step

    # NOTE(review): a negative real_env_steps appears to mark the pre-game
    # setup phase, which is routed to early_setup — confirm against the env.
    if obs["real_env_steps"] < 0:
        actions = agent.early_setup(step, obs, remaining_overage_time)
    else:
        actions = agent.act(step, obs, remaining_overage_time)
    return process_action(actions)
if __name__ == "__main__":

    def read_input():
        """
        Read one line from stdin.

        When the input stream is closed, the resulting ``EOFError`` is
        re-raised as ``SystemExit`` so the main loop terminates.
        """
        try:
            return input()
        except EOFError as eof:
            raise SystemExit(eof)

    # Local engine loop. Each stdin line is a JSON message carrying "step",
    # "obs", "remainingOverageTime", "player" and "info". One JSON line of
    # actions is written back per message.
    configurations = None
    is_first_message = True
    while True:
        inputs = read_input()
        obs = json.loads(inputs)
        observation = Namespace(
            step=obs["step"],
            obs=json.dumps(obs["obs"]),  # agent_fn expects the observation JSON-encoded
            remainingOverageTime=obs["remainingOverageTime"],
            player=obs["player"],
            info=obs["info"],
        )
        if is_first_message:
            # The env config is captured from the first message only and reused afterwards.
            configurations = obs["info"]["env_cfg"]
            is_first_message = False
        actions = agent_fn(observation, dict(env_cfg=configurations))
        # Send actions to the engine. Flush so the line is delivered immediately
        # even when stdout is a pipe (block-buffered) and Python is not run with -u.
        print(json.dumps(actions), flush=True)