1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

remove argument keep_networks_in_sync from GraphManager.act, and move this feature into the only place that activated it: GraphManager.train_and_act

This commit is contained in:
Zach Dwiel
2018-10-04 11:59:05 -04:00
committed by zach dwiel
parent b2d864a5bd
commit 5fee48dcfd

View File

@@ -348,11 +348,10 @@ class GraphManager(object):
[environment.reset_internal_state(force_environment_reset) for environment in self.environments] [environment.reset_internal_state(force_environment_reset) for environment in self.environments]
[manager.reset_internal_state() for manager in self.level_managers] [manager.reset_internal_state() for manager in self.level_managers]
def act(self, steps: PlayingStepsType, keep_networks_in_sync=False) -> (int, bool): def act(self, steps: PlayingStepsType) -> (int, bool):
""" """
Do several steps of acting on the environment Do several steps of acting on the environment
:param steps: the number of steps as a tuple of steps time and steps count :param steps: the number of steps as a tuple of steps time and steps count
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
""" """
self.verify_graph_was_created() self.verify_graph_was_created()
@@ -386,8 +385,6 @@ class GraphManager(object):
self.handle_episode_ended() self.handle_episode_ended()
self.reset_required = True self.reset_required = True
if keep_networks_in_sync:
self.sync_graph()
def train_and_act(self, steps: StepMethod) -> None: def train_and_act(self, steps: StepMethod) -> None:
""" """
@@ -436,7 +433,8 @@ class GraphManager(object):
# act for at least `steps`, though don't interrupt an episode # act for at least `steps`, though don't interrupt an episode
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
while self.total_steps_counters[self.phase][steps.__class__] < count_end: while self.total_steps_counters[self.phase][steps.__class__] < count_end:
self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync) self.act(EnvironmentEpisodes(1))
self.sync_graph()
def restore_checkpoint(self): def restore_checkpoint(self):
self.verify_graph_was_created() self.verify_graph_was_created()