1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

importing heads parameters from the correct file on tutorial #1 (#403)

This commit is contained in:
Pi Esposito
2019-09-24 14:44:49 -03:00
committed by Gal Leibovich
parent 34bc292e60
commit 9e82c06be3
2 changed files with 11 additions and 9 deletions

View File

@@ -81,6 +81,7 @@
"# Adding module path to sys path if not there, so rl_coach submodules can be imported\n", "# Adding module path to sys path if not there, so rl_coach submodules can be imported\n",
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import tensorflow as tf\n",
"module_path = os.path.abspath(os.path.join('..'))\n", "module_path = os.path.abspath(os.path.join('..'))\n",
"resources_path = os.path.abspath(os.path.join('Resources'))\n", "resources_path = os.path.abspath(os.path.join('Resources'))\n",
"if module_path not in sys.path:\n", "if module_path not in sys.path:\n",
@@ -156,6 +157,7 @@
"from rl_coach.base_parameters import VisualizationParameters\n", "from rl_coach.base_parameters import VisualizationParameters\n",
"from rl_coach.core_types import EnvironmentSteps\n", "from rl_coach.core_types import EnvironmentSteps\n",
"\n", "\n",
"tf.reset_default_graph()\n",
"coach = CoachInterface(preset='CartPole_ClippedPPO')\n", "coach = CoachInterface(preset='CartPole_ClippedPPO')\n",
"\n", "\n",
"# registering an iteration signal before starting to run\n", "# registering an iteration signal before starting to run\n",
@@ -242,6 +244,9 @@
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n", "from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
"from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n", "from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n",
"\n", "\n",
"# Resetting tensorflow graph as the network has changed.\n",
"tf.reset_default_graph()\n",
"\n",
"# define the environment parameters\n", "# define the environment parameters\n",
"bit_length = 10\n", "bit_length = 10\n",
"env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n", "env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n",
@@ -310,6 +315,9 @@
"# Graph Scheduling #\n", "# Graph Scheduling #\n",
"####################\n", "####################\n",
"\n", "\n",
"# Resetting tensorflow graph as the network has changed.\n",
"tf.reset_default_graph()\n",
"\n",
"schedule_params = ScheduleParameters()\n", "schedule_params = ScheduleParameters()\n",
"schedule_params.improve_steps = TrainingSteps(4000)\n", "schedule_params.improve_steps = TrainingSteps(4000)\n",
"schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)\n", "schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)\n",
@@ -405,13 +413,6 @@
"# Clearning up\n", "# Clearning up\n",
"shutil.rmtree(my_checkpoint_dir)" "shutil.rmtree(my_checkpoint_dir)"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {

View File

@@ -54,7 +54,8 @@
" sys.path.append(module_path)\n", " sys.path.append(module_path)\n",
"\n", "\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters\n", "from rl_coach.architectures.tensorflow_components.heads.head import Head\n",
"from rl_coach.architectures.head_parameters import HeadParameters\n",
"from rl_coach.base_parameters import AgentParameters\n", "from rl_coach.base_parameters import AgentParameters\n",
"from rl_coach.core_types import QActionStateValue\n", "from rl_coach.core_types import QActionStateValue\n",
"from rl_coach.spaces import SpacesDefinition" "from rl_coach.spaces import SpacesDefinition"