{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Getting Started Guide" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creating a very simple graph containing a single clipped ppo agent running with the CartPole-v0 Gym environment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n", "from rl_coach.environments.gym_environment import GymVectorEnvironment\n", "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n", "from rl_coach.graph_managers.graph_manager import SimpleSchedule\n", "\n", "graph_manager = BasicRLGraphManager(\n", " agent_params=ClippedPPOAgentParameters(),\n", " env_params=GymVectorEnvironment(level='CartPole-v0'),\n", " schedule_params=SimpleSchedule()\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running the graph according to the given schedule:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "graph_manager.improve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running each phase manually:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rl_coach.core_types import EnvironmentSteps\n", "\n", "graph_manager = BasicRLGraphManager(\n", " agent_params=ClippedPPOAgentParameters(),\n", " env_params=GymVectorEnvironment(level='CartPole-v0'),\n", " schedule_params=SimpleSchedule()\n", ")\n", "\n", "graph_manager.heatup(EnvironmentSteps(100))\n", "graph_manager.train_and_act(EnvironmentSteps(100))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Changing the default parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n", "from rl_coach.environments.gym_environment import 
GymVectorEnvironment\n", "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n", "from rl_coach.graph_managers.graph_manager import SimpleSchedule\n", "from rl_coach.graph_managers.graph_manager import ScheduleParameters\n", "from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps\n", "\n", "# schedule\n", "schedule_params = ScheduleParameters()\n", "schedule_params.improve_steps = TrainingSteps(10000000)\n", "schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048)\n", "schedule_params.evaluation_steps = EnvironmentEpisodes(5)\n", "schedule_params.heatup_steps = EnvironmentSteps(0)\n", "\n", "# agent parameters\n", "agent_params = ClippedPPOAgentParameters()\n", "agent_params.algorithm.discount = 1.0\n", "\n", "graph_manager = BasicRLGraphManager(\n", " agent_params=agent_params,\n", " env_params=GymVectorEnvironment(level='CartPole-v0'),\n", " schedule_params=schedule_params\n", ")\n", "\n", "graph_manager.improve()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using a custom gym environment\n", "\n", "We can use a custom gym environment without registering it. 
\n", "We just need the path to the environment module.\n", "We can also pass custom parameters to the environment's `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters\n", "from rl_coach.environments.gym_environment import GymVectorEnvironment\n", "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n", "from rl_coach.graph_managers.graph_manager import SimpleSchedule\n", "from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters\n", "\n", "# define the environment parameters\n", "bit_length = 10\n", "env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n", "env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}\n", "\n", "# Clipped PPO\n", "agent_params = ClippedPPOAgentParameters()\n", "agent_params.network_wrappers['main'].input_embedders_parameters = {\n", " 'state': InputEmbedderParameters(scheme=[]),\n", " 'desired_goal': InputEmbedderParameters(scheme=[])\n", "}\n", "\n", "graph_manager = BasicRLGraphManager(\n", " agent_params=agent_params,\n", " env_params=env_params,\n", " schedule_params=SimpleSchedule()\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "graph_manager.improve()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }