-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsb3_minerl_envs.py
69 lines (50 loc) · 1.9 KB
/
sb3_minerl_envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gym
from gym.envs.registration import register
from gym_wrappers import (
ObservationToInfos,
DictToMultiDiscreteActionSpace,
HiddenStateObservationSpace,
ObservationToCPU,
)
def _make_sb3_env(env_id, minerl_agent):
    """Create ``env_id`` with ``gym.make`` and wrap it so SB3 can use it.

    This is the single place that defines the SB3 compatibility wrapper
    stack shared by every factory below, so the stack cannot drift between
    tasks.  Wrappers are applied innermost-first, in this exact order:

        ObservationToInfos -> DictToMultiDiscreteActionSpace
        -> HiddenStateObservationSpace -> ObservationToCPU

    Args:
        env_id: Gym id of the underlying MineRL environment to create.
        minerl_agent: Passed to ``DictToMultiDiscreteActionSpace`` and
            ``HiddenStateObservationSpace``; how they use it is defined in
            ``gym_wrappers`` (not visible here — presumably the agent's
            action mapping / recurrent state, verify there).

    Returns:
        The fully wrapped environment (the outermost ``ObservationToCPU``).
    """
    env = gym.make(env_id)
    # Order matters: each wrapper sees the spaces produced by the previous one.
    sb3_env = ObservationToInfos(env)
    sb3_env = DictToMultiDiscreteActionSpace(sb3_env, minerl_agent)
    sb3_env = HiddenStateObservationSpace(sb3_env, minerl_agent)
    sb3_env = ObservationToCPU(sb3_env)
    return sb3_env


def sb3_minerl_findcave_env(minerl_agent):
    """Return ``MineRLBasaltFindCave-v0`` wrapped for SB3."""
    return _make_sb3_env("MineRLBasaltFindCave-v0", minerl_agent)


def sb3_minerl_makewaterfall_env(minerl_agent):
    """Return ``MineRLBasaltMakeWaterfall-v0`` wrapped for SB3."""
    return _make_sb3_env("MineRLBasaltMakeWaterfall-v0", minerl_agent)


def sb3_minerl_buildvillagehouse_env(minerl_agent):
    """Return ``MineRLBasaltBuildVillageHouse-v0`` wrapped for SB3."""
    return _make_sb3_env("MineRLBasaltBuildVillageHouse-v0", minerl_agent)


def sb3_minerl_createvillageanimalpen_env(minerl_agent):
    """Return ``MineRLBasaltCreateVillageAnimalPen-v0`` wrapped for SB3."""
    return _make_sb3_env("MineRLBasaltCreateVillageAnimalPen-v0", minerl_agent)
# Register an SB3-compatible variant of each BASALT task under the id
# "MineRLBasalt<Task>SB3-v0".  Each entry point names the matching factory
# in this module, sb3_minerl_<task lowercased>_env, so the task names below
# must stay in sync with the factory function names.
for env_name in [
    "FindCave",
    "MakeWaterfall",
    "BuildVillageHouse",
    "CreateVillageAnimalPen",
]:
    register(
        f"MineRLBasalt{env_name}SB3-v0",
        # Use this module's actual import name instead of a hard-coded
        # "sb3_minerl_envs": gym resolves the "module:attribute" string by
        # importing the module, so a hard-coded name breaks (or loads and
        # re-registers a second copy) when this file is imported under a
        # package-qualified name or run as a script.
        entry_point=f"{__name__}:sb3_minerl_{env_name.lower()}_env",
    )