1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

rename GraphManager.sync_graph -> sync

This commit is contained in:
Zach Dwiel
2018-10-04 16:43:38 -04:00
committed by zach dwiel
parent 5fee48dcfd
commit 496a516de1

View File

@@ -362,10 +362,6 @@ class GraphManager(object):
# perform several steps of playing
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
# The assumption here is that the total_steps_counters are each updated when an event
# takes place (i.e. an episode ends)
# TODO - The counter of frames is not updated correctly. need to fix that.
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
# reset the environment if the previous episode was terminated
if self.reset_required:
@@ -398,8 +394,7 @@ class GraphManager(object):
if steps.num_steps > 0:
with self.phase_context(RunPhase.TRAIN):
self.reset_internal_state(force_environment_reset=True)
#TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
# episodes in memory
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
# The actual steps being done on the environment are decided by the agents themselves.
@@ -408,7 +403,7 @@ class GraphManager(object):
self.train()
self.occasionally_save_checkpoint()
def sync_graph(self) -> None:
def sync(self) -> None:
"""
Sync the global network parameters to the graph
:return:
@@ -428,13 +423,13 @@ class GraphManager(object):
with self.phase_context(RunPhase.TEST):
# reset all the levels before starting to evaluate
self.reset_internal_state(force_environment_reset=True)
self.sync_graph()
self.sync()
# act for at least `steps`, though don't interrupt an episode
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
self.act(EnvironmentEpisodes(1))
self.sync_graph()
self.sync()
def restore_checkpoint(self):
self.verify_graph_was_created()
@@ -511,7 +506,7 @@ class GraphManager(object):
self.verify_graph_was_created()
# initialize the network parameters from the global network
self.sync_graph()
self.sync()
# heatup
self.heatup(self.heatup_steps)