mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
rename GraphManager.sync_graph -> sync
This commit is contained in:
@@ -362,10 +362,6 @@ class GraphManager(object):
|
|||||||
|
|
||||||
# perform several steps of playing
|
# perform several steps of playing
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||||
|
|
||||||
# The assumption here is that the total_steps_counters are each updated when an event
|
|
||||||
# takes place (i.e. an episode ends)
|
|
||||||
# TODO - The counter of frames is not updated correctly. need to fix that.
|
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
# reset the environment if the previous episode was terminated
|
# reset the environment if the previous episode was terminated
|
||||||
if self.reset_required:
|
if self.reset_required:
|
||||||
@@ -398,8 +394,7 @@ class GraphManager(object):
|
|||||||
if steps.num_steps > 0:
|
if steps.num_steps > 0:
|
||||||
with self.phase_context(RunPhase.TRAIN):
|
with self.phase_context(RunPhase.TRAIN):
|
||||||
self.reset_internal_state(force_environment_reset=True)
|
self.reset_internal_state(force_environment_reset=True)
|
||||||
#TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
|
|
||||||
# episodes in memory
|
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
# The actual steps being done on the environment are decided by the agents themselves.
|
# The actual steps being done on the environment are decided by the agents themselves.
|
||||||
@@ -408,7 +403,7 @@ class GraphManager(object):
|
|||||||
self.train()
|
self.train()
|
||||||
self.occasionally_save_checkpoint()
|
self.occasionally_save_checkpoint()
|
||||||
|
|
||||||
def sync_graph(self) -> None:
|
def sync(self) -> None:
|
||||||
"""
|
"""
|
||||||
Sync the global network parameters to the graph
|
Sync the global network parameters to the graph
|
||||||
:return:
|
:return:
|
||||||
@@ -428,13 +423,13 @@ class GraphManager(object):
|
|||||||
with self.phase_context(RunPhase.TEST):
|
with self.phase_context(RunPhase.TEST):
|
||||||
# reset all the levels before starting to evaluate
|
# reset all the levels before starting to evaluate
|
||||||
self.reset_internal_state(force_environment_reset=True)
|
self.reset_internal_state(force_environment_reset=True)
|
||||||
self.sync_graph()
|
self.sync()
|
||||||
|
|
||||||
# act for at least `steps`, though don't interrupt an episode
|
# act for at least `steps`, though don't interrupt an episode
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
self.act(EnvironmentEpisodes(1))
|
self.act(EnvironmentEpisodes(1))
|
||||||
self.sync_graph()
|
self.sync()
|
||||||
|
|
||||||
def restore_checkpoint(self):
|
def restore_checkpoint(self):
|
||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
@@ -511,7 +506,7 @@ class GraphManager(object):
|
|||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
|
|
||||||
# initialize the network parameters from the global network
|
# initialize the network parameters from the global network
|
||||||
self.sync_graph()
|
self.sync()
|
||||||
|
|
||||||
# heatup
|
# heatup
|
||||||
self.heatup(self.heatup_steps)
|
self.heatup(self.heatup_steps)
|
||||||
|
|||||||
Reference in New Issue
Block a user