{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Goal-Based Data Collection\n", "A practical approach to robot reinforcement learning is to first collect a large batch of real or simulated robot interaction data, \n", "using some data collection policy, and then learn from this data to perform various tasks, using offline learning algorithms.\n", "\n", "In this notebook, we will demonstrate how to collect diverse dataset for a simple robotics manipulation task\n", "using the algorithms detailed in the following paper:\n", "[Efficient Self-Supervised Data Collection for Offline Robot Learning](https://arxiv.org/abs/2105.04607).\n", "\n", "The implementation is based on the Robosuite simulator, which should be installed before running this notebook. Follow the instructions in the Coach readme [here](https://github.com/IntelLabs/coach#robosuite).\n", "\n", "Presets with predefined parameters for all three algorithms shown in the paper can be found here:\n", "\n", "* Random Agent: ```presets/RoboSuite_CubeExp_Random.py```\n", "\n", "* Intrinsic Reward Agent: ```presets/RoboSuite_CubeExp_TD3_Intrinsic_Reward.py```\n", "\n", "* Goal-Based Agent: ```presets/RoboSuite_CubeExp_TD3_Goal_Based.py```\n", "\n", "You can run those presets using the command line:\n", "\n", "`coach -p RoboSuite_CubeExp_TD3_Goal_Based`\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preliminaries\n", "First, get the required imports and other general settings we need for this notebook.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from rl_coach.agents.td3_exp_agent import TD3GoalBasedAgentParameters\n", "from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n", "from rl_coach.architectures.layers import Dense, Conv2d, BatchnormActivationDropout, Flatten\n", "from rl_coach.base_parameters import EmbedderScheme\n", "from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps\n", "from rl_coach.environments.robosuite_environment import RobosuiteGoalBasedExpEnvironmentParameters, \\\n", " OptionalObservations\n", "from rl_coach.filters.filter import NoInputFilter, NoOutputFilter\n", "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n", "from rl_coach.graph_managers.graph_manager import ScheduleParameters\n", "from rl_coach.architectures.head_parameters import RNDHeadParameters\n", "from rl_coach.schedules import LinearSchedule\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we define the training schedule for the agent. `improve_steps` dictates the number of samples in the final data-set.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "####################\n", "# Graph Scheduling #\n", "####################\n", "\n", "schedule_params = ScheduleParameters()\n", "schedule_params.improve_steps = TrainingSteps(300000)\n", "schedule_params.steps_between_evaluation_periods = TrainingSteps(300000)\n", "schedule_params.evaluation_steps = EnvironmentEpisodes(0)\n", "schedule_params.heatup_steps = EnvironmentSteps(1000)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will be using the goal-based algorithm for data-collection. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will be using the goal-based algorithm for data collection. Therefore, we populate\n", "the `TD3GoalBasedAgentParameters` class with our desired algorithm-specific parameters.\n", "\n", "The goal-based data collection agent is based on TD3, and through this class you can change the TD3-specific parameters as well.\n", "\n", "A detailed description of the goal-based and TD3 algorithm-specific parameters can be found in\n", "```agents/td3_exp_agent.py``` and ```agents/td3_agent.py```, respectively.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "#########\n", "# Agent #\n", "#########\n", "\n", "agent_params = TD3GoalBasedAgentParameters()\n", "agent_params.algorithm.use_non_zero_discount_for_terminal_states = False\n", "agent_params.algorithm.identity_goal_sample_rate = 0.04\n", "agent_params.exploration.noise_schedule = LinearSchedule(1.5, 0.5, 300000)\n", "\n", "agent_params.algorithm.rnd_sample_size = 2000\n", "agent_params.algorithm.rnd_batch_size = 500\n", "agent_params.algorithm.rnd_optimization_epochs = 4\n", "agent_params.algorithm.td3_training_ratio = 1.0\n", "agent_params.algorithm.env_obs_key = 'camera'\n", "agent_params.algorithm.agent_obs_key = 'obs-goal'\n", "agent_params.algorithm.replay_buffer_save_steps = 25000\n", "agent_params.algorithm.replay_buffer_save_path = './Resources'\n", "\n", "agent_params.input_filter = NoInputFilter()\n", "agent_params.output_filter = NoOutputFilter()\n" ] },
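{ "cell_type": "markdown", "metadata": {}, "source": [ "The replay buffer is periodically saved under `replay_buffer_save_path`. As a small convenience step (not part of the original preset, and assuming paths are resolved relative to the notebook's working directory), we can make sure this directory exists before training starts:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Create the replay buffer save directory if it doesn't exist yet.\n", "# './Resources' matches agent_params.algorithm.replay_buffer_save_path above.\n", "os.makedirs(agent_params.algorithm.replay_buffer_save_path, exist_ok=True)\n" ] },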
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll define the networks' architectures and parameters as they appear in the paper.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Camera observation pre-processing network scheme\n", "camera_obs_scheme = [\n", "    Conv2d(32, 8, 4),\n", "    BatchnormActivationDropout(activation_function='relu'),\n", "    Conv2d(64, 4, 2),\n", "    BatchnormActivationDropout(activation_function='relu'),\n", "    Conv2d(64, 3, 1),\n", "    BatchnormActivationDropout(activation_function='relu'),\n", "    Flatten(),\n", "    Dense(256),\n", "    BatchnormActivationDropout(activation_function='relu')\n", "]\n", "\n", "# Actor\n", "actor_network = agent_params.network_wrappers['actor']\n", "actor_network.input_embedders_parameters = {\n", "    'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),\n", "    agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')\n", "}\n", "\n", "actor_network.middleware_parameters.scheme = [Dense(300), Dense(200)]\n", "actor_network.learning_rate = 1e-4\n", "\n", "# Critic\n", "critic_network = agent_params.network_wrappers['critic']\n", "critic_network.input_embedders_parameters = {\n", "    'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),\n", "    'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),\n", "    agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')\n", "}\n", "\n", "critic_network.middleware_parameters.scheme = [Dense(400), Dense(300)]\n", "critic_network.learning_rate = 1e-4\n", "\n", "# RND predictor and constant (random target) networks\n", "agent_params.network_wrappers['predictor'].input_embedders_parameters = \\\n", "    {agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,\n", "                                                                 input_rescaling={'image': 1.0},\n", "                                                                 flatten=False)}\n", "agent_params.network_wrappers['constant'].input_embedders_parameters = \\\n", "    {agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,\n", "                                                                 input_rescaling={'image': 1.0},\n", "                                                                 flatten=False)}\n", "agent_params.network_wrappers['predictor'].heads_parameters = [RNDHeadParameters(is_predictor=True)]\n" ] },
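{ "cell_type": "markdown", "metadata": {}, "source": [ "To see why the camera scheme ends with `Flatten()` followed by `Dense(256)`, we can trace the feature-map sizes through the three convolutions. This is a quick sketch of standard convolution size arithmetic, assuming no padding and the 84x84 camera images configured for the environment below:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: trace feature-map sizes through the camera scheme, assuming 84x84\n", "# inputs (camera_heights/camera_widths below) and unpadded convolutions,\n", "# where the output size is (size - kernel) // stride + 1.\n", "size = 84\n", "for channels, kernel, stride in [(32, 8, 4), (64, 4, 2), (64, 3, 1)]:\n", "    size = (size - kernel) // stride + 1\n", "    print('Conv2d(%d, %d, %d) -> %dx%dx%d' % (channels, kernel, stride, size, size, channels))\n", "print('Flatten -> %d features' % (size * size * 64))\n" ] },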
{ "cell_type": "markdown", "metadata": {}, "source": [ "The last thing we need to define is the environment parameters for the manipulation task.\n", "The environment consists of a 7-DoF Franka Panda robotic arm with a closed gripper and Cartesian\n", "position control of the end-effector. The robot is positioned on a table, and a cube with colored sides is placed in\n", "front of it.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "###############\n", "# Environment #\n", "###############\n", "env_params = RobosuiteGoalBasedExpEnvironmentParameters(level='CubeExp')\n", "env_params.robot = 'Panda'\n", "env_params.custom_controller_config_fpath = '../rl_coach/environments/robosuite/osc_pose.json'\n", "env_params.base_parameters.optional_observations = OptionalObservations.CAMERA\n", "env_params.base_parameters.render_camera = 'frontview'\n", "env_params.base_parameters.camera_names = 'agentview'\n", "env_params.base_parameters.camera_depths = False\n", "env_params.base_parameters.horizon = 200\n", "env_params.base_parameters.ignore_done = False\n", "env_params.base_parameters.use_object_obs = True\n", "env_params.frame_skip = 1\n", "env_params.base_parameters.control_freq = 2\n", "env_params.base_parameters.camera_heights = 84\n", "env_params.base_parameters.camera_widths = 84\n", "env_params.extra_parameters = {'hard_reset': False}\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we create the graph manager and call `graph_manager.improve()` in order to start the data collection.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, schedule_params=schedule_params)\n", "graph_manager.improve()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the data collection is complete, the data-set will be saved to the path specified by `agent_params.algorithm.replay_buffer_save_path`.\n", "\n", "At this point, the data can be used to learn any downstream task you define on that environment.\n", "\n", "The script below visualizes the data-set. Each dot represents a position of the cube on the table as seen in the data-set, and its color corresponds to the color of the cube face pointing up. The number above each plot is the number of dots that plot contains for its color.\n", "\n", "First, we load the data-set from disk. Note that this can take several minutes to complete." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import joblib\n", "\n", "print('Loading data-set (this can take several minutes)...')\n", "rb_path = os.path.join('./Resources', 'RB_TD3GoalBasedAgent.joblib.bz2')\n", "episodes = joblib.load(rb_path)\n", "print('Done')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can run the visualization script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import numpy as np\n", "from collections import OrderedDict\n", "from enum import IntEnum\n", "from pylab import subplot\n", "from gym.envs.robotics.rotations import mat2euler, quat2mat\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "class CubeColor(IntEnum):\n", "    YELLOW = 0\n", "    CYAN = 1\n", "    WHITE = 2\n", "    RED = 3\n", "    GREEN = 4\n", "    BLUE = 5\n", "    UNKNOWN = 6\n", "\n", "\n", "x_range = [-0.3, 0.3]\n", "y_range = [-0.3, 0.3]\n", "\n", "COLOR_MAP = OrderedDict([\n", "    (int(CubeColor.YELLOW), 'yellow'),\n", "    (int(CubeColor.CYAN), 'cyan'),\n", "    (int(CubeColor.WHITE), 'white'),\n", "    (int(CubeColor.RED), 'red'),\n", "    (int(CubeColor.GREEN), 'green'),\n", "    (int(CubeColor.BLUE), 'blue'),\n", "    (int(CubeColor.UNKNOWN), 'black'),\n", "])\n", "\n", "# Mapping between (a subset of) euler angles and the top face color, based on the initial cube rotation\n", "COLOR_ROTATION_MAP = OrderedDict([\n", "    (CubeColor.YELLOW, (0, 2, [np.array([0, 0]),\n", "                               np.array([np.pi, np.pi]), np.array([-np.pi, -np.pi]),\n", "                               np.array([-np.pi, np.pi]), np.array([np.pi, -np.pi])])),\n", "    (CubeColor.CYAN, (0, 2, [np.array([0, np.pi]), np.array([0, -np.pi]),\n", "                             np.array([np.pi, 0]), np.array([-np.pi, 0])])),\n", "    (CubeColor.WHITE, (1, 2, [np.array([-np.pi / 2])])),\n", "    (CubeColor.RED, (1, 2, [np.array([np.pi / 2])])),\n", "    (CubeColor.GREEN, (0, 2, [np.array([np.pi / 2, 0])])),\n", "    (CubeColor.BLUE, (0, 2, [np.array([-np.pi / 2, 0])])),\n", "])\n", "\n", "\n", "def get_cube_top_color(cube_quat, atol):\n", "    euler = mat2euler(quat2mat(cube_quat))\n", "    for color, (start_dim, end_dim, xy_rotations) in COLOR_ROTATION_MAP.items():\n", "        if any(np.allclose(euler[start_dim:end_dim], xy_rotation, atol=atol) for xy_rotation in xy_rotations):\n", "            return color\n", "    return CubeColor.UNKNOWN\n", "\n", "\n", "# Indices of the cube position and orientation quaternion within the 'measurements' vector\n", "pos_idx = 25\n", "quat_idx = 28\n", "positions = []\n", "colors = []\n", "print('Extracting cube positions and colors...')\n", "for episode in episodes:\n", "    for transition in episode:\n", "        x, y = transition.state['measurements'][pos_idx:pos_idx+2]\n", "        positions.append([x, y])\n", "        colors.append(int(get_cube_top_color(transition.state['measurements'][quat_idx:quat_idx+4], np.pi / 4)))\n", "\n", "    x, y = episode[-1].next_state['measurements'][pos_idx:pos_idx+2]\n", "    positions.append([x, y])\n", "    colors.append(int(get_cube_top_color(episode[-1].next_state['measurements'][quat_idx:quat_idx+4], np.pi / 4)))\n", "print('Done')\n", "\n", "\n", "fig = plt.figure(figsize=(15.0, 5.0))\n", "axes = []\n", "for j in range(6):\n", "    axes.append(subplot(1, 6, j + 1))\n", "    xy = np.array(positions)[np.array(colors) == list(COLOR_MAP.keys())[j]]\n", "    axes[-1].scatter(xy[:, 1], xy[:, 0], c=COLOR_MAP[j], alpha=0.01, edgecolors='black')\n", "    plt.xlim(y_range)\n", "    plt.ylim(x_range)\n", "    plt.xticks([])\n", "    plt.yticks([])\n", "    axes[-1].set_aspect('equal', adjustable='box')\n", "    title = 'N=' + str(xy.shape[0])\n", "    plt.title(title)\n", "\n", "for ax in axes:\n", "    ax.invert_yaxis()\n", "\n", "plt.show()\n" ] },
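{ "cell_type": "markdown", "metadata": {}, "source": [ "Beyond visualization, the loaded episodes can be flattened into arrays for a downstream offline learning algorithm. The cell below is a minimal sketch of one possible conversion, assuming each transition exposes `state`, `action`, `reward` and `next_state` attributes as in Coach's `Transition` class; here we keep only the low-dimensional `'measurements'` observation, but the camera observation could be extracted in the same way.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch: flatten the episodic data-set into transition arrays.\n", "# Assumes Coach Transition objects with .state, .action, .reward and\n", "# .next_state, where observations are dicts keyed by 'measurements'.\n", "states, actions, rewards, next_states = [], [], [], []\n", "for episode in episodes:\n", "    for t in episode:\n", "        states.append(t.state['measurements'])\n", "        actions.append(t.action)\n", "        rewards.append(t.reward)\n", "        next_states.append(t.next_state['measurements'])\n", "\n", "states, actions = np.array(states), np.array(actions)\n", "rewards, next_states = np.array(rewards), np.array(next_states)\n", "print('Total transitions:', states.shape[0])\n" ] }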
"for j in range(6):\n", " axes.append(subplot(1, 6, j + 1))\n", " xy = np.array(positions)[np.array(colors) == list(COLOR_MAP.keys())[j]]\n", " axes[-1].scatter(xy[:, 1], xy[:, 0], c=COLOR_MAP[j], alpha=0.01, edgecolors='black')\n", " plt.xlim(y_range)\n", " plt.ylim(x_range)\n", " plt.xticks([])\n", " plt.yticks([])\n", " axes[-1].set_aspect('equal', adjustable='box')\n", " title = 'N=' + str(xy.shape[0])\n", " plt.title(title)\n", "\n", "for ax in axes:\n", " ax.invert_yaxis()\n", "\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 1 }