1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00
Files
coach/tutorials/0. Quick Start Guide.ipynb

197 lines
5.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting Started Guide"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a very simple graph containing a single Clipped PPO agent that runs on the CartPole-v0 Gym environment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
"\n",
"# a single Clipped PPO agent with default parameters, acting on CartPole-v0, driven by the simple schedule\n",
"agent_params = ClippedPPOAgentParameters()\n",
"env_params = GymVectorEnvironment(level='CartPole-v0')\n",
"schedule_params = SimpleSchedule()\n",
"\n",
"graph_manager = BasicRLGraphManager(\n",
"    agent_params=agent_params,\n",
"    env_params=env_params,\n",
"    schedule_params=schedule_params\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the graph according to the given schedule by calling `improve()`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# run the whole experiment, as dictated by the schedule the graph manager was built with\n",
"graph_manager.improve()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
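"Alternatively, run each phase manually. The cell below builds a fresh graph manager, runs 100 environment steps of heatup, and then runs 100 environment steps of training and acting:"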
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.core_types import EnvironmentSteps\n",
"\n",
"# start again from a freshly constructed graph manager\n",
"graph_manager = BasicRLGraphManager(\n",
"    agent_params=ClippedPPOAgentParameters(),\n",
"    env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
"    schedule_params=SimpleSchedule()\n",
")\n",
"\n",
"# length of each manually-run phase, in environment steps\n",
"HEATUP_STEPS = 100\n",
"TRAIN_AND_ACT_STEPS = 100\n",
"\n",
"graph_manager.heatup(EnvironmentSteps(HEATUP_STEPS))\n",
"graph_manager.train_and_act(EnvironmentSteps(TRAIN_AND_ACT_STEPS))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
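"## Changing the default parameters\n",
"\n",
"Schedule and agent parameters are plain objects: create them, override the attributes you need, and pass them to the graph manager. Below we set the number of training steps, evaluate for 5 episodes every 2048 environment steps, skip heatup, and set the discount factor to 1.0."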
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
"from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps\n",
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
"from rl_coach.graph_managers.graph_manager import ScheduleParameters\n",
"\n",
"# schedule: every phase length is set explicitly below\n",
"schedule_params = ScheduleParameters()\n",
"schedule_params.improve_steps = TrainingSteps(10000000)  # how long improve() runs, in training steps\n",
"schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048)  # evaluate every 2048 environment steps\n",
"schedule_params.evaluation_steps = EnvironmentEpisodes(5)  # each evaluation lasts 5 episodes\n",
"schedule_params.heatup_steps = EnvironmentSteps(0)  # no heatup phase\n",
"\n",
"# agent parameters: defaults, except for an undiscounted return\n",
"agent_params = ClippedPPOAgentParameters()\n",
"agent_params.algorithm.discount = 1.0\n",
"\n",
"graph_manager = BasicRLGraphManager(\n",
"    agent_params=agent_params,\n",
"    env_params=GymVectorEnvironment(level='CartPole-v0'),\n",
"    schedule_params=schedule_params\n",
")\n",
"\n",
"graph_manager.improve()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using a custom Gym environment\n",
"\n",
"A custom Gym environment can be used without registering it.\n",
"Pass its location as the `level`, in the form `module.path:ClassName`.\n",
"Custom parameters for the environment's `__init__` can be passed through the `additional_simulator_parameters` dictionary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n",
"from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n",
"from rl_coach.environments.gym_environment import GymVectorEnvironment\n",
"from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
"\n",
"# environment, referenced as 'module.path:ClassName' instead of a registered Gym id\n",
"bit_length = 10\n",
"env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n",
"# custom parameters for the environment's __init__\n",
"env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}\n",
"\n",
"# Clipped PPO, with its own empty-scheme input embedder for the 'state' and 'desired_goal' inputs\n",
"agent_params = ClippedPPOAgentParameters()\n",
"agent_params.network_wrappers['main'].input_embedders_parameters = {\n",
"    input_name: InputEmbedderParameters(scheme=[])\n",
"    for input_name in ('state', 'desired_goal')\n",
"}\n",
"\n",
"graph_manager = BasicRLGraphManager(\n",
"    agent_params=agent_params,\n",
"    env_params=env_params,\n",
"    schedule_params=SimpleSchedule()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# train the Clipped PPO agent on the BitFlip environment configured above\n",
"graph_manager.improve()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}