diff --git a/lib/RLTrader.py b/lib/RLTrader.py index 625a3f8..649cb28 100644 --- a/lib/RLTrader.py +++ b/lib/RLTrader.py @@ -213,23 +213,28 @@ def test(self, model_epoch: int = 0, should_render: bool = True): del train_provider - test_env = SubprocVecEnv([make_env(test_provider, i) for i in range(self.n_envs)]) + test_env = DummyVecEnv([make_env(test_provider, i) for i in range(1)]) model_path = path.join('data', 'agents', f'{self.study_name}__{model_epoch}.pkl') model = self.Model.load(model_path, env=test_env) self.logger.info(f'Testing model ({self.study_name}__{model_epoch})') + zero_completed_obs = np.zeros((self.n_envs,) + test_env.observation_space.shape) + zero_completed_obs[0, :] = test_env.reset() + state = None - obs, rewards = test_env.reset(), [] + rewards = [] for _ in range(len(test_provider.data_frame)): - action, state = model.predict(obs, state=state) + action, state = model.predict(zero_completed_obs, state=state) obs, reward, _, __ = test_env.step(action) + zero_completed_obs[0, :] = obs + rewards.append(reward) - if should_render and self.n_envs == 1: + if should_render: test_env.render(mode='human') self.logger.info( diff --git a/optimize.py b/optimize.py index f003f84..5dd190a 100644 --- a/optimize.py +++ b/optimize.py @@ -12,7 +12,7 @@ def optimize_code(params): if __name__ == '__main__': - n_process = multiprocessing.cpu_count() - 4 + n_process = multiprocessing.cpu_count() params = {} processes = []