mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Merge of the network_imporvements branch

This commit is contained in:
Shadi Endrawis
2018-10-02 13:41:46 +03:00
parent 72ea933384
commit 51726a5b80
110 changed files with 1639 additions and 1161 deletions


@@ -30,7 +30,7 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
from rl_coach.environments.environment import Environment
from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu
from rl_coach.utils import set_cpu, start_shell_command_and_wait
class ScheduleParameters(Parameters):
@@ -51,6 +51,27 @@ class HumanPlayScheduleParameters(ScheduleParameters):
self.improve_steps = TrainingSteps(10000000000)
class SimpleScheduleWithoutEvaluation(ScheduleParameters):
def __init__(self, improve_steps=TrainingSteps(10000000000)):
super().__init__()
self.heatup_steps = EnvironmentSteps(0)
self.evaluation_steps = EnvironmentEpisodes(0)
self.steps_between_evaluation_periods = improve_steps
self.improve_steps = improve_steps
class SimpleSchedule(ScheduleParameters):
def __init__(self,
improve_steps=TrainingSteps(10000000000),
steps_between_evaluation_periods=EnvironmentEpisodes(50),
evaluation_steps=EnvironmentEpisodes(5)):
super().__init__()
self.heatup_steps = EnvironmentSteps(0)
self.evaluation_steps = evaluation_steps
self.steps_between_evaluation_periods = steps_between_evaluation_periods
self.improve_steps = improve_steps
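
Editor's note: the two presets added above cut the boilerplate of spelling out every phase length. A minimal usage sketch, assuming these classes and the step types are importable as in this file (values other than the shown defaults are illustrative):

from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes

# Default preset: no heatup, evaluate for 5 episodes after every 50 training episodes.
schedule = SimpleSchedule()

# Custom budget: 1M training steps, evaluating 5 episodes every 100 episodes.
schedule = SimpleSchedule(improve_steps=TrainingSteps(1000000),
                          steps_between_evaluation_periods=EnvironmentEpisodes(100),
                          evaluation_steps=EnvironmentEpisodes(5))

# Evaluation disabled entirely: steps_between_evaluation_periods is set equal
# to improve_steps, so the evaluation window is never reached during improvement.
schedule = SimpleScheduleWithoutEvaluation(improve_steps=TrainingSteps(1000000))
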
class GraphManager(object):
"""
A graph manager is responsible for creating and initializing a graph of agents, including all its internal
@@ -78,6 +99,7 @@ class GraphManager(object):
# timers
self.graph_initialization_time = time.time()
self.graph_creation_time = None
self.heatup_start_time = None
self.training_start_time = None
self.last_evaluation_start_time = None
@@ -94,7 +116,8 @@ class GraphManager(object):
self.checkpoint_saver = None
self.graph_logger = Logger()
def create_graph(self, task_parameters: TaskParameters):
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
self.graph_creation_time = time.time()
self.task_parameters = task_parameters
if isinstance(task_parameters, DistributedTaskParameters):
@@ -129,6 +152,8 @@ class GraphManager(object):
self.setup_logger()
return self
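
Editor's note: giving task_parameters a default value means the common single-process case no longer has to build TaskParameters by hand. A minimal sketch of the two call styles, assuming graph_manager is an already-configured GraphManager instance and TaskParameters is imported from rl_coach.base_parameters:

from rl_coach.base_parameters import TaskParameters

# Explicit form, still used for distributed or otherwise customized runs.
graph_manager.create_graph(TaskParameters())

# New shorthand: create_graph() falls back to a default TaskParameters().
graph_manager.create_graph()
graph_manager.improve()
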
def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
"""
Create all the graph modules and the graph scheduler
@@ -207,6 +232,29 @@ class GraphManager(object):
# restore from checkpoint if given
self.restore_checkpoint()
# tf.train.write_graph(tf.get_default_graph(),
# logdir=self.task_parameters.save_checkpoint_dir,
# name='graphdef.pb',
# as_text=False)
# self.save_checkpoint()
#
# output_nodes = []
# for level in self.level_managers:
# for agent in level.agents.values():
# for network in agent.networks.values():
# for output in network.online_network.outputs:
# output_nodes.append(output.name.split(":")[0])
#
# freeze_graph_command = [
# "python -m tensorflow.python.tools.freeze_graph",
# "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")),
# "--input_binary=true",
# "--output_node_names='{}'".format(','.join(output_nodes)),
# "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")),
# "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb"))
# ]
# start_shell_command_and_wait(" ".join(freeze_graph_command))
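
Editor's note: the block above stays commented out, but it documents the intended use of the newly imported start_shell_command_and_wait helper: serialize the graph definition, then shell out to TensorFlow's freeze_graph tool. The helper itself lives in rl_coach.utils; the sketch below only illustrates what such a wrapper typically looks like and is not the project's actual implementation:

import subprocess

def start_shell_command_and_wait(command: str) -> None:
    # Run the command through the shell and block until it exits.
    # Illustrative only; the real helper in rl_coach.utils may add
    # process-group handling, output capture or error checking.
    process = subprocess.Popen(command, shell=True)
    process.wait()
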
def setup_logger(self) -> None:
# dump documentation
logger_prefix = "{graph_name}".format(graph_name=self.name)
@@ -250,6 +298,8 @@ class GraphManager(object):
:param steps: the number of steps as a tuple of steps time and steps count
:return: None
"""
self.verify_graph_was_created()
steps_copy = copy.copy(steps)
if steps_copy.num_steps > 0:
@@ -284,6 +334,8 @@ class GraphManager(object):
:param steps: number of training iterations to perform
:return: None
"""
self.verify_graph_was_created()
# perform several steps of training interleaved with acting
count_end = self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] + steps.num_steps
while self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] < count_end:
@@ -299,6 +351,8 @@ class GraphManager(object):
lives available
:return: None
"""
self.verify_graph_was_created()
self.reset_required = False
[environment.reset_internal_state(force_environment_reset) for environment in self.environments]
[manager.reset_internal_state() for manager in self.level_managers]
@@ -314,6 +368,8 @@ class GraphManager(object):
:return: the actual number of steps done, a boolean value that represent if the episode was done when finishing
the function call
"""
self.verify_graph_was_created()
# perform several steps of playing
result = None
@@ -366,6 +422,8 @@ class GraphManager(object):
:param steps: the number of steps as a tuple of steps time and steps count
:return: None
"""
self.verify_graph_was_created()
# perform several steps of training interleaved with acting
if steps.num_steps > 0:
self.phase = RunPhase.TRAIN
@@ -395,6 +453,8 @@ class GraphManager(object):
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
:return: None
"""
self.verify_graph_was_created()
if steps.num_steps > 0:
self.phase = RunPhase.TEST
self.last_evaluation_start_time = time.time()
@@ -411,6 +471,8 @@ class GraphManager(object):
self.phase = RunPhase.UNDEFINED
def restore_checkpoint(self):
self.verify_graph_was_created()
# TODO: find better way to load checkpoints that were saved with a global network into the online network
if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
import tensorflow as tf
@@ -473,6 +535,7 @@ class GraphManager(object):
2.2. Evaluate
:return: None
"""
self.verify_graph_was_created()
# initialize the network parameters from the global network
self.sync_graph()
@@ -491,6 +554,14 @@ class GraphManager(object):
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
def verify_graph_was_created(self):
"""
Verifies that the graph was already created, and if not, it creates it with the default task parameters
:return: None
"""
if self.graph_creation_time is None:
self.create_graph()
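
Editor's note: because this guard now runs at the top of act(), train(), improve(), evaluate() and restore_checkpoint(), the graph is built lazily on first use. A minimal sketch of the calling pattern this enables (construction of the concrete graph manager is elided):

# No explicit create_graph(): the first improve() sees that
# graph_creation_time is still None and builds the graph with the
# default TaskParameters before doing any work.
graph_manager.improve()

# Later calls reuse the graph that was created above.
graph_manager.evaluate(graph_manager.evaluation_steps)
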
def __str__(self):
result = ""
for key, val in self.__dict__.items():