import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

import pytest
from rl_coach.environments.gym_environment import GymEnvironment
from rl_coach.base_parameters import VisualizationParameters
import numpy as np
from rl_coach.spaces import (
    DiscreteActionSpace,
    BoxActionSpace,
    ImageObservationSpace,
    VectorObservationSpace,
)


def _make_atari_env():
    """Build a seeded Breakout-v0 GymEnvironment (frame_skip=4).

    Kept separate from the ``atari_env`` fixture so the environment can also be
    constructed outside of pytest (pytest forbids calling fixture functions directly).
    """
    env = GymEnvironment(level='Breakout-v0', seed=1, frame_skip=4,
                         visualization_parameters=VisualizationParameters())
    # NOTE(review): meaning of the positional True flag is defined in GymEnvironment — confirm there
    env.reset_internal_state(True)
    return env


def _make_continuous_env():
    """Build a seeded Pendulum-v0 GymEnvironment (frame_skip=1).

    Kept separate from the ``continuous_env`` fixture so the environment can also be
    constructed outside of pytest (pytest forbids calling fixture functions directly).
    """
    env = GymEnvironment(level='Pendulum-v0', seed=1, frame_skip=1,
                         visualization_parameters=VisualizationParameters())
    env.reset_internal_state(True)
    return env


@pytest.fixture()
def atari_env():
    """Pytest fixture: a seeded Breakout-v0 gym environment."""
    return _make_atari_env()


@pytest.fixture()
def continuous_env():
    """Pytest fixture: a seeded Pendulum-v0 gym environment."""
    return _make_continuous_env()


@pytest.mark.unit_test
def test_gym_discrete_environment(atari_env):
    """Check Breakout's image observation space, discrete action space and seeding."""
    # observation space
    assert isinstance(atari_env.state_space['observation'], ImageObservationSpace)
    assert np.all(atari_env.state_space['observation'].shape == [210, 160, 3])
    assert np.all(atari_env.last_env_response.next_state['observation'].shape == (210, 160, 3))

    # action space
    assert isinstance(atari_env.action_space, DiscreteActionSpace)
    assert np.all(atari_env.action_space.high == 3)

    # make sure that the seed is working properly: with seed=1 the initial
    # observation must always sum to the same (integer) value
    assert np.sum(atari_env.last_env_response.next_state['observation']) == 4115856


@pytest.mark.unit_test
def test_gym_continuous_environment(continuous_env):
    """Check Pendulum's vector observation space, box action space and seeding."""
    # observation space
    assert isinstance(continuous_env.state_space['observation'], VectorObservationSpace)
    assert np.all(continuous_env.state_space['observation'].shape == [3])
    assert np.all(continuous_env.last_env_response.next_state['observation'].shape == (3,))

    # action space
    assert isinstance(continuous_env.action_space, BoxActionSpace)
    assert np.all(continuous_env.action_space.shape == np.array([1]))

    # make sure that the seed is working properly; the observation sum is a float,
    # so compare with a tolerance rather than exact equality
    assert np.sum(continuous_env.last_env_response.next_state['observation']) == \
        pytest.approx(0.6118565010687202)


@pytest.mark.unit_test
def test_step(atari_env):
    """Smoke test: stepping the Breakout environment with action 0 must not raise."""
    # NOTE(review): the step result is not checked; add assertions once its contract is pinned down
    atari_env.step(0)


if __name__ == '__main__':
    # Fixture functions cannot be called directly (pytest raises an error),
    # so build the environment through the plain helper instead.
    test_gym_continuous_environment(_make_continuous_env())