mirror of https://github.com/gryf/coach.git
coach/rl_coach/environments/environment.py
shadiendrawis 0896f43097 Robosuite exploration (#478)
* Add Robosuite parameters for all env types + initialize env flow

* Init flow done

* Rest of Environment API complete for RobosuiteEnvironment

* RobosuiteEnvironment changes

* Observation stacking filter
* Add proper frame_skip in addition to control_freq
* Hardcode Coach rendering to 'frontview' camera

* Robosuite_Lift_DDPG preset + Robosuite env updates

* Move observation stacking filter from env to preset
* Pre-process observation - concatenate depth map (if exists)
  to image and object state (if exists) to robot state
* Preset parameters based on Surreal DDPG parameters, taken from:
  https://github.com/SurrealAI/surreal/blob/master/surreal/main/ddpg_configs.py

* RobosuiteEnvironment fixes - working now with PyGame rendering

* Preset minor modifications

* ObservationStackingFilter - option to concat non-vector observations

* Consider frame skip when setting horizon in robosuite env

* Robosuite lift preset - update heatup length and training interval

* Robosuite env - change control_freq to 10 to match Surreal usage

* Robosuite clipped PPO preset

* Distribute multiple workers (-n #) over multiple GPUs

* Clipped PPO memory optimization from @shadiendrawis

* Fixes to evaluation only workers

* RoboSuite_ClippedPPO: Update training interval

* Undo last commit (update training interval)

* Fix "doube-negative" if conditions

* multi-agent single-trainer clipped PPO training with CartPole

* cleanups (not done yet) + roughly tuned hyper-params for MAST

* Switch to Robosuite v1 APIs

* Change presets to IK controller

* more cleanups + enabling evaluation worker + better logging

* RoboSuite_Lift_ClippedPPO updates

* Fix major bug in obs normalization filter setup

* Reduce coupling between Robosuite API and Coach environment

* Now only non task-specific parameters are explicitly defined
  in Coach
* Removed a bunch of enums of Robosuite elements, using simple
  strings instead
* With this change new environments/robots/controllers in Robosuite
  can be used immediately in Coach

* MAST: better logging of actor-trainer interaction + bug fixes + performance improvements.

Still missing: fixed pubsub for obs normalization running stats + logging for trainer signals

* LSTM support for PPO

* setting JOINT VELOCITY action space by default + fix for EveryNEpisodes video dump filter + new TaskIDDumpFilter + allowing OR between video dump filters

* Separate Robosuite clipped PPO preset for the non-MAST case

* Add flatten layer to architectures and use it in Robosuite presets

This is required for embedders that mix conv and dense

TODO: Add MXNet implementation

* publishing running_stats together with the published policy + hyper-param for when to publish a policy + cleanups

* bug-fix for memory leak in MAST

* Bugfix: Return value in TF BatchnormActivationDropout.to_tf_instance

* Explicit activations in embedder scheme so there's no ReLU after flatten

* Add clipped PPO heads with configurable dense layers at the beginning

* This is a workaround needed to mimic Surreal-PPO, where the CNN and
  LSTM are shared between actor and critic but the FC layers are not
  shared
* Added a "SchemeBuilder" class, currently only used for the new heads
  but we can change Middleware and Embedder implementations to use it
  as well

* Video dump setting fix in basic preset

* logging screen output to file

* coach to start the redis-server for a MAST run

* trainer drops off-policy data + old policy in ClippedPPO updates only after policy was published + logging free memory stats + actors check for a new policy only at the beginning of a new episode + fixed a bug where the trainer was logging "Training Reward = 0", causing the dashboard to incorrectly display the signal

* Add missing set_internal_state function in TFSharedRunningStats

* Robosuite preset - use SingleLevelSelect instead of hard-coded level

* policy ID published directly on Redis

* Small fix when writing to log file

* Major bugfix in Robosuite presets - pass dense sizes to heads

* RoboSuite_Lift_ClippedPPO hyper-params update

* add horizon and value bootstrap to GAE calculation, fix A3C with LSTM

* adam hyper-params from mujoco

* updated MAST preset with IK_POSE_POS controller

* configurable initialization for policy stdev + custom extra noise per actor + logging of policy stdev to dashboard

* value loss weighting of 0.5

* minor fixes + presets

* bug-fix for MAST where the old policy in the trainer kept updating every training iteration while it should only update after every policy publish

* bug-fix: reset_internal_state was not called by the trainer

* bug-fixes in the LSTM flow + some hyper-param adjustments for CartPole_ClippedPPO_LSTM -> now trains and sometimes reaches 200

* adding back the horizon hyper-param - a messy commit

* another bug-fix missing from prev commit

* set control_freq=2 to match action_scale 0.125

* ClippedPPO with MAST cleanups and some preps for TD3 with MAST

* TD3 presets. RoboSuite_Lift_TD3 seems to work well with multi-process runs (-n 8)

* setting termination on collision to be on by default

* bug-fix following prev-prev commit

* initial commit of the cube exploration environment with TD3

* bug fix + minor refactoring

* several parameter changes and RND debugging

* Robosuite Gym wrapper + Rename TD3_Random* -> Random*

* algorithm update

* Add RoboSuite v1 env + presets (to eventually replace non-v1 ones)

* Remove grasping presets, keep only V1 exp. presets (w/o V1 tag)

* Keep just robosuite V1 env as the 'robosuite_environment' module

* Exclude Robosuite and MAST presets from integration tests

* Exclude LSTM and MAST presets from golden tests

* Fix mistakenly removed import

* Revert debug changes in ReaderWriterLock

* Try another way to exclude LSTM/MAST golden tests

* Remove debug prints

* Remove PreDense heads, unused in the end

* Missed removing an instance of PreDense head

* Remove MAST, not required for this PR

* Undo unused concat option in ObservationStackingFilter

* Remove LSTM updates, not required in this PR

* Update README.md

* code changes for the exploration flow to work with robosuite master branch

* code cleanup + documentation

* jupyter tutorial for the goal-based exploration + scatter plot

* typo fix

* Update README.md

* separate parameter for the obs-goal observation + small fixes

* code clarity fixes

* adjustment in tutorial 5

* Update tutorial

* Update tutorial

Co-authored-by: Guy Jacob <guy.jacob@intel.com>
Co-authored-by: Gal Leibovich <gal.leibovich@intel.com>
Co-authored-by: shadi.endrawis <sendrawi@aipg-ra-skx-03.ra.intel.com>
2021-06-01 00:34:19 +03:00


#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
import time
from collections import OrderedDict
from typing import Union, List, Tuple, Dict
import numpy as np
from rl_coach import logger
from rl_coach.base_parameters import Parameters
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.core_types import GoalType, ActionType, EnvResponse, RunPhase
from rl_coach.environments.environment_interface import EnvironmentInterface
from rl_coach.logger import screen
from rl_coach.renderer import Renderer
from rl_coach.spaces import ActionSpace, ObservationSpace, DiscreteActionSpace, RewardSpace, StateSpace
from rl_coach.utils import squeeze_list, force_list
class LevelSelection(object):
def __init__(self, level: str):
self.selected_level = level
def select(self, level: str):
self.selected_level = level
def __str__(self):
if self.selected_level is None:
logger.screen.error("No level has been selected. Please select a level using the -lvl command line flag, "
"or change the level in the preset.", crash=True)
return self.selected_level
class SingleLevelSelection(LevelSelection):
def __init__(self, levels: Union[str, List[str], Dict[str, str]], force_lower=True):
super().__init__(None)
self.levels = levels
if isinstance(levels, list):
self.levels = {level: level for level in levels}
if isinstance(levels, str):
self.levels = {levels: levels}
self.force_lower = force_lower
def __str__(self):
if self.selected_level is None:
logger.screen.error("No level has been selected. Please select a level using the -lvl command line flag, "
"or change the level in the preset. \nThe available levels are: \n{}"
.format(', '.join(sorted(self.levels.keys()))), crash=True)
selected_level = self.selected_level.lower() if self.force_lower else self.selected_level
if selected_level not in self.levels.keys():
logger.screen.error("The selected level ({}) is not part of the available levels ({})"
.format(selected_level, ', '.join(self.levels.keys())), crash=True)
return self.levels[selected_level]
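# Illustrative usage sketch (added for documentation, not part of the original file): a preset
# typically exposes the available levels through SingleLevelSelection, and the -lvl command
# line flag picks one at runtime. Assuming 'lift' and 'door' are valid level names:
#
#     level = SingleLevelSelection(['lift', 'door'])
#     level.select('Lift')          # e.g. driven by the -lvl flag
#     str(level)                    # -> 'lift', since force_lower=True lower-cases the selection
#
# Calling str() before select() crashes with a message listing the available levels.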
# class SingleLevelPerPhase(LevelSelection):
# def __init__(self, levels: Dict[RunPhase, str]):
# super().__init__(None)
# self.levels = levels
#
# def __str__(self):
# super().__str__()
# if self.selected_level not in self.levels.keys():
# logger.screen.error("The selected level ({}) is not part of the available levels ({})"
# .format(self.selected_level, self.levels.keys()), crash=True)
# return self.levels[self.selected_level]
class CustomWrapper(object):
def __init__(self, environment):
super().__init__()
self.environment = environment
def __getattr__(self, attr):
if attr in self.__dict__:
return self.__dict__[attr]
else:
return getattr(self.environment, attr, False)
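# Note (added for clarity, not part of the original file): CustomWrapper forwards any attribute
# it does not define to the wrapped environment, so a wrapper can override a handful of methods
# and transparently proxy the rest. The getattr() default of False means a missing attribute is
# reported as False rather than raising AttributeError. A hypothetical sketch:
#
#     class RewardScalingWrapper(CustomWrapper):
#         def step(self, action):
#             response = self.environment.step(action)
#             response.reward *= 0.1
#             return response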
class EnvironmentParameters(Parameters):
def __init__(self, level=None):
super().__init__()
self.level = level
self.frame_skip = 4
self.seed = None
self.human_control = False
self.custom_reward_threshold = None
self.default_input_filter = None
self.default_output_filter = None
self.experiment_path = None
# Set target reward and target_success if present
self.target_success_rate = 1.0
@property
def path(self):
return 'rl_coach.environments.environment:Environment'
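# Illustrative sketch (not part of the original file): a concrete environment defines a matching
# parameters class whose 'path' property points at the class to instantiate. Module and class
# names below are hypothetical:
#
#     class MyRobotEnvironmentParameters(EnvironmentParameters):
#         def __init__(self, level=None):
#             super().__init__(level=level)
#             self.frame_skip = 1
#
#         @property
#         def path(self):
#             return 'my_package.my_robot_environment:MyRobotEnvironment'
#
# Coach resolves this 'module:Class' string dynamically and instantiates the environment with
# the parameters' fields passed as keyword arguments, which is why **kwargs is accepted below.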
class Environment(EnvironmentInterface):
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
target_success_rate: float=1.0, **kwargs):
"""
:param level: the environment level. Each environment can have multiple levels
:param seed: a seed for the random number generator of the environment
:param frame_skip: number of frames to skip (while repeating the same action) between two successive agent directives
:param human_control: whether a human should control the environment
:param custom_reward_threshold: the reward threshold above which an episode is considered successful
:param visualization_parameters: a blob of parameters used for visualization of the environment
:param target_success_rate: the success rate the environment is expected to reach, exposed via get_target_success_rate
:param **kwargs: as the class is instantiated by EnvironmentParameters, this is used to support having
additional arguments which will be ignored by this class, but might be used by others
"""
super().__init__()
# env initialization
self.game = []
self.state = {}
self.observation = None
self.goal = None
self.reward = 0
self.done = False
self.info = {}
self._last_env_response = None
self.last_action = 0
self.episode_idx = 0
self.total_steps_counter = 0
self.current_episode_steps_counter = 0
self.last_episode_time = time.time()
self.key_to_action = {}
self.last_episode_images = []
# rewards
self.total_reward_in_current_episode = 0
self.max_reward_achieved = -np.inf
self.reward_success_threshold = custom_reward_threshold
# spaces
self.state_space = self._state_space = None
self.goal_space = self._goal_space = None
self.action_space = self._action_space = None
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold) # TODO: add a getter and setter
self.env_id = str(level)
self.seed = seed
self.frame_skip = frame_skip
# human interaction and visualization
self.human_control = human_control
self.wait_for_explicit_human_action = False
self.is_rendered = visualization_parameters.render or self.human_control
self.native_rendering = visualization_parameters.native_rendering and not self.human_control
self.visualization_parameters = visualization_parameters
if not self.native_rendering:
self.renderer = Renderer()
# Set target reward and target_success if present
self.target_success_rate = target_success_rate
@property
def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
"""
Get the action space of the environment
:return: the action space
"""
return self._action_space
@action_space.setter
def action_space(self, val: Union[List[ActionSpace], ActionSpace]):
"""
Set the action space of the environment
:return: None
"""
self._action_space = val
@property
def state_space(self) -> Union[List[StateSpace], StateSpace]:
"""
Get the state space of the environment
:return: the state space
"""
return self._state_space
@state_space.setter
def state_space(self, val: Union[List[StateSpace], StateSpace]):
"""
Set the state space of the environment
:return: None
"""
self._state_space = val
@property
def goal_space(self) -> Union[List[ObservationSpace], ObservationSpace]:
"""
Get the goal space of the environment
:return: the goal space
"""
return self._goal_space
@goal_space.setter
def goal_space(self, val: Union[List[ObservationSpace], ObservationSpace]):
"""
Set the goal space of the environment
:return: None
"""
self._goal_space = val
def get_action_from_user(self) -> ActionType:
"""
Get an action from the user keyboard
:return: action index
"""
if self.wait_for_explicit_human_action:
while len(self.renderer.pressed_keys) == 0:
self.renderer.get_events()
if self.key_to_action == {}:
# the keys are the numbers on the keyboard corresponding to the action index
if len(self.renderer.pressed_keys) > 0:
action_idx = self.renderer.pressed_keys[0] - ord("1")
if 0 <= action_idx < self.action_space.shape[0]:
return action_idx
else:
# the keys are mapped through the environment to more intuitive keyboard keys
# key = tuple(self.renderer.pressed_keys)
# for key in self.renderer.pressed_keys:
for env_keys in self.key_to_action.keys():
if set(env_keys) == set(self.renderer.pressed_keys):
return self.action_space.actions[self.key_to_action[env_keys]]
# return the default action 0 so that the environment will continue running
return self.action_space.default_action
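# Note (illustrative, not part of the original file): child environments that support human
# control populate self.key_to_action with a mapping from keyboard-key tuples to action indices,
# for example:
#
#     self.key_to_action = {(): 0, (ord('w'),): 1, (ord('a'),): 2, (ord('d'),): 3}
#
# where the empty tuple maps the "no key pressed" case to a default/no-op action.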
@property
def last_env_response(self) -> Union[List[EnvResponse], EnvResponse]:
"""
Get the last environment response
:return: the last EnvResponse (or list of EnvResponses), containing the state, reward, etc.
"""
return squeeze_list(self._last_env_response)
@last_env_response.setter
def last_env_response(self, val: Union[List[EnvResponse], EnvResponse]):
"""
Set the last environment response
:param val: the last environment response
"""
self._last_env_response = force_list(val)
def step(self, action: ActionType) -> EnvResponse:
"""
Make a single step in the environment using the given action
:param action: an action to use for stepping the environment. Should follow the definition of the action space.
:return: the environment response as returned in get_last_env_response
"""
action = self.action_space.clip_action_to_space(action)
if self.action_space and not self.action_space.contains(action):
raise ValueError("The given action does not match the action space definition. "
"Action = {}, action space definition = {}".format(action, self.action_space))
# store the last agent action done and allow passing None actions to repeat the previously done action
if action is None:
action = self.last_action
self.last_action = action
if self.visualization_parameters.add_rendered_image_to_env_response:
current_rendered_image = self.get_rendered_image()
self.current_episode_steps_counter += 1
if self.phase != RunPhase.UNDEFINED:
self.total_steps_counter += 1
# act
self._take_action(action)
# observe
self._update_state()
if self.is_rendered:
self.render()
self.total_reward_in_current_episode += self.reward
if self.visualization_parameters.add_rendered_image_to_env_response:
self.info['image'] = current_rendered_image
self.last_env_response = \
EnvResponse(
reward=self.reward,
next_state=self.state,
goal=self.goal,
game_over=self.done,
info=self.info
)
# store observations for video / gif dumping
if self.should_dump_video_of_the_current_episode(episode_terminated=False) and \
(self.visualization_parameters.dump_mp4 or self.visualization_parameters.dump_gifs):
self.last_episode_images.append(self.get_rendered_image())
return self.last_env_response
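# Typical interaction loop (sketch added for illustration, not part of the original file),
# assuming SomeEnvironmentSubclass is a concrete implementation:
#
#     env = SomeEnvironmentSubclass(...)
#     env.reset_internal_state()
#     while not env.done:
#         response = env.step(env.get_random_action())
#         # response.next_state / response.reward / response.game_over are what agents consume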
def render(self) -> None:
"""
Call the environment function for rendering to the screen
:return: None
"""
if self.native_rendering:
self._render()
else:
self.renderer.render_image(self.get_rendered_image())
def handle_episode_ended(self) -> None:
"""
End an episode
:return: None
"""
self.dump_video_of_last_episode_if_needed()
def reset_internal_state(self, force_environment_reset=False) -> EnvResponse:
"""
Reset the environment and all the variables of the wrapper
:param force_environment_reset: forces environment reset even when the game did not end
:return: the EnvResponse for the new episode, containing the initial state, reward, done flag, etc.
"""
self._restart_environment_episode(force_environment_reset)
self.last_episode_time = time.time()
if self.current_episode_steps_counter > 0 and self.phase != RunPhase.UNDEFINED:
self.episode_idx += 1
self.done = False
self.total_reward_in_current_episode = self.reward = 0.0
self.last_action = 0
self.current_episode_steps_counter = 0
self.last_episode_images = []
self._update_state()
# render before the preprocessing of the observation, so that the image will be in its original quality
if self.is_rendered:
self.render()
self.last_env_response = \
EnvResponse(
reward=self.reward,
next_state=self.state,
goal=self.goal,
game_over=self.done,
info=self.info
)
return self.last_env_response
def get_random_action(self) -> ActionType:
"""
Returns an action picked uniformly from the available actions
:return: a numpy array with a random action
"""
return self.action_space.sample()
def get_available_keys(self) -> List[Tuple[str, ActionType]]:
"""
Return a list of tuples mapping between action names and the keyboard key that triggers them
:return: a list of tuples mapping between action names and the keyboard key that triggers them
"""
available_keys = []
if self.key_to_action != {}:
for key, idx in sorted(self.key_to_action.items(), key=operator.itemgetter(1)):
if key != ():
key_names = [self.renderer.get_key_names([k])[0] for k in key]
available_keys.append((self.action_space.descriptions[idx], ' + '.join(key_names)))
elif type(self.action_space) == DiscreteActionSpace:
for action in range(self.action_space.shape):
available_keys.append(("Action {}".format(action + 1), action + 1))
return available_keys
def get_goal(self) -> GoalType:
"""
Get the current goal that the agent needs to achieve in the environment
:return: The goal
"""
return self.goal
def set_goal(self, goal: GoalType) -> None:
"""
Set the current goal that the agent needs to achieve in the environment
:param goal: the goal that needs to be achieved
:return: None
"""
self.goal = goal
def should_dump_video_of_the_current_episode(self, episode_terminated=False):
if self.visualization_parameters.video_dump_filters:
for video_dump_filter in force_list(self.visualization_parameters.video_dump_filters):
if not video_dump_filter.should_dump(episode_terminated, **self.__dict__):
return False
return True
return True
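# Note (added for clarity, not part of the original file): video_dump_filters holds filter
# objects such as the EveryNEpisodes filter mentioned in the commit log; as written here, a
# dump happens only if every filter's should_dump() returns True, i.e. the filters combine as
# a logical AND.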
def dump_video_of_last_episode_if_needed(self):
if self.last_episode_images != [] and self.should_dump_video_of_the_current_episode(episode_terminated=True):
self.dump_video_of_last_episode()
def dump_video_of_last_episode(self):
frame_skipping = max(1, int(5 / self.frame_skip))
file_name = 'episode-{}_score-{}'.format(self.episode_idx, self.total_reward_in_current_episode)
fps = 10
if self.visualization_parameters.dump_gifs:
logger.create_gif(self.last_episode_images[::frame_skipping], name=file_name, fps=fps)
if self.visualization_parameters.dump_mp4:
logger.create_mp4(self.last_episode_images[::frame_skipping], name=file_name, fps=fps)
# The following functions define the interaction with the environment.
# Any new environment that inherits the Environment class should use these signatures.
# Some of these functions are optional - please read their description for more details.
def _take_action(self, action_idx: ActionType) -> None:
"""
An environment dependent function that sends an action to the simulator.
:param action_idx: the action to perform on the environment
:return: None
"""
raise NotImplementedError("")
def _update_state(self) -> None:
"""
Updates the state from the environment.
Should update self.observation, self.reward, self.done, self.measurements and self.info
:return: None
"""
raise NotImplementedError("")
def _restart_environment_episode(self, force_environment_reset=False) -> None:
"""
Restarts the simulator episode
:param force_environment_reset: Force the environment to reset even if the episode is not done yet.
:return: None
"""
raise NotImplementedError("")
def _render(self) -> None:
"""
Renders the environment using the native simulator renderer
:return: None
"""
pass
def get_rendered_image(self) -> np.ndarray:
"""
Return a numpy array containing the image that will be rendered to the screen.
This can be different from the observation. For example, MuJoCo's observation is a measurements vector.
:return: numpy array containing the image that will be rendered to the screen
"""
return np.transpose(self.state['observation'], [1, 2, 0])
def get_target_success_rate(self) -> float:
return self.target_success_rate
def close(self) -> None:
"""
Clean up steps.
:return: None
"""
pass
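# ----------------------------------------------------------------------------------------------
# Illustrative sketch (appended for documentation purposes, not part of the original module):
# a minimal concrete environment showing which methods a subclass is expected to implement.
# The dynamics below are made up, and the space constructors are shown as commonly used in
# Coach presets; adjust them if the actual signatures differ.
# ----------------------------------------------------------------------------------------------
class _ToyCounterEnvironment(Environment):
    """Toy environment: the state is a counter, action 1 increments it, action 0 decrements it."""
    def __init__(self, level, seed, frame_skip, human_control, custom_reward_threshold,
                 visualization_parameters, **kwargs):
        super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
                         visualization_parameters, **kwargs)
        self.counter = 0
        # two discrete actions: 0 -> decrement, 1 -> increment
        self.action_space = DiscreteActionSpace(2)
        self.state_space = StateSpace({'observation': ObservationSpace(shape=np.array([1]))})
        self._update_state()

    def _take_action(self, action_idx: ActionType) -> None:
        # send the action to the "simulator" (here, just update the counter)
        self.counter += 1 if action_idx == 1 else -1

    def _update_state(self) -> None:
        # refresh self.state, self.reward and self.done from the "simulator"
        self.state = {'observation': np.array([self.counter], dtype=np.float32)}
        self.reward = float(self.counter)
        self.done = abs(self.counter) >= 10

    def _restart_environment_episode(self, force_environment_reset=False) -> None:
        self.counter = 0

    def get_rendered_image(self) -> np.ndarray:
        # blank placeholder frame; a real environment would grab a frame from its renderer
        return np.zeros((64, 64, 3), dtype=np.uint8)

# A matching parameters class (see the EnvironmentParameters sketch above) would point its
# 'path' property at this class so presets can reference it.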