1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

renamed the quick start guide to a tutorial

This commit is contained in:
Gal Novik
2018-10-03 18:15:29 +03:00
parent f7990d4003
commit 5c4f9d58dd

View File

@@ -0,0 +1,196 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting Started Guide"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Creating a very simple graph containing a single Clipped PPO agent running with the CartPole-v0 Gym environment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
"\n",
"# Assemble a graph manager from three ingredients:\n",
"#   - the agent (Clipped PPO with its default parameters),\n",
"#   - the environment (the CartPole-v0 Gym level),\n",
"#   - the schedule (SimpleSchedule - the library's default phase schedule).\n",
"graph_manager = BasicRLGraphManager(\n",
"    agent_params=ClippedPPOAgentParameters(),\n",
"    env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
"    schedule_params=SimpleSchedule()\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Running the graph according to the given schedule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# improve() runs the full loop defined by the schedule passed above.\n",
"graph_manager.improve()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Running each phase manually:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.core_types import EnvironmentSteps\n",
"\n",
"# Rebuild the graph manager so the manual phase calls below start from a\n",
"# fresh state (uses the classes imported in the first code cell).\n",
"graph_manager = BasicRLGraphManager(\n",
"    agent_params=ClippedPPOAgentParameters(),\n",
"    env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
"    schedule_params=SimpleSchedule()\n",
")\n",
"\n",
"# Instead of improve(), drive each phase explicitly,\n",
"# measuring both phases in environment steps (100 each here):\n",
"graph_manager.heatup(EnvironmentSteps(100))\n",
"graph_manager.train_and_act(EnvironmentSteps(100))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Changing the default parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
"from rl_coach.graph_managers.graph_manager import ScheduleParameters\n",
"from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps\n",
"\n",
"# Schedule: instead of SimpleSchedule, spell out every phase length explicitly.\n",
"schedule_params = ScheduleParameters()\n",
"schedule_params.improve_steps = TrainingSteps(10000000)\n",
"schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048)\n",
"schedule_params.evaluation_steps = EnvironmentEpisodes(5)\n",
"schedule_params.heatup_steps = EnvironmentSteps(0)  # skip the heatup phase\n",
"\n",
"# Agent: start from the Clipped PPO defaults and override one hyperparameter.\n",
"agent_params = ClippedPPOAgentParameters()\n",
"agent_params.algorithm.discount = 1.0  # undiscounted returns\n",
"\n",
"graph_manager = BasicRLGraphManager(\n",
"    agent_params=agent_params,\n",
"    env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
"    schedule_params=schedule_params\n",
")\n",
"\n",
"graph_manager.improve()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using a custom gym environment\n",
"\n",
"We can use a custom Gym environment without registering it; \n",
"we just need the path to the environment module.\n",
"We can also pass custom parameters to the environment's `__init__`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
"from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters\n",
"\n",
"# Environment: point level at 'module.path:ClassName' - no Gym registration\n",
"# needed. Extra keyword arguments for the environment's __init__ go in\n",
"# additional_simulator_parameters.\n",
"bit_length = 10\n",
"env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n",
"env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}\n",
"\n",
"# Agent: Clipped PPO, with one input embedder per observation\n",
"# ('state' and 'desired_goal'), each using an empty scheme.\n",
"agent_params = ClippedPPOAgentParameters()\n",
"agent_params.network_wrappers['main'].input_embedders_parameters = {\n",
"    'state': InputEmbedderParameters(scheme=[]),\n",
"    'desired_goal': InputEmbedderParameters(scheme=[])\n",
"}\n",
"\n",
"graph_manager = BasicRLGraphManager(\n",
"    agent_params=agent_params,\n",
"    env_params=env_params,\n",
"    schedule_params=SimpleSchedule()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train on the custom BitFlip environment with the default schedule.\n",
"graph_manager.improve()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}