mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Release 0.9
Main changes are detailed below: New features - * CARLA 0.7 simulator integration * Human control of the game play * Recording of human game play and storing / loading the replay buffer * Behavioral cloning agent and presets * Golden tests for several presets * Selecting between deep / shallow image embedders * Rendering through pygame (with some boost in performance) API changes - * Improved environment wrapper API * Added an evaluate flag to allow convenient evaluation of existing checkpoints * Improve frameskip definition in Gym Bug fixes - * Fixed loading of checkpoints for agents with more than one network * Fixed the N Step Q learning agent python3 compatibility
This commit is contained in:
62
environments/CarlaSettings.ini
Normal file
62
environments/CarlaSettings.ini
Normal file
@@ -0,0 +1,62 @@
|
||||
[CARLA/Server]
|
||||
; If set to false, a mock controller will be used instead of waiting for a real
|
||||
; client to connect.
|
||||
UseNetworking=true
|
||||
; Ports to use for the server-client communication. This can be overridden by
|
||||
; the command-line switch `-world-port=N`, write and read ports will be set to
|
||||
; N+1 and N+2 respectively.
|
||||
WorldPort=2000
|
||||
; Time-out in milliseconds for the networking operations.
|
||||
ServerTimeOut=10000000000
|
||||
; In synchronous mode, CARLA waits every frame until the control from the client
|
||||
; is received.
|
||||
SynchronousMode=true
|
||||
; Send info about every non-player agent in the scene every frame, the
|
||||
; information is attached to the measurements message. This includes other
|
||||
; vehicles, pedestrians and traffic signs. Disabled by default to improve
|
||||
; performance.
|
||||
SendNonPlayerAgentsInfo=false
|
||||
|
||||
[CARLA/LevelSettings]
|
||||
; Path of the vehicle class to be used for the player. Leave empty for default.
|
||||
; Paths follow the pattern "/Game/Blueprints/Vehicles/Mustang/Mustang.Mustang_C"
|
||||
PlayerVehicle=
|
||||
; Number of non-player vehicles to be spawned into the level.
|
||||
NumberOfVehicles=15
|
||||
; Number of non-player pedestrians to be spawned into the level.
|
||||
NumberOfPedestrians=30
|
||||
; Index of the weather/lighting presets to use. If negative, the default presets
|
||||
; of the map will be used.
|
||||
WeatherId=1
|
||||
; Seeds for the pseudo-random number generators.
|
||||
SeedVehicles=123456789
|
||||
SeedPedestrians=123456789
|
||||
|
||||
[CARLA/SceneCapture]
|
||||
; Names of the cameras to be attached to the player, comma-separated, each of
|
||||
; them should be defined in its own subsection. E.g., Uncomment next line to add
|
||||
; a camera called MyCamera to the vehicle
|
||||
|
||||
Cameras=CameraRGB
|
||||
|
||||
; Now, every camera we added needs to be defined it in its own subsection.
|
||||
[CARLA/SceneCapture/CameraRGB]
|
||||
; Post-processing effect to be applied. Valid values:
|
||||
; * None No effects applied.
|
||||
; * SceneFinal Post-processing present at scene (bloom, fog, etc).
|
||||
; * Depth Depth map ground-truth only.
|
||||
; * SemanticSegmentation Semantic segmentation ground-truth only.
|
||||
PostProcessing=SceneFinal
|
||||
; Size of the captured image in pixels.
|
||||
ImageSizeX=360
|
||||
ImageSizeY=256
|
||||
; Camera (horizontal) field of view in degrees.
|
||||
CameraFOV=90
|
||||
; Position of the camera relative to the car in centimeters.
|
||||
CameraPositionX=200
|
||||
CameraPositionY=0
|
||||
CameraPositionZ=140
|
||||
; Rotation of the camera relative to the car in degrees.
|
||||
CameraRotationPitch=0
|
||||
CameraRotationRoll=0
|
||||
CameraRotationYaw=0
|
||||
@@ -15,13 +15,16 @@
|
||||
#
|
||||
|
||||
from logger import *
|
||||
from utils import Enum
|
||||
from utils import Enum, get_open_port
|
||||
from environments.gym_environment_wrapper import *
|
||||
from environments.doom_environment_wrapper import *
|
||||
from environments.carla_environment_wrapper import *
|
||||
|
||||
|
||||
class EnvTypes(Enum):
|
||||
Doom = "DoomEnvironmentWrapper"
|
||||
Gym = "GymEnvironmentWrapper"
|
||||
Carla = "CarlaEnvironmentWrapper"
|
||||
|
||||
|
||||
def create_environment(tuning_parameters):
|
||||
|
||||
230
environments/carla_environment_wrapper.py
Normal file
230
environments/carla_environment_wrapper.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import sys
|
||||
from os import path, environ
|
||||
|
||||
try:
|
||||
sys.path.append(path.join(environ.get('CARLA_ROOT'), 'PythonClient'))
|
||||
from carla.client import CarlaClient
|
||||
from carla.settings import CarlaSettings
|
||||
from carla.tcp import TCPConnectionError
|
||||
from carla.sensor import Camera
|
||||
from carla.client import VehicleControl
|
||||
except ImportError:
|
||||
from logger import failed_imports
|
||||
failed_imports.append("CARLA")
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
import logging
|
||||
import subprocess
|
||||
import signal
|
||||
from environments.environment_wrapper import EnvironmentWrapper
|
||||
from utils import *
|
||||
from logger import screen, logger
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# enum of the available levels and their path
|
||||
class CarlaLevel(Enum):
|
||||
TOWN1 = "/Game/Maps/Town01"
|
||||
TOWN2 = "/Game/Maps/Town02"
|
||||
|
||||
key_map = {
|
||||
'BRAKE': (274,), # down arrow
|
||||
'GAS': (273,), # up arrow
|
||||
'TURN_LEFT': (276,), # left arrow
|
||||
'TURN_RIGHT': (275,), # right arrow
|
||||
'GAS_AND_TURN_LEFT': (273, 276),
|
||||
'GAS_AND_TURN_RIGHT': (273, 275),
|
||||
'BRAKE_AND_TURN_LEFT': (274, 276),
|
||||
'BRAKE_AND_TURN_RIGHT': (274, 275),
|
||||
}
|
||||
|
||||
|
||||
class CarlaEnvironmentWrapper(EnvironmentWrapper):
|
||||
def __init__(self, tuning_parameters):
|
||||
EnvironmentWrapper.__init__(self, tuning_parameters)
|
||||
|
||||
self.tp = tuning_parameters
|
||||
|
||||
# server configuration
|
||||
self.server_height = self.tp.env.server_height
|
||||
self.server_width = self.tp.env.server_width
|
||||
self.port = get_open_port()
|
||||
self.host = 'localhost'
|
||||
self.map = CarlaLevel().get(self.tp.env.level)
|
||||
|
||||
# client configuration
|
||||
self.verbose = self.tp.env.verbose
|
||||
self.depth = self.tp.env.depth
|
||||
self.stereo = self.tp.env.stereo
|
||||
self.semantic_segmentation = self.tp.env.semantic_segmentation
|
||||
self.height = self.server_height * (1 + int(self.depth) + int(self.semantic_segmentation))
|
||||
self.width = self.server_width * (1 + int(self.stereo))
|
||||
self.size = (self.width, self.height)
|
||||
|
||||
self.config = self.tp.env.config
|
||||
if self.config:
|
||||
# load settings from file
|
||||
with open(self.config, 'r') as fp:
|
||||
self.settings = fp.read()
|
||||
else:
|
||||
# hard coded settings
|
||||
self.settings = CarlaSettings()
|
||||
self.settings.set(
|
||||
SynchronousMode=True,
|
||||
SendNonPlayerAgentsInfo=False,
|
||||
NumberOfVehicles=15,
|
||||
NumberOfPedestrians=30,
|
||||
WeatherId=1)
|
||||
self.settings.randomize_seeds()
|
||||
|
||||
# add cameras
|
||||
camera = Camera('CameraRGB')
|
||||
camera.set_image_size(self.width, self.height)
|
||||
camera.set_position(200, 0, 140)
|
||||
camera.set_rotation(0, 0, 0)
|
||||
self.settings.add_sensor(camera)
|
||||
|
||||
# open the server
|
||||
self.server = self._open_server()
|
||||
|
||||
logging.disable(40)
|
||||
|
||||
# open the client
|
||||
self.game = CarlaClient(self.host, self.port, timeout=99999999)
|
||||
self.game.connect()
|
||||
scene = self.game.load_settings(self.settings)
|
||||
|
||||
# get available start positions
|
||||
positions = scene.player_start_spots
|
||||
self.num_pos = len(positions)
|
||||
self.iterator_start_positions = 0
|
||||
|
||||
# action space
|
||||
self.discrete_controls = False
|
||||
self.action_space_size = 2
|
||||
self.action_space_high = [1, 1]
|
||||
self.action_space_low = [-1, -1]
|
||||
self.action_space_abs_range = np.maximum(np.abs(self.action_space_low), np.abs(self.action_space_high))
|
||||
self.steering_strength = 0.5
|
||||
self.gas_strength = 1.0
|
||||
self.brake_strength = 0.5
|
||||
self.actions = {0: [0., 0.],
|
||||
1: [0., -self.steering_strength],
|
||||
2: [0., self.steering_strength],
|
||||
3: [self.gas_strength, 0.],
|
||||
4: [-self.brake_strength, 0],
|
||||
5: [self.gas_strength, -self.steering_strength],
|
||||
6: [self.gas_strength, self.steering_strength],
|
||||
7: [self.brake_strength, -self.steering_strength],
|
||||
8: [self.brake_strength, self.steering_strength]}
|
||||
self.actions_description = ['NO-OP', 'TURN_LEFT', 'TURN_RIGHT', 'GAS', 'BRAKE',
|
||||
'GAS_AND_TURN_LEFT', 'GAS_AND_TURN_RIGHT',
|
||||
'BRAKE_AND_TURN_LEFT', 'BRAKE_AND_TURN_RIGHT']
|
||||
for idx, action in enumerate(self.actions_description):
|
||||
for key in key_map.keys():
|
||||
if action == key:
|
||||
self.key_to_action[key_map[key]] = idx
|
||||
self.num_speedup_steps = 30
|
||||
|
||||
# measurements
|
||||
self.measurements_size = (1,)
|
||||
self.autopilot = None
|
||||
|
||||
# env initialization
|
||||
self.reset(True)
|
||||
|
||||
# render
|
||||
if self.is_rendered:
|
||||
image = self.get_rendered_image()
|
||||
self.renderer.create_screen(image.shape[1], image.shape[0])
|
||||
|
||||
def _open_server(self):
|
||||
log_path = path.join(logger.experiments_path, "CARLA_LOG_{}.txt".format(self.port))
|
||||
with open(log_path, "wb") as out:
|
||||
cmd = [path.join(environ.get('CARLA_ROOT'), 'CarlaUE4.sh'), self.map,
|
||||
"-benchmark", "-carla-server", "-fps=10", "-world-port={}".format(self.port),
|
||||
"-windowed -ResX={} -ResY={}".format(self.server_width, self.server_height),
|
||||
"-carla-no-hud"]
|
||||
if self.config:
|
||||
cmd.append("-carla-settings={}".format(self.config))
|
||||
p = subprocess.Popen(cmd, stdout=out, stderr=out)
|
||||
|
||||
return p
|
||||
|
||||
def _close_server(self):
|
||||
os.killpg(os.getpgid(self.server.pid), signal.SIGKILL)
|
||||
|
||||
def _update_state(self):
|
||||
# get measurements and observations
|
||||
measurements = []
|
||||
while type(measurements) == list:
|
||||
measurements, sensor_data = self.game.read_data()
|
||||
self.observation = sensor_data['CameraRGB'].data
|
||||
|
||||
self.location = (measurements.player_measurements.transform.location.x,
|
||||
measurements.player_measurements.transform.location.y,
|
||||
measurements.player_measurements.transform.location.z)
|
||||
|
||||
is_collision = measurements.player_measurements.collision_vehicles != 0 \
|
||||
or measurements.player_measurements.collision_pedestrians != 0 \
|
||||
or measurements.player_measurements.collision_other != 0
|
||||
|
||||
speed_reward = measurements.player_measurements.forward_speed - 1
|
||||
if speed_reward > 30.:
|
||||
speed_reward = 30.
|
||||
self.reward = speed_reward \
|
||||
- (measurements.player_measurements.intersection_otherlane * 5) \
|
||||
- (measurements.player_measurements.intersection_offroad * 5) \
|
||||
- is_collision * 100 \
|
||||
- np.abs(self.control.steer) * 10
|
||||
|
||||
# update measurements
|
||||
self.measurements = [measurements.player_measurements.forward_speed]
|
||||
self.autopilot = measurements.player_measurements.autopilot_control
|
||||
|
||||
# action_p = ['%.2f' % member for member in [self.control.throttle, self.control.steer]]
|
||||
# screen.success('REWARD: %.2f, ACTIONS: %s' % (self.reward, action_p))
|
||||
|
||||
if (measurements.game_timestamp >= self.tp.env.episode_max_time) or is_collision:
|
||||
# screen.success('EPISODE IS DONE. GameTime: {}, Collision: {}'.format(str(measurements.game_timestamp),
|
||||
# str(is_collision)))
|
||||
self.done = True
|
||||
|
||||
def _take_action(self, action_idx):
|
||||
if type(action_idx) == int:
|
||||
action = self.actions[action_idx]
|
||||
else:
|
||||
action = action_idx
|
||||
self.last_action_idx = action
|
||||
|
||||
self.control = VehicleControl()
|
||||
self.control.throttle = np.clip(action[0], 0, 1)
|
||||
self.control.steer = np.clip(action[1], -1, 1)
|
||||
self.control.brake = np.abs(np.clip(action[0], -1, 0))
|
||||
if not self.tp.env.allow_braking:
|
||||
self.control.brake = 0
|
||||
self.control.hand_brake = False
|
||||
self.control.reverse = False
|
||||
|
||||
self.game.send_control(self.control)
|
||||
|
||||
def _restart_environment_episode(self, force_environment_reset=False):
|
||||
self.iterator_start_positions += 1
|
||||
if self.iterator_start_positions >= self.num_pos:
|
||||
self.iterator_start_positions = 0
|
||||
|
||||
try:
|
||||
self.game.start_episode(self.iterator_start_positions)
|
||||
except:
|
||||
self.game.connect()
|
||||
self.game.start_episode(self.iterator_start_positions)
|
||||
|
||||
# start the game with some initial speed
|
||||
observation = None
|
||||
for i in range(self.num_speedup_steps):
|
||||
observation = self.step([1.0, 0])['observation']
|
||||
self.observation = observation
|
||||
|
||||
return observation
|
||||
|
||||
@@ -25,6 +25,7 @@ import numpy as np
|
||||
from environments.environment_wrapper import EnvironmentWrapper
|
||||
from os import path, environ
|
||||
from utils import *
|
||||
from logger import *
|
||||
|
||||
|
||||
# enum of the available levels and their path
|
||||
@@ -39,6 +40,43 @@ class DoomLevel(Enum):
|
||||
DEFEND_THE_LINE = "defend_the_line.cfg"
|
||||
DEADLY_CORRIDOR = "deadly_corridor.cfg"
|
||||
|
||||
key_map = {
|
||||
'NO-OP': 96, # `
|
||||
'ATTACK': 13, # enter
|
||||
'CROUCH': 306, # ctrl
|
||||
'DROP_SELECTED_ITEM': ord("t"),
|
||||
'DROP_SELECTED_WEAPON': ord("t"),
|
||||
'JUMP': 32, # spacebar
|
||||
'LAND': ord("l"),
|
||||
'LOOK_DOWN': 274, # down arrow
|
||||
'LOOK_UP': 273, # up arrow
|
||||
'MOVE_BACKWARD': ord("s"),
|
||||
'MOVE_DOWN': ord("s"),
|
||||
'MOVE_FORWARD': ord("w"),
|
||||
'MOVE_LEFT': 276,
|
||||
'MOVE_RIGHT': 275,
|
||||
'MOVE_UP': ord("w"),
|
||||
'RELOAD': ord("r"),
|
||||
'SELECT_NEXT_WEAPON': ord("q"),
|
||||
'SELECT_PREV_WEAPON': ord("e"),
|
||||
'SELECT_WEAPON0': ord("0"),
|
||||
'SELECT_WEAPON1': ord("1"),
|
||||
'SELECT_WEAPON2': ord("2"),
|
||||
'SELECT_WEAPON3': ord("3"),
|
||||
'SELECT_WEAPON4': ord("4"),
|
||||
'SELECT_WEAPON5': ord("5"),
|
||||
'SELECT_WEAPON6': ord("6"),
|
||||
'SELECT_WEAPON7': ord("7"),
|
||||
'SELECT_WEAPON8': ord("8"),
|
||||
'SELECT_WEAPON9': ord("9"),
|
||||
'SPEED': 304, # shift
|
||||
'STRAFE': 9, # tab
|
||||
'TURN180': ord("u"),
|
||||
'TURN_LEFT': ord("a"), # left arrow
|
||||
'TURN_RIGHT': ord("d"), # right arrow
|
||||
'USE': ord("f"),
|
||||
}
|
||||
|
||||
|
||||
class DoomEnvironmentWrapper(EnvironmentWrapper):
|
||||
def __init__(self, tuning_parameters):
|
||||
@@ -49,26 +87,42 @@ class DoomEnvironmentWrapper(EnvironmentWrapper):
|
||||
self.scenarios_dir = path.join(environ.get('VIZDOOM_ROOT'), 'scenarios')
|
||||
self.game = vizdoom.DoomGame()
|
||||
self.game.load_config(path.join(self.scenarios_dir, self.level))
|
||||
self.game.set_window_visible(self.is_rendered)
|
||||
self.game.set_window_visible(False)
|
||||
self.game.add_game_args("+vid_forcesurface 1")
|
||||
if self.is_rendered:
|
||||
|
||||
self.wait_for_explicit_human_action = True
|
||||
if self.human_control:
|
||||
self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_640X480)
|
||||
self.renderer.create_screen(640, 480)
|
||||
elif self.is_rendered:
|
||||
self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_320X240)
|
||||
self.renderer.create_screen(320, 240)
|
||||
else:
|
||||
# lower resolution since we actually take only 76x60 and we don't need to render
|
||||
self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_160X120)
|
||||
|
||||
self.game.set_render_hud(False)
|
||||
self.game.set_render_crosshair(False)
|
||||
self.game.set_render_decals(False)
|
||||
self.game.set_render_particles(False)
|
||||
self.game.init()
|
||||
|
||||
# action space
|
||||
self.action_space_abs_range = 0
|
||||
self.actions = {}
|
||||
self.action_space_size = self.game.get_available_buttons_size()
|
||||
for action_idx in range(self.action_space_size):
|
||||
self.actions[action_idx] = [0] * self.action_space_size
|
||||
self.actions[action_idx][action_idx] = 1
|
||||
self.actions_description = [str(action) for action in self.game.get_available_buttons()]
|
||||
self.action_space_size = self.game.get_available_buttons_size() + 1
|
||||
self.action_vector_size = self.action_space_size - 1
|
||||
self.actions[0] = [0] * self.action_vector_size
|
||||
for action_idx in range(self.action_vector_size):
|
||||
self.actions[action_idx + 1] = [0] * self.action_vector_size
|
||||
self.actions[action_idx + 1][action_idx] = 1
|
||||
self.actions_description = ['NO-OP']
|
||||
self.actions_description += [str(action).split(".")[1] for action in self.game.get_available_buttons()]
|
||||
for idx, action in enumerate(self.actions_description):
|
||||
if action in key_map.keys():
|
||||
self.key_to_action[(key_map[action],)] = idx
|
||||
|
||||
# measurement
|
||||
self.measurements_size = self.game.get_state().game_variables.shape
|
||||
|
||||
self.width = self.game.get_screen_width()
|
||||
@@ -77,27 +131,17 @@ class DoomEnvironmentWrapper(EnvironmentWrapper):
|
||||
self.game.set_seed(self.tp.seed)
|
||||
self.reset()
|
||||
|
||||
def _update_observation_and_measurements(self):
|
||||
def _update_state(self):
|
||||
# extract all data from the current state
|
||||
state = self.game.get_state()
|
||||
if state is not None and state.screen_buffer is not None:
|
||||
self.observation = self._preprocess_observation(state.screen_buffer)
|
||||
self.observation = state.screen_buffer
|
||||
self.measurements = state.game_variables
|
||||
self.reward = self.game.get_last_reward()
|
||||
self.done = self.game.is_episode_finished()
|
||||
|
||||
def step(self, action_idx):
|
||||
self.reward = 0
|
||||
for frame in range(self.tp.env.frame_skip):
|
||||
self.reward += self.game.make_action(self._idx_to_action(action_idx))
|
||||
self._update_observation_and_measurements()
|
||||
if self.done:
|
||||
break
|
||||
|
||||
return {'observation': self.observation,
|
||||
'reward': self.reward,
|
||||
'done': self.done,
|
||||
'action': action_idx,
|
||||
'measurements': self.measurements}
|
||||
def _take_action(self, action_idx):
|
||||
self.game.make_action(self._idx_to_action(action_idx), self.frame_skip)
|
||||
|
||||
def _preprocess_observation(self, observation):
|
||||
if observation is None:
|
||||
@@ -108,3 +152,5 @@ class DoomEnvironmentWrapper(EnvironmentWrapper):
|
||||
|
||||
def _restart_environment_episode(self, force_environment_reset=False):
|
||||
self.game.new_episode()
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
import numpy as np
|
||||
from utils import *
|
||||
from configurations import Preset
|
||||
from renderer import Renderer
|
||||
import operator
|
||||
import time
|
||||
|
||||
|
||||
class EnvironmentWrapper(object):
|
||||
@@ -31,13 +34,19 @@ class EnvironmentWrapper(object):
|
||||
self.observation = []
|
||||
self.reward = 0
|
||||
self.done = False
|
||||
self.default_action = 0
|
||||
self.last_action_idx = 0
|
||||
self.episode_idx = 0
|
||||
self.last_episode_time = time.time()
|
||||
self.measurements = []
|
||||
self.info = []
|
||||
self.action_space_low = 0
|
||||
self.action_space_high = 0
|
||||
self.action_space_abs_range = 0
|
||||
self.actions_description = {}
|
||||
self.discrete_controls = True
|
||||
self.action_space_size = 0
|
||||
self.key_to_action = {}
|
||||
self.width = 1
|
||||
self.height = 1
|
||||
self.is_state_type_image = True
|
||||
@@ -50,17 +59,11 @@ class EnvironmentWrapper(object):
|
||||
self.is_rendered = self.tp.visualization.render
|
||||
self.seed = self.tp.seed
|
||||
self.frame_skip = self.tp.env.frame_skip
|
||||
|
||||
def _update_observation_and_measurements(self):
|
||||
# extract all the available measurments (ovservation, depthmap, lives, ammo etc.)
|
||||
pass
|
||||
|
||||
def _restart_environment_episode(self, force_environment_reset=False):
|
||||
"""
|
||||
:param force_environment_reset: Force the environment to reset even if the episode is not done yet.
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
self.human_control = self.tp.env.human_control
|
||||
self.wait_for_explicit_human_action = False
|
||||
self.is_rendered = self.is_rendered or self.human_control
|
||||
self.game_is_open = True
|
||||
self.renderer = Renderer()
|
||||
|
||||
def _idx_to_action(self, action_idx):
|
||||
"""
|
||||
@@ -71,13 +74,43 @@ class EnvironmentWrapper(object):
|
||||
"""
|
||||
return self.actions[action_idx]
|
||||
|
||||
def _preprocess_observation(self, observation):
|
||||
def _action_to_idx(self, action):
|
||||
"""
|
||||
Do initial observation preprocessing such as cropping, rgb2gray, rescale etc.
|
||||
:param observation: a raw observation from the environment
|
||||
:return: the preprocessed observation
|
||||
Convert an environment action to one of the available actions of the wrapper.
|
||||
For example, if the available actions are 4,5,6 then this function will map 4->0, 5->1, 6->2
|
||||
:param action: the environment action
|
||||
:return: an action index between 0 and self.action_space_size - 1, or -1 if the action does not exist
|
||||
"""
|
||||
pass
|
||||
for key, val in self.actions.items():
|
||||
if val == action:
|
||||
return key
|
||||
return -1
|
||||
|
||||
def get_action_from_user(self):
|
||||
"""
|
||||
Get an action from the user keyboard
|
||||
:return: action index
|
||||
"""
|
||||
if self.wait_for_explicit_human_action:
|
||||
while len(self.renderer.pressed_keys) == 0:
|
||||
self.renderer.get_events()
|
||||
|
||||
if self.key_to_action == {}:
|
||||
# the keys are the numbers on the keyboard corresponding to the action index
|
||||
if len(self.renderer.pressed_keys) > 0:
|
||||
action_idx = self.renderer.pressed_keys[0] - ord("1")
|
||||
if 0 <= action_idx < self.action_space_size:
|
||||
return action_idx
|
||||
else:
|
||||
# the keys are mapped through the environment to more intuitive keyboard keys
|
||||
# key = tuple(self.renderer.pressed_keys)
|
||||
# for key in self.renderer.pressed_keys:
|
||||
for env_keys in self.key_to_action.keys():
|
||||
if set(env_keys) == set(self.renderer.pressed_keys):
|
||||
return self.key_to_action[env_keys]
|
||||
|
||||
# return the default action 0 so that the environment will continue running
|
||||
return self.default_action
|
||||
|
||||
def step(self, action_idx):
|
||||
"""
|
||||
@@ -85,13 +118,29 @@ class EnvironmentWrapper(object):
|
||||
:param action_idx: the action to perform on the environment
|
||||
:return: A dictionary containing the observation, reward, done flag, action and measurements
|
||||
"""
|
||||
pass
|
||||
self.last_action_idx = action_idx
|
||||
|
||||
self._take_action(action_idx)
|
||||
|
||||
self._update_state()
|
||||
|
||||
if self.is_rendered:
|
||||
self.render()
|
||||
|
||||
self.observation = self._preprocess_observation(self.observation)
|
||||
|
||||
return {'observation': self.observation,
|
||||
'reward': self.reward,
|
||||
'done': self.done,
|
||||
'action': self.last_action_idx,
|
||||
'measurements': self.measurements,
|
||||
'info': self.info}
|
||||
|
||||
def render(self):
|
||||
"""
|
||||
Call the environment function for rendering to the screen
|
||||
"""
|
||||
pass
|
||||
self.renderer.render_image(self.get_rendered_image())
|
||||
|
||||
def reset(self, force_environment_reset=False):
|
||||
"""
|
||||
@@ -100,15 +149,25 @@ class EnvironmentWrapper(object):
|
||||
:return: A dictionary containing the observation, reward, done flag, action and measurements
|
||||
"""
|
||||
self._restart_environment_episode(force_environment_reset)
|
||||
self.last_episode_time = time.time()
|
||||
self.done = False
|
||||
self.episode_idx += 1
|
||||
self.reward = 0.0
|
||||
self.last_action_idx = 0
|
||||
self._update_observation_and_measurements()
|
||||
self._update_state()
|
||||
|
||||
# render before the preprocessing of the observation, so that the image will be in its original quality
|
||||
if self.is_rendered:
|
||||
self.render()
|
||||
|
||||
self.observation = self._preprocess_observation(self.observation)
|
||||
|
||||
return {'observation': self.observation,
|
||||
'reward': self.reward,
|
||||
'done': self.done,
|
||||
'action': self.last_action_idx,
|
||||
'measurements': self.measurements}
|
||||
'measurements': self.measurements,
|
||||
'info': self.info}
|
||||
|
||||
def get_random_action(self):
|
||||
"""
|
||||
@@ -129,10 +188,62 @@ class EnvironmentWrapper(object):
|
||||
"""
|
||||
self.phase = phase
|
||||
|
||||
def get_available_keys(self):
|
||||
"""
|
||||
Return a list of tuples mapping between action names and the keyboard key that triggers them
|
||||
:return: a list of tuples mapping between action names and the keyboard key that triggers them
|
||||
"""
|
||||
available_keys = []
|
||||
if self.key_to_action != {}:
|
||||
for key, idx in sorted(self.key_to_action.items(), key=operator.itemgetter(1)):
|
||||
if key != ():
|
||||
key_names = [self.renderer.get_key_names([k])[0] for k in key]
|
||||
available_keys.append((self.actions_description[idx], ' + '.join(key_names)))
|
||||
elif self.discrete_controls:
|
||||
for action in range(self.action_space_size):
|
||||
available_keys.append(("Action {}".format(action + 1), action + 1))
|
||||
return available_keys
|
||||
|
||||
# The following functions define the interaction with the environment.
|
||||
# Any new environment that inherits the EnvironmentWrapper class should use these signatures.
|
||||
# Some of these functions are optional - please read their description for more details.
|
||||
|
||||
def _take_action(self, action_idx):
|
||||
"""
|
||||
An environment dependent function that sends an action to the simulator.
|
||||
:param action_idx: the action to perform on the environment
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
|
||||
def _preprocess_observation(self, observation):
|
||||
"""
|
||||
Do initial observation preprocessing such as cropping, rgb2gray, rescale etc.
|
||||
Implementing this function is optional.
|
||||
:param observation: a raw observation from the environment
|
||||
:return: the preprocessed observation
|
||||
"""
|
||||
return observation
|
||||
|
||||
def _update_state(self):
|
||||
"""
|
||||
Updates the state from the environment.
|
||||
Should update self.observation, self.reward, self.done, self.measurements and self.info
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
|
||||
def _restart_environment_episode(self, force_environment_reset=False):
|
||||
"""
|
||||
:param force_environment_reset: Force the environment to reset even if the episode is not done yet.
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_rendered_image(self):
|
||||
"""
|
||||
Return a numpy array containing the image that will be rendered to the screen.
|
||||
This can be different from the observation. For example, mujoco's observation is a measurements vector.
|
||||
:return: numpy array containing the image that will be rendered to the screen
|
||||
"""
|
||||
return self.observation
|
||||
return self.observation
|
||||
@@ -15,8 +15,10 @@
|
||||
#
|
||||
|
||||
import sys
|
||||
from logger import *
|
||||
import gym
|
||||
import numpy as np
|
||||
import time
|
||||
try:
|
||||
import roboschool
|
||||
from OpenGL import GL
|
||||
@@ -40,8 +42,6 @@ from gym import wrappers
|
||||
from utils import force_list, RunPhase
|
||||
from environments.environment_wrapper import EnvironmentWrapper
|
||||
|
||||
i = 0
|
||||
|
||||
|
||||
class GymEnvironmentWrapper(EnvironmentWrapper):
|
||||
def __init__(self, tuning_parameters):
|
||||
@@ -53,29 +53,30 @@ class GymEnvironmentWrapper(EnvironmentWrapper):
|
||||
self.env.seed(self.seed)
|
||||
|
||||
# self.env_spec = gym.spec(self.env_id)
|
||||
self.env.frameskip = self.frame_skip
|
||||
self.discrete_controls = type(self.env.action_space) != gym.spaces.box.Box
|
||||
|
||||
# pybullet requires rendering before resetting the environment, but other gym environments (Pendulum) will crash
|
||||
try:
|
||||
if self.is_rendered:
|
||||
self.render()
|
||||
except:
|
||||
pass
|
||||
|
||||
o = self.reset(True)['observation']
|
||||
self.observation = self.reset(True)['observation']
|
||||
|
||||
# render
|
||||
if self.is_rendered:
|
||||
self.render()
|
||||
image = self.get_rendered_image()
|
||||
scale = 1
|
||||
if self.human_control:
|
||||
scale = 2
|
||||
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
|
||||
|
||||
self.is_state_type_image = len(o.shape) > 1
|
||||
self.is_state_type_image = len(self.observation.shape) > 1
|
||||
if self.is_state_type_image:
|
||||
self.width = o.shape[1]
|
||||
self.height = o.shape[0]
|
||||
self.width = self.observation.shape[1]
|
||||
self.height = self.observation.shape[0]
|
||||
else:
|
||||
self.width = o.shape[0]
|
||||
self.width = self.observation.shape[0]
|
||||
|
||||
# action space
|
||||
self.actions_description = {}
|
||||
if hasattr(self.env.unwrapped, 'get_action_meanings'):
|
||||
self.actions_description = self.env.unwrapped.get_action_meanings()
|
||||
if self.discrete_controls:
|
||||
self.action_space_size = self.env.action_space.n
|
||||
self.action_space_abs_range = 0
|
||||
@@ -85,34 +86,31 @@ class GymEnvironmentWrapper(EnvironmentWrapper):
|
||||
self.action_space_low = self.env.action_space.low
|
||||
self.action_space_abs_range = np.maximum(np.abs(self.action_space_low), np.abs(self.action_space_high))
|
||||
self.actions = {i: i for i in range(self.action_space_size)}
|
||||
self.key_to_action = {}
|
||||
if hasattr(self.env.unwrapped, 'get_keys_to_action'):
|
||||
self.key_to_action = self.env.unwrapped.get_keys_to_action()
|
||||
|
||||
# measurements
|
||||
self.timestep_limit = self.env.spec.timestep_limit
|
||||
self.current_ale_lives = 0
|
||||
self.measurements_size = len(self.step(0)['info'].keys())
|
||||
|
||||
# env intialization
|
||||
self.observation = o
|
||||
self.reward = 0
|
||||
self.done = False
|
||||
self.last_action = self.actions[0]
|
||||
|
||||
def render(self):
|
||||
self.env.render()
|
||||
|
||||
def step(self, action_idx):
|
||||
def _update_state(self):
|
||||
if hasattr(self.env.env, 'ale'):
|
||||
if self.phase == RunPhase.TRAIN and hasattr(self, 'current_ale_lives'):
|
||||
# signal termination for life loss
|
||||
if self.current_ale_lives != self.env.env.ale.lives():
|
||||
self.done = True
|
||||
self.current_ale_lives = self.env.env.ale.lives()
|
||||
|
||||
def _take_action(self, action_idx):
|
||||
if action_idx is None:
|
||||
action_idx = self.last_action_idx
|
||||
|
||||
self.last_action_idx = action_idx
|
||||
|
||||
if self.discrete_controls:
|
||||
action = self.actions[action_idx]
|
||||
else:
|
||||
action = action_idx
|
||||
|
||||
if hasattr(self.env.env, 'ale'):
|
||||
prev_ale_lives = self.env.env.ale.lives()
|
||||
|
||||
# pendulum-v0 for example expects a list
|
||||
if not self.discrete_controls:
|
||||
# catching cases where the action for continuous control is a number instead of a list the
|
||||
@@ -128,42 +126,26 @@ class GymEnvironmentWrapper(EnvironmentWrapper):
|
||||
|
||||
self.observation, self.reward, self.done, self.info = self.env.step(action)
|
||||
|
||||
if hasattr(self.env.env, 'ale') and self.phase == RunPhase.TRAIN:
|
||||
# signal termination for breakout life loss
|
||||
if prev_ale_lives != self.env.env.ale.lives():
|
||||
self.done = True
|
||||
|
||||
def _preprocess_observation(self, observation):
|
||||
if any(env in self.env_id for env in ["Breakout", "Pong"]):
|
||||
# crop image
|
||||
self.observation = self.observation[34:195, :, :]
|
||||
|
||||
if self.is_rendered:
|
||||
self.render()
|
||||
|
||||
return {'observation': self.observation,
|
||||
'reward': self.reward,
|
||||
'done': self.done,
|
||||
'action': self.last_action_idx,
|
||||
'info': self.info}
|
||||
observation = observation[34:195, :, :]
|
||||
return observation
|
||||
|
||||
def _restart_environment_episode(self, force_environment_reset=False):
|
||||
# prevent reset of environment if there are ale lives left
|
||||
if "Breakout" in self.env_id and self.env.env.ale.lives() > 0 and not force_environment_reset:
|
||||
if (hasattr(self.env.env, 'ale') and self.env.env.ale.lives() > 0) \
|
||||
and not force_environment_reset and not self.env._past_limit():
|
||||
return self.observation
|
||||
|
||||
if self.seed:
|
||||
self.env.seed(self.seed)
|
||||
observation = self.env.reset()
|
||||
while observation is None:
|
||||
observation = self.step(0)['observation']
|
||||
|
||||
if "Breakout" in self.env_id:
|
||||
# crop image
|
||||
observation = observation[34:195, :, :]
|
||||
self.observation = self.env.reset()
|
||||
while self.observation is None:
|
||||
self.step(0)
|
||||
|
||||
self.observation = observation
|
||||
|
||||
return observation
|
||||
return self.observation
|
||||
|
||||
def get_rendered_image(self):
|
||||
return self.env.render(mode='rgb_array')
|
||||
|
||||
Reference in New Issue
Block a user