mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Remove double call to reset_internal_state() on gym environments (#364)
This commit is contained in:
@@ -392,9 +392,6 @@ class GymEnvironment(Environment):
|
|||||||
else:
|
else:
|
||||||
screen.error("Error: Environment {} does not support human control.".format(self.env), crash=True)
|
screen.error("Error: Environment {} does not support human control.".format(self.env), crash=True)
|
||||||
|
|
||||||
# initialize the state by getting a new state from the environment
|
|
||||||
self.reset_internal_state(True)
|
|
||||||
|
|
||||||
# render
|
# render
|
||||||
if self.is_rendered:
|
if self.is_rendered:
|
||||||
image = self.get_rendered_image()
|
image = self.get_rendered_image()
|
||||||
@@ -405,7 +402,6 @@ class GymEnvironment(Environment):
|
|||||||
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
|
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
|
||||||
|
|
||||||
# the info is only updated after the first step
|
# the info is only updated after the first step
|
||||||
self.state = self.step(self.action_space.default_action).next_state
|
|
||||||
self.state_space['measurements'] = VectorObservationSpace(shape=len(self.info.keys()))
|
self.state_space['measurements'] = VectorObservationSpace(shape=len(self.info.keys()))
|
||||||
|
|
||||||
if self.env.spec and custom_reward_threshold is None:
|
if self.env.spec and custom_reward_threshold is None:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ def atari_env():
|
|||||||
seed=1,
|
seed=1,
|
||||||
frame_skip=4,
|
frame_skip=4,
|
||||||
visualization_parameters=VisualizationParameters())
|
visualization_parameters=VisualizationParameters())
|
||||||
|
env.reset_internal_state(True)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
@@ -26,6 +27,7 @@ def continuous_env():
|
|||||||
seed=1,
|
seed=1,
|
||||||
frame_skip=1,
|
frame_skip=1,
|
||||||
visualization_parameters=VisualizationParameters())
|
visualization_parameters=VisualizationParameters())
|
||||||
|
env.reset_internal_state(True)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
@@ -56,7 +58,7 @@ def test_gym_continuous_environment(continuous_env):
|
|||||||
assert np.all(continuous_env.action_space.shape == np.array([1]))
|
assert np.all(continuous_env.action_space.shape == np.array([1]))
|
||||||
|
|
||||||
# make sure that the seed is working properly
|
# make sure that the seed is working properly
|
||||||
assert np.sum(continuous_env.last_env_response.next_state['observation']) == 1.2661630859028832
|
assert np.sum(continuous_env.last_env_response.next_state['observation']) == 0.6118565010687202
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit_test
|
@pytest.mark.unit_test
|
||||||
|
|||||||
Reference in New Issue
Block a user