mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
rename GraphManager.sync_graph -> sync
This commit is contained in:
@@ -362,10 +362,6 @@ class GraphManager(object):
|
||||
|
||||
# perform several steps of playing
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
|
||||
# The assumption here is that the total_steps_counters are each updated when an event
|
||||
# takes place (i.e. an episode ends)
|
||||
# TODO - The counter of frames is not updated correctly. need to fix that.
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
# reset the environment if the previous episode was terminated
|
||||
if self.reset_required:
|
||||
@@ -398,8 +394,7 @@ class GraphManager(object):
|
||||
if steps.num_steps > 0:
|
||||
with self.phase_context(RunPhase.TRAIN):
|
||||
self.reset_internal_state(force_environment_reset=True)
|
||||
#TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
|
||||
# episodes in memory
|
||||
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
# The actual steps being done on the environment are decided by the agents themselves.
|
||||
@@ -408,7 +403,7 @@ class GraphManager(object):
|
||||
self.train()
|
||||
self.occasionally_save_checkpoint()
|
||||
|
||||
def sync_graph(self) -> None:
|
||||
def sync(self) -> None:
|
||||
"""
|
||||
Sync the global network parameters to the graph
|
||||
:return:
|
||||
@@ -428,13 +423,13 @@ class GraphManager(object):
|
||||
with self.phase_context(RunPhase.TEST):
|
||||
# reset all the levels before starting to evaluate
|
||||
self.reset_internal_state(force_environment_reset=True)
|
||||
self.sync_graph()
|
||||
self.sync()
|
||||
|
||||
# act for at least `steps`, though don't interrupt an episode
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
self.sync_graph()
|
||||
self.sync()
|
||||
|
||||
def restore_checkpoint(self):
|
||||
self.verify_graph_was_created()
|
||||
@@ -511,7 +506,7 @@ class GraphManager(object):
|
||||
self.verify_graph_was_created()
|
||||
|
||||
# initialize the network parameters from the global network
|
||||
self.sync_graph()
|
||||
self.sync()
|
||||
|
||||
# heatup
|
||||
self.heatup(self.heatup_steps)
|
||||
|
||||
Reference in New Issue
Block a user