From 587b74e04a944d135a50e79dbe2950f1a3fa4c7a Mon Sep 17 00:00:00 2001
From: Gal Leibovich
Date: Tue, 2 Jul 2019 13:43:23 +0300
Subject: [PATCH] Remove double call to reset_internal_state() on gym environments (#364)

---
 rl_coach/environments/gym_environment.py            | 4 ----
 rl_coach/tests/environments/test_gym_environment.py | 6 ++++--
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/rl_coach/environments/gym_environment.py b/rl_coach/environments/gym_environment.py
index 0052c40..b1ebafb 100644
--- a/rl_coach/environments/gym_environment.py
+++ b/rl_coach/environments/gym_environment.py
@@ -392,9 +392,6 @@ class GymEnvironment(Environment):
             else:
                 screen.error("Error: Environment {} does not support human control.".format(self.env), crash=True)
 
-        # initialize the state by getting a new state from the environment
-        self.reset_internal_state(True)
-
         # render
         if self.is_rendered:
             image = self.get_rendered_image()
@@ -405,7 +402,6 @@ class GymEnvironment(Environment):
                 self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
 
         # the info is only updated after the first step
-        self.state = self.step(self.action_space.default_action).next_state
         self.state_space['measurements'] = VectorObservationSpace(shape=len(self.info.keys()))
 
         if self.env.spec and custom_reward_threshold is None:
diff --git a/rl_coach/tests/environments/test_gym_environment.py b/rl_coach/tests/environments/test_gym_environment.py
index 589467e..d777fc5 100644
--- a/rl_coach/tests/environments/test_gym_environment.py
+++ b/rl_coach/tests/environments/test_gym_environment.py
@@ -16,6 +16,7 @@ def atari_env():
                          seed=1,
                          frame_skip=4,
                          visualization_parameters=VisualizationParameters())
+    env.reset_internal_state(True)
     return env
 
 
@@ -26,6 +27,7 @@ def continuous_env():
                          seed=1,
                          frame_skip=1,
                          visualization_parameters=VisualizationParameters())
+    env.reset_internal_state(True)
     return env
 
 
@@ -56,7 +58,7 @@ def test_gym_continuous_environment(continuous_env):
     assert np.all(continuous_env.action_space.shape == np.array([1]))
 
     # make sure that the seed is working properly
-    assert np.sum(continuous_env.last_env_response.next_state['observation']) == 1.2661630859028832
+    assert np.sum(continuous_env.last_env_response.next_state['observation']) == 0.6118565010687202
 
 
 @pytest.mark.unit_test
@@ -64,4 +66,4 @@ def test_step(atari_env):
     result = atari_env.step(0)
 
 if __name__ == '__main__':
-    test_gym_continuous_environment(continuous_env())
\ No newline at end of file
+    test_gym_continuous_environment(continuous_env())