{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we'll add the DeepMind Control Suite environment to Coach, and create a preset that trains the DDPG agent on the new environment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup\n",
    "First, follow the installation instructions here: https://github.com/deepmind/dm_control#installation-and-requirements.\n",
    "\n",
    "Make sure your ```LD_LIBRARY_PATH``` contains the paths to the GLEW and GLFW libraries (https://github.com/openai/mujoco-py/issues/110).\n",
    "\n",
    "In addition, MuJoCo rendering might need to be disabled (https://github.com/deepmind/dm_control/issues/20):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# disable MuJoCo rendering before dm_control is imported\n",
    "os.environ['DISABLE_MUJOCO_RENDERING'] = '1'\n",
    "\n",
    "import sys\n",
    "# make the Coach modules in the repository root importable from this notebook\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
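  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before moving on, we can run a quick, optional sanity check - it is not part of the tutorial's code, it just verifies that ```dm_control``` and MuJoCo load correctly. If it fails, revisit the installation steps and ```LD_LIBRARY_PATH``` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional sanity check - load one of the benchmark tasks and print its action spec\n",
    "from dm_control import suite\n",
    "\n",
    "test_env = suite.load(domain_name='cartpole', task_name='swingup')\n",
    "print(test_env.action_spec())"
   ]
  },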
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Environment Wrapper\n",
    "\n",
    "To integrate an environment with Coach, we need to implement an environment wrapper, which is placed under the environments folder. In our case, we'll implement the ```control_suite_environment.py``` file.\n",
    "\n",
    "We'll start with some helper classes - ```ObservationType``` and ```ControlSuiteEnvironmentParameters```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from enum import Enum\n",
    "from dm_control import suite\n",
    "from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection\n",
    "from rl_coach.filters.filter import NoInputFilter, NoOutputFilter\n",
    "\n",
    "\n",
    "class ObservationType(Enum):\n",
    "    Measurements = 1\n",
    "    Image = 2\n",
    "    Image_and_Measurements = 3\n",
    "\n",
    "\n",
    "# Parameters\n",
    "class ControlSuiteEnvironmentParameters(EnvironmentParameters):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.observation_type = ObservationType.Measurements\n",
    "        self.default_input_filter = ControlSuiteInputFilter\n",
    "        self.default_output_filter = ControlSuiteOutputFilter\n",
    "\n",
    "    @property\n",
    "    def path(self):\n",
    "        return 'environments.control_suite_environment:ControlSuiteEnvironment'\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "ControlSuite Environment Components\n",
    "\"\"\"\n",
    "ControlSuiteInputFilter = NoInputFilter()\n",
    "ControlSuiteOutputFilter = NoOutputFilter()\n",
    "\n",
    "# map each (domain, task) benchmark pair to a 'domain:task' level name\n",
    "control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING}"
   ]
  },
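  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ```control_suite_envs``` dictionary maps every benchmark task to a ```'domain:task'``` level name. A quick look at it (this check is not part of the environment file):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# each benchmark task is identified by a 'domain:task' string, e.g. 'walker:walk'\n",
    "print(len(control_suite_envs))\n",
    "print(sorted(control_suite_envs)[:5])"
   ]
  },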
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define the Control Suite environment wrapper class.\n",
    "\n",
    "In the ```__init__``` function we'll load and initialize the environment, and set up the internal state space and action space members, which make sure that states and actions stay within their allowed limits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "from typing import Union\n",
    "from rl_coach.base_parameters import VisualizationParameters\n",
    "from rl_coach.spaces import BoxActionSpace, ImageObservationSpace, VectorObservationSpace, StateSpace\n",
    "from dm_control.suite.wrappers import pixels\n",
    "\n",
    "\n",
    "# Environment\n",
    "class ControlSuiteEnvironment(Environment):\n",
    "    def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,\n",
    "                 seed: Union[None, int]=None, human_control: bool=False,\n",
    "                 observation_type: ObservationType=ObservationType.Measurements,\n",
    "                 custom_reward_threshold: Union[int, float]=None, **kwargs):\n",
    "        super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)\n",
    "\n",
    "        self.observation_type = observation_type\n",
    "\n",
    "        # load and initialize environment\n",
    "        domain_name, task_name = self.env_id.split(\":\")\n",
    "        self.env = suite.load(domain_name=domain_name, task_name=task_name)\n",
    "\n",
    "        if observation_type != ObservationType.Measurements:\n",
    "            self.env = pixels.Wrapper(self.env, pixels_only=observation_type == ObservationType.Image)\n",
    "\n",
    "        # seed\n",
    "        if self.seed is not None:\n",
    "            np.random.seed(self.seed)\n",
    "            random.seed(self.seed)\n",
    "\n",
    "        self.state_space = StateSpace({})\n",
    "\n",
    "        # image observations\n",
    "        if observation_type != ObservationType.Measurements:\n",
    "            self.state_space['pixels'] = ImageObservationSpace(shape=self.env.observation_spec()['pixels'].shape,\n",
    "                                                               high=255)\n",
    "\n",
    "        # measurements observations\n",
    "        if observation_type != ObservationType.Image:\n",
    "            measurements_space_size = 0\n",
    "            measurements_names = []\n",
    "            for observation_space_name, observation_space in self.env.observation_spec().items():\n",
    "                if len(observation_space.shape) == 0:\n",
    "                    measurements_space_size += 1\n",
    "                    measurements_names.append(observation_space_name)\n",
    "                elif len(observation_space.shape) == 1:\n",
    "                    measurements_space_size += observation_space.shape[0]\n",
    "                    measurements_names.extend([\"{}_{}\".format(observation_space_name, i) for i in\n",
    "                                               range(observation_space.shape[0])])\n",
    "            self.state_space['measurements'] = VectorObservationSpace(shape=measurements_space_size,\n",
    "                                                                      measurements_names=measurements_names)\n",
    "\n",
    "        # actions\n",
    "        self.action_space = BoxActionSpace(\n",
    "            shape=self.env.action_spec().shape[0],\n",
    "            low=self.env.action_spec().minimum,\n",
    "            high=self.env.action_spec().maximum\n",
    "        )\n",
    "\n",
    "        # initialize the state by getting a new state from the environment\n",
    "        self.reset_internal_state(True)\n",
    "\n",
    "        # render\n",
    "        if self.is_rendered:\n",
    "            image = self.get_rendered_image()\n",
    "            scale = 1\n",
    "            if self.human_control:\n",
    "                scale = 2\n",
    "            if not self.native_rendering:\n",
    "                self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following functions cover the API expected from a new environment wrapper:\n",
    "\n",
    "1. ```_update_state``` - update the internal state of the wrapper (to be queried by the agent)\n",
    "2. ```_take_action``` - take an action on the environment\n",
    "3. ```_restart_environment_episode``` - restart the environment on a new episode\n",
    "4. ```get_rendered_image``` - get a rendered image of the environment in its current state\n",
    "\n",
    "Note that the next cell repeats the ```class ControlSuiteEnvironment``` header only for readability; in the actual ```control_suite_environment.py``` file, these methods belong to the same class as the ```__init__``` defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ControlSuiteEnvironment(Environment):\n",
    "    def _update_state(self):\n",
    "        self.state = {}\n",
    "\n",
    "        if self.observation_type != ObservationType.Measurements:\n",
    "            self.pixels = self.last_result.observation['pixels']\n",
    "            self.state['pixels'] = self.pixels\n",
    "\n",
    "        if self.observation_type != ObservationType.Image:\n",
    "            # flatten the named observation arrays into a single measurements vector\n",
    "            self.measurements = np.array([])\n",
    "            for sub_observation in self.last_result.observation.values():\n",
    "                if isinstance(sub_observation, np.ndarray) and len(sub_observation.shape) == 1:\n",
    "                    self.measurements = np.concatenate((self.measurements, sub_observation))\n",
    "                else:\n",
    "                    self.measurements = np.concatenate((self.measurements, np.array([sub_observation])))\n",
    "            self.state['measurements'] = self.measurements\n",
    "\n",
    "        self.reward = self.last_result.reward if self.last_result.reward is not None else 0\n",
    "\n",
    "        self.done = self.last_result.last()\n",
    "\n",
    "    def _take_action(self, action):\n",
    "        if type(self.action_space) == BoxActionSpace:\n",
    "            action = self.action_space.clip_action_to_space(action)\n",
    "\n",
    "        self.last_result = self.env.step(action)\n",
    "\n",
    "    def _restart_environment_episode(self, force_environment_reset=False):\n",
    "        self.last_result = self.env.reset()\n",
    "\n",
    "    def get_rendered_image(self):\n",
    "        return self.env.physics.render(camera_id=0)"
   ]
  },
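  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional aside (not part of the environment file), we can peek at the raw ```dm_control``` ```TimeStep``` that ```_update_state``` consumes: each step returns a named dictionary of observation arrays, plus a reward and a ```last()``` flag - exactly the fields the wrapper reads from ```self.last_result```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional: inspect the TimeStep structure that _update_state consumes\n",
    "raw_env = suite.load(domain_name='walker', task_name='walk')\n",
    "time_step = raw_env.reset()\n",
    "print({name: np.shape(obs) for name, obs in time_step.observation.items()})\n",
    "print(time_step.reward, time_step.last())  # the reward is None on the very first step"
   ]
  },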
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Preset\n",
    "The new preset will be defined in a new file - ```presets/ControlSuite_DDPG.py```.\n",
    "\n",
    "First, let's define the agent parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.agents.ddpg_agent import DDPGAgentParameters\n",
    "from rl_coach.architectures.tensorflow_components.architecture import Dense\n",
    "from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme\n",
    "from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase\n",
    "from rl_coach.environments.gym_environment import MujocoInputFilter\n",
    "from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter\n",
    "\n",
    "\n",
    "agent_params = DDPGAgentParameters()\n",
    "# rename the default 'observation' input embedders to 'measurements', to match\n",
    "# the state space keys exposed by our environment wrapper\n",
    "agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'] = \\\n",
    "    agent_params.network_wrappers['actor'].input_embedders_parameters.pop('observation')\n",
    "agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'] = \\\n",
    "    agent_params.network_wrappers['critic'].input_embedders_parameters.pop('observation')\n",
    "agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'].scheme = [Dense([300])]\n",
    "agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense([200])]\n",
    "agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'].scheme = [Dense([400])]\n",
    "agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense([300])]\n",
    "agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty\n",
    "agent_params.input_filter = MujocoInputFilter()\n",
    "agent_params.input_filter.add_reward_filter(\"rescale\", RewardRescaleFilter(1/10.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define the environment and visualization parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs\n",
    "from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection\n",
    "\n",
    "env_params = ControlSuiteEnvironmentParameters()\n",
    "# the preset can run on any of the benchmark levels; the level is selected at runtime\n",
    "env_params.level = SingleLevelSelection(control_suite_envs)\n",
    "\n",
    "vis_params = VisualizationParameters()\n",
    "# dump videos only for evaluation episodes that achieve a new maximum reward\n",
    "vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]\n",
    "vis_params.dump_mp4 = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The schedule parameters define the total number of training steps, the number of heatup steps (random steps used to fill the replay buffer before training starts), the number of episodes between evaluation periods, and the length of each evaluation period."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.graph_managers.graph_manager import ScheduleParameters\n",
    "\n",
    "\n",
    "schedule_params = ScheduleParameters()\n",
    "schedule_params.improve_steps = TrainingSteps(10000000000)\n",
    "schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)\n",
    "schedule_params.evaluation_steps = EnvironmentEpisodes(1)\n",
    "schedule_params.heatup_steps = EnvironmentSteps(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we'll create and run the graph manager."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
    "from rl_coach.base_parameters import TaskParameters, Frameworks\n",
    "\n",
    "\n",
    "graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,\n",
    "                                    schedule_params=schedule_params, vis_params=vis_params)\n",
    "\n",
    "graph_manager.env_params.level.select('walker:walk')\n",
    "# graph_manager.visualization_parameters.render = True\n",
    "\n",
    "log_path = '../experiments/control_suite_walker_ddpg'\n",
    "if not os.path.exists(log_path):\n",
    "    os.makedirs(log_path)\n",
    "\n",
    "task_parameters = TaskParameters(framework_type=\"tensorflow\",\n",
    "                                 evaluate_only=False,\n",
    "                                 experiment_path=log_path)\n",
    "\n",
    "task_parameters.__dict__['save_checkpoint_secs'] = None\n",
    "\n",
    "graph_manager.create_graph(task_parameters)\n",
    "\n",
    "# let the adventure begin\n",
    "graph_manager.improve()"
   ]
  },
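  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the preset file is saved as ```presets/ControlSuite_DDPG.py```, it should also be runnable through Coach's command line in the usual way, e.g. ```coach -p ControlSuite_DDPG -lvl walker:walk``` (assuming a standard Coach installation)."
   ]
  }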
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}