{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial we'll add the DeepMind Control Suite environment to Coach, and create a preset that trains the DDPG agent on the new environment."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"First, follow the installation instructions here: https://github.com/deepmind/dm_control#installation-and-requirements. \n",
"\n",
"\n",
"Make sure your ```LD_LIBRARY_PATH``` contains the path to the GLEW and LGFW libraries (https://github.com/openai/mujoco-py/issues/110).\n",
"\n",
"\n",
"In addition, Mujoco rendering might need to be disabled (https://github.com/deepmind/dm_control/issues/20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"#os.environ['DISABLE_MUJOCO_RENDERING'] = '1'\n",
"\n",
"import sys\n",
"module_path = os.path.abspath(os.path.join('..'))\n",
"if module_path not in sys.path:\n",
" sys.path.append(module_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The Environment Wrapper\n",
"\n",
"To integrate an environment with Coach, we need to implement an environment wrapper which is placed under the environments folder. In our case, we'll implement the ```control_suite_environment.py``` file.\n",
"\n",
"\n",
"We'll start with some helper classes - ```ObservationType``` and ```ControlSuiteEnvironmentParameters```."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from enum import Enum\n",
"from dm_control import suite\n",
"from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection\n",
"from rl_coach.filters.filter import NoInputFilter, NoOutputFilter\n",
"\n",
"\n",
"\n",
"class ObservationType(Enum):\n",
" Measurements = 1\n",
" Image = 2\n",
" Image_and_Measurements = 3\n",
"\n",
"\n",
"# Parameters\n",
"class ControlSuiteEnvironmentParameters(EnvironmentParameters):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.observation_type = ObservationType.Measurements\n",
" self.default_input_filter = ControlSuiteInputFilter\n",
" self.default_output_filter = ControlSuiteOutputFilter\n",
"\n",
" @property\n",
" def path(self):\n",
" return 'environments.control_suite_environment:ControlSuiteEnvironment'\n",
"\n",
"\n",
"\"\"\"\n",
"ControlSuite Environment Components\n",
"\"\"\"\n",
"ControlSuiteInputFilter = NoInputFilter()\n",
"ControlSuiteOutputFilter = NoOutputFilter()\n",
"\n",
"control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING}"
]
},
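{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (this cell is just for the tutorial and is not part of ```control_suite_environment.py```), we can print a few of the resulting level names. ```suite.BENCHMARKING``` is a list of ```(domain, task)``` tuples, so each key of ```control_suite_envs``` has the form ```domain:task```, e.g. ```walker:walk```."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# list a few of the available levels (the keys of the control_suite_envs dictionary defined above)\n",
"print(len(control_suite_envs), 'benchmarking levels available')\n",
"print(sorted(control_suite_envs)[:5])"
]
},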
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's define the control suite's environment wrapper class.\n",
"\n",
"In the ```__init__``` function we'll load and initialize the environment, and the internal state and action space members which will make sure the states and actions are within their allowed limits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import random\n",
"from typing import Union\n",
"from rl_coach.base_parameters import VisualizationParameters\n",
"from rl_coach.spaces import BoxActionSpace, ImageObservationSpace, VectorObservationSpace, StateSpace\n",
"from dm_control.suite.wrappers import pixels\n",
"\n",
"\n",
"# Environment\n",
"class ControlSuiteEnvironment(Environment):\n",
" def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,\n",
" seed: Union[None, int]=None, human_control: bool=False,\n",
" observation_type: ObservationType=ObservationType.Measurements,\n",
" custom_reward_threshold: Union[int, float]=None, **kwargs):\n",
" super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)\n",
"\n",
" self.observation_type = observation_type\n",
"\n",
" # load and initialize environment\n",
" domain_name, task_name = self.env_id.split(\":\")\n",
" self.env = suite.load(domain_name=domain_name, task_name=task_name)\n",
"\n",
" if observation_type != ObservationType.Measurements:\n",
" self.env = pixels.Wrapper(self.env, pixels_only=observation_type == ObservationType.Image)\n",
"\n",
" # seed\n",
" if self.seed is not None:\n",
" np.random.seed(self.seed)\n",
" random.seed(self.seed)\n",
"\n",
" self.state_space = StateSpace({})\n",
"\n",
" # image observations\n",
" if observation_type != ObservationType.Measurements:\n",
" self.state_space['pixels'] = ImageObservationSpace(shape=self.env.observation_spec()['pixels'].shape,\n",
" high=255)\n",
"\n",
" # measurements observations\n",
" if observation_type != ObservationType.Image:\n",
" measurements_space_size = 0\n",
" measurements_names = []\n",
" for observation_space_name, observation_space in self.env.observation_spec().items():\n",
" if len(observation_space.shape) == 0:\n",
" measurements_space_size += 1\n",
" measurements_names.append(observation_space_name)\n",
" elif len(observation_space.shape) == 1:\n",
" measurements_space_size += observation_space.shape[0]\n",
" measurements_names.extend([\"{}_{}\".format(observation_space_name, i) for i in\n",
" range(observation_space.shape[0])])\n",
" self.state_space['measurements'] = VectorObservationSpace(shape=measurements_space_size,\n",
" measurements_names=measurements_names)\n",
"\n",
" # actions\n",
" self.action_space = BoxActionSpace(\n",
" shape=self.env.action_spec().shape[0],\n",
" low=self.env.action_spec().minimum,\n",
" high=self.env.action_spec().maximum\n",
" )\n",
"\n",
" # initialize the state by getting a new state from the environment\n",
" self.reset_internal_state(True)\n",
"\n",
" # render\n",
" if self.is_rendered:\n",
" image = self.get_rendered_image()\n",
" scale = 1\n",
" if self.human_control:\n",
" scale = 2\n",
" if not self.native_rendering:\n",
" self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)"
]
},
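{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the ```__init__``` above maps into Coach's observation and action spaces, we can inspect the raw dm_control specs directly. This is just an illustration (it assumes dm_control can load the ```walker:walk``` task) and is not part of ```control_suite_environment.py```."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load a raw dm_control task and inspect the specs that the wrapper converts into\n",
"# a VectorObservationSpace and a BoxActionSpace\n",
"raw_env = suite.load(domain_name='walker', task_name='walk')\n",
"print('observations:', {name: spec.shape for name, spec in raw_env.observation_spec().items()})\n",
"print('actions:', raw_env.action_spec().shape, raw_env.action_spec().minimum, raw_env.action_spec().maximum)"
]
},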
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following functions cover the API expected from a new environment wrapper:\n",
"\n",
"1. ```_update_state``` - update the internal state of the wrapper (to be queried by the agent)\n",
"2. ```_take_action``` - take an action on the environment \n",
"3. ```_restart_environment_episode``` - restart the environment on a new episode \n",
"4. ```get_rendered_image``` - get a rendered image of the environment in its current state"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ControlSuiteEnvironment(Environment):\n",
" def _update_state(self):\n",
" self.state = {}\n",
"\n",
" if self.observation_type != ObservationType.Measurements:\n",
" self.pixels = self.last_result.observation['pixels']\n",
" self.state['pixels'] = self.pixels\n",
"\n",
" if self.observation_type != ObservationType.Image:\n",
" self.measurements = np.array([])\n",
" for sub_observation in self.last_result.observation.values():\n",
" if isinstance(sub_observation, np.ndarray) and len(sub_observation.shape) == 1:\n",
" self.measurements = np.concatenate((self.measurements, sub_observation))\n",
" else:\n",
" self.measurements = np.concatenate((self.measurements, np.array([sub_observation])))\n",
" self.state['measurements'] = self.measurements\n",
"\n",
" self.reward = self.last_result.reward if self.last_result.reward is not None else 0\n",
"\n",
" self.done = self.last_result.last()\n",
"\n",
" def _take_action(self, action):\n",
" if type(self.action_space) == BoxActionSpace:\n",
" action = self.action_space.clip_action_to_space(action)\n",
"\n",
" self.last_result = self.env.step(action)\n",
"\n",
" def _restart_environment_episode(self, force_environment_reset=False):\n",
" self.last_result = self.env.reset()\n",
"\n",
" def get_rendered_image(self):\n",
" return self.env.physics.render(camera_id=0)"
]
},
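{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that ```_take_action``` clips any out-of-bounds action to the action space limits before stepping the environment. Here is a tiny standalone illustration of that clipping (again, just for the tutorial, not part of the wrapper file):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a toy 2-dimensional box action space bounded to [-1, 1]\n",
"toy_action_space = BoxActionSpace(shape=2, low=-1, high=1)\n",
"# an out-of-range action is clipped to the limits, i.e. to [1., -1.]\n",
"print(toy_action_space.clip_action_to_space(np.array([1.5, -2.0])))"
]
},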
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The Preset\n",
"The new preset will be defined in a new file - ```presets\\ControlSuite_DDPG.py```. \n",
"\n",
"First - let's define the agent parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.agents.ddpg_agent import DDPGAgentParameters\n",
"from rl_coach.architectures.tensorflow_components.architecture import Dense\n",
"from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme\n",
"from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase\n",
"from rl_coach.environments.gym_environment import MujocoInputFilter\n",
"from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter\n",
"\n",
"\n",
"agent_params = DDPGAgentParameters()\n",
"agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'] = \\\n",
" agent_params.network_wrappers['actor'].input_embedders_parameters.pop('observation')\n",
"agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'] = \\\n",
" agent_params.network_wrappers['critic'].input_embedders_parameters.pop('observation')\n",
"agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'].scheme = [Dense([300])]\n",
"agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense([200])]\n",
"agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'].scheme = [Dense([400])]\n",
"agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense([300])]\n",
"agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty\n",
"agent_params.input_filter = MujocoInputFilter()\n",
"agent_params.input_filter.add_reward_filter(\"rescale\", RewardRescaleFilter(1/10.))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's define the environment parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs\n",
"from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection\n",
"\n",
"env_params = ControlSuiteEnvironmentParameters()\n",
"env_params.level = SingleLevelSelection(control_suite_envs)\n",
"\n",
"vis_params = VisualizationParameters()\n",
"vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]\n",
"vis_params.dump_mp4 = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The schedule parameters will define the number of heatup steps, periodice evaluation steps, training steps between evaluations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.graph_managers.graph_manager import ScheduleParameters\n",
"\n",
"\n",
"schedule_params = ScheduleParameters()\n",
"schedule_params.improve_steps = TrainingSteps(10000000000)\n",
"schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)\n",
"schedule_params.evaluation_steps = EnvironmentEpisodes(1)\n",
"schedule_params.heatup_steps = EnvironmentSteps(1000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we'll create and run the graph manager"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
"from rl_coach.base_parameters import TaskParameters, Frameworks\n",
"\n",
"\n",
"graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,\n",
" schedule_params=schedule_params, vis_params=vis_params)\n",
"\n",
"graph_manager.env_params.level.select('walker:walk')\n",
"graph_manager.visualization_parameters.render = True\n",
"\n",
"\n",
"log_path = '../experiments/control_suite_walker_ddpg'\n",
"if not os.path.exists(log_path):\n",
" os.makedirs(log_path)\n",
" \n",
"task_parameters = TaskParameters(framework_type=\"tensorflow\", \n",
" evaluate_only=False,\n",
" experiment_path=log_path)\n",
"\n",
"task_parameters.__dict__['save_checkpoint_secs'] = None\n",
"\n",
"\n",
"graph_manager.create_graph(task_parameters)\n",
"\n",
"# let the adventure begin\n",
"graph_manager.improve()\n"
]
},
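{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once ```control_suite_environment.py``` is placed under the ```rl_coach/environments``` folder and the preset is saved as ```presets/ControlSuite_DDPG.py```, the same training run should also be launchable from the command line with Coach's launcher, e.g. ```coach -p ControlSuite_DDPG -lvl walker:walk```."
]
},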
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}