{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Getting Started Guide"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table of Contents\n",
    "- [Using Coach from the Command Line](#Using-Coach-from-the-Command-Line)\n",
    "- [Using Coach as a Library](#Using-Coach-as-a-Library)\n",
    "  - [Preset based - using `CoachInterface`](#Preset-based---using-CoachInterface)\n",
    "    - [Training a preset](#Training-a-preset)\n",
    "    - [Running each training or inference iteration manually](#Running-each-training-or-inference-iteration-manually)\n",
    "  - [Non-preset - using `GraphManager` directly](#Non-preset---using-GraphManager-directly)\n",
    "    - [Training an agent with a custom Gym environment](#Training-an-agent-with-a-custom-Gym-environment)\n",
    "    - [Advanced functionality - proprietary exploration policy, checkpoint evaluation](#Advanced-functionality---proprietary-exploration-policy,-checkpoint-evaluation)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Coach from the Command Line"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When running Coach from the command line, we use a preset module to define the experiment parameters.\n",
    "As its name implies, a preset is a predefined set of parameters for running some agent on some environment.\n",
    "Coach ships with many predefined presets that follow the algorithm definitions in the published papers, and they allow training many of the existing algorithms with essentially no coding at all. These presets can easily be run from the command line. For example:\n",
    "\n",
    "`coach -p CartPole_DQN`\n",
    "\n",
    "You can find all the predefined presets under the `presets` directory, or list them using the following command:\n",
    "\n",
    "`coach -l`\n",
    "\n",
    "Coach can also be used with an externally defined preset by passing the absolute path to the module and the name of the graph manager object defined in the preset:\n",
    "\n",
    "`coach -p /home/my_user/my_agent_dir/my_preset.py:graph_manager`\n",
    "\n",
    "Some presets are generic across multiple environment levels, and therefore require specifying the level through the command line:\n",
    "\n",
    "`coach -p Atari_DQN -lvl breakout`\n",
    "\n",
    "There are plenty of other command line arguments you can use to customize the experiment. Full documentation of the available arguments can be printed with the following command:\n",
    "\n",
    "`coach -h`"
   ]
  },
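  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give a sense of what such an externally defined preset contains, below is a minimal sketch of a `my_preset.py` (the file name and contents here are illustrative, mirroring the library examples later in this notebook, rather than a preset shipped with Coach). A preset simply builds a graph manager object at module level, which the `coach` command then looks up by the name given after the colon:\n",
    "\n",
    "```python\n",
    "# my_preset.py - an illustrative sketch, not a preset shipped with Coach\n",
    "from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
    "from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
    "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
    "from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
    "\n",
    "# the module-level graph_manager object is what `coach -p <path>:graph_manager` picks up\n",
    "graph_manager = BasicRLGraphManager(\n",
    "    agent_params=ClippedPPOAgentParameters(),\n",
    "    env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
    "    schedule_params=SimpleSchedule()\n",
    ")\n",
    "```"
   ]
  },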
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Coach as a Library"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, Coach can be used as a library directly from Python. As described above, Coach uses the preset mechanism to define experiments. A preset is essentially a Python module which instantiates a `GraphManager` object. The graph manager is a container that holds the agents and the environments, along with additional parameters for running the experiment, such as visualization parameters. The graph manager acts as the scheduler which orchestrates the experiment.\n",
    "\n",
    "**Note: Each of the examples in this section is independent, so the notebook kernel needs to be restarted before running each one. Make sure you run the next cell before running any of the examples.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adding module path to sys path if not there, so rl_coach submodules can be imported\n",
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "resources_path = os.path.abspath(os.path.join('Resources'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "if resources_path not in sys.path:\n",
    "    sys.path.append(resources_path)\n",
    "\n",
    "from rl_coach.coach import CoachInterface"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preset based - using `CoachInterface`\n",
    "\n",
    "The basic way to run Coach directly from Python is through a `CoachInterface` object, which accepts the same arguments as the command line invocation but allows for more flexibility and additional control over the training/inference process.\n",
    "\n",
    "Let's start with some examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training a preset\n",
    "In this example, we'll create a very simple graph containing a Clipped PPO agent running with the CartPole-v0 Gym environment. `CoachInterface` has a few useful parameters, such as `custom_parameter`, which enables overriding preset settings, and other optional parameters that control the training process. We'll override the preset's schedule parameters, train with a single rollout worker, and save checkpoints every 10 seconds:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coach = CoachInterface(preset='CartPole_ClippedPPO',\n",
    "                       # The optional custom_parameter enables overriding preset settings\n",
    "                       custom_parameter='heatup_steps=EnvironmentSteps(5);improve_steps=TrainingSteps(3)',\n",
    "                       # Other optional parameters enable easy access to advanced functionalities\n",
    "                       num_workers=1, checkpoint_save_secs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coach.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Running each training or inference iteration manually"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The graph manager (which was instantiated in the preset) can be accessed from the `CoachInterface` object. The graph manager simplifies the scheduling process by encapsulating the calls to each of the training phases. Sometimes it can be beneficial to have more fine-grained control over the scheduling process. This can easily be done by calling the individual phase functions directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.environments.gym_environment import GymEnvironment, GymVectorEnvironment\n",
    "from rl_coach.base_parameters import VisualizationParameters\n",
    "from rl_coach.core_types import EnvironmentSteps\n",
    "\n",
    "coach = CoachInterface(preset='CartPole_ClippedPPO')\n",
    "\n",
    "# registering an iteration signal before starting to run\n",
    "coach.graph_manager.log_signal('iteration', -1)\n",
    "\n",
    "coach.graph_manager.heatup(EnvironmentSteps(100))\n",
    "\n",
    "# training\n",
    "for it in range(10):\n",
    "    # logging the iteration signal during training\n",
    "    coach.graph_manager.log_signal('iteration', it)\n",
    "    # using the graph manager to train and act for a given number of steps\n",
    "    coach.graph_manager.train_and_act(EnvironmentSteps(100))\n",
    "    # reading signals during training\n",
    "    training_reward = coach.graph_manager.get_signal_value('Training Reward')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes we may want to track the agent's decisions, log them, or even modify them.\n",
    "We can access the agent itself through the `CoachInterface` as follows.\n",
    "\n",
    "Note that we also need an instance of the environment to do so. In this case we instantiate a `GymEnvironment` object using the CartPole `GymVectorEnvironment` parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference\n",
    "env_params = GymVectorEnvironment(level='CartPole-v0')\n",
    "env = GymEnvironment(**env_params.__dict__, visualization_parameters=VisualizationParameters())\n",
    "\n",
    "response = env.reset_internal_state()\n",
    "for _ in range(10):\n",
    "    action_info = coach.graph_manager.get_agent().choose_action(response.next_state)\n",
    "    print(\"State:{}, Action:{}\".format(response.next_state, action_info.action))\n",
    "    response = env.step(action_info.action)\n",
    "    print(\"Reward:{}\".format(response.reward))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Non-preset - using `GraphManager` directly"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is also possible to invoke Coach directly from Python code without defining a preset (which is necessary for `CoachInterface`) by using the `GraphManager` object directly. Using Coach this way won't give you access to functionalities such as multi-threading, but it might be convenient if you don't want to define a preset file."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training an agent with a custom Gym environment\n",
    "\n",
    "Here we show an example of how to use the `GraphManager` to train an agent on a custom Gym environment.\n",
    "\n",
    "We first construct a `GymEnvironmentParameters` object describing the environment parameters. For Gym environments with vector observations, we can use the more specific `GymVectorEnvironment` object.\n",
    "\n",
    "The path to the custom environment is defined in the `level` parameter. It can be the absolute path to the class (e.g. `'/home/user/my_environment_dir/my_environment_module.py:MyEnvironmentClass'`) or the Python module path to the class, as in this example. In either case, we can use the custom Gym environment without registering it.\n",
    "\n",
    "Custom parameters for the environment's `__init__` function can be passed as `additional_simulator_parameters`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
    "from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
    "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
    "from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
    "from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n",
    "\n",
    "# define the environment parameters\n",
    "bit_length = 10\n",
    "env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n",
    "env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}\n",
    "\n",
    "# Clipped PPO\n",
    "agent_params = ClippedPPOAgentParameters()\n",
    "agent_params.network_wrappers['main'].input_embedders_parameters = {\n",
    "    'state': InputEmbedderParameters(scheme=[]),\n",
    "    'desired_goal': InputEmbedderParameters(scheme=[])\n",
    "}\n",
    "\n",
    "graph_manager = BasicRLGraphManager(\n",
    "    agent_params=agent_params,\n",
    "    env_params=env_params,\n",
    "    schedule_params=SimpleSchedule()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_manager.improve()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Advanced functionality - proprietary exploration policy, checkpoint evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Agent modules, such as the exploration policy, the memory, and the neural network topology, can be replaced with proprietary ones. In this example we'll show how to replace the default exploration policy of the DQN agent with a different one that is defined under the Resources folder. We'll also show how to change the default checkpoint save settings, and how to load a checkpoint for evaluation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start with the standard definitions of a DQN agent solving the CartPole environment (taken from the CartPole_DQN preset)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.agents.dqn_agent import DQNAgentParameters\n",
    "from rl_coach.base_parameters import VisualizationParameters, TaskParameters\n",
    "from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps\n",
    "from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
    "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
    "from rl_coach.graph_managers.graph_manager import ScheduleParameters\n",
    "from rl_coach.memories.memory import MemoryGranularity\n",
    "\n",
    "\n",
    "####################\n",
    "# Graph Scheduling #\n",
    "####################\n",
    "\n",
    "schedule_params = ScheduleParameters()\n",
    "schedule_params.improve_steps = TrainingSteps(4000)\n",
    "schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)\n",
    "schedule_params.evaluation_steps = EnvironmentEpisodes(1)\n",
    "schedule_params.heatup_steps = EnvironmentSteps(1000)\n",
    "\n",
    "#########\n",
    "# Agent #\n",
    "#########\n",
    "agent_params = DQNAgentParameters()\n",
    "\n",
    "# DQN params\n",
    "agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)\n",
    "agent_params.algorithm.discount = 0.99\n",
    "agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)\n",
    "\n",
    "# NN configuration\n",
    "agent_params.network_wrappers['main'].learning_rate = 0.00025\n",
    "agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False\n",
    "\n",
    "# ER size\n",
    "agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)\n",
    "\n",
    "###############\n",
    "# Environment #\n",
    "###############\n",
    "env_params = GymVectorEnvironment(level='CartPole-v0')"
   ]
  },
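  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an aside, before plugging in a fully custom policy: the same `agent_params.exploration` attribute is also the place to swap in, or re-tune, one of the exploration parameter classes that ship with Coach. The cell below is a minimal sketch of that pattern, assuming Coach's built-in e-greedy exploration policy and linear schedule; the schedule values are illustrative, not taken from any preset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: overriding the exploration policy with a built-in one (e-greedy with a custom schedule)\n",
    "# The schedule values below are arbitrary examples, not tuned settings\n",
    "from rl_coach.exploration_policies.e_greedy import EGreedyParameters\n",
    "from rl_coach.schedules import LinearSchedule\n",
    "\n",
    "egreedy_params = EGreedyParameters()\n",
    "# decay epsilon linearly from 1.0 to 0.01 over 10000 environment steps\n",
    "egreedy_params.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)\n",
    "agent_params.exploration = egreedy_params"
   ]
  },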
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll override the exploration policy with our own policy, defined in `Resources/exploration.py`.\n",
    "We'll also define the checkpoint save directory and the save interval in seconds.\n",
    "\n",
    "Make sure the first code cell at the top of this notebook has been run before the following one, so that `module_path` and `resources_path` are added to `sys.path`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from exploration import MyExplorationParameters\n",
    "\n",
    "# Overriding the default DQN Agent exploration policy with my exploration policy\n",
    "agent_params.exploration = MyExplorationParameters()\n",
    "\n",
    "# Creating a graph manager to train a DQN agent to solve CartPole\n",
    "graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,\n",
    "                                    schedule_params=schedule_params, vis_params=VisualizationParameters())\n",
    "\n",
    "# Resources path was defined at the top of this notebook\n",
    "my_checkpoint_dir = resources_path + '/checkpoints'\n",
    "\n",
    "# Checkpoints will be stored every 5 seconds to the given directory\n",
    "task_parameters1 = TaskParameters()\n",
    "task_parameters1.checkpoint_save_dir = my_checkpoint_dir\n",
    "task_parameters1.checkpoint_save_secs = 5\n",
    "\n",
    "graph_manager.create_graph(task_parameters1)\n",
    "graph_manager.improve()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we'll load the latest checkpoint from the checkpoint directory and evaluate it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import shutil\n",
    "\n",
    "# Clearing the previous graph before creating the new one to avoid name conflicts\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Updating the graph manager's task parameters to restore the latest stored checkpoint from the checkpoints directory\n",
    "task_parameters2 = TaskParameters()\n",
    "task_parameters2.checkpoint_restore_path = my_checkpoint_dir\n",
    "\n",
    "graph_manager.create_graph(task_parameters2)\n",
    "graph_manager.evaluate(EnvironmentSteps(5))\n",
    "\n",
    "# Cleaning up the stored checkpoints\n",
    "shutil.rmtree(my_checkpoint_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}