diff --git a/lib/RLTrader.py b/lib/RLTrader.py index 97b3ca9..36d3974 100644 --- a/lib/RLTrader.py +++ b/lib/RLTrader.py @@ -223,10 +223,11 @@ def test(self, model_epoch: int = 0, should_render: bool = True): self.logger.info(f'Testing model ({self.study_name}__{model_epoch})') state = None - obs, done, rewards = test_env.reset(), [False], [] - while not all(done): + obs, rewards = test_env.reset(), [] + + for _ in range(len(test_provider.data_frame)): action, state = model.predict(obs, state=state) - obs, reward, done, _ = test_env.step(action) + obs, reward, _, __ = test_env.step(action) rewards.append(reward)