-
Notifications
You must be signed in to change notification settings - Fork 0
/
crossing.py
155 lines (125 loc) · 4.71 KB
/
crossing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from gym_minigrid.minigrid import *
from gym_minigrid.register import register
import itertools as itt
class CrossingEnv(MiniGridEnv):
    """
    Environment with wall or lava obstacles, sparse reward.

    The agent starts in the top-left corner and must reach the goal in the
    bottom-right corner. Full-length "rivers" of obstacles (vertical and/or
    horizontal) split the grid into rooms; one opening per river is carved
    along a randomly sampled path so the goal is always reachable.

    Args:
        size: side length of the square grid; must be odd.
        num_crossings: number of rivers to place (capped by the number of
            candidate river positions).
        obstacle_type: class instantiated for every obstacle cell
            (``Lava`` by default; ``Wall`` is used by the simple variants).
        seed: seed forwarded to the base environment.
    """

    def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None):
        # These are assigned before the base constructor runs; presumably
        # the base class generates the grid during initialization and
        # _gen_grid needs them -- confirm against MiniGridEnv.
        self.num_crossings = num_crossings
        self.obstacle_type = obstacle_type
        super().__init__(
            grid_size=size,
            max_steps=4 * size * size,
            # Set this to True for maximum speed
            see_through_walls=False,
            # Forward the caller's seed (it was previously hard-coded to
            # None, which silently discarded the `seed` argument).
            seed=seed,
        )

    def _gen_grid(self, width, height):
        """Build the grid: outer walls, agent, goal, rivers and openings."""
        assert width % 2 == 1 and height % 2 == 1, "grid size must be odd"

        # Create an empty grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Place the agent in the top-left corner
        self.agent_pos = (1, 1)
        self.agent_dir = 0

        # Place a goal square in the bottom-right corner
        self.grid.set(width - 2, height - 2, Goal())

        # Place obstacles (lava or walls)
        v, h = object(), object()  # singleton `vertical` and `horizontal` objects

        # Lava rivers or walls specified by direction and position in grid.
        # A vertical river sits at a fixed x (column), so its candidates span
        # the width; a horizontal river sits at a fixed y (row), so its
        # candidates span the height. For square grids both ranges coincide.
        rivers = [(v, x) for x in range(2, width - 2, 2)]
        rivers += [(h, y) for y in range(2, height - 2, 2)]
        self.np_random.shuffle(rivers)
        rivers = rivers[:self.num_crossings]  # sample random rivers
        rivers_v = sorted(pos for direction, pos in rivers if direction is v)
        rivers_h = sorted(pos for direction, pos in rivers if direction is h)
        obstacle_pos = itt.chain(
            itt.product(range(1, width - 1), rivers_h),
            itt.product(rivers_v, range(1, height - 1)),
        )
        for i, j in obstacle_pos:
            self.grid.set(i, j, self.obstacle_type())

        # Sample path to goal: one horizontal move per vertical river and one
        # vertical move per horizontal river, in random order.
        path = [h] * len(rivers_v) + [v] * len(rivers_h)
        self.np_random.shuffle(path)

        # Create openings. limits_v holds the x boundaries of the rooms and
        # limits_h the y boundaries, including the outer walls.
        limits_v = [0] + rivers_v + [width - 1]
        limits_h = [0] + rivers_h + [height - 1]
        room_i, room_j = 0, 0
        for direction in path:
            if direction is h:
                # Cross the next vertical river at a random row of this room.
                i = limits_v[room_i + 1]
                j = self.np_random.choice(
                    range(limits_h[room_j] + 1, limits_h[room_j + 1]))
                room_i += 1
            elif direction is v:
                # Cross the next horizontal river at a random column of this room.
                i = self.np_random.choice(
                    range(limits_v[room_i] + 1, limits_v[room_i + 1]))
                j = limits_h[room_j + 1]
                room_j += 1
            else:
                # Unreachable: `path` only contains `h` and `v`. Raised
                # explicitly so the check survives `python -O`.
                raise AssertionError("unknown path direction")
            self.grid.set(i, j, None)

        self.mission = (
            "avoid the lava and get to the green goal square"
            if self.obstacle_type == Lava
            else "find the opening and get to the green goal square"
        )
class LavaCrossingEnv(CrossingEnv):
    """Crossing environment on a 9x9 grid with one river (default obstacle)."""

    def __init__(self):
        super().__init__(
            size=9,
            num_crossings=1,
        )
class LavaCrossingS9N2Env(CrossingEnv):
    """Crossing environment on a 9x9 grid with two rivers (default obstacle)."""

    def __init__(self):
        super().__init__(
            size=9,
            num_crossings=2,
        )
class LavaCrossingS9N3Env(CrossingEnv):
    """Crossing environment on a 9x9 grid with three rivers (default obstacle)."""

    def __init__(self):
        super().__init__(
            size=9,
            num_crossings=3,
        )
class LavaCrossingS11N5Env(CrossingEnv):
    """Crossing environment on an 11x11 grid with five rivers (default obstacle)."""

    def __init__(self):
        super().__init__(
            size=11,
            num_crossings=5,
        )
# Register the lava-crossing variants, in the same order as before.
for _env_id, _entry_point in (
    ('MiniGrid-LavaCrossingS9N1-v0', 'gym_minigrid.envs:LavaCrossingEnv'),
    ('MiniGrid-LavaCrossingS9N2-v0', 'gym_minigrid.envs:LavaCrossingS9N2Env'),
    ('MiniGrid-LavaCrossingS9N3-v0', 'gym_minigrid.envs:LavaCrossingS9N3Env'),
    ('MiniGrid-LavaCrossingS11N5-v0', 'gym_minigrid.envs:LavaCrossingS11N5Env'),
):
    register(id=_env_id, entry_point=_entry_point)
class SimpleCrossingEnv(CrossingEnv):
    """Crossing environment on a 9x9 grid with one river made of walls."""

    def __init__(self):
        super().__init__(
            size=9,
            num_crossings=1,
            obstacle_type=Wall,
        )
class SimpleCrossingS9N2Env(CrossingEnv):
    """Crossing environment on a 9x9 grid with two rivers made of walls."""

    def __init__(self):
        super().__init__(
            size=9,
            num_crossings=2,
            obstacle_type=Wall,
        )
class SimpleCrossingS9N3Env(CrossingEnv):
    """Crossing environment on a 9x9 grid with three rivers made of walls."""

    def __init__(self):
        super().__init__(
            size=9,
            num_crossings=3,
            obstacle_type=Wall,
        )
class SimpleCrossingS11N5Env(CrossingEnv):
    """Crossing environment on an 11x11 grid with five rivers made of walls."""

    def __init__(self):
        super().__init__(
            size=11,
            num_crossings=5,
            obstacle_type=Wall,
        )
# Register the wall-based (simple) crossing variants, in the same order as before.
for _env_id, _entry_point in (
    ('MiniGrid-SimpleCrossingS9N1-v0', 'gym_minigrid.envs:SimpleCrossingEnv'),
    ('MiniGrid-SimpleCrossingS9N2-v0', 'gym_minigrid.envs:SimpleCrossingS9N2Env'),
    ('MiniGrid-SimpleCrossingS9N3-v0', 'gym_minigrid.envs:SimpleCrossingS9N3Env'),
    ('MiniGrid-SimpleCrossingS11N5-v0', 'gym_minigrid.envs:SimpleCrossingS11N5Env'),
):
    register(id=_env_id, entry_point=_entry_point)