mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
renamed quick start guide tutorial
This commit is contained in:
196
tutorials/0. Quick Start Guide.ipynb
Normal file
196
tutorials/0. Quick Start Guide.ipynb
Normal file
@@ -0,0 +1,196 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Quick Start Guide"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Creating a very simple graph containing a single Clipped PPO agent running in the CartPole-v0 Gym environment:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
|
||||
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
|
||||
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
|
||||
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
|
||||
"\n",
|
||||
"graph_manager = BasicRLGraphManager(\n",
|
||||
" agent_params=ClippedPPOAgentParameters(),\n",
|
||||
" env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
|
||||
" schedule_params=SimpleSchedule()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Running the graph according to the given schedule:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"graph_manager.improve()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Running each phase manually:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from rl_coach.core_types import EnvironmentSteps\n",
|
||||
"\n",
|
||||
"graph_manager = BasicRLGraphManager(\n",
|
||||
" agent_params=ClippedPPOAgentParameters(),\n",
|
||||
" env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
|
||||
" schedule_params=SimpleSchedule()\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"graph_manager.heatup(EnvironmentSteps(100))\n",
|
||||
"graph_manager.train_and_act(EnvironmentSteps(100))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Changing the default parameters"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
|
||||
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
|
||||
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
|
||||
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
|
||||
"from rl_coach.graph_managers.graph_manager import ScheduleParameters\n",
|
||||
"from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps\n",
|
||||
"\n",
|
||||
"# schedule\n",
|
||||
"schedule_params = ScheduleParameters()\n",
|
||||
"schedule_params.improve_steps = TrainingSteps(10000000)\n",
|
||||
"schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048)\n",
|
||||
"schedule_params.evaluation_steps = EnvironmentEpisodes(5)\n",
|
||||
"schedule_params.heatup_steps = EnvironmentSteps(0)\n",
|
||||
"\n",
|
||||
"# agent parameters\n",
|
||||
"agent_params = ClippedPPOAgentParameters()\n",
|
||||
"agent_params.algorithm.discount = 1.0\n",
|
||||
"\n",
|
||||
"graph_manager = BasicRLGraphManager(\n",
|
||||
" agent_params=agent_params,\n",
|
||||
" env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
|
||||
" schedule_params=schedule_params\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"graph_manager.improve()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using a custom gym environment\n",
|
||||
"\n",
|
||||
"We can use a custom gym environment without registering it. \n",
|
||||
"We just need the path to the environment class, written as `module.path:ClassName`.\n",
|
||||
"We can also pass custom parameters to the environment's `__init__` method."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
|
||||
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
|
||||
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
|
||||
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
|
||||
"from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters\n",
|
||||
"\n",
|
||||
"# define the environment parameters\n",
|
||||
"bit_length = 10\n",
|
||||
"env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n",
|
||||
"env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}\n",
|
||||
"\n",
|
||||
"# Clipped PPO\n",
|
||||
"agent_params = ClippedPPOAgentParameters()\n",
|
||||
"agent_params.network_wrappers['main'].input_embedders_parameters = {\n",
|
||||
" 'state': InputEmbedderParameters(scheme=[]),\n",
|
||||
" 'desired_goal': InputEmbedderParameters(scheme=[])\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"graph_manager = BasicRLGraphManager(\n",
|
||||
" agent_params=agent_params,\n",
|
||||
" env_params=env_params,\n",
|
||||
" schedule_params=SimpleSchedule()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"graph_manager.improve()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Reference in New Issue
Block a user