1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

waiting for a new checkpoint if it's available

This commit is contained in:
Ajay Deshpande
2018-10-05 19:08:24 -07:00
committed by zach dwiel
parent 5eac0102de
commit 7f00235ed5
7 changed files with 49 additions and 20 deletions

View File

@@ -81,7 +81,7 @@ class Agent(AgentInterface):
self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
if self.ap.memory.memory_backend_params.run_type == 'trainer':
self.memory_backend.subscribe(self.memory)
self.memory_backend.subscribe(self)
else:
self.memory.set_memory_backend(self.memory_backend)
@@ -523,8 +523,7 @@ class Agent(AgentInterface):
Determine if online weights should be copied to the target.
:return: boolean: True if the online weights should be copied to the target.
"""
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
# update the target network of every network that has a target network
step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
if step_method.__class__ == TrainingSteps:
@@ -546,7 +545,7 @@ class Agent(AgentInterface):
:return: boolean: True if we should start a training phase
"""
should_update = self._should_train_helper(wait_for_full_episode)
should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)
step_method = self.ap.algorithm.num_consecutive_playing_steps
@@ -560,10 +559,8 @@ class Agent(AgentInterface):
def _should_train_helper(self, wait_for_full_episode=False):
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
step_method = self.ap.algorithm.num_consecutive_playing_steps
if step_method.__class__ == EnvironmentEpisodes:
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps