
Release 0.9

Main changes are detailed below:

New features -
* CARLA 0.7 simulator integration
* Human control of gameplay
* Recording of human gameplay and storing / loading the replay buffer
* Behavioral cloning agent and presets
* Golden tests for several presets
* Selecting between deep / shallow image embedders
* Rendering through pygame (with some performance boost)

API changes -
* Improved environment wrapper API
* Added an evaluate flag for convenient evaluation of existing checkpoints
* Improved frameskip definition in Gym

Bug fixes -
* Fixed loading of checkpoints for agents with more than one network
* Fixed Python 3 compatibility of the N-step Q-learning agent
Author: Itai Caspi (committed by GitHub)
Date: 2017-12-19 19:27:16 +02:00
Parent: 11faf19649
Commit: 125c7ee38d
41 changed files with 1713 additions and 260 deletions

View File

@@ -0,0 +1,62 @@
[CARLA/Server]
; If set to false, a mock controller will be used instead of waiting for a real
; client to connect.
UseNetworking=true
; Ports to use for the server-client communication. This can be overridden by
; the command-line switch `-world-port=N`, write and read ports will be set to
; N+1 and N+2 respectively.
WorldPort=2000
; Time-out in milliseconds for the networking operations.
ServerTimeOut=10000000000
; In synchronous mode, CARLA waits every frame until the control from the client
; is received.
SynchronousMode=true
; Send info about every non-player agent in the scene every frame, the
; information is attached to the measurements message. This includes other
; vehicles, pedestrians and traffic signs. Disabled by default to improve
; performance.
SendNonPlayerAgentsInfo=false
[CARLA/LevelSettings]
; Path of the vehicle class to be used for the player. Leave empty for default.
; Paths follow the pattern "/Game/Blueprints/Vehicles/Mustang/Mustang.Mustang_C"
PlayerVehicle=
; Number of non-player vehicles to be spawned into the level.
NumberOfVehicles=15
; Number of non-player pedestrians to be spawned into the level.
NumberOfPedestrians=30
; Index of the weather/lighting presets to use. If negative, the default presets
; of the map will be used.
WeatherId=1
; Seeds for the pseudo-random number generators.
SeedVehicles=123456789
SeedPedestrians=123456789
[CARLA/SceneCapture]
; Names of the cameras to be attached to the player, comma-separated. Each of
; them should be defined in its own subsection, as done for CameraRGB below.
Cameras=CameraRGB
; Now, every camera we added needs to be defined in its own subsection.
[CARLA/SceneCapture/CameraRGB]
; Post-processing effect to be applied. Valid values:
; * None No effects applied.
; * SceneFinal Post-processing present at scene (bloom, fog, etc).
; * Depth Depth map ground-truth only.
; * SemanticSegmentation Semantic segmentation ground-truth only.
PostProcessing=SceneFinal
; Size of the captured image in pixels.
ImageSizeX=360
ImageSizeY=256
; Camera (horizontal) field of view in degrees.
CameraFOV=90
; Position of the camera relative to the car in centimeters.
CameraPositionX=200
CameraPositionY=0
CameraPositionZ=140
; Rotation of the camera relative to the car in degrees.
CameraRotationPitch=0
CameraRotationRoll=0
CameraRotationYaw=0
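
For reference, the environment wrapper below passes a settings file like this one through to the server verbatim. A minimal standalone sketch of that flow, based only on the calls used in the CARLA wrapper in this commit (the file name 'CarlaSettings.ini' and port 2000 are illustrative assumptions):

# Sketch: feed an ini file like the one above to a running CARLA server,
# mirroring the load_settings() call in the CARLA environment wrapper below.
# 'CarlaSettings.ini' and port 2000 are assumptions for illustration.
from carla.client import CarlaClient

with open('CarlaSettings.ini', 'r') as fp:
    settings = fp.read()  # the raw ini text is sent as-is

client = CarlaClient('localhost', 2000, timeout=10)
client.connect()
scene = client.load_settings(settings)
print('{} player start positions'.format(len(scene.player_start_spots)))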

View File

@@ -15,13 +15,16 @@
#
from logger import *
from utils import Enum
from utils import Enum, get_open_port
from environments.gym_environment_wrapper import *
from environments.doom_environment_wrapper import *
from environments.carla_environment_wrapper import *
class EnvTypes(Enum):
    Doom = "DoomEnvironmentWrapper"
    Gym = "GymEnvironmentWrapper"
    Carla = "CarlaEnvironmentWrapper"


def create_environment(tuning_parameters):

View File

@@ -0,0 +1,230 @@
import sys
from os import path, environ

try:
    sys.path.append(path.join(environ.get('CARLA_ROOT'), 'PythonClient'))
    from carla.client import CarlaClient
    from carla.settings import CarlaSettings
    from carla.tcp import TCPConnectionError
    from carla.sensor import Camera
    from carla.client import VehicleControl
except ImportError:
    from logger import failed_imports
    failed_imports.append("CARLA")

import numpy as np
import time
import logging
import subprocess
import signal

from environments.environment_wrapper import EnvironmentWrapper
from utils import *
from logger import screen, logger
from PIL import Image


# enum of the available levels and their path
class CarlaLevel(Enum):
    TOWN1 = "/Game/Maps/Town01"
    TOWN2 = "/Game/Maps/Town02"


key_map = {
    'BRAKE': (274,),  # down arrow
    'GAS': (273,),  # up arrow
    'TURN_LEFT': (276,),  # left arrow
    'TURN_RIGHT': (275,),  # right arrow
    'GAS_AND_TURN_LEFT': (273, 276),
    'GAS_AND_TURN_RIGHT': (273, 275),
    'BRAKE_AND_TURN_LEFT': (274, 276),
    'BRAKE_AND_TURN_RIGHT': (274, 275),
}


class CarlaEnvironmentWrapper(EnvironmentWrapper):
    def __init__(self, tuning_parameters):
        EnvironmentWrapper.__init__(self, tuning_parameters)

        self.tp = tuning_parameters

        # server configuration
        self.server_height = self.tp.env.server_height
        self.server_width = self.tp.env.server_width
        self.port = get_open_port()
        self.host = 'localhost'
        self.map = CarlaLevel().get(self.tp.env.level)

        # client configuration
        self.verbose = self.tp.env.verbose
        self.depth = self.tp.env.depth
        self.stereo = self.tp.env.stereo
        self.semantic_segmentation = self.tp.env.semantic_segmentation
        self.height = self.server_height * (1 + int(self.depth) + int(self.semantic_segmentation))
        self.width = self.server_width * (1 + int(self.stereo))
        self.size = (self.width, self.height)

        self.config = self.tp.env.config
        if self.config:
            # load settings from file
            with open(self.config, 'r') as fp:
                self.settings = fp.read()
        else:
            # hard coded settings
            self.settings = CarlaSettings()
            self.settings.set(
                SynchronousMode=True,
                SendNonPlayerAgentsInfo=False,
                NumberOfVehicles=15,
                NumberOfPedestrians=30,
                WeatherId=1)
            self.settings.randomize_seeds()

            # add cameras
            camera = Camera('CameraRGB')
            camera.set_image_size(self.width, self.height)
            camera.set_position(200, 0, 140)
            camera.set_rotation(0, 0, 0)
            self.settings.add_sensor(camera)

        # open the server
        self.server = self._open_server()
        logging.disable(40)

        # open the client
        self.game = CarlaClient(self.host, self.port, timeout=99999999)
        self.game.connect()
        scene = self.game.load_settings(self.settings)

        # get available start positions
        positions = scene.player_start_spots
        self.num_pos = len(positions)
        self.iterator_start_positions = 0

        # action space
        self.discrete_controls = False
        self.action_space_size = 2
        self.action_space_high = [1, 1]
        self.action_space_low = [-1, -1]
        self.action_space_abs_range = np.maximum(np.abs(self.action_space_low), np.abs(self.action_space_high))
        self.steering_strength = 0.5
        self.gas_strength = 1.0
        self.brake_strength = 0.5
        self.actions = {0: [0., 0.],
                        1: [0., -self.steering_strength],
                        2: [0., self.steering_strength],
                        3: [self.gas_strength, 0.],
                        4: [-self.brake_strength, 0],
                        5: [self.gas_strength, -self.steering_strength],
                        6: [self.gas_strength, self.steering_strength],
                        7: [self.brake_strength, -self.steering_strength],
                        8: [self.brake_strength, self.steering_strength]}
        self.actions_description = ['NO-OP', 'TURN_LEFT', 'TURN_RIGHT', 'GAS', 'BRAKE',
                                    'GAS_AND_TURN_LEFT', 'GAS_AND_TURN_RIGHT',
                                    'BRAKE_AND_TURN_LEFT', 'BRAKE_AND_TURN_RIGHT']
        for idx, action in enumerate(self.actions_description):
            for key in key_map.keys():
                if action == key:
                    self.key_to_action[key_map[key]] = idx
        self.num_speedup_steps = 30

        # measurements
        self.measurements_size = (1,)

        self.autopilot = None

        # env initialization
        self.reset(True)

        # render
        if self.is_rendered:
            image = self.get_rendered_image()
            self.renderer.create_screen(image.shape[1], image.shape[0])

    def _open_server(self):
        log_path = path.join(logger.experiments_path, "CARLA_LOG_{}.txt".format(self.port))
        with open(log_path, "wb") as out:
            cmd = [path.join(environ.get('CARLA_ROOT'), 'CarlaUE4.sh'), self.map,
                   "-benchmark", "-carla-server", "-fps=10", "-world-port={}".format(self.port),
                   "-windowed -ResX={} -ResY={}".format(self.server_width, self.server_height),
                   "-carla-no-hud"]
            if self.config:
                cmd.append("-carla-settings={}".format(self.config))
            p = subprocess.Popen(cmd, stdout=out, stderr=out)
        return p

    def _close_server(self):
        os.killpg(os.getpgid(self.server.pid), signal.SIGKILL)

    def _update_state(self):
        # get measurements and observations
        measurements = []
        while type(measurements) == list:
            measurements, sensor_data = self.game.read_data()
        self.observation = sensor_data['CameraRGB'].data

        self.location = (measurements.player_measurements.transform.location.x,
                         measurements.player_measurements.transform.location.y,
                         measurements.player_measurements.transform.location.z)

        is_collision = measurements.player_measurements.collision_vehicles != 0 \
                       or measurements.player_measurements.collision_pedestrians != 0 \
                       or measurements.player_measurements.collision_other != 0

        speed_reward = measurements.player_measurements.forward_speed - 1
        if speed_reward > 30.:
            speed_reward = 30.
        self.reward = speed_reward \
                      - (measurements.player_measurements.intersection_otherlane * 5) \
                      - (measurements.player_measurements.intersection_offroad * 5) \
                      - is_collision * 100 \
                      - np.abs(self.control.steer) * 10

        # update measurements
        self.measurements = [measurements.player_measurements.forward_speed]
        self.autopilot = measurements.player_measurements.autopilot_control

        # action_p = ['%.2f' % member for member in [self.control.throttle, self.control.steer]]
        # screen.success('REWARD: %.2f, ACTIONS: %s' % (self.reward, action_p))

        if (measurements.game_timestamp >= self.tp.env.episode_max_time) or is_collision:
            # screen.success('EPISODE IS DONE. GameTime: {}, Collision: {}'.format(str(measurements.game_timestamp),
            #                                                                      str(is_collision)))
            self.done = True

    def _take_action(self, action_idx):
        if type(action_idx) == int:
            action = self.actions[action_idx]
        else:
            action = action_idx
        self.last_action_idx = action

        self.control = VehicleControl()
        self.control.throttle = np.clip(action[0], 0, 1)
        self.control.steer = np.clip(action[1], -1, 1)
        self.control.brake = np.abs(np.clip(action[0], -1, 0))
        if not self.tp.env.allow_braking:
            self.control.brake = 0
        self.control.hand_brake = False
        self.control.reverse = False

        self.game.send_control(self.control)

    def _restart_environment_episode(self, force_environment_reset=False):
        self.iterator_start_positions += 1
        if self.iterator_start_positions >= self.num_pos:
            self.iterator_start_positions = 0

        try:
            self.game.start_episode(self.iterator_start_positions)
        except:
            self.game.connect()
            self.game.start_episode(self.iterator_start_positions)

        # start the game with some initial speed
        observation = None
        for i in range(self.num_speedup_steps):
            observation = self.step([1.0, 0])['observation']
        self.observation = observation

        return observation
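
The reward shaping in _update_state above reduces to a small standalone function. A sketch for reference (argument names mirror the player measurements used above):

import numpy as np

def carla_reward(forward_speed, intersection_otherlane, intersection_offroad,
                 is_collision, steer):
    # speed bonus (forward speed minus a constant 1), capped at 30
    speed_reward = min(forward_speed - 1, 30.)
    # penalties: lane crossing, off-road driving, collisions and hard steering
    return (speed_reward
            - intersection_otherlane * 5
            - intersection_offroad * 5
            - is_collision * 100
            - np.abs(steer) * 10)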

View File

@@ -25,6 +25,7 @@ import numpy as np
from environments.environment_wrapper import EnvironmentWrapper
from os import path, environ
from utils import *
from logger import *
# enum of the available levels and their path
@@ -39,6 +40,43 @@ class DoomLevel(Enum):
    DEFEND_THE_LINE = "defend_the_line.cfg"
    DEADLY_CORRIDOR = "deadly_corridor.cfg"


key_map = {
    'NO-OP': 96,  # backquote (`)
    'ATTACK': 13,  # enter
    'CROUCH': 306,  # ctrl
    'DROP_SELECTED_ITEM': ord("t"),
    'DROP_SELECTED_WEAPON': ord("t"),
    'JUMP': 32,  # spacebar
    'LAND': ord("l"),
    'LOOK_DOWN': 274,  # down arrow
    'LOOK_UP': 273,  # up arrow
    'MOVE_BACKWARD': ord("s"),
    'MOVE_DOWN': ord("s"),
    'MOVE_FORWARD': ord("w"),
    'MOVE_LEFT': 276,  # left arrow
    'MOVE_RIGHT': 275,  # right arrow
    'MOVE_UP': ord("w"),
    'RELOAD': ord("r"),
    'SELECT_NEXT_WEAPON': ord("q"),
    'SELECT_PREV_WEAPON': ord("e"),
    'SELECT_WEAPON0': ord("0"),
    'SELECT_WEAPON1': ord("1"),
    'SELECT_WEAPON2': ord("2"),
    'SELECT_WEAPON3': ord("3"),
    'SELECT_WEAPON4': ord("4"),
    'SELECT_WEAPON5': ord("5"),
    'SELECT_WEAPON6': ord("6"),
    'SELECT_WEAPON7': ord("7"),
    'SELECT_WEAPON8': ord("8"),
    'SELECT_WEAPON9': ord("9"),
    'SPEED': 304,  # shift
    'STRAFE': 9,  # tab
    'TURN180': ord("u"),
    'TURN_LEFT': ord("a"),
    'TURN_RIGHT': ord("d"),
    'USE': ord("f"),
}
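
The numeric values in key_map appear to be pygame 1.x keycodes (the renderer introduced in this release is pygame based; 273-276 are the arrow keys there). Under that assumption, the bindings can be printed with a short sketch:

import pygame

pygame.init()
for action, code in sorted(key_map.items()):
    # pygame.key.name() turns a keycode into a readable name, e.g. 273 -> 'up'
    print('{:22s} -> {}'.format(action, pygame.key.name(code)))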
class DoomEnvironmentWrapper(EnvironmentWrapper):
    def __init__(self, tuning_parameters):

@@ -49,26 +87,42 @@ class DoomEnvironmentWrapper(EnvironmentWrapper):
        self.scenarios_dir = path.join(environ.get('VIZDOOM_ROOT'), 'scenarios')
        self.game = vizdoom.DoomGame()
        self.game.load_config(path.join(self.scenarios_dir, self.level))
        self.game.set_window_visible(self.is_rendered)
        self.game.set_window_visible(False)
        self.game.add_game_args("+vid_forcesurface 1")
        if self.is_rendered:
            self.wait_for_explicit_human_action = True
        if self.human_control:
            self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_640X480)
            self.renderer.create_screen(640, 480)
        elif self.is_rendered:
            self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_320X240)
            self.renderer.create_screen(320, 240)
        else:
            # lower resolution since we actually take only 76x60 and we don't need to render
            self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_160X120)
        self.game.set_render_hud(False)
        self.game.set_render_crosshair(False)
        self.game.set_render_decals(False)
        self.game.set_render_particles(False)
        self.game.init()

        # action space
        self.action_space_abs_range = 0
        self.actions = {}
        self.action_space_size = self.game.get_available_buttons_size()
        for action_idx in range(self.action_space_size):
            self.actions[action_idx] = [0] * self.action_space_size
            self.actions[action_idx][action_idx] = 1
        self.actions_description = [str(action) for action in self.game.get_available_buttons()]
        self.action_space_size = self.game.get_available_buttons_size() + 1
        self.action_vector_size = self.action_space_size - 1
        self.actions[0] = [0] * self.action_vector_size
        for action_idx in range(self.action_vector_size):
            self.actions[action_idx + 1] = [0] * self.action_vector_size
            self.actions[action_idx + 1][action_idx] = 1
        self.actions_description = ['NO-OP']
        self.actions_description += [str(action).split(".")[1] for action in self.game.get_available_buttons()]
        for idx, action in enumerate(self.actions_description):
            if action in key_map.keys():
                self.key_to_action[(key_map[action],)] = idx

        # measurement
        self.measurements_size = self.game.get_state().game_variables.shape
        self.width = self.game.get_screen_width()

@@ -77,27 +131,17 @@ class DoomEnvironmentWrapper(EnvironmentWrapper):
        self.game.set_seed(self.tp.seed)
        self.reset()

    def _update_observation_and_measurements(self):
    def _update_state(self):
        # extract all data from the current state
        state = self.game.get_state()
        if state is not None and state.screen_buffer is not None:
            self.observation = self._preprocess_observation(state.screen_buffer)
            self.observation = state.screen_buffer
            self.measurements = state.game_variables
        self.reward = self.game.get_last_reward()
        self.done = self.game.is_episode_finished()

    def step(self, action_idx):
        self.reward = 0
        for frame in range(self.tp.env.frame_skip):
            self.reward += self.game.make_action(self._idx_to_action(action_idx))
            self._update_observation_and_measurements()
            if self.done:
                break
        return {'observation': self.observation,
                'reward': self.reward,
                'done': self.done,
                'action': action_idx,
                'measurements': self.measurements}

    def _take_action(self, action_idx):
        self.game.make_action(self._idx_to_action(action_idx), self.frame_skip)

    def _preprocess_observation(self, observation):
        if observation is None:

@@ -108,3 +152,5 @@ class DoomEnvironmentWrapper(EnvironmentWrapper):
    def _restart_environment_episode(self, force_environment_reset=False):
        self.game.new_episode()
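
The action-space block above replaces plain one-hot actions with a NO-OP at index 0 followed by one one-hot vector per ViZDoom button. A distilled sketch of that construction:

def build_actions(num_buttons):
    # index 0 is a NO-OP that presses no button; index i + 1 presses button i
    actions = {0: [0] * num_buttons}
    for i in range(num_buttons):
        one_hot = [0] * num_buttons
        one_hot[i] = 1
        actions[i + 1] = one_hot
    return actions

assert build_actions(2) == {0: [0, 0], 1: [1, 0], 2: [0, 1]}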

View File

@@ -17,6 +17,9 @@
import numpy as np
from utils import *
from configurations import Preset
from renderer import Renderer
import operator
import time
class EnvironmentWrapper(object):
@@ -31,13 +34,19 @@ class EnvironmentWrapper(object):
        self.observation = []
        self.reward = 0
        self.done = False
        self.default_action = 0
        self.last_action_idx = 0
        self.episode_idx = 0
        self.last_episode_time = time.time()
        self.measurements = []
        self.info = []
        self.action_space_low = 0
        self.action_space_high = 0
        self.action_space_abs_range = 0
        self.actions_description = {}
        self.discrete_controls = True
        self.action_space_size = 0
        self.key_to_action = {}
        self.width = 1
        self.height = 1
        self.is_state_type_image = True

@@ -50,17 +59,11 @@ class EnvironmentWrapper(object):
        self.is_rendered = self.tp.visualization.render
        self.seed = self.tp.seed
        self.frame_skip = self.tp.env.frame_skip

    def _update_observation_and_measurements(self):
        # extract all the available measurements (observation, depth map, lives, ammo, etc.)
        pass

    def _restart_environment_episode(self, force_environment_reset=False):
        """
        :param force_environment_reset: Force the environment to reset even if the episode is not done yet.
        :return:
        """
        pass

        self.human_control = self.tp.env.human_control
        self.wait_for_explicit_human_action = False
        self.is_rendered = self.is_rendered or self.human_control
        self.game_is_open = True
        self.renderer = Renderer()

    def _idx_to_action(self, action_idx):
        """

@@ -71,13 +74,43 @@ class EnvironmentWrapper(object):
        """
        return self.actions[action_idx]

    def _preprocess_observation(self, observation):
    def _action_to_idx(self, action):
        """
        Do initial observation preprocessing such as cropping, rgb2gray, rescale etc.
        :param observation: a raw observation from the environment
        :return: the preprocessed observation
        Convert an environment action to one of the available actions of the wrapper.
        For example, if the available actions are 4, 5, 6 then this function will map 4->0, 5->1, 6->2
        :param action: the environment action
        :return: an action index between 0 and self.action_space_size - 1, or -1 if the action does not exist
        """
        pass
        for key, val in self.actions.items():
            if val == action:
                return key
        return -1

    def get_action_from_user(self):
        """
        Get an action from the user keyboard
        :return: action index
        """
        if self.wait_for_explicit_human_action:
            while len(self.renderer.pressed_keys) == 0:
                self.renderer.get_events()

        if self.key_to_action == {}:
            # the keys are the numbers on the keyboard corresponding to the action index
            if len(self.renderer.pressed_keys) > 0:
                action_idx = self.renderer.pressed_keys[0] - ord("1")
                if 0 <= action_idx < self.action_space_size:
                    return action_idx
        else:
            # the keys are mapped through the environment to more intuitive keyboard keys
            # key = tuple(self.renderer.pressed_keys)
            # for key in self.renderer.pressed_keys:
            for env_keys in self.key_to_action.keys():
                if set(env_keys) == set(self.renderer.pressed_keys):
                    return self.key_to_action[env_keys]

        # return the default action 0 so that the environment will continue running
        return self.default_action

    def step(self, action_idx):
        """

@@ -85,13 +118,29 @@ class EnvironmentWrapper(object):
        :param action_idx: the action to perform on the environment
        :return: A dictionary containing the observation, reward, done flag, action and measurements
        """
        pass
        self.last_action_idx = action_idx

        self._take_action(action_idx)

        self._update_state()

        if self.is_rendered:
            self.render()

        self.observation = self._preprocess_observation(self.observation)

        return {'observation': self.observation,
                'reward': self.reward,
                'done': self.done,
                'action': self.last_action_idx,
                'measurements': self.measurements,
                'info': self.info}

    def render(self):
        """
        Call the environment function for rendering to the screen
        """
        pass
        self.renderer.render_image(self.get_rendered_image())

    def reset(self, force_environment_reset=False):
        """

@@ -100,15 +149,25 @@ class EnvironmentWrapper(object):
        :return: A dictionary containing the observation, reward, done flag, action and measurements
        """
        self._restart_environment_episode(force_environment_reset)
        self.last_episode_time = time.time()

        self.done = False
        self.episode_idx += 1
        self.reward = 0.0
        self.last_action_idx = 0

        self._update_observation_and_measurements()
        self._update_state()

        # render before the preprocessing of the observation, so that the image will be in its original quality
        if self.is_rendered:
            self.render()

        self.observation = self._preprocess_observation(self.observation)

        return {'observation': self.observation,
                'reward': self.reward,
                'done': self.done,
                'action': self.last_action_idx,
                'measurements': self.measurements}
                'measurements': self.measurements,
                'info': self.info}

    def get_random_action(self):
        """

@@ -129,10 +188,62 @@ class EnvironmentWrapper(object):
        """
        self.phase = phase

    def get_available_keys(self):
        """
        Return a list of tuples mapping between action names and the keyboard key that triggers them
        :return: a list of tuples mapping between action names and the keyboard key that triggers them
        """
        available_keys = []
        if self.key_to_action != {}:
            for key, idx in sorted(self.key_to_action.items(), key=operator.itemgetter(1)):
                if key != ():
                    key_names = [self.renderer.get_key_names([k])[0] for k in key]
                    available_keys.append((self.actions_description[idx], ' + '.join(key_names)))
        elif self.discrete_controls:
            for action in range(self.action_space_size):
                available_keys.append(("Action {}".format(action + 1), action + 1))
        return available_keys

    # The following functions define the interaction with the environment.
    # Any new environment that inherits the EnvironmentWrapper class should use these signatures.
    # Some of these functions are optional - please read their description for more details.

    def _take_action(self, action_idx):
        """
        An environment dependent function that sends an action to the simulator.
        :param action_idx: the action to perform on the environment
        :return: None
        """
        pass

    def _preprocess_observation(self, observation):
        """
        Do initial observation preprocessing such as cropping, rgb2gray, rescale etc.
        Implementing this function is optional.
        :param observation: a raw observation from the environment
        :return: the preprocessed observation
        """
        return observation

    def _update_state(self):
        """
        Updates the state from the environment.
        Should update self.observation, self.reward, self.done, self.measurements and self.info
        :return: None
        """
        pass

    def _restart_environment_episode(self, force_environment_reset=False):
        """
        :param force_environment_reset: Force the environment to reset even if the episode is not done yet.
        :return:
        """
        pass

    def get_rendered_image(self):
        """
        Return a numpy array containing the image that will be rendered to the screen.
        This can be different from the observation. For example, mujoco's observation is a measurements vector.
        :return: numpy array containing the image that will be rendered to the screen
        """
        return self.observation
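
With this refactor, step() and reset() live in the base class and environments only implement the private hooks. A toy subclass illustrating the new contract (a hypothetical example, not part of this commit):

class ToyEnvironmentWrapper(EnvironmentWrapper):
    # Hypothetical one-step environment with a constant observation.
    def _take_action(self, action_idx):
        pass  # nothing to send to a simulator

    def _update_state(self):
        self.observation = np.zeros((8, 8))
        self.reward = 1.0
        self.done = True  # every episode lasts a single step
        self.measurements = []
        self.info = {}

    def _restart_environment_episode(self, force_environment_reset=False):
        self.observation = np.zeros((8, 8))
        return self.observation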

View File

@@ -15,8 +15,10 @@
#
import sys
from logger import *
import gym
import numpy as np
import time
try:
    import roboschool
    from OpenGL import GL

@@ -40,8 +42,6 @@ from gym import wrappers
from utils import force_list, RunPhase
from environments.environment_wrapper import EnvironmentWrapper

i = 0


class GymEnvironmentWrapper(EnvironmentWrapper):
    def __init__(self, tuning_parameters):

@@ -53,29 +53,30 @@ class GymEnvironmentWrapper(EnvironmentWrapper):
        self.env.seed(self.seed)
        # self.env_spec = gym.spec(self.env_id)
        self.env.frameskip = self.frame_skip
        self.discrete_controls = type(self.env.action_space) != gym.spaces.box.Box

        # pybullet requires rendering before resetting the environment, but other gym environments (Pendulum) will crash
        try:
            if self.is_rendered:
                self.render()
        except:
            pass

        o = self.reset(True)['observation']
        self.observation = self.reset(True)['observation']

        # render
        if self.is_rendered:
            self.render()
            image = self.get_rendered_image()
            scale = 1
            if self.human_control:
                scale = 2
            self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)

        self.is_state_type_image = len(o.shape) > 1
        self.is_state_type_image = len(self.observation.shape) > 1
        if self.is_state_type_image:
            self.width = o.shape[1]
            self.height = o.shape[0]
            self.width = self.observation.shape[1]
            self.height = self.observation.shape[0]
        else:
            self.width = o.shape[0]
            self.width = self.observation.shape[0]

        # action space
        self.actions_description = {}
        if hasattr(self.env.unwrapped, 'get_action_meanings'):
            self.actions_description = self.env.unwrapped.get_action_meanings()
        if self.discrete_controls:
            self.action_space_size = self.env.action_space.n
            self.action_space_abs_range = 0

@@ -85,34 +86,31 @@ class GymEnvironmentWrapper(EnvironmentWrapper):
            self.action_space_low = self.env.action_space.low
            self.action_space_abs_range = np.maximum(np.abs(self.action_space_low), np.abs(self.action_space_high))
        self.actions = {i: i for i in range(self.action_space_size)}
        self.key_to_action = {}
        if hasattr(self.env.unwrapped, 'get_keys_to_action'):
            self.key_to_action = self.env.unwrapped.get_keys_to_action()

        # measurements
        self.timestep_limit = self.env.spec.timestep_limit
        self.current_ale_lives = 0
        self.measurements_size = len(self.step(0)['info'].keys())

        # env initialization
        self.observation = o
        self.reward = 0
        self.done = False
        self.last_action = self.actions[0]

    def render(self):
        self.env.render()

    def step(self, action_idx):
    def _update_state(self):
        if hasattr(self.env.env, 'ale'):
            if self.phase == RunPhase.TRAIN and hasattr(self, 'current_ale_lives'):
                # signal termination for life loss
                if self.current_ale_lives != self.env.env.ale.lives():
                    self.done = True
            self.current_ale_lives = self.env.env.ale.lives()

    def _take_action(self, action_idx):
        if action_idx is None:
            action_idx = self.last_action_idx
        self.last_action_idx = action_idx

        if self.discrete_controls:
            action = self.actions[action_idx]
        else:
            action = action_idx

        if hasattr(self.env.env, 'ale'):
            prev_ale_lives = self.env.env.ale.lives()

        # pendulum-v0 for example expects a list
        if not self.discrete_controls:
            # catching cases where the action for continuous control is a number instead of a list the

@@ -128,42 +126,26 @@ class GymEnvironmentWrapper(EnvironmentWrapper):
        self.observation, self.reward, self.done, self.info = self.env.step(action)

        if hasattr(self.env.env, 'ale') and self.phase == RunPhase.TRAIN:
            # signal termination for breakout life loss
            if prev_ale_lives != self.env.env.ale.lives():
                self.done = True

    def _preprocess_observation(self, observation):
        if any(env in self.env_id for env in ["Breakout", "Pong"]):
            # crop image
            self.observation = self.observation[34:195, :, :]
            if self.is_rendered:
                self.render()
            return {'observation': self.observation,
                    'reward': self.reward,
                    'done': self.done,
                    'action': self.last_action_idx,
                    'info': self.info}
            observation = observation[34:195, :, :]
        return observation

    def _restart_environment_episode(self, force_environment_reset=False):
        # prevent reset of environment if there are ale lives left
        if "Breakout" in self.env_id and self.env.env.ale.lives() > 0 and not force_environment_reset:
        if (hasattr(self.env.env, 'ale') and self.env.env.ale.lives() > 0) \
                and not force_environment_reset and not self.env._past_limit():
            return self.observation

        if self.seed:
            self.env.seed(self.seed)

        observation = self.env.reset()
        while observation is None:
            observation = self.step(0)['observation']

        if "Breakout" in self.env_id:
            # crop image
            observation = observation[34:195, :, :]
        self.observation = self.env.reset()
        while self.observation is None:
            self.step(0)

        self.observation = observation
        return observation
        return self.observation

    def get_rendered_image(self):
        return self.env.render(mode='rgb_array')
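
The crop in _preprocess_observation above targets the standard 210x160 ALE frame: rows 34..194 keep the play area and drop the score bar. A quick check of the arithmetic:

import numpy as np

frame = np.zeros((210, 160, 3), dtype=np.uint8)  # stand-in for an ALE RGB frame
cropped = frame[34:195, :, :]                    # the same slice as above
assert cropped.shape == (161, 160, 3)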