{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Goal-Based Data Collection\n",
"A practical approach to robot reinforcement learning is to first collect a large batch of real or simulated robot interaction data, \n",
"using some data collection policy, and then learn from this data to perform various tasks, using offline learning algorithms.\n",
"\n",
"In this notebook, we will demonstrate how to collect a diverse dataset for a simple robotic manipulation task,\n",
"using the algorithms detailed in the following paper:\n",
"[Efficient Self-Supervised Data Collection for Offline Robot Learning](https://arxiv.org/abs/2105.04607).\n",
"\n",
"The implementation is based on the Robosuite simulator, which should be installed before running this notebook. Follow the instructions in the Coach readme [here](https://github.com/IntelLabs/coach#robosuite).\n",
"\n",
"Presets with predefined parameters for all three algorithms shown in the paper can be found here:\n",
"\n",
"* Random Agent: ```presets/RoboSuite_CubeExp_Random.py```\n",
"\n",
"* Intrinsic Reward Agent: ```presets/RoboSuite_CubeExp_TD3_Intrinsic_Reward.py```\n",
"\n",
"* Goal-Based Agent: ```presets/RoboSuite_CubeExp_TD3_Goal_Based.py```\n",
"\n",
"You can run these presets from the command line, for example:\n",
"\n",
"`coach -p RoboSuite_CubeExp_TD3_Goal_Based`\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preliminaries\n",
"First, we take care of the imports and other general settings needed for this notebook.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from rl_coach.agents.td3_exp_agent import TD3GoalBasedAgentParameters\n",
"from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n",
"from rl_coach.architectures.layers import Dense, Conv2d, BatchnormActivationDropout, Flatten\n",
"from rl_coach.base_parameters import EmbedderScheme\n",
"from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps\n",
"from rl_coach.environments.robosuite_environment import RobosuiteGoalBasedExpEnvironmentParameters, \\\n",
"    OptionalObservations\n",
"from rl_coach.filters.filter import NoInputFilter, NoOutputFilter\n",
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
"from rl_coach.graph_managers.graph_manager import ScheduleParameters\n",
"from rl_coach.architectures.head_parameters import RNDHeadParameters\n",
"from rl_coach.schedules import LinearSchedule\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we define the training schedule for the agent. `improve_steps` dictates the number of samples in the final data-set.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"####################\n",
"# Graph Scheduling #\n",
"####################\n",
"\n",
"schedule_params = ScheduleParameters()\n",
"schedule_params.improve_steps = TrainingSteps(300000)\n",
"schedule_params.steps_between_evaluation_periods = TrainingSteps(300000)\n",
"schedule_params.evaluation_steps = EnvironmentEpisodes(0)\n",
"schedule_params.heatup_steps = EnvironmentSteps(1000)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we will be using the goal-based algorithm for data collection. Therefore, we populate\n",
"the `TD3GoalBasedAgentParameters` class with our desired algorithm-specific parameters.\n",
"\n",
"The goal-based data collection agent is based on TD3, so through this class you can change the TD3-specific parameters as well.\n",
"\n",
"A detailed description of the goal-based and TD3 algorithm-specific parameters can be found in \n",
"```agents/td3_exp_agent.py``` and ```agents/td3_agent.py```, respectively.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"#########\n",
"# Agent #\n",
"#########\n",
"\n",
"agent_params = TD3GoalBasedAgentParameters()\n",
"agent_params.algorithm.use_non_zero_discount_for_terminal_states = False\n",
"agent_params.exploration.noise_schedule = LinearSchedule(1.5, 0.5, 300000)\n",
"\n",
"agent_params.algorithm.rnd_sample_size = 2000\n",
"agent_params.algorithm.rnd_batch_size = 500\n",
"agent_params.algorithm.rnd_optimization_epochs = 4\n",
"agent_params.algorithm.td3_training_ratio = 1.0\n",
"agent_params.algorithm.identity_goal_sample_rate = 0.0\n",
"agent_params.algorithm.env_obs_key = 'camera'\n",
"agent_params.algorithm.agent_obs_key = 'obs-goal'\n",
"agent_params.algorithm.replay_buffer_save_steps = 25000\n",
"agent_params.algorithm.replay_buffer_save_path = './Resources'\n",
"\n",
"agent_params.input_filter = NoInputFilter()\n",
"agent_params.output_filter = NoOutputFilter()\n"
]
},
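{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick alternative to opening those source files, you can list the available parameter names directly from the parameter objects. This is a minimal sketch: `vars()` simply dumps the attribute dictionary of each parameters object, so the printed names are the fields you can tune.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# List the algorithm-specific and exploration parameter names\n",
"print(sorted(vars(agent_params.algorithm).keys()))\n",
"print(sorted(vars(agent_params.exploration).keys()))\n"
]
},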
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we'll define the networks' architecture and parameters as they appear in the paper.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Camera observation pre-processing network scheme\n",
"camera_obs_scheme = [\n",
"    Conv2d(32, 8, 4),\n",
"    BatchnormActivationDropout(activation_function='relu'),\n",
"    Conv2d(64, 4, 2),\n",
"    BatchnormActivationDropout(activation_function='relu'),\n",
"    Conv2d(64, 3, 1),\n",
"    BatchnormActivationDropout(activation_function='relu'),\n",
"    Flatten(),\n",
"    Dense(256),\n",
"    BatchnormActivationDropout(activation_function='relu')\n",
"]\n",
"\n",
"# Actor\n",
"actor_network = agent_params.network_wrappers['actor']\n",
"actor_network.input_embedders_parameters = {\n",
"    'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),\n",
"    agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')\n",
"}\n",
"\n",
"actor_network.middleware_parameters.scheme = [Dense(300), Dense(200)]\n",
"actor_network.learning_rate = 1e-4\n",
"\n",
"# Critic\n",
"critic_network = agent_params.network_wrappers['critic']\n",
"critic_network.input_embedders_parameters = {\n",
"    'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),\n",
"    'measurements': InputEmbedderParameters(scheme=EmbedderScheme.Empty),\n",
"    agent_params.algorithm.agent_obs_key: InputEmbedderParameters(scheme=camera_obs_scheme, activation_function='none')\n",
"}\n",
"\n",
"critic_network.middleware_parameters.scheme = [Dense(400), Dense(300)]\n",
"critic_network.learning_rate = 1e-4\n",
"\n",
"# RND\n",
"agent_params.network_wrappers['predictor'].input_embedders_parameters = \\\n",
"    {agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,\n",
"                                                                 input_rescaling={'image': 1.0},\n",
"                                                                 flatten=False)}\n",
"agent_params.network_wrappers['constant'].input_embedders_parameters = \\\n",
"    {agent_params.algorithm.env_obs_key: InputEmbedderParameters(scheme=EmbedderScheme.Empty,\n",
"                                                                 input_rescaling={'image': 1.0},\n",
"                                                                 flatten=False)}\n",
"agent_params.network_wrappers['predictor'].heads_parameters = [RNDHeadParameters(is_predictor=True)]\n"
]
},
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The last thing we need to define is the environment parameters for the manipulation task.\n",
|
|
"This environment is a 7DoF Franka Panda robotic arm with a closed gripper and cartesian\n",
|
|
"position control of the end-effector. The robot is positioned on a table, and a cube object with colored sides is placed in\n",
|
|
"front of it.\n"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"###############\n",
"# Environment #\n",
"###############\n",
"env_params = RobosuiteGoalBasedExpEnvironmentParameters(level='CubeExp')\n",
"env_params.robot = 'Panda'\n",
"env_params.custom_controller_config_fpath = '../rl_coach/environments/robosuite/osc_pose.json'\n",
"env_params.base_parameters.optional_observations = OptionalObservations.CAMERA\n",
"env_params.base_parameters.render_camera = 'frontview'\n",
"env_params.base_parameters.camera_names = 'agentview'\n",
"env_params.base_parameters.camera_depths = False\n",
"env_params.base_parameters.horizon = 200\n",
"env_params.base_parameters.ignore_done = False\n",
"env_params.base_parameters.use_object_obs = True\n",
"env_params.frame_skip = 1\n",
"env_params.base_parameters.control_freq = 2\n",
"env_params.base_parameters.camera_heights = 84\n",
"env_params.base_parameters.camera_widths = 84\n",
"env_params.extra_parameters = {'hard_reset': False}\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we create the graph manager and call `graph_manager.improve()` to start the data collection.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, schedule_params=schedule_params)\n",
"graph_manager.improve()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the data collection is complete, the data-set will be saved to the path specified by `agent_params.algorithm.replay_buffer_save_path`.\n",
"\n",
"At this point, the data can be used to learn any downstream task you define on that environment.\n",
"\n",
"The script below visualizes the data-set. Each dot represents a position of the cube on the table as seen in the data-set, and its color corresponds to the color of the cube face pointing up. The number above each plot is the number of dots it contains for that color.\n",
"\n",
"First, we load the data-set from disk. Note that this can take several minutes to complete."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import joblib\n",
"\n",
"print('Loading data-set (this can take several minutes)...')\n",
"rb_path = os.path.join('./Resources', 'RB_TD3GoalBasedAgent.joblib.bz2')\n",
"episodes = joblib.load(rb_path)\n",
"print('Done')"
]
},
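{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before plotting, it can help to peek at a single transition to see what the replay buffer stores. This is a minimal sketch assuming the standard Coach `Transition` fields (`state`, `action`, `reward`, `next_state`):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Peek at the first transition of the first episode\n",
"t = episodes[0][0]\n",
"print('State keys: ', list(t.state.keys()))\n",
"print('Action shape:', np.shape(t.action))\n",
"print('Reward:      ', t.reward)"
]
},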
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can run the visualization script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"from collections import OrderedDict\n",
"from enum import IntEnum\n",
"from pylab import subplot\n",
"from gym.envs.robotics.rotations import mat2euler, quat2mat\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"class CubeColor(IntEnum):\n",
"    YELLOW = 0\n",
"    CYAN = 1\n",
"    WHITE = 2\n",
"    RED = 3\n",
"    GREEN = 4\n",
"    BLUE = 5\n",
"    UNKNOWN = 6\n",
"\n",
"\n",
"# Table area covered by the plots\n",
"x_range = [-0.3, 0.3]\n",
"y_range = [-0.3, 0.3]\n",
"\n",
"COLOR_MAP = OrderedDict([\n",
"    (int(CubeColor.YELLOW), 'yellow'),\n",
"    (int(CubeColor.CYAN), 'cyan'),\n",
"    (int(CubeColor.WHITE), 'white'),\n",
"    (int(CubeColor.RED), 'red'),\n",
"    (int(CubeColor.GREEN), 'green'),\n",
"    (int(CubeColor.BLUE), 'blue'),\n",
"    (int(CubeColor.UNKNOWN), 'black'),\n",
"])\n",
"\n",
"# Mapping from (a subset of) euler angles to top-face color, based on the initial cube rotation\n",
"COLOR_ROTATION_MAP = OrderedDict([\n",
"    (CubeColor.YELLOW, (0, 2, [np.array([0, 0]),\n",
"                               np.array([np.pi, np.pi]), np.array([-np.pi, -np.pi]),\n",
"                               np.array([-np.pi, np.pi]), np.array([np.pi, -np.pi])])),\n",
"    (CubeColor.CYAN, (0, 2, [np.array([0, np.pi]), np.array([0, -np.pi]),\n",
"                             np.array([np.pi, 0]), np.array([-np.pi, 0])])),\n",
"    (CubeColor.WHITE, (1, 2, [np.array([-np.pi / 2])])),\n",
"    (CubeColor.RED, (1, 2, [np.array([np.pi / 2])])),\n",
"    (CubeColor.GREEN, (0, 2, [np.array([np.pi / 2, 0])])),\n",
"    (CubeColor.BLUE, (0, 2, [np.array([-np.pi / 2, 0])])),\n",
"])\n",
"\n",
"\n",
"def get_cube_top_color(cube_quat, atol):\n",
"    # Convert the orientation quaternion to euler angles and match against\n",
"    # the rotations listed in COLOR_ROTATION_MAP\n",
"    euler = mat2euler(quat2mat(cube_quat))\n",
"    for color, (start_dim, end_dim, xy_rotations) in COLOR_ROTATION_MAP.items():\n",
"        if any(np.allclose(euler[start_dim:end_dim], xy_rotation, atol=atol) for xy_rotation in xy_rotations):\n",
"            return color\n",
"    return CubeColor.UNKNOWN\n",
"\n",
"\n",
"def pos2cord(x, y):\n",
"    # Map a table (x, y) position to a 100x100 grid coordinate\n",
"    x = max(min(x, x_range[1]), x_range[0])\n",
"    y = max(min(y, y_range[1]), y_range[0])\n",
"    x = int(((x - x_range[0])/(x_range[1] - x_range[0]))*99)\n",
"    y = int(((y - y_range[0])/(y_range[1] - y_range[0]))*99)\n",
"    return x, y\n",
"\n",
"\n",
"# Indices of the cube position and orientation quaternion within the 'measurements' vector\n",
"pos_idx = 25\n",
"quat_idx = 28\n",
"positions = []\n",
"colors = []\n",
"print('Extracting cube positions and colors...')\n",
"for episode in episodes:\n",
"    for transition in episode:\n",
"        x, y = transition.state['measurements'][pos_idx:pos_idx+2]\n",
"        positions.append([x, y])\n",
"        colors.append(int(get_cube_top_color(transition.state['measurements'][quat_idx:quat_idx+4], np.pi / 4)))\n",
"\n",
"        x_cord, y_cord = pos2cord(x, y)  # grid coordinates (not used by this plot)\n",
"\n",
"    # Also include the final state of the episode\n",
"    x, y = episode[-1].next_state['measurements'][pos_idx:pos_idx+2]\n",
"    positions.append([x, y])\n",
"    colors.append(int(get_cube_top_color(episode[-1].next_state['measurements'][quat_idx:quat_idx+4], np.pi / 4)))\n",
"\n",
"    x_cord, y_cord = pos2cord(x, y)\n",
"print('Done')\n",
"\n",
"\n",
"# One scatter plot per color; each dot is a cube position, N is the number of dots\n",
"fig = plt.figure(figsize=(15.0, 5.0))\n",
"axes = []\n",
"for j in range(6):\n",
"    axes.append(subplot(1, 6, j + 1))\n",
"    xy = np.array(positions)[np.array(colors) == list(COLOR_MAP.keys())[j]]\n",
"    axes[-1].scatter(xy[:, 1], xy[:, 0], c=COLOR_MAP[j], alpha=0.01, edgecolors='black')\n",
"    plt.xlim(y_range)\n",
"    plt.ylim(x_range)\n",
"    plt.xticks([])\n",
"    plt.yticks([])\n",
"    axes[-1].set_aspect('equal', adjustable='box')\n",
"    title = 'N=' + str(xy.shape[0])\n",
"    plt.title(title)\n",
"\n",
"for ax in axes:\n",
"    ax.invert_yaxis()\n",
"\n",
"plt.show()\n"
]
},
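{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick numeric complement to the plots, the sketch below (using the `colors` list computed above) prints the fraction of samples per top-face color, giving a rough measure of how diverse the collected data is:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fraction of samples per top-face color\n",
"counts = np.bincount(np.array(colors), minlength=len(COLOR_MAP))\n",
"for color_id, name in COLOR_MAP.items():\n",
"    print('{:>7}: {:.3f}'.format(name, counts[color_id] / len(colors)))"
]
},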
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}