Additional cmd line examples (#377)
Adding command line examples to the Quick Start Guide tutorial
@@ -71,10 +71,11 @@
 "import os\n",
 "import sys\n",
 "module_path = os.path.abspath(os.path.join('..'))\n",
+"resources_path = os.path.abspath(os.path.join('Resources'))\n",
 "if module_path not in sys.path:\n",
 "    sys.path.append(module_path)\n",
-"\n",
-"from rl_coach.coach import CoachInterface"
+"if resources_path not in sys.path:\n",
+"    sys.path.append(resources_path)"
 ]
 },
 {
@@ -83,6 +84,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from rl_coach.coach import CoachInterface\n",
+"\n",
 "coach = CoachInterface(preset='CartPole_ClippedPPO',\n",
 "                       custom_parameter='heatup_steps=EnvironmentSteps(5);improve_steps=TrainingSteps(3)')"
 ]
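In the cell above, `custom_parameter` appears to override preset attributes by name, with multiple assignments separated by semicolons. A minimal usage sketch, under the assumption that `CoachInterface` also exposes a `run()` method (not shown in this hunk):

    # Assumption: CoachInterface provides run() to execute the heatup/improve
    # schedule configured above; verify against the definition in rl_coach.coach.
    coach.run()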
@@ -134,8 +137,13 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Additional functionality\n",
-"\n",
+"### Additional functionality"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
 "`CoachInterface` allows for easy access to functionalities such as multi-threading and saving checkpoints:"
 ]
 },
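A rough sketch of what using those options could look like; the keyword arguments below (`num_workers`, `checkpoint_save_dir`, `checkpoint_save_secs`) are assumptions that mirror Coach's command-line flags and are not confirmed by this diff:

    # Hypothetical keyword arguments -- check the CoachInterface signature before use.
    coach = CoachInterface(preset='CartPole_ClippedPPO',
                           num_workers=4,                        # multi-threaded rollouts (assumed)
                           checkpoint_save_dir='./checkpoints',  # checkpoint location (assumed)
                           checkpoint_save_secs=60)              # save interval in seconds (assumed)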
@@ -153,8 +161,13 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Agent functionality\n",
-"\n",
+"### Agent functionality"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
 "When using `CoachInterface` (single agent with one level of hierarchy) it's also possible to easily use the `Agent` object functionality, such as logging and reading signals and applying the policy the agent has learned on a given state:"
 ]
 },
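As a rough illustration, a sketch of reaching the agent through the graph manager and using it directly; the attribute chain and method names (`graph_manager`, `level_managers`, `register_signal`, `choose_action`) are assumptions based on rl_coach's structure, not taken from this diff:

    import numpy as np

    # Hypothetical access path: the single agent of the single level manager.
    agent = list(coach.graph_manager.level_managers[0].agents.values())[0]

    # Log a custom signal alongside the built-in training signals (assumed API).
    my_signal = agent.register_signal('MyCustomSignal')
    my_signal.add_sample(42)

    # Apply the learned policy to a given state; CartPole observations are 4-dimensional,
    # and the expected input format (a dict of observations) is also an assumption.
    action_info = agent.choose_action({'observation': np.zeros(4)})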
@@ -197,8 +210,13 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Using GraphManager Directly\n",
-"\n",
+"## Using GraphManager Directly"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
 "It is also possible to invoke Coach directly in Python code without defining a preset (which is necessary for `CoachInterface`) by using the `GraphManager` object directly. Using Coach this way won't allow you to access functionalities such as multi-threading, but it might be convenient if you don't want to define a preset file.\n",
 "\n",
 "Here we show an example of how to do so with a custom environment.\n",
@@ -262,6 +280,142 @@
 "source": [
 "env_params = GymVectorEnvironment(level='/home/user/my_environment_dir/my_environment_module.py:MyEnvironmentClass')"
 ]
 },
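For reference, the class named by that path would be a standard Gym environment. Below is a minimal sketch of what such a module could contain; it is hypothetical (not the actual my_environment_module.py) and assumes `GymVectorEnvironment` can load any `gym.Env` subclass given as 'path/to/file.py:ClassName':

    # Hypothetical my_environment_module.py -- a toy 1D target-reaching task.
    import gym
    import numpy as np
    from gym import spaces


    class MyEnvironmentClass(gym.Env):
        def __init__(self):
            self.observation_space = spaces.Box(low=-10.0, high=10.0, shape=(1,), dtype=np.float32)
            self.action_space = spaces.Discrete(2)
            self.position = 0.0

        def reset(self):
            self.position = 0.0
            return np.array([self.position], dtype=np.float32)

        def step(self, action):
            # Action 1 moves right, action 0 moves left; the target sits at +5.
            self.position += 0.1 if action == 1 else -0.1
            reward = -abs(self.position - 5.0)
            done = abs(self.position - 5.0) < 0.1
            return np.array([self.position], dtype=np.float32), reward, done, {}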
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Advanced functionality - proprietary exploration policy, checkpoint evaluation"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Agent modules, such as the exploration policy, memory, and neural network topology, can be replaced with proprietary ones. In this example we'll show how to replace the default exploration policy of the DQN agent with a different one that is defined under the Resources folder. We'll also show how to change the default checkpoint save settings, and how to load a checkpoint for evaluation."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"We'll start with the standard definitions of a DQN agent solving the CartPole environment (taken from the CartPole_DQN preset)."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from rl_coach.agents.dqn_agent import DQNAgentParameters\n",
+"from rl_coach.base_parameters import VisualizationParameters, TaskParameters\n",
+"from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps\n",
+"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
+"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
+"from rl_coach.graph_managers.graph_manager import ScheduleParameters\n",
+"from rl_coach.memories.memory import MemoryGranularity\n",
+"\n",
+"\n",
+"####################\n",
+"# Graph Scheduling #\n",
+"####################\n",
+"\n",
+"schedule_params = ScheduleParameters()\n",
+"schedule_params.improve_steps = TrainingSteps(4000)\n",
+"schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)\n",
+"schedule_params.evaluation_steps = EnvironmentEpisodes(1)\n",
+"schedule_params.heatup_steps = EnvironmentSteps(1000)\n",
+"\n",
+"#########\n",
+"# Agent #\n",
+"#########\n",
+"agent_params = DQNAgentParameters()\n",
+"\n",
+"# DQN params\n",
+"agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)\n",
+"agent_params.algorithm.discount = 0.99\n",
+"agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)\n",
+"\n",
+"# NN configuration\n",
+"agent_params.network_wrappers['main'].learning_rate = 0.00025\n",
+"agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False\n",
+"\n",
+"# ER size\n",
+"agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)\n",
+"\n",
+"################\n",
+"# Environment #\n",
+"################\n",
+"env_params = GymVectorEnvironment(level='CartPole-v0')"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Next, we'll override the exploration policy with our own policy defined in `Resources/exploration.py`.\n",
+"We'll also define the checkpoint save directory and interval in seconds.\n",
+"\n",
+"Make sure the first cell at the top of this notebook is run before the following one, so that module_path and resources_path are added to sys.path."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from exploration import MyExplorationParameters\n",
+"\n",
+"# Overriding the default DQN Agent exploration policy with my exploration policy\n",
+"agent_params.exploration = MyExplorationParameters()\n",
+"\n",
+"# Creating a graph manager to train a DQN agent to solve CartPole\n",
+"graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,\n",
+"                                    schedule_params=schedule_params, vis_params=VisualizationParameters())\n",
+"\n",
+"# Resources path was defined at the top of this notebook\n",
+"my_checkpoint_dir = resources_path + '/checkpoints'\n",
+"\n",
+"# Checkpoints will be stored every 5 seconds to the given directory\n",
+"task_parameters1 = TaskParameters()\n",
+"task_parameters1.checkpoint_save_dir = my_checkpoint_dir\n",
+"task_parameters1.checkpoint_save_secs = 5\n",
+"\n",
+"graph_manager.create_graph(task_parameters1)\n",
+"graph_manager.improve()\n"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Finally, we'll load the latest checkpoint from the checkpoint directory and evaluate it."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"import tensorflow as tf\n",
+"import shutil\n",
+"\n",
+"# Clearing the previous graph before creating the new one to avoid name conflicts\n",
+"tf.reset_default_graph()\n",
+"\n",
+"# Updating the graph manager's task parameters to restore the latest stored checkpoint from the checkpoints directory\n",
+"task_parameters2 = TaskParameters()\n",
+"task_parameters2.checkpoint_restore_path = my_checkpoint_dir\n",
+"\n",
+"graph_manager.create_graph(task_parameters2)\n",
+"graph_manager.evaluate(EnvironmentSteps(5))\n",
+"\n",
+"# Cleaning up\n",
+"shutil.rmtree(my_checkpoint_dir)"
+]
+}
 ],
 "metadata": {
@@ -273,16 +427,16 @@
 "language_info": {
 "codemirror_mode": {
 "name": "ipython",
-"version": 3.0
+"version": 3
 },
 "file_extension": ".py",
 "mimetype": "text/x-python",
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.5.2"
+"version": "3.6.4"
 }
 },
 "nbformat": 4,
-"nbformat_minor": 0
-}
+"nbformat_minor": 1
+}