mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
committed by
Gal Leibovich
parent
34bc292e60
commit
9e82c06be3
@@ -81,13 +81,14 @@
|
||||
"# Adding module path to sys path if not there, so rl_coach submodules can be imported\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import tensorflow as tf\n",
|
||||
"module_path = os.path.abspath(os.path.join('..'))\n",
|
||||
"resources_path = os.path.abspath(os.path.join('Resources'))\n",
|
||||
"if module_path not in sys.path:\n",
|
||||
" sys.path.append(module_path)\n",
|
||||
"if resources_path not in sys.path:\n",
|
||||
" sys.path.append(resources_path)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"from rl_coach.coach import CoachInterface"
|
||||
]
|
||||
},
|
||||
@@ -156,6 +157,7 @@
|
||||
"from rl_coach.base_parameters import VisualizationParameters\n",
|
||||
"from rl_coach.core_types import EnvironmentSteps\n",
|
||||
"\n",
|
||||
"tf.reset_default_graph()\n",
|
||||
"coach = CoachInterface(preset='CartPole_ClippedPPO')\n",
|
||||
"\n",
|
||||
"# registering an iteration signal before starting to run\n",
|
||||
@@ -242,6 +244,9 @@
|
||||
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
|
||||
"from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n",
|
||||
"\n",
|
||||
"# Resetting tensorflow graph as the network has changed.\n",
|
||||
"tf.reset_default_graph()\n",
|
||||
"\n",
|
||||
"# define the environment parameters\n",
|
||||
"bit_length = 10\n",
|
||||
"env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n",
|
||||
@@ -310,6 +315,9 @@
|
||||
"# Graph Scheduling #\n",
|
||||
"####################\n",
|
||||
"\n",
|
||||
"# Resetting tensorflow graph as the network has changed.\n",
|
||||
"tf.reset_default_graph()\n",
|
||||
"\n",
|
||||
"schedule_params = ScheduleParameters()\n",
|
||||
"schedule_params.improve_steps = TrainingSteps(4000)\n",
|
||||
"schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)\n",
|
||||
@@ -405,13 +413,6 @@
|
||||
"# Cleaning up\n",
|
||||
"shutil.rmtree(my_checkpoint_dir)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -54,7 +54,8 @@
|
||||
" sys.path.append(module_path)\n",
|
||||
"\n",
|
||||
"import tensorflow as tf\n",
|
||||
"from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters\n",
|
||||
"from rl_coach.architectures.tensorflow_components.heads.head import Head\n",
|
||||
"from rl_coach.architectures.head_parameters import HeadParameters\n",
|
||||
"from rl_coach.base_parameters import AgentParameters\n",
|
||||
"from rl_coach.core_types import QActionStateValue\n",
|
||||
"from rl_coach.spaces import SpacesDefinition"
|
||||
|
||||
Reference in New Issue
Block a user