mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
committed by
Gal Leibovich
parent
34bc292e60
commit
9e82c06be3
@@ -81,6 +81,7 @@
|
|||||||
"# Adding module path to sys path if not there, so rl_coach submodules can be imported\n",
|
"# Adding module path to sys path if not there, so rl_coach submodules can be imported\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import sys\n",
|
"import sys\n",
|
||||||
|
"import tensorflow as tf\n",
|
||||||
"module_path = os.path.abspath(os.path.join('..'))\n",
|
"module_path = os.path.abspath(os.path.join('..'))\n",
|
||||||
"resources_path = os.path.abspath(os.path.join('Resources'))\n",
|
"resources_path = os.path.abspath(os.path.join('Resources'))\n",
|
||||||
"if module_path not in sys.path:\n",
|
"if module_path not in sys.path:\n",
|
||||||
@@ -156,6 +157,7 @@
|
|||||||
"from rl_coach.base_parameters import VisualizationParameters\n",
|
"from rl_coach.base_parameters import VisualizationParameters\n",
|
||||||
"from rl_coach.core_types import EnvironmentSteps\n",
|
"from rl_coach.core_types import EnvironmentSteps\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"tf.reset_default_graph()\n",
|
||||||
"coach = CoachInterface(preset='CartPole_ClippedPPO')\n",
|
"coach = CoachInterface(preset='CartPole_ClippedPPO')\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# registering an iteration signal before starting to run\n",
|
"# registering an iteration signal before starting to run\n",
|
||||||
@@ -242,6 +244,9 @@
|
|||||||
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
|
"from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
|
||||||
"from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n",
|
"from rl_coach.architectures.embedder_parameters import InputEmbedderParameters\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Resetting tensorflow graph as the network has changed.\n",
|
||||||
|
"tf.reset_default_graph()\n",
|
||||||
|
"\n",
|
||||||
"# define the environment parameters\n",
|
"# define the environment parameters\n",
|
||||||
"bit_length = 10\n",
|
"bit_length = 10\n",
|
||||||
"env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n",
|
"env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')\n",
|
||||||
@@ -310,6 +315,9 @@
|
|||||||
"# Graph Scheduling #\n",
|
"# Graph Scheduling #\n",
|
||||||
"####################\n",
|
"####################\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Resetting tensorflow graph as the network has changed.\n",
|
||||||
|
"tf.reset_default_graph()\n",
|
||||||
|
"\n",
|
||||||
"schedule_params = ScheduleParameters()\n",
|
"schedule_params = ScheduleParameters()\n",
|
||||||
"schedule_params.improve_steps = TrainingSteps(4000)\n",
|
"schedule_params.improve_steps = TrainingSteps(4000)\n",
|
||||||
"schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)\n",
|
"schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)\n",
|
||||||
@@ -405,13 +413,6 @@
|
|||||||
"# Clearning up\n",
|
"# Clearning up\n",
|
||||||
"shutil.rmtree(my_checkpoint_dir)"
|
"shutil.rmtree(my_checkpoint_dir)"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
@@ -54,7 +54,8 @@
|
|||||||
" sys.path.append(module_path)\n",
|
" sys.path.append(module_path)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import tensorflow as tf\n",
|
"import tensorflow as tf\n",
|
||||||
"from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters\n",
|
"from rl_coach.architectures.tensorflow_components.heads.head import Head\n",
|
||||||
|
"from rl_coach.architectures.head_parameters import HeadParameters\n",
|
||||||
"from rl_coach.base_parameters import AgentParameters\n",
|
"from rl_coach.base_parameters import AgentParameters\n",
|
||||||
"from rl_coach.core_types import QActionStateValue\n",
|
"from rl_coach.core_types import QActionStateValue\n",
|
||||||
"from rl_coach.spaces import SpacesDefinition"
|
"from rl_coach.spaces import SpacesDefinition"
|
||||||
|
|||||||
Reference in New Issue
Block a user