From 496a516de1724c2fc808db1654208f07249c0d7b Mon Sep 17 00:00:00 2001 From: Zach Dwiel Date: Thu, 4 Oct 2018 16:43:38 -0400 Subject: [PATCH] rename GraphManager.sync_graph -> sync --- rl_coach/graph_managers/graph_manager.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index f0788ff..03da59d 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -362,10 +362,6 @@ class GraphManager(object): # perform several steps of playing count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps - - # The assumption here is that the total_steps_counters are each updated when an event - # takes place (i.e. an episode ends) - # TODO - The counter of frames is not updated correctly. need to fix that. while self.total_steps_counters[self.phase][steps.__class__] < count_end: # reset the environment if the previous episode was terminated if self.reset_required: @@ -398,8 +394,7 @@ class GraphManager(object): if steps.num_steps > 0: with self.phase_context(RunPhase.TRAIN): self.reset_internal_state(force_environment_reset=True) - #TODO - the below while loop should end with full episodes, so to avoid situations where we have partial - # episodes in memory + count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps while self.total_steps_counters[self.phase][steps.__class__] < count_end: # The actual steps being done on the environment are decided by the agents themselves. @@ -408,7 +403,7 @@ class GraphManager(object): self.train() self.occasionally_save_checkpoint() - def sync_graph(self) -> None: + def sync(self) -> None: """ Sync the global network parameters to the graph :return: @@ -428,13 +423,13 @@ class GraphManager(object): with self.phase_context(RunPhase.TEST): # reset all the levels before starting to evaluate self.reset_internal_state(force_environment_reset=True) - self.sync_graph() + self.sync() # act for at least `steps`, though don't interrupt an episode count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps while self.total_steps_counters[self.phase][steps.__class__] < count_end: self.act(EnvironmentEpisodes(1)) - self.sync_graph() + self.sync() def restore_checkpoint(self): self.verify_graph_was_created() @@ -511,7 +506,7 @@ class GraphManager(object): self.verify_graph_was_created() # initialize the network parameters from the global network - self.sync_graph() + self.sync() # heatup self.heatup(self.heatup_steps)