# preprocess.py
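"""Atari preprocessing wrappers for gym environments.

Implements the standard DQN-style pipeline: repeat each action over several
frames and max-pool the last two raw frames, convert observations to
grayscale, resize and scale them to [0, 1], and stack consecutive frames
along the channel axis.
"""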
import collections
from typing import Iterable

import cv2
import gym
import numpy as np


class RepeatActionAndMaxFrame(gym.Wrapper):
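    """Repeats each action `repeat` times and max-pools the last two raw
    frames to remove Atari sprite flicker. Optionally clips rewards to
    [-1, 1], takes a random number of no-op steps on reset, and presses
    FIRE first on environments that require it.
    """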
    def __init__(self, env: gym.Env, repeat: int = 4,
                 clip_reward: bool = False, no_ops: int = 0,
                 fire_first: bool = False) -> None:
        super().__init__(env)
        self.repeat = repeat
        self.shape = env.observation_space.low.shape
        # Buffer holding the two most recent raw frames for max-pooling.
        # (np.zeros_like((2, shape)) built an array like the tuple itself;
        # we want a (2, *shape) array of zeros.)
        self.frame_buffer = np.zeros((2, *self.shape))
        self.clip_reward = clip_reward
        self.no_ops = no_ops
        self.fire_first = fire_first
    def step(self, action):
        t_reward = 0.0
        done = False
        for i in range(self.repeat):
            obs, reward, done, info = self.env.step(action)
            if self.clip_reward:
                reward = np.clip(reward, -1, 1)
            t_reward += reward
            # Alternate between the two buffer slots.
            self.frame_buffer[i % 2] = obs
            if done:
                break
        # Pixel-wise max over the last two frames removes sprite flicker.
        max_frame = np.maximum(self.frame_buffer[0], self.frame_buffer[1])
        return max_frame, t_reward, done, info
    def reset(self):
        obs = self.env.reset()
        no_ops = np.random.randint(self.no_ops) + 1 if self.no_ops > 0 else 0
        for _ in range(no_ops):
            # Capture the observation so the returned frame reflects the
            # state after the no-op steps, not the initial reset.
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        if self.fire_first:
            assert self.env.unwrapped.get_action_meanings()[1] == "FIRE"
            obs, _, _, _ = self.env.step(1)
        self.frame_buffer = np.zeros((2, *self.shape))
        self.frame_buffer[0] = obs
        return obs

class PreprocessFrame(gym.ObservationWrapper):
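    """Converts RGB frames to grayscale, resizes them to `shape`, moves the
    channel axis to the front, and scales pixel values to [0, 1].
    """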
    def __init__(self, shape: Iterable, env: gym.Env) -> None:
        super().__init__(env)
        # Swap the channel axis from last to first: (H, W, C) -> (C, H, W).
        self.shape = (shape[2], shape[0], shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=self.shape, dtype=np.float32)
    def observation(self, observation):
        new_observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        # cv2.resize takes (width, height), i.e. (W, H) from our (C, H, W).
        new_observation = cv2.resize(new_observation,
                                     (self.shape[2], self.shape[1]))
        # Scale to [0, 1] as float32 to match the declared observation space.
        new_observation = new_observation.reshape(self.shape).astype(np.float32) / 255.0
        return new_observation

class StackFrames(gym.ObservationWrapper):
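    """Maintains a rolling stack of the `repeat` most recent frames along
    the channel axis, so the agent can infer motion from one observation.
    """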
    def __init__(self, env: gym.Env, repeat: int) -> None:
        super().__init__(env)
        # Repeat the single-frame bounds along the channel axis:
        # (1, H, W) -> (repeat, H, W).
        self.observation_space = gym.spaces.Box(
            env.observation_space.low.repeat(repeat, axis=0),
            env.observation_space.high.repeat(repeat, axis=0),
            dtype=np.float32
        )
        self.stack = collections.deque(maxlen=repeat)
    def reset(self):
        self.stack.clear()
        observation = self.env.reset()
        # Fill the stack with copies of the first frame.
        for _ in range(self.stack.maxlen):
            self.stack.append(observation)
        return np.array(self.stack).reshape(self.observation_space.low.shape)

    def observation(self, observation):
        self.stack.append(observation)
        return np.array(self.stack).reshape(self.observation_space.low.shape)

def make_env(env_name, shape=(84, 84, 1), repeat=4, clip_reward=False, no_ops=0, fire_first=False):
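    """Builds an environment with the full preprocessing pipeline applied."""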
    env = gym.make(env_name)
    env = RepeatActionAndMaxFrame(env, repeat, clip_reward, no_ops, fire_first)
    env = PreprocessFrame(shape, env)
    env = StackFrames(env, repeat)
    return env
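

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline).
    # "PongNoFrameskip-v4" is an assumed example id; any Atari environment
    # exposed by gym[atari] should work. Uses the old gym 4-tuple step API,
    # matching the wrappers above.
    env = make_env("PongNoFrameskip-v4")
    obs = env.reset()
    print("stacked observation shape:", obs.shape)  # expect (4, 84, 84)
    obs, reward, done, info = env.step(env.action_space.sample())
    print("one repeated step -> reward:", reward, "done:", done)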