1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Remove double call to reset_internal_state() on gym environments (#364)

This commit is contained in:
Gal Leibovich
2019-07-02 13:43:23 +03:00
committed by GitHub
parent a576ab5659
commit 587b74e04a
2 changed files with 4 additions and 6 deletions

View File

@@ -392,9 +392,6 @@ class GymEnvironment(Environment):
else:
screen.error("Error: Environment {} does not support human control.".format(self.env), crash=True)
# initialize the state by getting a new state from the environment
self.reset_internal_state(True)
# render
if self.is_rendered:
image = self.get_rendered_image()
@@ -405,7 +402,6 @@ class GymEnvironment(Environment):
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
# the info is only updated after the first step
self.state = self.step(self.action_space.default_action).next_state
self.state_space['measurements'] = VectorObservationSpace(shape=len(self.info.keys()))
if self.env.spec and custom_reward_threshold is None:

View File

@@ -16,6 +16,7 @@ def atari_env():
seed=1,
frame_skip=4,
visualization_parameters=VisualizationParameters())
env.reset_internal_state(True)
return env
@@ -26,6 +27,7 @@ def continuous_env():
seed=1,
frame_skip=1,
visualization_parameters=VisualizationParameters())
env.reset_internal_state(True)
return env
@@ -56,7 +58,7 @@ def test_gym_continuous_environment(continuous_env):
assert np.all(continuous_env.action_space.shape == np.array([1]))
# make sure that the seed is working properly
-assert np.sum(continuous_env.last_env_response.next_state['observation']) == 1.2661630859028832
+assert np.sum(continuous_env.last_env_response.next_state['observation']) == 0.6118565010687202
@pytest.mark.unit_test
@@ -64,4 +66,4 @@ def test_step(atari_env):
result = atari_env.step(0)
if __name__ == '__main__':
test_gym_continuous_environment(continuous_env())