diff --git a/src/train.py b/src/train.py index 8fc25fa..491ed1f 100755 --- a/src/train.py +++ b/src/train.py @@ -102,7 +102,7 @@ def run_episode(env, policy, scaler, animate=False): observes.append(obs) action = policy.sample(obs).reshape((1, -1)).astype(np.float64) actions.append(action) - obs, reward, done, _ = env.step(action) + obs, reward, done, _ = env.step(np.squeeze(action)) if not isinstance(reward, float): reward = np.asscalar(reward) rewards.append(reward)