
pre-release 0.10.0

Gal Novik
2018-08-13 17:11:34 +03:00
parent d44c329bb8
commit 19ca5c24b1
485 changed files with 33292 additions and 16770 deletions

View File

@@ -0,0 +1,112 @@
; Example of settings file for CARLA.
;
; This file can be loaded with the Python client to be sent to the server. It
; defines the parameters to be used when requesting a new episode.
;
; Note that server specific variables are only loaded when launching the
; simulator. Use it with `./CarlaUE4.sh -carla-settings=Path/To/This/File`.
[CARLA/Server]
; If set to false, a mock controller will be used instead of waiting for a real
; client to connect. (Server only)
UseNetworking=false
; Ports to use for the server-client communication. This can be overridden by
; the command-line switch `-world-port=N`; the write and read ports will be set
; to N+1 and N+2 respectively. (Server only)
WorldPort=2000
; Time-out in milliseconds for the networking operations. (Server only)
ServerTimeOut=100000000000
; In synchronous mode, CARLA waits every frame until the control from the client
; is received.
SynchronousMode=true
; Send info about every non-player agent in the scene every frame; the
; information is attached to the measurements message. This includes other
; vehicles, pedestrians and traffic signs. Disabled by default to improve
; performance.
SendNonPlayerAgentsInfo=false
[CARLA/QualitySettings]
; Quality level of the graphics, a lower level makes the simulation run
; considerably faster. Available: Low or Epic.
QualityLevel=Low
[CARLA/LevelSettings]
; Path of the vehicle class to be used for the player. Leave empty for default.
; Paths follow the pattern "/Game/Blueprints/Vehicles/Mustang/Mustang.Mustang_C"
PlayerVehicle=
; Number of non-player vehicles to be spawned into the level.
NumberOfVehicles=15
; Number of non-player pedestrians to be spawned into the level.
NumberOfPedestrians=30
; Index of the weather/lighting presets to use. If negative, the default presets
; of the map will be used.
WeatherId=1
; Seeds for the pseudo-random number generators.
SeedVehicles=123456789
SeedPedestrians=123456789
[CARLA/Sensor]
; Names of the sensors to be attached to the player, comma-separated; each of
; them should be defined in its own subsection.
; The next line adds a camera called FrontCamera to the vehicle
Sensors=FrontCamera
; or uncomment next line to add a camera and a Lidar
; Sensors=FrontCamera,MyLidar
; or uncomment next line to add a regular camera and a depth camera
; Sensors=FrontCamera,FrontCamera/Depth
; Now, every camera we added needs to be defined in its own subsection.
[CARLA/Sensor/FrontCamera]
; Type of the sensor. The available types are:
; * CAMERA A scene capture camera.
; * LIDAR_RAY_CAST A Lidar implementation based on ray-casting.
SensorType=CAMERA
; Post-processing effect to be applied to this camera. Valid values:
; * None No effects applied.
; * SceneFinal Post-processing present at scene (bloom, fog, etc).
; * Depth Depth map ground-truth only.
; * SemanticSegmentation Semantic segmentation ground-truth only.
PostProcessing=SceneFinal
; Size of the captured image in pixels.
ImageSizeX=360
ImageSizeY=256
; Camera (horizontal) field of view in degrees.
FOV=90
; Position of the camera relative to the car in meters.
PositionX=0.20
PositionY=0
PositionZ=1.30
; Rotation of the camera relative to the car in degrees.
RotationPitch=8
RotationRoll=0
RotationYaw=0
[CARLA/Sensor/FrontCamera/Depth]
; The sensor can be defined in a subsection of FrontCamera so it inherits the
; values in FrontCamera. This adds a camera similar to FrontCamera but generating
; depth map images instead.
PostProcessing=Depth
[CARLA/Sensor/MyLidar]
SensorType=LIDAR_RAY_CAST
; Number of lasers.
Channels=32
; Measure distance in meters.
Range=50.0
; Points generated by all lasers per second.
PointsPerSecond=100000
; Lidar rotation frequency.
RotationFrequency=10
; Upper and lower laser angles; positive values mean above the horizontal line.
UpperFOVLimit=10
LowerFOVLimit=-30
; Position and rotation relative to the vehicle.
PositionX=0
PositionY=0
PositionZ=1.40
RotationPitch=0
RotationYaw=0
RotationRoll=0
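For reference, a settings file like this is consumed by reading it as a plain string and handing it to the Python client, which is exactly what the CARLA environment wrapper later in this commit does. A minimal standalone sketch along those lines (host, port and file path are placeholders):

```python
# Sketch of sending this settings file to a running CARLA server, mirroring the
# CarlaClient usage in carla_environment.py below. Host, port and the file path
# are placeholders.
from carla.client import CarlaClient

with open('CarlaSettings.ini', 'r') as fp:
    settings = fp.read()

client = CarlaClient('localhost', 2000, timeout=15)
client.connect()
scene = client.load_settings(settings)          # returns the scene description
client.start_episode(0)                         # start at the first player start spot
measurements, sensor_data = client.read_data()  # measurements + sensor images for one frame
```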

View File

@@ -0,0 +1,19 @@
A custom input filter implementation should look like this:
```python
from rl_coach.core_types import EnvResponse
from rl_coach.filters.filter import InputFilter
from rl_coach.spaces import ObservationSpace, RewardSpace
class CustomFilter(InputFilter):
def __init__(self):
...
def _filter(self, env_response: EnvResponse) -> EnvResponse:
...
def _get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
...
def _get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
...
def _validate_input_observation_space(self, input_observation_space: ObservationSpace):
...
def _reset(self):
...
```
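For instance, a minimal sketch of a concrete filter that fills in this skeleton and clips rewards to [-1, 1] (the class name and bounds are illustrative only, and the pass-through handling of the observation space is an assumption about what the framework expects):

```python
# A minimal sketch that fills in the skeleton above: clip every reward to [-1, 1].
# Class name and bounds are illustrative; not part of Coach itself.
import numpy as np

from rl_coach.core_types import EnvResponse
from rl_coach.filters.filter import InputFilter
from rl_coach.spaces import ObservationSpace, RewardSpace


class ClipRewardFilter(InputFilter):
    def __init__(self):
        super().__init__()

    def _filter(self, env_response: EnvResponse) -> EnvResponse:
        # assumption: the reward is exposed as a mutable attribute of EnvResponse
        env_response.reward = float(np.clip(env_response.reward, -1.0, 1.0))
        return env_response

    def _get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        # observations pass through unchanged
        return input_observation_space

    def _get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        # after filtering, rewards are bounded to [-1, 1]
        return RewardSpace(1)

    def _validate_input_observation_space(self, input_observation_space: ObservationSpace):
        # nothing to validate for a reward-only filter
        pass

    def _reset(self):
        # no state is kept between episodes
        pass
```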

View File

@@ -0,0 +1,16 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@@ -0,0 +1,357 @@
import random
import sys
from os import path, environ
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
from rl_coach.filters.observation.observation_rgb_to_y_filter import ObservationRGBToYFilter
try:
if 'CARLA_ROOT' in environ:
sys.path.append(path.join(environ.get('CARLA_ROOT'), 'PythonClient'))
from carla.client import CarlaClient
from carla.settings import CarlaSettings
from carla.tcp import TCPConnectionError
from carla.sensor import Camera
from carla.client import VehicleControl
except ImportError:
from rl_coach.logger import failed_imports
failed_imports.append("CARLA")
import logging
import subprocess
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
from rl_coach.spaces import BoxActionSpace, ImageObservationSpace, StateSpace, \
VectorObservationSpace
from rl_coach.utils import get_open_port, force_list
from enum import Enum
import os
import signal
from typing import List, Union
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.filters.filter import InputFilter, NoOutputFilter
# needed for the discrete human-control action mapping below
from rl_coach.filters.action.partial_discrete_action_space_map import PartialDiscreteActionSpaceMap
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
import numpy as np
# enum of the available levels and their path
class CarlaLevel(Enum):
TOWN1 = "/Game/Maps/Town01"
TOWN2 = "/Game/Maps/Town02"
key_map = {
'BRAKE': (274,), # down arrow
'GAS': (273,), # up arrow
'TURN_LEFT': (276,), # left arrow
'TURN_RIGHT': (275,), # right arrow
'GAS_AND_TURN_LEFT': (273, 276),
'GAS_AND_TURN_RIGHT': (273, 275),
'BRAKE_AND_TURN_LEFT': (274, 276),
'BRAKE_AND_TURN_RIGHT': (274, 275),
}
CarlaInputFilter = InputFilter(is_a_reference_filter=True)
CarlaInputFilter.add_observation_filter('forward_camera', 'rescaling',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([128, 180, 3]),
high=255)))
CarlaInputFilter.add_observation_filter('forward_camera', 'to_grayscale', ObservationRGBToYFilter())
CarlaInputFilter.add_observation_filter('forward_camera', 'to_uint8', ObservationToUInt8Filter(0, 255))
CarlaInputFilter.add_observation_filter('forward_camera', 'stacking', ObservationStackingFilter(4))
CarlaOutputFilter = NoOutputFilter()
class CameraTypes(Enum):
FRONT = "forward_camera"
LEFT = "left_camera"
RIGHT = "right_camera"
SEGMENTATION = "segmentation"
DEPTH = "depth"
LIDAR = "lidar"
class CarlaEnvironmentParameters(EnvironmentParameters):
class Quality(Enum):
LOW = "Low"
EPIC = "Epic"
def __init__(self):
super().__init__()
self.frame_skip = 3 # the frame skip affects the fps of the server directly. fps = 30 / frameskip
self.server_height = 512
self.server_width = 720
self.camera_height = 128
self.camera_width = 180
self.config = None #'environments/CarlaSettings.ini' # TODO: remove the config to prevent confusion
self.level = 'town1'
self.quality = self.Quality.LOW
self.cameras = [CameraTypes.FRONT]
self.weather_id = [1]
self.verbose = True
self.episode_max_time = 100000 # milliseconds for each episode
self.allow_braking = False
self.default_input_filter = CarlaInputFilter
self.default_output_filter = CarlaOutputFilter
@property
def path(self):
return 'rl_coach.environments.carla_environment:CarlaEnvironment'
class CarlaEnvironment(Environment):
def __init__(self, level: LevelSelection,
seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float],
visualization_parameters: VisualizationParameters,
server_height: int, server_width: int, camera_height: int, camera_width: int,
verbose: bool, config: str, episode_max_time: int,
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
cameras: List[CameraTypes], weather_id: List[int], **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
# server configuration
self.server_height = server_height
self.server_width = server_width
self.port = get_open_port()
self.host = 'localhost'
self.map = self.env_id
# client configuration
self.verbose = verbose
self.quality = quality
self.cameras = cameras
self.weather_id = weather_id
self.episode_max_time = episode_max_time
self.allow_braking = allow_braking
self.camera_width = camera_width
self.camera_height = camera_height
# state space
self.state_space = StateSpace({
"measurements": VectorObservationSpace(4, measurements_names=["forward_speed", "x", "y", "z"])
})
for camera in self.cameras:
self.state_space[camera.value] = ImageObservationSpace(
shape=np.array([self.camera_height, self.camera_width, 3]),
high=255)
# setup server settings
self.config = config
if self.config:
# load settings from file
with open(self.config, 'r') as fp:
self.settings = fp.read()
else:
# hard coded settings
self.settings = CarlaSettings()
self.settings.set(
SynchronousMode=True,
SendNonPlayerAgentsInfo=False,
NumberOfVehicles=15,
NumberOfPedestrians=30,
WeatherId=random.choice(force_list(self.weather_id)),
QualityLevel=self.quality.value)
self.settings.randomize_seeds()
self.settings = self._add_cameras(self.settings, self.cameras, self.camera_width, self.camera_height)
# open the server
self.server = self._open_server()
logging.disable(40)
# open the client
self.game = CarlaClient(self.host, self.port, timeout=99999999)
self.game.connect()
scene = self.game.load_settings(self.settings)
# get available start positions
positions = scene.player_start_spots
self.num_pos = len(positions)
self.iterator_start_positions = 0
# action space
self.action_space = BoxActionSpace(shape=2, low=np.array([-1, -1]), high=np.array([1, 1]))
# human control
if self.human_control:
# convert continuous action space to discrete
self.steering_strength = 0.5
self.gas_strength = 1.0
self.brake_strength = 0.5
self.action_space = PartialDiscreteActionSpaceMap(
target_actions=[[0., 0.],
[0., -self.steering_strength],
[0., self.steering_strength],
[self.gas_strength, 0.],
[-self.brake_strength, 0],
[self.gas_strength, -self.steering_strength],
[self.gas_strength, self.steering_strength],
[self.brake_strength, -self.steering_strength],
[self.brake_strength, self.steering_strength]],
target_action_space=self.action_space,
descriptions=['NO-OP', 'TURN_LEFT', 'TURN_RIGHT', 'GAS', 'BRAKE',
'GAS_AND_TURN_LEFT', 'GAS_AND_TURN_RIGHT',
'BRAKE_AND_TURN_LEFT', 'BRAKE_AND_TURN_RIGHT']
)
# map keyboard keys to actions
for idx, action in enumerate(self.action_space.descriptions):
for key in key_map.keys():
if action == key:
self.key_to_action[key_map[key]] = idx
self.num_speedup_steps = 30
# measurements
self.autopilot = None
# env initialization
self.reset_internal_state(True)
# render
if self.is_rendered:
image = self.get_rendered_image()
self.renderer.create_screen(image.shape[1], image.shape[0])
def _add_cameras(self, settings, cameras, camera_width, camera_height):
# add a front facing camera
if CameraTypes.FRONT in cameras:
camera = Camera(CameraTypes.FRONT.value)
camera.set_image_size(camera_width, camera_height)
camera.set_position(0.2, 0, 1.3)
camera.set_rotation(8, 0, 0)
settings.add_sensor(camera)
# add a left facing camera
if CameraTypes.LEFT in cameras:
camera = Camera(CameraTypes.LEFT.value)
camera.set_image_size(camera_width, camera_height)
camera.set_position(0.2, 0, 1.3)
camera.set_rotation(8, -30, 0)
settings.add_sensor(camera)
# add a right facing camera
if CameraTypes.RIGHT in cameras:
camera = Camera(CameraTypes.RIGHT.value)
camera.set_image_size(camera_width, camera_height)
camera.set_position(0.2, 0, 1.3)
camera.set_rotation(8, 30, 0)
settings.add_sensor(camera)
# add a front facing depth camera
if CameraTypes.DEPTH in cameras:
camera = Camera(CameraTypes.DEPTH.value)
camera.set_image_size(camera_width, camera_height)
camera.set_position(0.2, 0, 1.3)
camera.set_rotation(8, 30, 0)
camera.PostProcessing = 'Depth'
settings.add_sensor(camera)
# add a front facing semantic segmentation camera
if CameraTypes.SEGMENTATION in cameras:
camera = Camera(CameraTypes.SEGMENTATION.value)
camera.set_image_size(camera_width, camera_height)
camera.set_position(0.2, 0, 1.3)
camera.set_rotation(8, 30, 0)
camera.PostProcessing = 'SemanticSegmentation'
settings.add_sensor(camera)
return settings
def _open_server(self):
# TODO: get experiment path
log_path = path.join('./logs/', "CARLA_LOG_{}.txt".format(self.port))
with open(log_path, "wb") as out:
cmd = [path.join(environ.get('CARLA_ROOT'), 'CarlaUE4.sh'), self.map,
"-benchmark", "-carla-server", "-fps={}".format(30 / self.frame_skip),
"-world-port={}".format(self.port),
"-windowed -ResX={} -ResY={}".format(self.server_width, self.server_height),
"-carla-no-hud"]
if self.config:
cmd.append("-carla-settings={}".format(self.config))
p = subprocess.Popen(cmd, stdout=out, stderr=out)
return p
def _close_server(self):
os.killpg(os.getpgid(self.server.pid), signal.SIGKILL)
def _update_state(self):
# get measurements and observations
measurements = []
while type(measurements) == list:
measurements, sensor_data = self.game.read_data()
self.state = {}
for camera in self.cameras:
self.state[camera.value] = sensor_data[camera.value].data
self.location = [measurements.player_measurements.transform.location.x,
measurements.player_measurements.transform.location.y,
measurements.player_measurements.transform.location.z]
is_collision = measurements.player_measurements.collision_vehicles != 0 \
or measurements.player_measurements.collision_pedestrians != 0 \
or measurements.player_measurements.collision_other != 0
speed_reward = measurements.player_measurements.forward_speed - 1
if speed_reward > 30.:
speed_reward = 30.
self.reward = speed_reward \
- (measurements.player_measurements.intersection_otherlane * 5) \
- (measurements.player_measurements.intersection_offroad * 5) \
- is_collision * 100 \
- np.abs(self.control.steer) * 10
# update measurements
self.measurements = [measurements.player_measurements.forward_speed] + self.location
self.autopilot = measurements.player_measurements.autopilot_control
# action_p = ['%.2f' % member for member in [self.control.throttle, self.control.steer]]
# screen.success('REWARD: %.2f, ACTIONS: %s' % (self.reward, action_p))
if (measurements.game_timestamp >= self.episode_max_time) or is_collision:
# screen.success('EPISODE IS DONE. GameTime: {}, Collision: {}'.format(str(measurements.game_timestamp),
# str(is_collision)))
self.done = True
self.state['measurements'] = self.measurements
def _take_action(self, action):
self.control = VehicleControl()
self.control.throttle = np.clip(action[0], 0, 1)
self.control.steer = np.clip(action[1], -1, 1)
self.control.brake = np.abs(np.clip(action[0], -1, 0))
if not self.allow_braking:
self.control.brake = 0
self.control.hand_brake = False
self.control.reverse = False
self.game.send_control(self.control)
def _restart_environment_episode(self, force_environment_reset=False):
self.iterator_start_positions += 1
if self.iterator_start_positions >= self.num_pos:
self.iterator_start_positions = 0
try:
self.game.start_episode(self.iterator_start_positions)
except:
self.game.connect()
self.game.start_episode(self.iterator_start_positions)
# start the game with some initial speed
for i in range(self.num_speedup_steps):
self._take_action([1.0, 0])
def get_rendered_image(self) -> np.ndarray:
"""
Return a numpy array containing the image that will be rendered to the screen.
This can be different from the observation. For example, mujoco's observation is a measurements vector.
:return: numpy array containing the image that will be rendered to the screen
"""
image = [self.state[camera.value] for camera in self.cameras]
image = np.vstack(image)
return image
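A hedged sketch of how the parameters class above might be configured (the values are illustrative and not taken from an actual preset; wiring the parameters into a preset/graph manager is omitted):

```python
# Illustrative configuration of the CARLA environment parameters defined above.
# The values are examples only.
from rl_coach.environments.carla_environment import CameraTypes, CarlaEnvironmentParameters

env_params = CarlaEnvironmentParameters()
env_params.level = 'town1'                                    # default level name (see CarlaLevel above)
env_params.quality = CarlaEnvironmentParameters.Quality.EPIC  # or Quality.LOW (default)
env_params.cameras = [CameraTypes.FRONT, CameraTypes.DEPTH]
env_params.allow_braking = True
env_params.episode_max_time = 100000                          # milliseconds per episode
```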

View File

@@ -0,0 +1,162 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from enum import Enum
from typing import Union
import numpy as np
try:
from dm_control import suite
from dm_control.suite.wrappers import pixels
except ImportError:
from rl_coach.logger import failed_imports
failed_imports.append("DeepMind Control Suite")
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
from rl_coach.spaces import BoxActionSpace, ImageObservationSpace, VectorObservationSpace, StateSpace
class ObservationType(Enum):
Measurements = 1
Image = 2
Image_and_Measurements = 3
# Parameters
class ControlSuiteEnvironmentParameters(EnvironmentParameters):
def __init__(self):
super().__init__()
self.observation_type = ObservationType.Measurements
self.default_input_filter = ControlSuiteInputFilter
self.default_output_filter = ControlSuiteOutputFilter
@property
def path(self):
return 'rl_coach.environments.control_suite_environment:ControlSuiteEnvironment'
"""
ControlSuite Environment Components
"""
ControlSuiteInputFilter = NoInputFilter()
ControlSuiteOutputFilter = NoOutputFilter()
control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING}
# Environment
class ControlSuiteEnvironment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
seed: Union[None, int]=None, human_control: bool=False,
observation_type: ObservationType=ObservationType.Measurements,
custom_reward_threshold: Union[int, float]=None, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
self.observation_type = observation_type
# load and initialize environment
domain_name, task_name = self.env_id.split(":")
self.env = suite.load(domain_name=domain_name, task_name=task_name)
if observation_type != ObservationType.Measurements:
self.env = pixels.Wrapper(self.env, pixels_only=observation_type == ObservationType.Image)
# seed
if self.seed is not None:
np.random.seed(self.seed)
random.seed(self.seed)
self.state_space = StateSpace({})
# image observations
if observation_type != ObservationType.Measurements:
self.state_space['pixels'] = ImageObservationSpace(shape=self.env.observation_spec()['pixels'].shape,
high=255)
# measurements observations
if observation_type != ObservationType.Image:
measurements_space_size = 0
measurements_names = []
for observation_space_name, observation_space in self.env.observation_spec().items():
if len(observation_space.shape) == 0:
measurements_space_size += 1
measurements_names.append(observation_space_name)
elif len(observation_space.shape) == 1:
measurements_space_size += observation_space.shape[0]
measurements_names.extend(["{}_{}".format(observation_space_name, i) for i in
range(observation_space.shape[0])])
self.state_space['measurements'] = VectorObservationSpace(shape=measurements_space_size,
measurements_names=measurements_names)
# actions
self.action_space = BoxActionSpace(
shape=self.env.action_spec().shape[0],
low=self.env.action_spec().minimum,
high=self.env.action_spec().maximum
)
# initialize the state by getting a new state from the environment
self.reset_internal_state(True)
# render
if self.is_rendered:
image = self.get_rendered_image()
scale = 1
if self.human_control:
scale = 2
if not self.native_rendering:
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
def _update_state(self):
self.state = {}
if self.observation_type != ObservationType.Measurements:
self.pixels = self.last_result.observation['pixels']
self.state['pixels'] = self.pixels
if self.observation_type != ObservationType.Image:
self.measurements = np.array([])
for sub_observation in self.last_result.observation.values():
if isinstance(sub_observation, np.ndarray) and len(sub_observation.shape) == 1:
self.measurements = np.concatenate((self.measurements, sub_observation))
else:
self.measurements = np.concatenate((self.measurements, np.array([sub_observation])))
self.state['measurements'] = self.measurements
self.reward = self.last_result.reward if self.last_result.reward is not None else 0
self.done = self.last_result.last()
def _take_action(self, action):
if type(self.action_space) == BoxActionSpace:
action = self.action_space.clip_action_to_space(action)
self.last_result = self.env.step(action)
def _restart_environment_episode(self, force_environment_reset=False):
self.last_result = self.env.reset()
def _render(self):
pass
def get_rendered_image(self):
return self.env.physics.render(camera_id=0)
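As with the other environments in this commit, the parameters class is the intended entry point. A minimal sketch (the 'cartpole:swingup' level follows the domain:task naming built by control_suite_envs above and assumes that task is available in your dm_control installation):

```python
# Illustrative use of the control suite parameters defined above. The level
# string follows the "domain:task" convention from control_suite_envs.
from rl_coach.environments.control_suite_environment import (
    ControlSuiteEnvironmentParameters, ObservationType)

env_params = ControlSuiteEnvironmentParameters()
env_params.level = 'cartpole:swingup'
env_params.observation_type = ObservationType.Image_and_Measurements
```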

View File

@@ -0,0 +1,39 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
doom_scenario_path = D2_navigation.wad
doom_map = map01
# Rewards
# Each step is good for you!
living_reward = 1
# And death is not!
death_penalty = 0
# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = false
render_weapon = false
render_decals = false
render_particles = false
window_visible = false
# make episodes finish after 2100 actions (tics)
episode_timeout = 2100
# Available buttons
available_buttons =
{
TURN_LEFT
TURN_RIGHT
MOVE_FORWARD
}
# Game variables that will be in the state
available_game_variables = { HEALTH }
mode = PLAYER
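For reference, a scenario file like this can also be stepped directly with the ViZDoom Python API, which is how the Doom environment wrapper later in this commit uses it. A rough sketch (the config path is a placeholder and the random policy is for illustration only):

```python
# Minimal sketch of driving the D2_navigation scenario directly with ViZDoom.
# The config path is a placeholder. Three buttons are defined above, so an
# action is a list of three 0/1 flags (TURN_LEFT, TURN_RIGHT, MOVE_FORWARD).
import random
import vizdoom

game = vizdoom.DoomGame()
game.load_config('D2_navigation.cfg')
game.init()

game.new_episode()
while not game.is_episode_finished():
    state = game.get_state()                            # screen buffer + game variables (HEALTH)
    action = [random.randint(0, 1) for _ in range(3)]
    reward = game.make_action(action)                   # reward returned for this action
game.close()
```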

Binary file not shown.

View File

@@ -0,0 +1,44 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
# modify these to point to your vizdoom binary and freedoom2.wad
doom_scenario_path = D3_battle.wad
doom_map = map01
# Rewards
living_reward = 0
death_penalty = 0
# Rendering options
screen_resolution = RES_320X240
screen_format = CRCGCB
render_hud = false
render_crosshair = true
render_weapon = true
render_decals = false
render_particles = false
window_visible = false
# make episodes finish after 2100 actions (tics)
episode_timeout = 2100
# Available buttons
available_buttons =
{
MOVE_FORWARD
MOVE_BACKWARD
MOVE_RIGHT
MOVE_LEFT
TURN_LEFT
TURN_RIGHT
ATTACK
SPEED
}
# Game variables that will be in the state
available_game_variables = {AMMO2 HEALTH USER2}
mode = PLAYER
doom_skill = 2

Binary file not shown.

View File

@@ -0,0 +1,229 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
try:
import vizdoom
except ImportError:
from rl_coach.logger import failed_imports
failed_imports.append("ViZDoom")
import os
from enum import Enum
from os import path, environ
from typing import Union, List
import numpy as np
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
from rl_coach.filters.action.full_discrete_action_space_map import FullDiscreteActionSpaceMap
from rl_coach.filters.filter import InputFilter, OutputFilter
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
from rl_coach.spaces import MultiSelectActionSpace, ImageObservationSpace, \
VectorObservationSpace, StateSpace
from rl_coach.filters.observation.observation_rgb_to_y_filter import ObservationRGBToYFilter
# enum of the available levels and their path
class DoomLevel(Enum):
BASIC = "basic.cfg"
DEFEND = "defend_the_center.cfg"
DEATHMATCH = "deathmatch.cfg"
MY_WAY_HOME = "my_way_home.cfg"
TAKE_COVER = "take_cover.cfg"
HEALTH_GATHERING = "health_gathering.cfg"
HEALTH_GATHERING_SUPREME_COACH_LOCAL = "D2_navigation.cfg" # from https://github.com/IntelVCL/DirectFuturePrediction/tree/master/maps
DEFEND_THE_LINE = "defend_the_line.cfg"
DEADLY_CORRIDOR = "deadly_corridor.cfg"
BATTLE_COACH_LOCAL = "D3_battle.cfg" # from https://github.com/IntelVCL/DirectFuturePrediction/tree/master/maps
key_map = {
'NO-OP': 96, # `
'ATTACK': 13, # enter
'CROUCH': 306, # ctrl
'DROP_SELECTED_ITEM': ord("t"),
'DROP_SELECTED_WEAPON': ord("t"),
'JUMP': 32, # spacebar
'LAND': ord("l"),
'LOOK_DOWN': 274, # down arrow
'LOOK_UP': 273, # up arrow
'MOVE_BACKWARD': ord("s"),
'MOVE_DOWN': ord("s"),
'MOVE_FORWARD': ord("w"),
'MOVE_LEFT': 276,
'MOVE_RIGHT': 275,
'MOVE_UP': ord("w"),
'RELOAD': ord("r"),
'SELECT_NEXT_WEAPON': ord("q"),
'SELECT_PREV_WEAPON': ord("e"),
'SELECT_WEAPON0': ord("0"),
'SELECT_WEAPON1': ord("1"),
'SELECT_WEAPON2': ord("2"),
'SELECT_WEAPON3': ord("3"),
'SELECT_WEAPON4': ord("4"),
'SELECT_WEAPON5': ord("5"),
'SELECT_WEAPON6': ord("6"),
'SELECT_WEAPON7': ord("7"),
'SELECT_WEAPON8': ord("8"),
'SELECT_WEAPON9': ord("9"),
'SPEED': 304, # shift
'STRAFE': 9, # tab
'TURN180': ord("u"),
'TURN_LEFT': ord("a"), # left arrow
'TURN_RIGHT': ord("d"), # right arrow
'USE': ord("f"),
}
DoomInputFilter = InputFilter(is_a_reference_filter=True)
DoomInputFilter.add_observation_filter('observation', 'rescaling',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([60, 76, 3]),
high=255)))
DoomInputFilter.add_observation_filter('observation', 'to_grayscale', ObservationRGBToYFilter())
DoomInputFilter.add_observation_filter('observation', 'to_uint8', ObservationToUInt8Filter(0, 255))
DoomInputFilter.add_observation_filter('observation', 'stacking', ObservationStackingFilter(3))
DoomOutputFilter = OutputFilter(is_a_reference_filter=True)
DoomOutputFilter.add_action_filter('to_discrete', FullDiscreteActionSpaceMap())
class DoomEnvironmentParameters(EnvironmentParameters):
def __init__(self):
super().__init__()
self.default_input_filter = DoomInputFilter
self.default_output_filter = DoomOutputFilter
self.cameras = [DoomEnvironment.CameraTypes.OBSERVATION]
@property
def path(self):
return 'rl_coach.environments.doom_environment:DoomEnvironment'
class DoomEnvironment(Environment):
class CameraTypes(Enum):
OBSERVATION = ("observation", "screen_buffer")
DEPTH = ("depth", "depth_buffer")
LABELS = ("labels", "labels_buffer")
MAP = ("map", "automap_buffer")
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
cameras: List[CameraTypes], **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
self.cameras = cameras
# load the emulator with the required level
self.level = DoomLevel[level.upper()]
local_scenarios_path = path.join(os.path.dirname(os.path.realpath(__file__)), 'doom')
self.scenarios_dir = local_scenarios_path if 'COACH_LOCAL' in level \
else path.join(environ.get('VIZDOOM_ROOT'), 'scenarios')
self.game = vizdoom.DoomGame()
self.game.load_config(path.join(self.scenarios_dir, self.level.value))
self.game.set_window_visible(False)
self.game.add_game_args("+vid_forcesurface 1")
self.wait_for_explicit_human_action = True
if self.human_control:
self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_640X480)
elif self.is_rendered:
self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_320X240)
else:
# lower resolution since we actually take only 76x60 and we don't need to render
self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_160X120)
self.game.set_render_hud(False)
self.game.set_render_crosshair(False)
self.game.set_render_decals(False)
self.game.set_render_particles(False)
for camera in self.cameras:
if hasattr(self.game, 'set_{}_enabled'.format(camera.value[1])):
getattr(self.game, 'set_{}_enabled'.format(camera.value[1]))(True)
self.game.init()
# actions
actions_description = ['NO-OP']
actions_description += [str(action).split(".")[1] for action in self.game.get_available_buttons()]
actions_description = actions_description[::-1]
self.action_space = MultiSelectActionSpace(self.game.get_available_buttons_size(),
max_simultaneous_selected_actions=1,
descriptions=actions_description,
allow_no_action_to_be_selected=True)
# human control
if self.human_control:
# TODO: add this to the action space
# map keyboard keys to actions
for idx, action in enumerate(self.action_space.descriptions):
if action in key_map.keys():
self.key_to_action[(key_map[action],)] = idx
# states
self.state_space = StateSpace({
"measurements": VectorObservationSpace(self.game.get_state().game_variables.shape[0],
measurements_names=[str(m) for m in
self.game.get_available_game_variables()])
})
for camera in self.cameras:
self.state_space[camera.value[0]] = ImageObservationSpace(
shape=np.array([self.game.get_screen_height(), self.game.get_screen_width(), 3]),
high=255)
# seed
if seed is not None:
self.game.set_seed(seed)
self.reset_internal_state()
# render
if self.is_rendered:
image = self.get_rendered_image()
self.renderer.create_screen(image.shape[1], image.shape[0])
def _update_state(self):
# extract all data from the current state
state = self.game.get_state()
if state is not None and state.screen_buffer is not None:
self.measurements = state.game_variables
self.state = {'measurements': self.measurements}
for camera in self.cameras:
observation = getattr(state, camera.value[1])
if len(observation.shape) == 3:
self.state[camera.value[0]] = np.transpose(observation, (1, 2, 0))
elif len(observation.shape) == 2:
self.state[camera.value[0]] = np.repeat(np.expand_dims(observation, -1), 3, axis=-1)
self.reward = self.game.get_last_reward()
self.done = self.game.is_episode_finished()
def _take_action(self, action):
self.game.make_action(list(action), self.frame_skip)
def _restart_environment_episode(self, force_environment_reset=False):
self.game.new_episode()
def get_rendered_image(self) -> np.ndarray:
"""
Return a numpy array containing the image that will be rendered to the screen.
This can be different from the observation. For example, mujoco's observation is a measurements vector.
:return: numpy array containing the image that will be rendered to the screen
"""
image = [self.state[camera.value[0]] for camera in self.cameras]
image = np.vstack(image)
return image
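As with CARLA above, a hedged sketch of configuring these parameters (values are illustrative; level names are resolved against the DoomLevel enum):

```python
# Illustrative configuration of the Doom environment parameters defined above.
from rl_coach.environments.doom_environment import DoomEnvironment, DoomEnvironmentParameters

env_params = DoomEnvironmentParameters()
env_params.level = 'basic'      # resolved as DoomLevel['BASIC'] -> basic.cfg
env_params.cameras = [DoomEnvironment.CameraTypes.OBSERVATION,
                      DoomEnvironment.CameraTypes.DEPTH]
```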

View File

@@ -0,0 +1,540 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
import time
from collections import OrderedDict
from typing import Union, List, Tuple, Dict
import numpy as np
from rl_coach.base_parameters import Parameters
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.core_types import GoalType, ActionType, EnvResponse, RunPhase
from rl_coach.renderer import Renderer
from rl_coach.spaces import ActionSpace, ObservationSpace, DiscreteActionSpace, RewardSpace, StateSpace
from rl_coach.utils import squeeze_list, force_list
from rl_coach import logger
from rl_coach.environments.environment_interface import EnvironmentInterface
from rl_coach.logger import screen
class LevelSelection(object):
def __init__(self, level: str):
self.selected_level = level
def select(self, level: str):
self.selected_level = level
def __str__(self):
if self.selected_level is None:
logger.screen.error("No level has been selected. Please select a level using the -lvl command line flag, "
"or change the level in the preset.", crash=True)
return self.selected_level
class SingleLevelSelection(LevelSelection):
def __init__(self, levels: Union[str, List[str], Dict[str, str]]):
super().__init__(None)
self.levels = levels
if isinstance(levels, list):
self.levels = {level: level for level in levels}
if isinstance(levels, str):
self.levels = {levels: levels}
def __str__(self):
if self.selected_level is None:
logger.screen.error("No level has been selected. Please select a level using the -lvl command line flag, "
"or change the level in the preset. \nThe available levels are: \n{}"
.format(', '.join(self.levels.keys())), crash=True)
if self.selected_level not in self.levels.keys():
logger.screen.error("The selected level ({}) is not part of the available levels ({})"
.format(self.selected_level, ', '.join(self.levels.keys())), crash=True)
return self.levels[self.selected_level]
# class SingleLevelPerPhase(LevelSelection):
# def __init__(self, levels: Dict[RunPhase, str]):
# super().__init__(None)
# self.levels = levels
#
# def __str__(self):
# super().__str__()
# if self.selected_level not in self.levels.keys():
# logger.screen.error("The selected level ({}) is not part of the available levels ({})"
# .format(self.selected_level, self.levels.keys()), crash=True)
# return self.levels[self.selected_level]
class CustomWrapper(object):
def __init__(self, environment):
super().__init__()
self.environment = environment
def __getattr__(self, attr):
if attr in self.__dict__:
return self.__dict__[attr]
else:
return getattr(self.environment, attr, False)
class EnvironmentParameters(Parameters):
def __init__(self):
super().__init__()
self.level = None
self.frame_skip = 4
self.seed = None
self.human_control = False
self.custom_reward_threshold = None
self.default_input_filter = None
self.default_output_filter = None
@property
def path(self):
return 'rl_coach.environments.environment:Environment'
class Environment(EnvironmentInterface):
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
**kwargs):
"""
:param level: The environment level. Each environment can have multiple levels
:param seed: a seed for the random number generator of the environment
:param frame_skip: number of frames to skip (while repeating the same action) between two consecutive agent directives
:param human_control: whether a human should control the environment
:param custom_reward_threshold: a custom reward threshold above which the episode is considered successful
:param visualization_parameters: a blob of parameters used for visualization of the environment
:param **kwargs: as the class is instantiated by EnvironmentParameters, this is used to support having
additional arguments which will be ignored by this class, but might be used by others
"""
super().__init__()
# env initialization
self.game = []
self.state = {}
self.observation = None
self.goal = None
self.reward = 0
self.done = False
self.info = {}
self._last_env_response = None
self.last_action = 0
self.episode_idx = 0
self.total_steps_counter = 0
self.current_episode_steps_counter = 0
self.last_episode_time = time.time()
self.key_to_action = {}
self.last_episode_images = []
# rewards
self.total_reward_in_current_episode = 0
self.max_reward_achieved = -np.inf
self.reward_success_threshold = custom_reward_threshold
# spaces
self.state_space = self._state_space = None
self.goal_space = self._goal_space = None
self.action_space = self._action_space = None
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold) # TODO: add a getter and setter
self.env_id = str(level)
self.seed = seed
self.frame_skip = frame_skip
# human interaction and visualization
self.human_control = human_control
self.wait_for_explicit_human_action = False
self.is_rendered = visualization_parameters.render or self.human_control
self.native_rendering = visualization_parameters.native_rendering or self.human_control
self.visualization_parameters = visualization_parameters
if not self.native_rendering:
self.renderer = Renderer()
@property
def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
"""
Get the action space of the environment
:return: the action space
"""
return self._action_space
@action_space.setter
def action_space(self, val: Union[List[ActionSpace], ActionSpace]):
"""
Set the action space of the environment
:return: None
"""
self._action_space = val
@property
def state_space(self) -> Union[List[StateSpace], StateSpace]:
"""
Get the state space of the environment
:return: the state space
"""
return self._state_space
@state_space.setter
def state_space(self, val: Union[List[StateSpace], StateSpace]):
"""
Set the state space of the environment
:return: None
"""
self._state_space = val
@property
def goal_space(self) -> Union[List[ObservationSpace], ObservationSpace]:
"""
Get the goal space of the environment
:return: the goal space
"""
return self._goal_space
@goal_space.setter
def goal_space(self, val: Union[List[ObservationSpace], ObservationSpace]):
"""
Set the goal space of the environment
:return: None
"""
self._goal_space = val
def get_action_from_user(self) -> ActionType:
"""
Get an action from the user keyboard
:return: action index
"""
if self.wait_for_explicit_human_action:
while len(self.renderer.pressed_keys) == 0:
self.renderer.get_events()
if self.key_to_action == {}:
# the keys are the numbers on the keyboard corresponding to the action index
if len(self.renderer.pressed_keys) > 0:
action_idx = self.renderer.pressed_keys[0] - ord("1")
if 0 <= action_idx < self.action_space.shape[0]:
return action_idx
else:
# the keys are mapped through the environment to more intuitive keyboard keys
# key = tuple(self.renderer.pressed_keys)
# for key in self.renderer.pressed_keys:
for env_keys in self.key_to_action.keys():
if set(env_keys) == set(self.renderer.pressed_keys):
return self.action_space.actions[self.key_to_action[env_keys]]
# return the default action 0 so that the environment will continue running
return self.action_space.default_action
@property
def last_env_response(self) -> Union[List[EnvResponse], EnvResponse]:
"""
Get the last environment response
:return: a dictionary that contains the state, reward, etc.
"""
return squeeze_list(self._last_env_response)
@last_env_response.setter
def last_env_response(self, val: Union[List[EnvResponse], EnvResponse]):
"""
Set the last environment response
:param val: the last environment response
"""
self._last_env_response = force_list(val)
def step(self, action: ActionType) -> EnvResponse:
"""
Make a single step in the environment using the given action
:param action: an action to use for stepping the environment. Should follow the definition of the action space.
:return: the environment response as returned in get_last_env_response
"""
action = self.action_space.clip_action_to_space(action)
if self.action_space and not self.action_space.val_matches_space_definition(action):
raise ValueError("The given action does not match the action space definition. "
"Action = {}, action space definition = {}".format(action, self.action_space))
# store the last agent action done and allow passing None actions to repeat the previously done action
if action is None:
action = self.last_action
self.last_action = action
if self.visualization_parameters.add_rendered_image_to_env_response:
current_rendered_image = self.get_rendered_image()
self.current_episode_steps_counter += 1
if self.phase != RunPhase.UNDEFINED:
self.total_steps_counter += 1
# act
self._take_action(action)
# observe
self._update_state()
if self.is_rendered:
self.render()
self.total_reward_in_current_episode += self.reward
if self.visualization_parameters.add_rendered_image_to_env_response:
self.info['image'] = current_rendered_image
self.last_env_response = \
EnvResponse(
reward=self.reward,
next_state=self.state,
goal=self.goal,
game_over=self.done,
info=self.info
)
# store observations for video / gif dumping
if self.should_dump_video_of_the_current_episode(episode_terminated=False) and \
(self.visualization_parameters.dump_mp4 or self.visualization_parameters.dump_gifs):
self.last_episode_images.append(self.get_rendered_image())
return self.last_env_response
def render(self) -> None:
"""
Call the environment function for rendering to the screen
"""
if self.native_rendering:
self._render()
else:
self.renderer.render_image(self.get_rendered_image())
def reset_internal_state(self, force_environment_reset=False) -> EnvResponse:
"""
Reset the environment and all the variables of the wrapper
:param force_environment_reset: forces environment reset even when the game did not end
:return: the environment response for the initial state, as returned by last_env_response
"""
self.dump_video_of_last_episode_if_needed()
self._restart_environment_episode(force_environment_reset)
self.last_episode_time = time.time()
if self.current_episode_steps_counter > 0 and self.phase != RunPhase.UNDEFINED:
self.episode_idx += 1
self.done = False
self.total_reward_in_current_episode = self.reward = 0.0
self.last_action = 0
self.current_episode_steps_counter = 0
self.last_episode_images = []
self._update_state()
# render before the preprocessing of the observation, so that the image will be in its original quality
if self.is_rendered:
self.render()
self.last_env_response = \
EnvResponse(
reward=self.reward,
next_state=self.state,
goal=self.goal,
game_over=self.done,
info=self.info
)
return self.last_env_response
def get_random_action(self) -> ActionType:
"""
Returns an action picked uniformly from the available actions
:return: a numpy array with a random action
"""
return self.action_space.sample()
def get_available_keys(self) -> List[Tuple[str, ActionType]]:
"""
Return a list of tuples mapping between action names and the keyboard key that triggers them
:return: a list of tuples mapping between action names and the keyboard key that triggers them
"""
available_keys = []
if self.key_to_action != {}:
for key, idx in sorted(self.key_to_action.items(), key=operator.itemgetter(1)):
if key != ():
key_names = [self.renderer.get_key_names([k])[0] for k in key]
available_keys.append((self.action_space.descriptions[idx], ' + '.join(key_names)))
elif type(self.action_space) == DiscreteActionSpace:
for action in range(self.action_space.shape):
available_keys.append(("Action {}".format(action + 1), action + 1))
return available_keys
def get_goal(self) -> GoalType:
"""
Get the current goal that the agent needs to achieve in the environment
:return: The goal
"""
return self.goal
def set_goal(self, goal: GoalType) -> None:
"""
Set the current goal that the agent needs to achieve in the environment
:param goal: the goal that needs to be achieved
:return: None
"""
self.goal = goal
def should_dump_video_of_the_current_episode(self, episode_terminated=False):
if self.visualization_parameters.video_dump_methods:
for video_dump_method in force_list(self.visualization_parameters.video_dump_methods):
if not video_dump_method.should_dump(episode_terminated, **self.__dict__):
return False
return True
return False
def dump_video_of_last_episode_if_needed(self):
if self.visualization_parameters.video_dump_methods and self.last_episode_images != []:
if self.should_dump_video_of_the_current_episode(episode_terminated=True):
self.dump_video_of_last_episode()
def dump_video_of_last_episode(self):
frame_skipping = max(1, int(5 / self.frame_skip))
file_name = 'episode-{}_score-{}'.format(self.episode_idx, self.total_reward_in_current_episode)
fps = 10
if self.visualization_parameters.dump_gifs:
logger.create_gif(self.last_episode_images[::frame_skipping], name=file_name, fps=fps)
if self.visualization_parameters.dump_mp4:
logger.create_mp4(self.last_episode_images[::frame_skipping], name=file_name, fps=fps)
def log_to_screen(self):
# log to screen
log = OrderedDict()
log["Episode"] = self.episode_idx
log["Total reward"] = np.round(self.total_reward_in_current_episode, 2)
log["Steps"] = self.total_steps_counter
screen.log_dict(log, prefix=self.phase.value)
# The following functions define the interaction with the environment.
# Any new environment that inherits the Environment class should use these signatures.
# Some of these functions are optional - please read their description for more details.
def _take_action(self, action_idx: ActionType) -> None:
"""
An environment dependent function that sends an action to the simulator.
:param action_idx: the action to perform on the environment
:return: None
"""
raise NotImplementedError("")
def _update_state(self) -> None:
"""
Updates the state from the environment.
Should update self.observation, self.reward, self.done, self.measurements and self.info
:return: None
"""
raise NotImplementedError("")
def _restart_environment_episode(self, force_environment_reset=False) -> None:
"""
Restarts the simulator episode
:param force_environment_reset: Force the environment to reset even if the episode is not done yet.
:return: None
"""
raise NotImplementedError("")
def _render(self) -> None:
"""
Renders the environment using the native simulator renderer
:return: None
"""
pass
def get_rendered_image(self) -> np.ndarray:
"""
Return a numpy array containing the image that will be rendered to the screen.
This can be different from the observation. For example, mujoco's observation is a measurements vector.
:return: numpy array containing the image that will be rendered to the screen
"""
return np.transpose(self.state['observation'], [1, 2, 0])
"""
Video Dumping Methods
"""
class VideoDumpMethod(object):
"""
Method used to decide when to dump videos
"""
def should_dump(self, episode_terminated=False, **kwargs):
raise NotImplementedError("")
class AlwaysDumpMethod(VideoDumpMethod):
"""
Dump video for every episode
"""
def __init__(self):
super().__init__()
def should_dump(self, episode_terminated=False, **kwargs):
return True
class MaxDumpMethod(VideoDumpMethod):
"""
Dump video every time a new max total reward has been achieved
"""
def __init__(self):
super().__init__()
self.max_reward_achieved = -np.inf
def should_dump(self, episode_terminated=False, **kwargs):
# if the episode has not finished yet we want to be prepared for dumping a video
if not episode_terminated:
return True
if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
self.max_reward_achieved = kwargs['total_reward_in_current_episode']
return True
else:
return False
class EveryNEpisodesDumpMethod(object):
"""
Dump videos once in every N episodes
"""
def __init__(self, num_episodes_between_dumps: int):
super().__init__()
self.num_episodes_between_dumps = num_episodes_between_dumps
self.last_dumped_episode = 0
if num_episodes_between_dumps < 1:
raise ValueError("the number of episodes between dumps should be a positive number")
def should_dump(self, episode_terminated=False, **kwargs):
if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
self.last_dumped_episode = kwargs['episode_idx']
return True
else:
return False
class SelectedPhaseOnlyDumpMethod(object):
"""
Dump videos when the phase of the environment matches a predefined phase
"""
def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
self.run_phases = force_list(run_phases)
def should_dump(self, episode_terminated=False, **kwargs):
if kwargs['_phase'] in self.run_phases:
return True
else:
return False
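To tie the interface together, a new environment only needs to fill in the private hooks defined above (_take_action, _update_state, _restart_environment_episode). A hedged skeleton with invented toy dynamics, purely for illustration:

```python
# Skeleton of a custom environment built on the Environment base class above.
# The 1D "move left / move right" dynamics and the module path in
# ToyEnvironmentParameters.path are invented for illustration only.
import numpy as np

from rl_coach.environments.environment import Environment, EnvironmentParameters
from rl_coach.spaces import DiscreteActionSpace, StateSpace, VectorObservationSpace


class ToyEnvironmentParameters(EnvironmentParameters):
    @property
    def path(self):
        return 'my_envs.toy_environment:ToyEnvironment'  # hypothetical module path


class ToyEnvironment(Environment):
    def __init__(self, level, seed, frame_skip, human_control, custom_reward_threshold,
                 visualization_parameters, **kwargs):
        super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
                         visualization_parameters)
        self.position = 0.0
        self.state_space = StateSpace({'observation': VectorObservationSpace(1)})
        self.action_space = DiscreteActionSpace(2)   # 0 = move left, 1 = move right
        self.reset_internal_state(True)

    def _take_action(self, action_idx):
        self.position += -1.0 if action_idx == 0 else 1.0

    def _update_state(self):
        self.state = {'observation': np.array([self.position])}
        self.reward = -abs(self.position)            # reward staying near the origin
        self.done = abs(self.position) > 10

    def _restart_environment_episode(self, force_environment_reset=False):
        self.position = 0.0
```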

View File

@@ -0,0 +1,149 @@
########################################################################################################################
####### Currently we are ignoring more complex cases including EnvironmentGroups - DO NOT USE THIS FILE ****************
########################################################################################################################
# #
# # Copyright (c) 2017 Intel Corporation
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
# #
#
# from typing import Union, List, Dict
# import numpy as np
# from environments import create_environment
# from environments.environment import Environment
# from environments.environment_interface import EnvironmentInterface, ActionType, ActionSpace
# from core_types import GoalType, Transition
#
#
# class EnvironmentGroup(EnvironmentInterface):
# """
# An EnvironmentGroup is a group of different environments.
# In the simple case, it will contain a single environment. But it can also contain multiple environments,
# where the agent can then act on them as a batch, such that the prediction of the action is more efficient.
# """
# def __init__(self, environments_parameters: List[Environment]):
# self.environments_parameters = environments_parameters
# self.environments = []
# self.action_space = []
# self.outgoing_control = []
# self._last_env_response = []
#
# @property
# def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
# """
# Get the action space of the environment
# :return: the action space
# """
# return self.action_space
#
# @action_space.setter
# def action_space(self, val: Union[List[ActionSpace], ActionSpace]):
# """
# Set the action space of the environment
# :return: None
# """
# self.action_space = val
#
# @property
# def phase(self) -> RunPhase:
# """
# Get the phase of the environments group
# :return: the current phase
# """
# return self.phase
#
# @phase.setter
# def phase(self, val: RunPhase):
# """
# Change the phase of each one of the environments in the group
# :param val: the new phase
# :return: None
# """
# self.phase = val
# call_method_for_all(self.environments, 'phase', val)
#
# def _create_environments(self):
# """
# Create the environments using the given parameters and update the environments list
# :return: None
# """
# for environment_parameters in self.environments_parameters:
# environment = create_environment(environment_parameters)
# self.action_space = self.action_space.append(environment.action_space)
# self.environments.append(environment)
#
# @property
# def last_env_response(self) -> Union[List[Transition], Transition]:
# """
# Get the last environment response
# :return: a dictionary that contains the state, reward, etc.
# """
# return squeeze_list(self._last_env_response)
#
# @last_env_response.setter
# def last_env_response(self, val: Union[List[Transition], Transition]):
# """
# Set the last environment response
# :param val: the last environment response
# """
# self._last_env_response = force_list(val)
#
# def step(self, actions: Union[List[ActionType], ActionType]) -> List[Transition]:
# """
# Act in all the environments in the group.
# :param actions: can be either a single action if there is a single environment in the group, or a list of
# actions in case there are multiple environments in the group. Each action can be an action index
# or a numpy array representing a continuous action for example.
# :return: The responses from all the environments in the group
# """
#
# actions = force_list(actions)
# if len(actions) != len(self.environments):
# raise ValueError("The number of actions does not match the number of environments in the group")
#
# result = []
# for environment, action in zip(self.environments, actions):
# result.append(environment.step(action))
#
# self.last_env_response = result
#
# return result
#
# def reset(self, force_environment_reset: bool=False) -> List[Transition]:
# """
# Reset all the environments in the group
# :param force_environment_reset: force the reset of each one of the environments
# :return: a list of the environments responses
# """
# return call_method_for_all(self.environments, 'reset', force_environment_reset)
#
# def get_random_action(self) -> List[ActionType]:
# """
# Get a list of random action that can be applied on the environments in the group
# :return: a list of random actions
# """
# return call_method_for_all(self.environments, 'get_random_action')
#
# def set_goal(self, goal: GoalType) -> None:
# """
# Set the goal of each one of the environments in the group to be the given goal
# :param goal: a goal vector
# :return: None
# """
# # TODO: maybe enable setting multiple goals?
# call_method_for_all(self.environments, 'set_goal', goal)

View File

@@ -0,0 +1,76 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Union, Dict
from rl_coach.spaces import ActionSpace
from rl_coach.core_types import ActionType, EnvResponse, RunPhase
class EnvironmentInterface(object):
def __init__(self):
self._phase = RunPhase.UNDEFINED
@property
def phase(self) -> RunPhase:
"""
Get the phase of the environment
:return: the current phase
"""
return self._phase
@phase.setter
def phase(self, val: RunPhase):
"""
Change the phase of the environment
:param val: the new phase
:return: None
"""
self._phase = val
@property
def action_space(self) -> Union[Dict[str, ActionSpace], ActionSpace]:
"""
Get the action space of the environment (or of each of the agents wrapped in this
environment, i.e. in the LevelManager case)
:return: the action space
"""
raise NotImplementedError("")
def get_random_action(self) -> ActionType:
"""
Get a random action from the environment action space
:return: An action that follows the definition of the action space.
"""
raise NotImplementedError("")
def step(self, action: ActionType) -> Union[None, EnvResponse]:
"""
Make a single step in the environment using the given action
:param action: an action to use for stepping the environment. Should follow the definition of the action space.
:return: the environment response as returned in get_last_env_response or None for LevelManager
"""
raise NotImplementedError("")
def reset_internal_state(self, force_environment_reset: bool=False) -> Union[None, EnvResponse]:
"""
Reset the environment episode
:param force_environment_reset: in some cases, resetting the environment can be suppressed by the environment
        itself. This flag allows forcing the reset.
:return: the environment response as returned in get_last_env_response or None for LevelManager
"""
raise NotImplementedError("")
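# --- Hedged illustration (not part of the original module) -------------------
# A minimal concrete EnvironmentInterface, mainly to show which members a
# subclass is expected to provide. The EnvResponse keyword names and the
# ActionSpace.sample() call below are assumptions about the surrounding API.
class _ToyEnvironment(EnvironmentInterface):
    def __init__(self):
        super().__init__()
        from rl_coach.spaces import DiscreteActionSpace
        self._action_space = DiscreteActionSpace(num_actions=2)
        self._state = {'observation': 0}

    @property
    def action_space(self) -> ActionSpace:
        return self._action_space

    def get_random_action(self) -> ActionType:
        # assumption: ActionSpace.sample() draws a random valid action
        return self._action_space.sample()

    def step(self, action: ActionType) -> EnvResponse:
        self._state = {'observation': int(action)}
        return EnvResponse(next_state=self._state, reward=0.0, game_over=False)

    def reset_internal_state(self, force_environment_reset: bool=False) -> EnvResponse:
        self._state = {'observation': 0}
        return EnvResponse(next_state=self._state, reward=0.0, game_over=False)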

View File

@@ -0,0 +1,454 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gym
import numpy as np
import scipy.ndimage
import scipy.misc  # needed for scipy.misc.imresize used in _get_robotics_image
from rl_coach.utils import lower_under_to_upper, short_dynamic_import
try:
import roboschool
from OpenGL import GL
except ImportError:
from rl_coach.logger import failed_imports
failed_imports.append("RoboSchool")
try:
from rl_coach.gym_extensions.continuous import mujoco
except Exception:
from rl_coach.logger import failed_imports
failed_imports.append("GymExtensions")
try:
import pybullet_envs
except ImportError:
from rl_coach.logger import failed_imports
failed_imports.append("PyBullet")
from typing import Dict, Any, Union
from rl_coach.core_types import RunPhase
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \
StateSpace, RewardSpace
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
from rl_coach.filters.observation.observation_rgb_to_y_filter import ObservationRGBToYFilter
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
from rl_coach.filters.filter import InputFilter
import random
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.logger import screen
# Parameters
class GymEnvironmentParameters(EnvironmentParameters):
def __init__(self):
super().__init__()
self.random_initialization_steps = 0
self.max_over_num_frames = 1
self.additional_simulator_parameters = None
@property
def path(self):
return 'rl_coach.environments.gym_environment:GymEnvironment'
"""
Roboschool Environment Components
"""
RoboSchoolInputFilters = NoInputFilter()
RoboSchoolOutputFilters = NoOutputFilter()
class Roboschool(GymEnvironmentParameters):
def __init__(self):
super().__init__()
self.frame_skip = 1
self.default_input_filter = RoboSchoolInputFilters
self.default_output_filter = RoboSchoolOutputFilters
gym_roboschool_envs = ['inverted_pendulum', 'inverted_pendulum_swingup', 'inverted_double_pendulum', 'reacher',
'hopper', 'walker2d', 'half_cheetah', 'ant', 'humanoid', 'humanoid_flagrun',
'humanoid_flagrun_harder', 'pong']
roboschool_v0 = {e: "{}".format(lower_under_to_upper(e) + '-v0') for e in gym_roboschool_envs}
"""
Mujoco Environment Components
"""
MujocoInputFilter = NoInputFilter()
MujocoOutputFilter = NoOutputFilter()
class Mujoco(GymEnvironmentParameters):
def __init__(self):
super().__init__()
self.frame_skip = 1
self.default_input_filter = MujocoInputFilter
self.default_output_filter = MujocoOutputFilter
gym_mujoco_envs = ['inverted_pendulum', 'inverted_double_pendulum', 'reacher', 'hopper', 'walker2d', 'half_cheetah',
'ant', 'swimmer', 'humanoid', 'humanoid_standup', 'pusher', 'thrower', 'striker']
mujoco_v2 = {e: "{}".format(lower_under_to_upper(e) + '-v2') for e in gym_mujoco_envs}
mujoco_v2['walker2d'] = 'Walker2d-v2'
gym_fetch_envs = ['reach', 'slide', 'push', 'pick_and_place']
fetch_v1 = {e: "{}".format('Fetch' + lower_under_to_upper(e) + '-v1') for e in gym_fetch_envs}
"""
Bullet Environment Components
"""
BulletInputFilter = NoInputFilter()
BulletOutputFilter = NoOutputFilter()
class Bullet(GymEnvironmentParameters):
def __init__(self):
super().__init__()
self.frame_skip = 1
self.default_input_filter = BulletInputFilter
self.default_output_filter = BulletOutputFilter
"""
Atari Environment Components
"""
AtariInputFilter = InputFilter(is_a_reference_filter=True)
AtariInputFilter.add_reward_filter('clipping', RewardClippingFilter(-1.0, 1.0))
AtariInputFilter.add_observation_filter('observation', 'rescaling',
ObservationRescaleToSizeFilter(ImageObservationSpace(np.array([84, 84, 3]),
high=255)))
AtariInputFilter.add_observation_filter('observation', 'to_grayscale', ObservationRGBToYFilter())
AtariInputFilter.add_observation_filter('observation', 'to_uint8', ObservationToUInt8Filter(0, 255))
AtariInputFilter.add_observation_filter('observation', 'stacking', ObservationStackingFilter(4))
AtariOutputFilter = NoOutputFilter()
class Atari(GymEnvironmentParameters):
def __init__(self):
super().__init__()
self.frame_skip = 4
self.max_over_num_frames = 2
self.random_initialization_steps = 30
self.default_input_filter = AtariInputFilter
self.default_output_filter = AtariOutputFilter
gym_atari_envs = ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', 'atlantis',
'bank_heist', 'battle_zone', 'beam_rider', 'berzerk', 'bowling', 'boxing', 'breakout', 'carnival',
'centipede', 'chopper_command', 'crazy_climber', 'demon_attack', 'double_dunk',
'elevator_action', 'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar',
'hero', 'ice_hockey', 'jamesbond', 'journey_escape', 'kangaroo', 'krull', 'kung_fu_master',
'montezuma_revenge', 'ms_pacman', 'name_this_game', 'phoenix', 'pitfall', 'pong', 'pooyan',
'private_eye', 'qbert', 'riverraid', 'road_runner', 'robotank', 'seaquest', 'skiing',
'solaris', 'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham', 'up_n_down',
'venture', 'video_pinball', 'wizard_of_wor', 'yars_revenge', 'zaxxon']
atari_deterministic_v4 = {e: "{}".format(lower_under_to_upper(e) + 'Deterministic-v4') for e in gym_atari_envs}
atari_no_frameskip_v4 = {e: "{}".format(lower_under_to_upper(e) + 'NoFrameskip-v4') for e in gym_atari_envs}
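# For illustration (assuming lower_under_to_upper('space_invaders') returns 'SpaceInvaders'):
#   atari_deterministic_v4['space_invaders'] == 'SpaceInvadersDeterministic-v4'
#   atari_no_frameskip_v4['space_invaders'] == 'SpaceInvadersNoFrameskip-v4'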
class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
def __init__(self, env, frameskip=4, max_over_num_frames=2):
super().__init__(env)
self.max_over_num_frames = max_over_num_frames
self.observations_stack = []
self.frameskip = frameskip
self.first_frame_to_max_over = self.frameskip - self.max_over_num_frames
def reset(self):
return self.env.reset()
def step(self, action):
total_reward = 0.0
done = None
info = None
self.observations_stack = []
for i in range(self.frameskip):
observation, reward, done, info = self.env.step(action)
if i >= self.first_frame_to_max_over:
self.observations_stack.append(observation)
total_reward += reward
if done:
# deal with last state in episode
if not self.observations_stack:
self.observations_stack.append(observation)
break
max_over_frames_observation = np.max(self.observations_stack, axis=0)
return max_over_frames_observation, total_reward, done, info
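# Hedged example of wrapping a raw Atari env directly (the environment id is only illustrative):
#   env = MaxOverFramesAndFrameskipEnvWrapper(gym.make('BreakoutNoFrameskip-v4'),
#                                             frameskip=4, max_over_num_frames=2)
# Each call to env.step(a) then repeats `a` for 4 frames and returns the pixel-wise
# max over the last 2 frames, which mitigates Atari sprite flickering.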
# Environment
class GymEnvironment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
additional_simulator_parameters: Dict[str, Any] = None, seed: Union[None, int]=None,
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
visualization_parameters)
self.random_initialization_steps = random_initialization_steps
self.max_over_num_frames = max_over_num_frames
self.additional_simulator_parameters = additional_simulator_parameters
# hide warnings
gym.logger.set_level(40)
"""
load and initialize environment
environment ids can be defined in 3 ways:
        1. Native gym environments, e.g. BreakoutDeterministic-v0
        2. Custom gym environments written and installed as python packages.
           These environments should have a python module with a class inheriting gym.Env, implementing the
           relevant functions (_reset, _step, _render) and defining the observation and action space.
           For example: my_environment_package:MyEnvironmentClass will run an environment defined in the
           MyEnvironmentClass class
        3. Custom gym environments written as an independent module which is not installed.
           These environments should have a python module with a class inheriting gym.Env, implementing the
           relevant functions (_reset, _step, _render) and defining the observation and action space.
           For example: path_to_my_environment.sub_directory.my_module:MyEnvironmentClass will run an
           environment defined in the MyEnvironmentClass class which is located in the module at the relative path
           path_to_my_environment.sub_directory.my_module
"""
if ':' in self.env_id:
# custom environments
if '/' in self.env_id or '.' in self.env_id:
                # environment in an absolute-path module written as a unix path, or in a relative-path
                # module written as a python import path
env_class = short_dynamic_import(self.env_id)
else:
# environment in a python package
env_class = gym.envs.registration.load(self.env_id)
# instantiate the environment
if self.additional_simulator_parameters:
self.env = env_class(**self.additional_simulator_parameters)
else:
self.env = env_class()
else:
self.env = gym.make(self.env_id)
# for classic control we want to use the native renderer because otherwise we will get 2 renderer windows
environment_to_always_use_with_native_rendering = ['classic_control', 'mujoco', 'robotics']
self.native_rendering = self.native_rendering or \
any([env in str(self.env.unwrapped.__class__)
for env in environment_to_always_use_with_native_rendering])
if self.native_rendering:
if hasattr(self, 'renderer'):
self.renderer.close()
# seed
if self.seed is not None:
self.env.seed(self.seed)
np.random.seed(self.seed)
random.seed(self.seed)
# frame skip and max between consecutive frames
self.is_robotics_env = 'robotics' in str(self.env.unwrapped.__class__)
self.is_mujoco_env = 'mujoco' in str(self.env.unwrapped.__class__)
self.is_atari_env = 'Atari' in str(self.env.unwrapped.__class__)
self.timelimit_env_wrapper = self.env
if self.is_atari_env:
self.env.unwrapped.frameskip = 1 # this accesses the atari env that is wrapped with a timelimit wrapper env
if self.env_id == "SpaceInvadersDeterministic-v4" and self.frame_skip == 4:
screen.warning("Warning: The frame-skip for Space Invaders was automatically updated from 4 to 3. "
"This is following the DQN paper where it was noticed that a frame-skip of 3 makes the "
"laser rays disappear. To force frame-skip of 4, please use SpaceInvadersNoFrameskip-v4.")
self.frame_skip = 3
self.env = MaxOverFramesAndFrameskipEnvWrapper(self.env,
frameskip=self.frame_skip,
max_over_num_frames=self.max_over_num_frames)
else:
self.env.unwrapped.frameskip = self.frame_skip
self.state_space = StateSpace({})
# observations
if not isinstance(self.env.observation_space, gym.spaces.dict_space.Dict):
state_space = {'observation': self.env.observation_space}
else:
state_space = self.env.observation_space.spaces
for observation_space_name, observation_space in state_space.items():
if len(observation_space.shape) == 3 and observation_space.shape[-1] == 3:
                # we assume gym image observations are RGB with values in the range 0-255
self.state_space[observation_space_name] = ImageObservationSpace(
shape=np.array(observation_space.shape),
high=255,
channels_axis=-1
)
else:
self.state_space[observation_space_name] = VectorObservationSpace(
shape=observation_space.shape[0],
low=observation_space.low,
high=observation_space.high
)
if 'desired_goal' in state_space.keys():
self.goal_space = self.state_space['desired_goal']
# actions
if type(self.env.action_space) == gym.spaces.box.Box:
self.action_space = BoxActionSpace(
shape=self.env.action_space.shape,
low=self.env.action_space.low,
high=self.env.action_space.high
)
elif type(self.env.action_space) == gym.spaces.discrete.Discrete:
actions_description = []
if hasattr(self.env.unwrapped, 'get_action_meanings'):
actions_description = self.env.unwrapped.get_action_meanings()
self.action_space = DiscreteActionSpace(
num_actions=self.env.action_space.n,
descriptions=actions_description
)
if self.human_control:
# TODO: add this to the action space
# map keyboard keys to actions
self.key_to_action = {}
if hasattr(self.env.unwrapped, 'get_keys_to_action'):
self.key_to_action = self.env.unwrapped.get_keys_to_action()
# initialize the state by getting a new state from the environment
self.reset_internal_state(True)
# render
if self.is_rendered:
image = self.get_rendered_image()
scale = 1
if self.human_control:
scale = 2
if not self.native_rendering:
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
# measurements
if self.env.spec is not None:
self.timestep_limit = self.env.spec.timestep_limit
else:
self.timestep_limit = None
# the info is only updated after the first step
self.state = self.step(self.action_space.default_action).next_state
self.state_space['measurements'] = VectorObservationSpace(shape=len(self.info.keys()))
if self.env.spec and custom_reward_threshold is None:
self.reward_success_threshold = self.env.spec.reward_threshold
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold)
def _wrap_state(self, state):
if not isinstance(self.env.observation_space, gym.spaces.Dict):
return {'observation': state}
return state
def _update_state(self):
if self.is_atari_env and hasattr(self, 'current_ale_lives') \
and self.current_ale_lives != self.env.unwrapped.ale.lives():
if self.phase == RunPhase.TRAIN or self.phase == RunPhase.HEATUP:
# signal termination for life loss
self.done = True
elif self.phase == RunPhase.TEST and not self.done:
# the episode is not terminated in evaluation, but we need to press fire again
self._press_fire()
self._update_ale_lives()
# TODO: update the measurements
if self.state and "desired_goal" in self.state.keys():
self.goal = self.state['desired_goal']
def _take_action(self, action):
if type(self.action_space) == BoxActionSpace:
action = self.action_space.clip_action_to_space(action)
self.state, self.reward, self.done, self.info = self.env.step(action)
self.state = self._wrap_state(self.state)
def _random_noop(self):
        # simulate a random initial environment state by stepping a random number of times
        # (between 0 and random_initialization_steps)
step_count = 0
random_initialization_steps = random.randint(0, self.random_initialization_steps)
while self.action_space is not None and (self.state is None or step_count < random_initialization_steps):
step_count += 1
self.step(self.action_space.default_action)
def _press_fire(self):
fire_action = 1
if self.is_atari_env and self.env.unwrapped.get_action_meanings()[fire_action] == 'FIRE':
self.current_ale_lives = self.env.unwrapped.ale.lives()
self.step(fire_action)
if self.done:
self.reset_internal_state()
def _update_ale_lives(self):
if self.is_atari_env:
self.current_ale_lives = self.env.unwrapped.ale.lives()
def _restart_environment_episode(self, force_environment_reset=False):
# prevent reset of environment if there are ale lives left
if (self.is_atari_env and self.env.unwrapped.ale.lives() > 0) \
and not force_environment_reset and not self.timelimit_env_wrapper._past_limit():
self.step(self.action_space.default_action)
else:
self.state = self.env.reset()
self.state = self._wrap_state(self.state)
self._update_ale_lives()
if self.is_atari_env:
self._random_noop()
self._press_fire()
# initialize the number of lives
self._update_ale_lives()
def _set_mujoco_camera(self, camera_idx: int):
"""
This function can be used to set the camera for rendering the mujoco simulator
:param camera_idx: The index of the camera to use. Should be defined in the model
:return: None
"""
if self.env.unwrapped.viewer.cam.fixedcamid != camera_idx and self.env.unwrapped.viewer._ncam > camera_idx:
from mujoco_py.generated import const
self.env.unwrapped.viewer.cam.type = const.CAMERA_FIXED
self.env.unwrapped.viewer.cam.fixedcamid = camera_idx
def _get_robotics_image(self):
self.env.render()
image = self.env.unwrapped._get_viewer().read_pixels(1600, 900, depth=False)[::-1, :, :]
image = scipy.misc.imresize(image, (270, 480, 3))
return image
def _render(self):
self.env.render(mode='human')
# required for setting up a fixed camera for mujoco
if self.is_mujoco_env:
self._set_mujoco_camera(0)
def get_rendered_image(self):
if self.is_robotics_env:
# necessary for fetch since the rendered image is cropped to an irrelevant part of the simulator
image = self._get_robotics_image()
else:
image = self.env.render(mode='rgb_array')
# required for setting up a fixed camera for mujoco
if self.is_mujoco_env:
self._set_mujoco_camera(0)
return image
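# --- Hedged usage sketch (not part of the original module) -------------------
# Assumes a plain string is accepted as the level argument and that the returned
# EnvResponse exposes `reward` and `game_over` fields.
if __name__ == '__main__':
    env = GymEnvironment(level='CartPole-v0', frame_skip=1, seed=0,
                         visualization_parameters=VisualizationParameters())
    for _ in range(10):
        # act randomly and print the immediate reward and the termination flag
        response = env.step(env.get_random_action())
        print(response.reward, response.game_over)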

View File

View File

@@ -0,0 +1,38 @@
# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions to manage the common assets for domains."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from dm_control.utils import resources
_SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
_FILENAMES = [
"common/materials.xml",
"common/skybox.xml",
"common/visual.xml",
]
ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
for filename in _FILENAMES}
def read_model(model_filename):
"""Reads a model XML file and returns its contents as a string."""
return resources.GetResource(os.path.join(_SUITE_DIR, model_filename))
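# Hedged usage example (the model file name below is an assumption about the suite layout):
#   xml_string = read_model('cartpole.xml')
#   # ASSETS holds the shared textures/materials referenced by the model's <include> tags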

View File

@@ -0,0 +1,22 @@
<!--
Common textures, colors and materials to be used throughout this suite. Some
materials such as xxx_highlight are activated on occurrence of certain events,
for example receiving a positive reward.
-->
<mujoco>
<asset>
<texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
<material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
<material name="self" rgba=".7 .5 .3 1"/>
<material name="self_default" rgba=".7 .5 .3 1"/>
<material name="self_highlight" rgba="0 .5 .3 1"/>
<material name="effector" rgba=".7 .4 .2 1"/>
<material name="effector_default" rgba=".7 .4 .2 1"/>
<material name="effector_highlight" rgba="0 .5 .3 1"/>
<material name="decoration" rgba=".3 .5 .7 1"/>
<material name="eye" rgba="0 .2 1 1"/>
<material name="target" rgba=".6 .3 .3 1"/>
<material name="target_default" rgba=".6 .3 .3 1"/>
<material name="target_highlight" rgba=".6 .3 .3 .4"/>
</asset>
</mujoco>

View File

@@ -0,0 +1,6 @@
<mujoco>
<asset>
<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0"
width="800" height="800" mark="random" markrgb="1 1 1"/>
</asset>
</mujoco>

View File

@@ -0,0 +1,7 @@
<mujoco>
<visual>
<headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/>
<map znear=".01"/>
<quality shadowsize="2048"/>
</visual>
</mujoco>

View File

@@ -0,0 +1,185 @@
import numpy as np
import gym
import os
from gym import spaces
from gym.envs.registration import EnvSpec
from mujoco_py import load_model_from_path, MjSim , MjViewer, MjRenderContextOffscreen
class PendulumWithGoals(gym.Env):
metadata = {
'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 30
}
def __init__(self, goal_reaching_thresholds=np.array([0.075, 0.075, 0.75]),
goal_not_reached_penalty=-1, goal_reached_reward=0, terminate_on_goal_reaching=True,
time_limit=1000, frameskip=1, random_goals_instead_of_standing_goal=False,
polar_coordinates: bool=False):
super().__init__()
dir = os.path.dirname(__file__)
model = load_model_from_path(dir + "/pendulum_with_goals.xml")
self.sim = MjSim(model)
self.viewer = None
self.rgb_viewer = None
self.frameskip = frameskip
self.goal = None
self.goal_reaching_thresholds = goal_reaching_thresholds
self.goal_not_reached_penalty = goal_not_reached_penalty
self.goal_reached_reward = goal_reached_reward
self.terminate_on_goal_reaching = terminate_on_goal_reaching
self.time_limit = time_limit
self.current_episode_steps_counter = 0
self.random_goals_instead_of_standing_goal = random_goals_instead_of_standing_goal
self.polar_coordinates = polar_coordinates
# spaces definition
self.action_space = spaces.Box(low=-self.sim.model.actuator_ctrlrange[:, 1],
high=self.sim.model.actuator_ctrlrange[:, 1],
dtype=np.float32)
if self.polar_coordinates:
self.observation_space = spaces.Dict({
"observation": spaces.Box(low=np.array([-np.pi, -15]),
high=np.array([np.pi, 15]),
dtype=np.float32),
"desired_goal": spaces.Box(low=np.array([-np.pi, -15]),
high=np.array([np.pi, 15]),
dtype=np.float32),
"achieved_goal": spaces.Box(low=np.array([-np.pi, -15]),
high=np.array([np.pi, 15]),
dtype=np.float32)
})
else:
self.observation_space = spaces.Dict({
"observation": spaces.Box(low=np.array([-1, -1, -15]),
high=np.array([1, 1, 15]),
dtype=np.float32),
"desired_goal": spaces.Box(low=np.array([-1, -1, -15]),
high=np.array([1, 1, 15]),
dtype=np.float32),
"achieved_goal": spaces.Box(low=np.array([-1, -1, -15]),
high=np.array([1, 1, 15]),
dtype=np.float32)
})
self.spec = EnvSpec('PendulumWithGoals-v0')
self.spec.reward_threshold = self.goal_not_reached_penalty * self.time_limit
self.reset()
def _goal_reached(self):
observation = self._get_obs()
if np.any(np.abs(observation['achieved_goal'] - observation['desired_goal']) > self.goal_reaching_thresholds):
return False
else:
return True
def _terminate(self):
if (self._goal_reached() and self.terminate_on_goal_reaching) or \
self.current_episode_steps_counter >= self.time_limit:
return True
else:
return False
def _reward(self):
if self._goal_reached():
return self.goal_reached_reward
else:
return self.goal_not_reached_penalty
def step(self, action):
self.sim.data.ctrl[:] = action
for _ in range(self.frameskip):
self.sim.step()
self.current_episode_steps_counter += 1
state = self._get_obs()
# visualize the angular velocities
state_velocity = np.copy(state['observation'][-1] / 20)
goal_velocity = self.goal[-1] / 20
self.sim.model.site_size[2] = np.array([0.01, 0.01, state_velocity])
self.sim.data.mocap_pos[2] = np.array([0.85, 0, 0.75 + state_velocity])
self.sim.model.site_size[3] = np.array([0.01, 0.01, goal_velocity])
self.sim.data.mocap_pos[3] = np.array([1.15, 0, 0.75 + goal_velocity])
return state, self._reward(), self._terminate(), {}
def _get_obs(self):
"""
y
^
|____
| /
| /
|~/
|/
--------> x
"""
# observation
angle = self.sim.data.qpos
angular_velocity = self.sim.data.qvel
if self.polar_coordinates:
observation = np.concatenate([angle - np.pi, angular_velocity])
else:
x = np.sin(angle)
y = np.cos(angle) # qpos is the angle relative to a standing pole
observation = np.concatenate([x, y, angular_velocity])
return {
"observation": observation,
"desired_goal": self.goal,
"achieved_goal": observation
}
def reset(self):
self.current_episode_steps_counter = 0
# set initial state
angle = np.random.uniform(np.pi / 4, 7 * np.pi / 4)
angular_velocity = np.random.uniform(-0.05, 0.05)
self.sim.data.qpos[0] = angle
self.sim.data.qvel[0] = angular_velocity
self.sim.step()
# goal
if self.random_goals_instead_of_standing_goal:
angle_target = np.random.uniform(-np.pi / 8, np.pi / 8)
angular_velocity_target = np.random.uniform(-0.2, 0.2)
else:
angle_target = 0
angular_velocity_target = 0
# convert target values to goal
x_target = np.sin(angle_target)
y_target = np.cos(angle_target)
if self.polar_coordinates:
self.goal = np.array([angle_target - np.pi, angular_velocity_target])
else:
self.goal = np.array([x_target, y_target, angular_velocity_target])
# visualize the goal
self.sim.data.mocap_pos[0] = [x_target, 0, y_target]
return self._get_obs()
def render(self, mode='human', close=False):
if mode == 'human':
if self.viewer is None:
self.viewer = MjViewer(self.sim)
self.viewer.render()
elif mode == 'rgb_array':
if self.rgb_viewer is None:
self.rgb_viewer = MjRenderContextOffscreen(self.sim, 0)
self.rgb_viewer.render(500, 500)
# window size used for old mujoco-py:
data = self.rgb_viewer.read_pixels(500, 500, depth=False)
# original image is upside-down, so flip it
return data[::-1, :, :]
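# --- Hedged usage sketch (not part of the original module) -------------------
# Assumes mujoco_py is installed and pendulum_with_goals.xml sits next to this file.
if __name__ == '__main__':
    env = PendulumWithGoals(random_goals_instead_of_standing_goal=True, time_limit=200)
    obs = env.reset()
    for _ in range(200):
        # apply a random torque and reset whenever the goal is reached or time runs out
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            obs = env.reset()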

View File

@@ -0,0 +1,42 @@
<mujoco model="pendulum_with_goals">
<include file="./common/visual.xml"/>
<include file="./common/skybox.xml"/>
<include file="./common/materials.xml"/>
<option timestep="0.002">
<flag contact="disable" energy="enable"/>
</option>
<worldbody>
<light name="light" pos="0 0 2"/>
<geom name="floor" size="2 2 .2" type="plane" material="grid"/>
<camera name="fixed" pos="0 -1.5 2" xyaxes='1 0 0 0 1 1'/>
<camera name="lookat" mode="targetbodycom" target="pole" pos="0 -2 1"/>
<body name="pole" pos="0 0 .6">
<joint name="hinge" type="hinge" axis="0 1 0" damping="0.1"/>
<geom name="base" material="decoration" type="cylinder" fromto="0 -.03 0 0 .03 0" size="0.021" mass="0"/>
<geom name="pole" material="self" type="capsule" fromto="0 0 0 0 0 0.5" size="0.02" mass="0"/>
<geom name="mass" material="effector" type="sphere" pos="0 0 0.5" size="0.05" mass="1"/>
</body>
<body name="end_goal" pos="0 0 0" mocap="true">
<site type="sphere" size="0.05" rgba="1 1 0 1" />
</body>
<!--<body name="sub_goal" pos="0 0 0" mocap="true">-->
<!--<site type="sphere" size="0.05" rgba="1 0 1 1" />-->
<!--</body>-->
<body name="current_velo" pos="0.0 0 0.0" mocap="true">
<site type="box" size="0.01 0.01 0.1" rgba="1 1 1 1" />
</body>
<body name="subgoal_velo" pos="0.0 0 0.0" mocap="true">
<site type="box" size="0.01 0.01 0.1" rgba="1 0 1 1" />
</body>
<body name="zero_velo" pos="1.0 0 0.75" mocap="true">
<site type="box" size="0.3 0.01 0.01" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<motor name="torque" joint="hinge" gear="1" ctrlrange="-2 2" ctrllimited="true"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,245 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import Enum
from typing import Union, List
import numpy as np
from rl_coach.filters.observation.observation_move_axis_filter import ObservationMoveAxisFilter
try:
from pysc2 import maps
from pysc2.env import sc2_env
from pysc2.env import available_actions_printer
from pysc2.lib import actions
from pysc2.lib import features
from pysc2.env import environment
from absl import app
from absl import flags
except ImportError:
from rl_coach.logger import failed_imports
failed_imports.append("PySc2")
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.spaces import BoxActionSpace, VectorObservationSpace, PlanarMapsObservationSpace, StateSpace, CompoundActionSpace, \
DiscreteActionSpace
from rl_coach.filters.filter import InputFilter, OutputFilter
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.action.linear_box_to_box_map import LinearBoxToBoxMap
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
FLAGS = flags.FLAGS
FLAGS(['coach.py'])
SCREEN_SIZE = 84 # will also impact the action space size
# Starcraft Constants
_NOOP = actions.FUNCTIONS.no_op.id
_MOVE_SCREEN = actions.FUNCTIONS.Move_screen.id
_SELECT_ARMY = actions.FUNCTIONS.select_army.id
_PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index
_NOT_QUEUED = [0]
_SELECT_ALL = [0]
class StarcraftObservationType(Enum):
Features = 0
RGB = 1
StarcraftInputFilter = InputFilter(is_a_reference_filter=True)
StarcraftInputFilter.add_observation_filter('screen', 'move_axis', ObservationMoveAxisFilter(0, -1))
StarcraftInputFilter.add_observation_filter('screen', 'rescaling',
ObservationRescaleToSizeFilter(
PlanarMapsObservationSpace(np.array([84, 84, 1]),
low=0, high=255, channels_axis=-1)))
StarcraftInputFilter.add_observation_filter('screen', 'to_uint8', ObservationToUInt8Filter(0, 255))
StarcraftInputFilter.add_observation_filter('minimap', 'move_axis', ObservationMoveAxisFilter(0, -1))
StarcraftInputFilter.add_observation_filter('minimap', 'rescaling',
ObservationRescaleToSizeFilter(
PlanarMapsObservationSpace(np.array([64, 64, 1]),
low=0, high=255, channels_axis=-1)))
StarcraftInputFilter.add_observation_filter('minimap', 'to_uint8', ObservationToUInt8Filter(0, 255))
StarcraftNormalizingOutputFilter = OutputFilter(is_a_reference_filter=True)
StarcraftNormalizingOutputFilter.add_action_filter(
'normalization', LinearBoxToBoxMap(input_space_low=-SCREEN_SIZE / 2, input_space_high=SCREEN_SIZE / 2 - 1))
class StarCraft2EnvironmentParameters(EnvironmentParameters):
def __init__(self):
super().__init__()
self.screen_size = 84
self.minimap_size = 64
self.feature_minimap_maps_to_use = range(7)
self.feature_screen_maps_to_use = range(17)
self.observation_type = StarcraftObservationType.Features
self.disable_fog = False
self.auto_select_all_army = True
self.default_input_filter = StarcraftInputFilter
self.default_output_filter = StarcraftNormalizingOutputFilter
self.use_full_action_space = False
@property
def path(self):
return 'rl_coach.environments.starcraft2_environment:StarCraft2Environment'
# Environment
class StarCraft2Environment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
seed: Union[None, int]=None, human_control: bool=False,
custom_reward_threshold: Union[int, float]=None,
screen_size: int=84, minimap_size: int=64,
feature_minimap_maps_to_use: List=range(7), feature_screen_maps_to_use: List=range(17),
observation_type: StarcraftObservationType=StarcraftObservationType.Features,
disable_fog: bool=False, auto_select_all_army: bool=True,
use_full_action_space: bool=False, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
self.screen_size = screen_size
self.minimap_size = minimap_size
self.feature_minimap_maps_to_use = feature_minimap_maps_to_use
self.feature_screen_maps_to_use = feature_screen_maps_to_use
self.observation_type = observation_type
self.features_screen_size = None
self.feature_minimap_size = None
self.rgb_screen_size = None
self.rgb_minimap_size = None
if self.observation_type == StarcraftObservationType.Features:
self.features_screen_size = screen_size
self.feature_minimap_size = minimap_size
elif self.observation_type == StarcraftObservationType.RGB:
self.rgb_screen_size = screen_size
self.rgb_minimap_size = minimap_size
self.disable_fog = disable_fog
self.auto_select_all_army = auto_select_all_army
self.use_full_action_space = use_full_action_space
        # step_mul is the equivalent of frame skipping. It is unclear whether the action is
        # repeated on the intermediate frames or not.
self.env = sc2_env.SC2Env(map_name=self.env_id, step_mul=frame_skip,
visualize=self.is_rendered,
agent_interface_format=sc2_env.AgentInterfaceFormat(
feature_dimensions=sc2_env.Dimensions(
screen=self.features_screen_size,
minimap=self.feature_minimap_size
)
# rgb_dimensions=sc2_env.Dimensions(
# screen=self.rgb_screen_size,
# minimap=self.rgb_screen_size
# )
),
# feature_screen_size=self.features_screen_size,
# feature_minimap_size=self.feature_minimap_size,
# rgb_screen_size=self.rgb_screen_size,
# rgb_minimap_size=self.rgb_screen_size,
disable_fog=disable_fog,
random_seed=self.seed
)
# print all the available actions
# self.env = available_actions_printer.AvailableActionsPrinter(self.env)
self.reset_internal_state(True)
"""
feature_screen: [height_map, visibility_map, creep, power, player_id, player_relative, unit_type, selected,
unit_hit_points, unit_hit_points_ratio, unit_energy, unit_energy_ratio, unit_shields,
unit_shields_ratio, unit_density, unit_density_aa, effects]
        feature_minimap: [height_map, visibility_map, creep, camera, player_id, player_relative, selected]
        player: [player_id, minerals, vespene, food_cap, food_army, food_workers, idle_worker_count,
army_count, warp_gate_count, larva_count]
"""
self.screen_shape = np.array(self.env.observation_spec()[0]['feature_screen'])
self.screen_shape[0] = len(self.feature_screen_maps_to_use)
self.minimap_shape = np.array(self.env.observation_spec()[0]['feature_minimap'])
self.minimap_shape[0] = len(self.feature_minimap_maps_to_use)
self.state_space = StateSpace({
"screen": PlanarMapsObservationSpace(shape=self.screen_shape, low=0, high=255, channels_axis=0),
"minimap": PlanarMapsObservationSpace(shape=self.minimap_shape, low=0, high=255, channels_axis=0),
"measurements": VectorObservationSpace(self.env.observation_spec()[0]["player"][0])
})
if self.use_full_action_space:
action_identifiers = list(self.env.action_spec()[0].functions)
num_action_identifiers = len(action_identifiers)
action_arguments = [(arg.name, arg.sizes) for arg in self.env.action_spec()[0].types]
sub_action_spaces = [DiscreteActionSpace(num_action_identifiers)]
for argument in action_arguments:
for dimension in argument[1]:
sub_action_spaces.append(DiscreteActionSpace(dimension))
self.action_space = CompoundActionSpace(sub_action_spaces)
else:
self.action_space = BoxActionSpace(2, 0, self.screen_size - 1, ["X-Axis, Y-Axis"],
default_action=np.array([self.screen_size/2, self.screen_size/2]))
def _update_state(self):
timestep = 0
self.screen = self.last_result[timestep].observation.feature_screen
# extract only the requested segmentation maps from the observation
self.screen = np.take(self.screen, self.feature_screen_maps_to_use, axis=0)
self.minimap = self.last_result[timestep].observation.feature_minimap
self.measurements = self.last_result[timestep].observation.player
self.reward = self.last_result[timestep].reward
self.done = self.last_result[timestep].step_type == environment.StepType.LAST
self.state = {
'screen': self.screen,
'minimap': self.minimap,
'measurements': self.measurements
}
def _take_action(self, action):
if self.use_full_action_space:
action_identifier = action[0]
action_arguments = action[1:]
action = actions.FunctionCall(action_identifier, action_arguments)
else:
coord = np.array(action[0:2])
noop = False
coord = coord.round()
coord = np.clip(coord, 0, SCREEN_SIZE - 1)
self.last_action_idx = coord
if noop:
action = actions.FunctionCall(_NOOP, [])
else:
action = actions.FunctionCall(_MOVE_SCREEN, [_NOT_QUEUED, coord])
self.last_result = self.env.step(actions=[action])
def _restart_environment_episode(self, force_environment_reset=False):
# reset the environment
self.last_result = self.env.reset()
# select all the units on the screen
if self.auto_select_all_army:
self.env.step(actions=[actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])])
def get_rendered_image(self):
screen = np.squeeze(np.tile(np.expand_dims(self.screen, -1), (1, 1, 3)))
screen = screen / np.max(screen) * 255
return screen.astype('uint8')
def dump_video_of_last_episode(self):
from rl_coach.logger import experiment_path
self.env._run_config.replay_dir = experiment_path
self.env.save_replay('replays')
super().dump_video_of_last_episode()
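# --- Hedged usage sketch (not part of the original module) -------------------
# Assumes a local StarCraft II installation with the mini-game maps, and that a
# plain string is accepted as the level argument.
if __name__ == '__main__':
    env = StarCraft2Environment(level='CollectMineralShards', frame_skip=8,
                                visualization_parameters=VisualizationParameters())
    for _ in range(10):
        # with use_full_action_space=False, the default action moves the selected
        # army to the center of the screen
        env.step(env.action_space.default_action)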

View File

@@ -0,0 +1,82 @@
import numpy as np
import gym
from gym import spaces
import random
class BitFlip(gym.Env):
metadata = {
'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 30
}
def __init__(self, bit_length=16, max_steps=None, mean_zero=False):
super(BitFlip, self).__init__()
if bit_length < 1:
raise ValueError('bit_length must be >= 1, found {}'.format(bit_length))
self.bit_length = bit_length
self.mean_zero = mean_zero
if max_steps is None:
# default to bit_length
self.max_steps = bit_length
elif max_steps == 0:
self.max_steps = None
else:
self.max_steps = max_steps
# spaces documentation: https://gym.openai.com/docs/
self.action_space = spaces.Discrete(bit_length)
self.observation_space = spaces.Dict({
'state': spaces.Box(low=0, high=1, shape=(bit_length, )),
'desired_goal': spaces.Box(low=0, high=1, shape=(bit_length, )),
'achieved_goal': spaces.Box(low=0, high=1, shape=(bit_length, ))
})
self.reset()
def _terminate(self):
return (self.state == self.goal).all() or self.steps >= self.max_steps
def _reward(self):
return -1 if (self.state != self.goal).any() else 0
def step(self, action):
# action is an int in the range [0, self.bit_length)
self.state[action] = int(not self.state[action])
self.steps += 1
return (self._get_obs(), self._reward(), self._terminate(), {})
def reset(self):
self.steps = 0
self.state = np.array([random.choice([1, 0]) for _ in range(self.bit_length)])
# make sure goal is not the initial state
self.goal = self.state
while (self.goal == self.state).all():
self.goal = np.array([random.choice([1, 0]) for _ in range(self.bit_length)])
return self._get_obs()
def _mean_zero(self, x):
if self.mean_zero:
return (x - 0.5) / 0.5
else:
return x
def _get_obs(self):
return {
'state': self._mean_zero(self.state),
'desired_goal': self._mean_zero(self.goal),
'achieved_goal': self._mean_zero(self.state)
}
def render(self, mode='human', close=False):
observation = np.zeros((20, 20 * self.bit_length, 3))
for bit_idx, (state_bit, goal_bit) in enumerate(zip(self.state, self.goal)):
# green if the bit matches
observation[:, bit_idx * 20:(bit_idx + 1) * 20, 1] = (state_bit == goal_bit) * 255
# red if the bit doesn't match
observation[:, bit_idx * 20:(bit_idx + 1) * 20, 0] = (state_bit != goal_bit) * 255
return observation
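# --- Hedged usage sketch (not part of the original module) -------------------
if __name__ == '__main__':
    env = BitFlip(bit_length=8)
    obs = env.reset()
    done, episode_return = False, 0
    while not done:
        # flip a random bit until the goal is matched or the step limit is reached
        obs, reward, done, info = env.step(env.action_space.sample())
        episode_return += reward
    print('return with a random policy:', episode_return)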

View File

@@ -0,0 +1,126 @@
import numpy as np
import gym
from gym import spaces
from enum import Enum
class ExplorationChain(gym.Env):
metadata = {
'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 30
}
class ObservationType(Enum):
OneHot = 0
Therm = 1
def __init__(self, chain_length=16, start_state=1, max_steps=None, observation_type=ObservationType.Therm,
left_state_reward=1/1000, right_state_reward=1, simple_render=True):
super().__init__()
if chain_length <= 3:
raise ValueError('Chain length must be > 3, found {}'.format(chain_length))
if not 0 <= start_state < chain_length:
raise ValueError('The start state should be within the chain bounds, found {}'.format(start_state))
self.chain_length = chain_length
self.start_state = start_state
self.max_steps = max_steps
self.observation_type = observation_type
self.left_state_reward = left_state_reward
self.right_state_reward = right_state_reward
self.simple_render = simple_render
# spaces documentation: https://gym.openai.com/docs/
self.action_space = spaces.Discrete(2) # 0 -> Go left, 1 -> Go right
        self.observation_space = spaces.Box(0, 1, shape=(chain_length,))  # alternatively: spaces.MultiBinary(chain_length)
self.reset()
def _terminate(self):
return self.steps >= self.max_steps
def _reward(self):
if self.state == 0:
return self.left_state_reward
elif self.state == self.chain_length - 1:
return self.right_state_reward
else:
return 0
def step(self, action):
# action is 0 or 1
if action == 0:
if 0 < self.state:
self.state -= 1
elif action == 1:
if self.state < self.chain_length - 1:
self.state += 1
else:
raise ValueError("An invalid action was given. The available actions are - 0 or 1, found {}".format(action))
self.steps += 1
return self._get_obs(), self._reward(), self._terminate(), {}
def reset(self):
self.steps = 0
self.state = self.start_state
return self._get_obs()
def _get_obs(self):
self.observation = np.zeros((self.chain_length,))
if self.observation_type == self.ObservationType.OneHot:
self.observation[self.state] = 1
elif self.observation_type == self.ObservationType.Therm:
self.observation[:(self.state+1)] = 1
return self.observation
def render(self, mode='human', close=False):
if self.simple_render:
observation = np.zeros((20, 20*self.chain_length))
observation[:, self.state*20:(self.state+1)*20] = 255
return observation
else:
            # lazy loading of networkx and matplotlib, so that the environment can be used
            # without installing them when full rendering is not needed
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout
import matplotlib.pyplot as plt
if not hasattr(self, 'G'):
self.states = list(range(self.chain_length))
self.G = nx.DiGraph(directed=True)
for i, origin_state in enumerate(self.states):
if i < self.chain_length - 1:
self.G.add_edge(origin_state,
origin_state + 1,
weight=0.5)
if i > 0:
self.G.add_edge(origin_state,
origin_state - 1,
weight=0.5, )
if i == 0 or i < self.chain_length - 1:
self.G.add_edge(origin_state,
origin_state,
weight=0.5, )
fig = plt.gcf()
if np.all(fig.get_size_inches() != [10, 2]):
fig.set_size_inches(5, 1)
color = ['y']*(len(self.G))
color[self.state] = 'r'
options = {
'node_color': color,
'node_size': 50,
'width': 1,
'arrowstyle': '-|>',
'arrowsize': 5,
'font_size': 6
}
pos = graphviz_layout(self.G, prog='dot', args='-Grankdir=LR')
nx.draw_networkx(self.G, pos, arrows=True, **options)
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data