From abaa58b559316d17698c74ba12d3d97226df2947 Mon Sep 17 00:00:00 2001 From: Gal Novik Date: Mon, 17 Sep 2018 15:59:00 +0300 Subject: [PATCH] human agent will exit when human control not supported by environment; jupyter notebooks fixes --- rl_coach/environments/gym_environment.py | 2 ++ tutorials/1. Implementing an Algorithm.ipynb | 2 +- tutorials/2. Adding an Environment.ipynb | 4 ++-- tutorials/3. Implementing a Hierarchical RL Graph.ipynb | 8 +++++--- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/rl_coach/environments/gym_environment.py b/rl_coach/environments/gym_environment.py index 2a6db97..f21939d 100644 --- a/rl_coach/environments/gym_environment.py +++ b/rl_coach/environments/gym_environment.py @@ -329,6 +329,8 @@ class GymEnvironment(Environment): self.key_to_action = {} if hasattr(self.env.unwrapped, 'get_keys_to_action'): self.key_to_action = self.env.unwrapped.get_keys_to_action() + else: + screen.error("Error: Environment {} does not support human control.".format(self.env), crash=True) # initialize the state by getting a new state from the environment self.reset_internal_state(True) diff --git a/tutorials/1. Implementing an Algorithm.ipynb b/tutorials/1. Implementing an Algorithm.ipynb index 3cc90ca..7e7866d 100644 --- a/tutorials/1. Implementing an Algorithm.ipynb +++ b/tutorials/1. Implementing an Algorithm.ipynb @@ -348,7 +348,7 @@ "metadata": {}, "source": [ "# Running the Preset\n", - "(this is normally done from command line by running ```python coach.py -p atari_categorical_dqn ... ```)" + "(this is normally done from command line by running ```coach -p Atari_C51 ... ```)" ] }, { diff --git a/tutorials/2. Adding an Environment.ipynb b/tutorials/2. Adding an Environment.ipynb index 95a1582..5bec481 100644 --- a/tutorials/2. Adding an Environment.ipynb +++ b/tutorials/2. Adding an Environment.ipynb @@ -28,7 +28,7 @@ "outputs": [], "source": [ "import os\n", - "os.environ['DISABLE_MUJOCO_RENDERING'] = '1'\n", + "#os.environ['DISABLE_MUJOCO_RENDERING'] = '1'\n", "\n", "import sys\n", "module_path = os.path.abspath(os.path.join('..'))\n", @@ -334,7 +334,7 @@ " schedule_params=schedule_params, vis_params=vis_params)\n", "\n", "graph_manager.env_params.level.select('walker:walk')\n", - "#graph_manager.visualization_parameters.render = True\n", + "graph_manager.visualization_parameters.render = True\n", "\n", "\n", "log_path = '../experiments/control_suite_walker_ddpg'\n", diff --git a/tutorials/3. Implementing a Hierarchical RL Graph.ipynb b/tutorials/3. Implementing a Hierarchical RL Graph.ipynb index d2dff6d..32e832b 100644 --- a/tutorials/3. Implementing a Hierarchical RL Graph.ipynb +++ b/tutorials/3. Implementing a Hierarchical RL Graph.ipynb @@ -163,7 +163,8 @@ "outputs": [], "source": [ "from rl_coach.architectures.tensorflow_components.architecture import Dense\n", - "from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme, InputEmbedderParameters\n", + "from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme\n", + "from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters\n", "from rl_coach.memories.episodic.episodic_hindsight_experience_replay import HindsightGoalSelectionMethod, \\\n", " EpisodicHindsightExperienceReplayParameters\n", "from rl_coach.memories.episodic.episodic_hrl_hindsight_experience_replay import \\\n", @@ -344,7 +345,8 @@ "source": [ "graph_manager = HRLGraphManager(agents_params=agents_params, env_params=env_params,\n", " schedule_params=schedule_params, vis_params=vis_params,\n", - " consecutive_steps_to_run_each_level=EnvironmentSteps(40))" + " consecutive_steps_to_run_each_level=EnvironmentSteps(40))\n", + "graph_manager.visualization_parameters.render = True" ] }, { @@ -360,7 +362,7 @@ "metadata": {}, "outputs": [], "source": [ - "from base_parameters import TaskParameters, Frameworks\n", + "from rl_coach.base_parameters import TaskParameters, Frameworks\n", "\n", "log_path = '../experiments/pendulum_hac'\n", "if not os.path.exists(log_path):\n",