
Adding should_train helper and should_train in graph_manager

Ajay Deshpande 2018-10-05 14:22:15 -07:00 (committed by zach dwiel)
parent a2e57a44f1 commit a7f5442015
7 changed files with 126 additions and 20 deletions


@@ -523,6 +523,8 @@ class Agent(AgentInterface):
        Determine if online weights should be copied to the target.
        :return: boolean: True if the online weights should be copied to the target.
        """
        if hasattr(self.ap.memory, 'memory_backend_params'):
            self.total_steps_counter = self.call_memory('num_transitions')
        # update the target network of every network that has a target network
        step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
        if step_method.__class__ == TrainingSteps:
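The two lines added in this hunk pull the current transition count from a shared memory backend (used when rollout workers push experience into a distributed replay memory), so the step-based target-copy schedule is compared against the shared count rather than a stale local one. Below is a minimal standalone sketch of that pattern; SharedMemoryBackend, ToyAgent and copy_every are hypothetical names for illustration, not Coach's actual classes or parameters.

# Illustrative sketch only -- SharedMemoryBackend, ToyAgent and copy_every are
# made-up names, not Coach's API. A trainer whose experience is collected by
# remote workers into a shared backend refreshes its step counter from that
# backend before evaluating a step-based schedule.

class SharedMemoryBackend:
    """Stand-in for a distributed replay memory (e.g. one backed by Redis)."""
    def __init__(self):
        self.transitions = []

    def num_transitions(self):
        return len(self.transitions)


class ToyAgent:
    def __init__(self, memory_backend=None, copy_every=1000):
        self.memory_backend = memory_backend          # None means purely local memory
        self.total_steps_counter = 0
        self.last_target_update_step = 0
        self.copy_every = copy_every

    def should_copy_online_weights_to_target(self):
        # Same idea as the hasattr(...) / call_memory('num_transitions') lines above.
        if self.memory_backend is not None:
            self.total_steps_counter = self.memory_backend.num_transitions()
        return (self.total_steps_counter - self.last_target_update_step) >= self.copy_every


backend = SharedMemoryBackend()
backend.transitions.extend(range(2500))               # pretend remote workers stored 2500 transitions
agent = ToyAgent(memory_backend=backend, copy_every=1000)
print(agent.should_copy_online_weights_to_target())   # True: 2500 - 0 >= 1000

In this toy setup the check fires even though the local process never stepped the environment itself, which is the point of refreshing the counter from the backend.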
@@ -544,22 +546,35 @@ class Agent(AgentInterface):
        :return: boolean: True if we should start a training phase
        """
        should_update = self._should_train_helper(wait_for_full_episode)

        step_method = self.ap.algorithm.num_consecutive_playing_steps

        if should_update:
            if step_method.__class__ == EnvironmentEpisodes:
                self.last_training_phase_step = self.current_episode
            if step_method.__class__ == EnvironmentSteps:
                self.last_training_phase_step = self.total_steps_counter

        return should_update

    def _should_train_helper(self, wait_for_full_episode=False):
        if hasattr(self.ap.memory, 'memory_backend_params'):
            self.total_steps_counter = self.call_memory('num_transitions')

        step_method = self.ap.algorithm.num_consecutive_playing_steps

        if step_method.__class__ == EnvironmentEpisodes:
            should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
            if should_update:
                self.last_training_phase_step = self.current_episode
        elif step_method.__class__ == EnvironmentSteps:
            should_update = (self.total_steps_counter - self.last_training_phase_step) >= step_method.num_steps
            if wait_for_full_episode:
                should_update = should_update and self.current_episode_buffer.is_complete
            if should_update:
                self.last_training_phase_step = self.total_steps_counter
        else:
            raise ValueError("The num_consecutive_playing_steps parameter should be either "
                             "EnvironmentSteps or Episodes. Instead it is {}".format(step_method.__class__))

        return should_update

    def train(self):
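The refactoring in this hunk moves the decision ("have enough episodes or environment steps passed since the last training phase?") into _should_train_helper, with _should_train doing the bookkeeping once the decision is positive. The sketch below reproduces that scheduling rule in isolation; ToyScheduler and its fields are hypothetical stand-ins, and the EnvironmentSteps/EnvironmentEpisodes classes here are toy versions of Coach's schedule types.

class EnvironmentSteps:
    def __init__(self, num_steps):
        self.num_steps = num_steps


class EnvironmentEpisodes:
    def __init__(self, num_steps):
        self.num_steps = num_steps


class ToyScheduler:
    def __init__(self, schedule):
        self.schedule = schedule
        self.total_steps_counter = 0
        self.current_episode = 0
        self.last_training_phase_step = 0

    def should_train(self, episode_complete=True, wait_for_full_episode=False):
        # Decide whether enough episodes/steps have passed since the last training phase.
        if isinstance(self.schedule, EnvironmentEpisodes):
            should_update = (self.current_episode - self.last_training_phase_step) >= self.schedule.num_steps
        elif isinstance(self.schedule, EnvironmentSteps):
            should_update = (self.total_steps_counter - self.last_training_phase_step) >= self.schedule.num_steps
            if wait_for_full_episode:
                should_update = should_update and episode_complete
        else:
            raise ValueError("schedule must be EnvironmentSteps or EnvironmentEpisodes")

        if should_update:
            # As in the refactored _should_train above, the bookkeeping update happens
            # only after the decision to train has been made.
            self.last_training_phase_step = (self.current_episode
                                             if isinstance(self.schedule, EnvironmentEpisodes)
                                             else self.total_steps_counter)
        return should_update


# Train every 200 environment steps:
sched = ToyScheduler(EnvironmentSteps(200))
sched.total_steps_counter = 450
assert sched.should_train() is True       # 450 - 0 >= 200
assert sched.last_training_phase_step == 450
sched.total_steps_counter = 500
assert sched.should_train() is False      # only 50 new steps since the last phase

Updating last_training_phase_step only when the decision is positive is what makes the schedule self-resetting: the next training phase is measured from the point where the previous one was triggered.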