{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "In this tutorial we'll add the DeepMind Control Suite environment to Coach, and create a preset that trains the DDPG agent on the new environment." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Setup\n", "First, follow the installation instructions here: https://github.com/deepmind/dm_control#installation-and-requirements. \n", "\n", "\n", "Make sure your ```LD_LIBRARY_PATH``` contains the path to the GLEW and LGFW libraries (https://github.com/openai/mujoco-py/issues/110).\n", "\n", "\n", "In addition, Mujoco rendering might need to be disabled (https://github.com/deepmind/dm_control/issues/20)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "#os.environ['DISABLE_MUJOCO_RENDERING'] = '1'\n", "\n", "import sys\n", "module_path = os.path.abspath(os.path.join('..'))\n", "if module_path not in sys.path:\n", " sys.path.append(module_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# The Environment Wrapper\n", "\n", "To integrate an environment with Coach, we need to implement an environment wrapper which is placed under the environments folder. In our case, we'll implement the ```control_suite_environment.py``` file.\n", "\n", "\n", "We'll start with some helper classes - ```ObservationType``` and ```ControlSuiteEnvironmentParameters```." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from enum import Enum\n", "from dm_control import suite\n", "from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection\n", "from rl_coach.filters.filter import NoInputFilter, NoOutputFilter\n", "\n", "\n", "\n", "class ObservationType(Enum):\n", " Measurements = 1\n", " Image = 2\n", " Image_and_Measurements = 3\n", "\n", "\n", "# Parameters\n", "class ControlSuiteEnvironmentParameters(EnvironmentParameters):\n", " def __init__(self):\n", " super().__init__()\n", " self.observation_type = ObservationType.Measurements\n", " self.default_input_filter = ControlSuiteInputFilter\n", " self.default_output_filter = ControlSuiteOutputFilter\n", "\n", " @property\n", " def path(self):\n", " return 'environments.control_suite_environment:ControlSuiteEnvironment'\n", "\n", "\n", "\"\"\"\n", "ControlSuite Environment Components\n", "\"\"\"\n", "ControlSuiteInputFilter = NoInputFilter()\n", "ControlSuiteOutputFilter = NoOutputFilter()\n", "\n", "control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's define the control suite's environment wrapper class.\n", "\n", "In the ```__init__``` function we'll load and initialize the environment, and the internal state and action space members which will make sure the states and actions are within their allowed limits." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import random\n", "from typing import Union\n", "from rl_coach.base_parameters import VisualizationParameters\n", "from rl_coach.spaces import BoxActionSpace, ImageObservationSpace, VectorObservationSpace, StateSpace\n", "from dm_control.suite.wrappers import pixels\n", "\n", "\n", "# Environment\n", "class ControlSuiteEnvironment(Environment):\n", " def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,\n", " seed: Union[None, int]=None, human_control: bool=False,\n", " observation_type: ObservationType=ObservationType.Measurements,\n", " custom_reward_threshold: Union[int, float]=None, **kwargs):\n", " super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)\n", "\n", " self.observation_type = observation_type\n", "\n", " # load and initialize environment\n", " domain_name, task_name = self.env_id.split(\":\")\n", " self.env = suite.load(domain_name=domain_name, task_name=task_name)\n", "\n", " if observation_type != ObservationType.Measurements:\n", " self.env = pixels.Wrapper(self.env, pixels_only=observation_type == ObservationType.Image)\n", "\n", " # seed\n", " if self.seed is not None:\n", " np.random.seed(self.seed)\n", " random.seed(self.seed)\n", "\n", " self.state_space = StateSpace({})\n", "\n", " # image observations\n", " if observation_type != ObservationType.Measurements:\n", " self.state_space['pixels'] = ImageObservationSpace(shape=self.env.observation_spec()['pixels'].shape,\n", " high=255)\n", "\n", " # measurements observations\n", " if observation_type != ObservationType.Image:\n", " measurements_space_size = 0\n", " measurements_names = []\n", " for observation_space_name, observation_space in self.env.observation_spec().items():\n", " if len(observation_space.shape) == 0:\n", " measurements_space_size += 1\n", " measurements_names.append(observation_space_name)\n", " elif len(observation_space.shape) == 1:\n", " measurements_space_size += observation_space.shape[0]\n", " measurements_names.extend([\"{}_{}\".format(observation_space_name, i) for i in\n", " range(observation_space.shape[0])])\n", " self.state_space['measurements'] = VectorObservationSpace(shape=measurements_space_size,\n", " measurements_names=measurements_names)\n", "\n", " # actions\n", " self.action_space = BoxActionSpace(\n", " shape=self.env.action_spec().shape[0],\n", " low=self.env.action_spec().minimum,\n", " high=self.env.action_spec().maximum\n", " )\n", "\n", " # initialize the state by getting a new state from the environment\n", " self.reset_internal_state(True)\n", "\n", " # render\n", " if self.is_rendered:\n", " image = self.get_rendered_image()\n", " scale = 1\n", " if self.human_control:\n", " scale = 2\n", " if not self.native_rendering:\n", " self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following functions cover the API expected from a new environment wrapper:\n", "\n", "1. ```_update_state``` - update the internal state of the wrapper (to be queried by the agent)\n", "2. ```_take_action``` - take an action on the environment \n", "3. ```_restart_environment_episode``` - restart the environment on a new episode \n", "4. ```get_rendered_image``` - get a rendered image of the environment in its current state" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ControlSuiteEnvironment(Environment):\n", " def _update_state(self):\n", " self.state = {}\n", "\n", " if self.observation_type != ObservationType.Measurements:\n", " self.pixels = self.last_result.observation['pixels']\n", " self.state['pixels'] = self.pixels\n", "\n", " if self.observation_type != ObservationType.Image:\n", " self.measurements = np.array([])\n", " for sub_observation in self.last_result.observation.values():\n", " if isinstance(sub_observation, np.ndarray) and len(sub_observation.shape) == 1:\n", " self.measurements = np.concatenate((self.measurements, sub_observation))\n", " else:\n", " self.measurements = np.concatenate((self.measurements, np.array([sub_observation])))\n", " self.state['measurements'] = self.measurements\n", "\n", " self.reward = self.last_result.reward if self.last_result.reward is not None else 0\n", "\n", " self.done = self.last_result.last()\n", "\n", " def _take_action(self, action):\n", " if type(self.action_space) == BoxActionSpace:\n", " action = self.action_space.clip_action_to_space(action)\n", "\n", " self.last_result = self.env.step(action)\n", "\n", " def _restart_environment_episode(self, force_environment_reset=False):\n", " self.last_result = self.env.reset()\n", "\n", " def get_rendered_image(self):\n", " return self.env.physics.render(camera_id=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# The Preset\n", "The new preset will be defined in a new file - ```presets\\ControlSuite_DDPG.py```. \n", "\n", "First - let's define the agent parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rl_coach.agents.ddpg_agent import DDPGAgentParameters\n", "from rl_coach.architectures.tensorflow_components.architecture import Dense\n", "from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme\n", "from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase\n", "from rl_coach.environments.gym_environment import MujocoInputFilter\n", "from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter\n", "\n", "\n", "agent_params = DDPGAgentParameters()\n", "agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'] = \\\n", " agent_params.network_wrappers['actor'].input_embedders_parameters.pop('observation')\n", "agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'] = \\\n", " agent_params.network_wrappers['critic'].input_embedders_parameters.pop('observation')\n", "agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'].scheme = [Dense([300])]\n", "agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense([200])]\n", "agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'].scheme = [Dense([400])]\n", "agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense([300])]\n", "agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty\n", "agent_params.input_filter = MujocoInputFilter()\n", "agent_params.input_filter.add_reward_filter(\"rescale\", RewardRescaleFilter(1/10.))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's define the environment parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs\n", "from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection\n", "\n", "env_params = ControlSuiteEnvironmentParameters()\n", "env_params.level = SingleLevelSelection(control_suite_envs)\n", "\n", "vis_params = VisualizationParameters()\n", "vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]\n", "vis_params.dump_mp4 = False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The schedule parameters will define the number of heatup steps, periodice evaluation steps, training steps between evaluations." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rl_coach.graph_managers.graph_manager import ScheduleParameters\n", "\n", "\n", "schedule_params = ScheduleParameters()\n", "schedule_params.improve_steps = TrainingSteps(10000000000)\n", "schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)\n", "schedule_params.evaluation_steps = EnvironmentEpisodes(1)\n", "schedule_params.heatup_steps = EnvironmentSteps(1000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we'll create and run the graph manager" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n", "from rl_coach.base_parameters import TaskParameters, Frameworks\n", "\n", "\n", "graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,\n", " schedule_params=schedule_params, vis_params=vis_params)\n", "\n", "graph_manager.env_params.level.select('walker:walk')\n", "graph_manager.visualization_parameters.render = True\n", "\n", "\n", "log_path = '../experiments/control_suite_walker_ddpg'\n", "if not os.path.exists(log_path):\n", " os.makedirs(log_path)\n", " \n", "task_parameters = TaskParameters(framework_type=Frameworks.tensorflow, \n", " evaluate_only=False,\n", " experiment_path=log_path)\n", "\n", "task_parameters.__dict__['checkpoint_save_secs'] = None\n", "\n", "\n", "graph_manager.create_graph(task_parameters)\n", "\n", "# let the adventure begin\n", "graph_manager.improve()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }