Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)
network_imporvements branch merge
@@ -30,7 +30,7 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
-from rl_coach.utils import set_cpu
+from rl_coach.utils import set_cpu, start_shell_command_and_wait
 
 
 class ScheduleParameters(Parameters):
@@ -51,6 +51,27 @@ class HumanPlayScheduleParameters(ScheduleParameters):
         self.improve_steps = TrainingSteps(10000000000)
 
 
+class SimpleScheduleWithoutEvaluation(ScheduleParameters):
+    def __init__(self, improve_steps=TrainingSteps(10000000000)):
+        super().__init__()
+        self.heatup_steps = EnvironmentSteps(0)
+        self.evaluation_steps = EnvironmentEpisodes(0)
+        self.steps_between_evaluation_periods = improve_steps
+        self.improve_steps = improve_steps
+
+
+class SimpleSchedule(ScheduleParameters):
+    def __init__(self,
+                 improve_steps=TrainingSteps(10000000000),
+                 steps_between_evaluation_periods=EnvironmentEpisodes(50),
+                 evaluation_steps=EnvironmentEpisodes(5)):
+        super().__init__()
+        self.heatup_steps = EnvironmentSteps(0)
+        self.evaluation_steps = evaluation_steps
+        self.steps_between_evaluation_periods = steps_between_evaluation_periods
+        self.improve_steps = improve_steps
+
+
 class GraphManager(object):
     """
     A graph manager is responsible for creating and initializing a graph of agents, including all its internal
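A brief usage sketch for the two schedules added above. It assumes the classes are importable from rl_coach.graph_managers.graph_manager (the usual location of this module in coach) and that the step classes expose a num_steps attribute, as the surrounding code suggests; it is an illustration, not part of the commit.

from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes
from rl_coach.graph_managers.graph_manager import SimpleSchedule, SimpleScheduleWithoutEvaluation

# Regular schedule: evaluate for 5 episodes after every 50 episodes of acting.
schedule = SimpleSchedule(improve_steps=TrainingSteps(100000),
                          steps_between_evaluation_periods=EnvironmentEpisodes(50),
                          evaluation_steps=EnvironmentEpisodes(5))

# Training-only schedule: zero evaluation episodes, so improve() skips the evaluation phase.
no_eval = SimpleScheduleWithoutEvaluation(improve_steps=TrainingSteps(100000))
assert no_eval.evaluation_steps.num_steps == 0
assert no_eval.steps_between_evaluation_periods is no_eval.improve_steps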
@@ -78,6 +99,7 @@ class GraphManager(object):
 
         # timers
         self.graph_initialization_time = time.time()
+        self.graph_creation_time = None
         self.heatup_start_time = None
         self.training_start_time = None
         self.last_evaluation_start_time = None
@@ -94,7 +116,8 @@ class GraphManager(object):
         self.checkpoint_saver = None
         self.graph_logger = Logger()
 
-    def create_graph(self, task_parameters: TaskParameters):
+    def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
         self.graph_creation_time = time.time()
+        self.task_parameters = task_parameters
 
         if isinstance(task_parameters, DistributedTaskParameters):
@@ -129,6 +152,8 @@ class GraphManager(object):
 
         self.setup_logger()
 
+        return self
+
     def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
         """
         Create all the graph modules and the graph scheduler
@@ -207,6 +232,29 @@ class GraphManager(object):
         # restore from checkpoint if given
         self.restore_checkpoint()
 
+        # tf.train.write_graph(tf.get_default_graph(),
+        #                      logdir=self.task_parameters.save_checkpoint_dir,
+        #                      name='graphdef.pb',
+        #                      as_text=False)
+        # self.save_checkpoint()
+        #
+        # output_nodes = []
+        # for level in self.level_managers:
+        #     for agent in level.agents.values():
+        #         for network in agent.networks.values():
+        #             for output in network.online_network.outputs:
+        #                 output_nodes.append(output.name.split(":")[0])
+        #
+        # freeze_graph_command = [
+        #     "python -m tensorflow.python.tools.freeze_graph",
+        #     "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")),
+        #     "--input_binary=true",
+        #     "--output_node_names='{}'".format(','.join(output_nodes)),
+        #     "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")),
+        #     "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb"))
+        # ]
+        # start_shell_command_and_wait(" ".join(freeze_graph_command))
+
     def setup_logger(self) -> None:
         # dump documentation
         logger_prefix = "{graph_name}".format(graph_name=self.name)
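The commented-out block above sketches freezing the TensorFlow graph after checkpoint restore, driven by the newly imported start_shell_command_and_wait helper. A rough, hypothetical stand-in for that helper (its real implementation in rl_coach.utils is not shown in this diff), assuming it only needs to run a shell command and block until it exits:

import shlex
import subprocess

def run_shell_command_and_wait(command: str) -> int:
    # Launch the command, let it inherit stdout/stderr, and wait for it to finish.
    process = subprocess.Popen(shlex.split(command))
    return process.wait()

# e.g. run_shell_command_and_wait("python -m tensorflow.python.tools.freeze_graph --help")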
@@ -250,6 +298,8 @@ class GraphManager(object):
         :param steps: the number of steps as a tuple of steps time and steps count
         :return: None
         """
+        self.verify_graph_was_created()
+
         steps_copy = copy.copy(steps)
 
         if steps_copy.num_steps > 0:
@@ -284,6 +334,8 @@ class GraphManager(object):
         :param steps: number of training iterations to perform
         :return: None
         """
+        self.verify_graph_was_created()
+
         # perform several steps of training interleaved with acting
         count_end = self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] + steps.num_steps
         while self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] < count_end:
@@ -299,6 +351,8 @@ class GraphManager(object):
                                        lives available
         :return: None
         """
+        self.verify_graph_was_created()
+
         self.reset_required = False
         [environment.reset_internal_state(force_environment_reset) for environment in self.environments]
         [manager.reset_internal_state() for manager in self.level_managers]
@@ -314,6 +368,8 @@ class GraphManager(object):
         :return: the actual number of steps done, a boolean value that represent if the episode was done when finishing
                  the function call
         """
+        self.verify_graph_was_created()
+
         # perform several steps of playing
         result = None
 
@@ -366,6 +422,8 @@ class GraphManager(object):
         :param steps: the number of steps as a tuple of steps time and steps count
         :return: None
         """
+        self.verify_graph_was_created()
+
         # perform several steps of training interleaved with acting
         if steps.num_steps > 0:
             self.phase = RunPhase.TRAIN
@@ -395,6 +453,8 @@ class GraphManager(object):
         :param keep_networks_in_sync: sync the network parameters with the global network before each episode
         :return: None
         """
+        self.verify_graph_was_created()
+
         if steps.num_steps > 0:
             self.phase = RunPhase.TEST
             self.last_evaluation_start_time = time.time()
@@ -411,6 +471,8 @@ class GraphManager(object):
             self.phase = RunPhase.UNDEFINED
 
     def restore_checkpoint(self):
+        self.verify_graph_was_created()
+
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
             import tensorflow as tf
@@ -473,6 +535,7 @@ class GraphManager(object):
             2.2. Evaluate
         :return: None
         """
+        self.verify_graph_was_created()
 
         # initialize the network parameters from the global network
         self.sync_graph()
@@ -491,6 +554,14 @@ class GraphManager(object):
             self.train_and_act(self.steps_between_evaluation_periods)
             self.evaluate(self.evaluation_steps)
 
+    def verify_graph_was_created(self):
+        """
+        Verifies that the graph was already created, and if not, it creates it with the default task parameters
+        :return: None
+        """
+        if self.graph_creation_time is None:
+            self.create_graph()
+
     def __str__(self):
         result = ""
         for key, val in self.__dict__.items():
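Taken together, the default TaskParameters() argument, the return self at the end of create_graph(), and the verify_graph_was_created() guard added to the public entry points let a graph manager be used without an explicit create_graph() call. A minimal standalone sketch of that pattern, using a simplified toy class rather than the actual rl_coach GraphManager:

import time

class TinyGraphManager:
    # Toy illustration of the lazy graph-creation pattern introduced in this commit.

    def __init__(self):
        self.graph_creation_time = None  # create_graph() has not run yet

    def create_graph(self, task_parameters=None):
        # Build level managers, environments and networks here (omitted in this toy).
        self.graph_creation_time = time.time()
        return self  # enables TinyGraphManager().create_graph().improve() chaining

    def verify_graph_was_created(self):
        # Lazily create the graph with default parameters if the caller skipped create_graph().
        if self.graph_creation_time is None:
            self.create_graph()

    def improve(self):
        self.verify_graph_was_created()
        # The heatup / train_and_act / evaluate loop would run here.

TinyGraphManager().improve()                 # works without an explicit create_graph() call
TinyGraphManager().create_graph().improve()  # the explicit, chained form also works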