-
Notifications
You must be signed in to change notification settings - Fork 0
/
stack_husk.py
118 lines (111 loc) · 3.76 KB
/
stack_husk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from craftground import craftground
from craftground.wrappers.action import ActionWrapper, Action
from craftground.wrappers.fast_reset import FastResetWrapper
from craftground.wrappers.time_limit import TimeLimitWrapper
from craftground.wrappers.vision import VisionWrapper
from stable_baselines3 import A2C, DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import (
VecVideoRecorder,
DummyVecEnv,
VecFrameStack,
)
from wandb.integration.sb3 import WandbCallback
import wandb
from avoid_damage import AvoidDamageWrapper
def main(device: str = "mps") -> None:
    """Train a frame-stacked DQN agent on the CraftGround husk scenario.

    Starts a Weights & Biases run, creates a CraftGround environment with
    five husks spawned around the player, wraps it with ``VisionWrapper``,
    ``AvoidDamageWrapper``, ``ActionWrapper`` (movement and turning only),
    ``TimeLimitWrapper``, ``FastResetWrapper`` and ``Monitor``, vectorizes it,
    records videos, stacks 4 frames, trains a ``DQN`` ``CnnPolicy`` for
    400,000 timesteps and saves the model as ``dqn_stack_craftground``.

    The environment stack is closed and the wandb run is finished even when
    setup, training or saving raises.

    Args:
        device: Torch device string passed to ``DQN``. Defaults to ``"mps"``,
            the value this script previously hard-coded; pass another device
            string (e.g. ``"cpu"``) on machines without that backend.
    """
    # Frame size requested from the game and used by VisionWrapper.
    image_width, image_height = 114, 64
    # One episode length, shared by the time limit and the recorded video.
    episode_steps = 400

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="craftground-sb3",
        entity="jourhyang123",
        # track hyperparameters and run metadata
        group="escape-husk-stacked",
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
    )

    # `env` always names the outermost layer built so far, so the cleanup in
    # `finally` can close whatever exists if construction fails midway.
    env = None
    try:
        env = craftground.make(
            port=8022,
            initialInventoryCommands=[],
            initialPosition=None,  # nullable
            initialMobsCommands=[
                "minecraft:husk ~ ~ ~5 {HandItems:[{Count:1,id:iron_shovel},{}]}",
                "minecraft:husk ~ ~ ~-5 {HandItems:[{Count:1,id:iron_shovel},{}]}",
                "minecraft:husk ~5 ~ ~5 {HandItems:[{Count:1,id:iron_shovel},{}]}",
                "minecraft:husk ~10 ~ ~10 {HandItems:[{Count:1,id:iron_shovel},{}]}",
                "minecraft:husk ~10 ~ ~-10 {HandItems:[{Count:1,id:iron_shovel},{}]}",
                # player looks at south (positive Z) when spawn
            ],
            imageSizeX=image_width,
            imageSizeY=image_height,
            visibleSizeX=image_width,
            visibleSizeY=image_height,
            seed=12345,  # nullable
            allowMobSpawn=False,
            alwaysDay=True,
            alwaysNight=False,
            initialWeather="clear",  # nullable
            isHardCore=False,
            isWorldFlat=True,  # superflat world
            obs_keys=["sound_subtitles"],
            initialExtraCommands=[],
            isHudHidden=False,
            render_action=True,
            render_distance=2,
            simulation_distance=5,
        )
        env = FastResetWrapper(
            TimeLimitWrapper(
                ActionWrapper(
                    AvoidDamageWrapper(
                        VisionWrapper(env, x_dim=image_width, y_dim=image_height)
                    ),
                    enabled_actions=[
                        Action.FORWARD,
                        Action.BACKWARD,
                        Action.STRAFE_LEFT,
                        Action.STRAFE_RIGHT,
                        Action.TURN_LEFT,
                        Action.TURN_RIGHT,
                    ],
                ),
                max_timesteps=episode_steps,
            )
        )
        # Bind the monitored env to its own name: the factory lambda must not
        # close over `env`, which is rebound to the vectorized env right below.
        monitored_env = Monitor(env)
        env = DummyVecEnv([lambda: monitored_env])
        env = VecVideoRecorder(
            env,
            f"videos/{run.id}",
            record_video_trigger=lambda step: step % 4000 == 0,
            video_length=episode_steps,
        )
        env = VecFrameStack(env, n_stack=4)

        model = DQN(
            "CnnPolicy",
            env,
            verbose=1,
            device=device,
            tensorboard_log=f"runs/{run.id}",
            optimize_memory_usage=True,
            # Set together with optimize_memory_usage=True above.
            replay_buffer_kwargs={"handle_timeout_termination": False},
        )
        model.learn(
            total_timesteps=400000,
            callback=WandbCallback(
                gradient_save_freq=100,
                model_save_path=f"models/{run.id}",
                verbose=2,
            ),
        )
        model.save("dqn_stack_craftground")

        # Optional evaluation rollout (disabled). If re-enabled it must stay
        # inside this `try`, i.e. before the environment is closed.
        # vec_env = model.get_env()
        # obs = vec_env.reset()
        # for i in range(1000):
        #     action, _state = model.predict(obs, deterministic=True)
        #     obs, reward, done, info = vec_env.step(action)
        #     # vec_env.render("human")
        #     # VecEnv resets automatically
        #     # if done:
        #     #     obs = vec_env.reset()
    finally:
        try:
            # Close the env stack first so the video recorder is closed
            # before the wandb run is finalized.
            if env is not None:
                env.close()
        finally:
            run.finish()
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()