1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Cleanup and refactoring (#171)

This commit is contained in:
Zach Dwiel
2019-01-15 03:04:53 -05:00
committed by Gal Leibovich
parent cd812b0d25
commit fedb4cbd7c
7 changed files with 45 additions and 33 deletions

View File

@@ -609,35 +609,35 @@ class Agent(AgentInterface):
:return: boolean: True if we should start a training phase
"""
should_update = self._should_train_helper()
should_update = self._should_update()
step_method = self.ap.algorithm.num_consecutive_playing_steps
steps = self.ap.algorithm.num_consecutive_playing_steps
if should_update:
if step_method.__class__ == EnvironmentEpisodes:
if steps.__class__ == EnvironmentEpisodes:
self.last_training_phase_step = self.current_episode
if step_method.__class__ == EnvironmentSteps:
if steps.__class__ == EnvironmentSteps:
self.last_training_phase_step = self.total_steps_counter
return should_update
def _should_update(self):
    """
    Check whether enough playing has been done since the last training phase to train again.

    The schedule is read from ``self.ap.algorithm.num_consecutive_playing_steps``:

    - ``EnvironmentEpisodes``: at least ``num_steps`` episodes must have passed since
      ``self.last_training_phase_step`` and ``self.call_memory('length')`` must be positive.
    - ``EnvironmentSteps``: at least ``num_steps`` steps must have passed since
      ``self.last_training_phase_step`` and ``self.call_memory('num_transitions')`` must be
      positive. When ``self.ap.algorithm.act_for_full_episodes`` is set, the current episode
      buffer must additionally be complete.

    This method only evaluates the condition; it does not assign
    ``self.last_training_phase_step`` itself.

    :return: boolean: True if the conditions for starting a training phase are met
    :raises ValueError: if num_consecutive_playing_steps is neither EnvironmentSteps
                        nor EnvironmentEpisodes
    """
    wait_for_full_episode = self.ap.algorithm.act_for_full_episodes
    steps = self.ap.algorithm.num_consecutive_playing_steps

    if steps.__class__ == EnvironmentEpisodes:
        # the schedule is counted in episodes
        should_update = (self.current_episode - self.last_training_phase_step) >= steps.num_steps
        # never report a training phase as due while the memory is empty
        should_update = should_update and self.call_memory('length') > 0

    elif steps.__class__ == EnvironmentSteps:
        # the schedule is counted in environment steps
        should_update = (self.total_steps_counter - self.last_training_phase_step) >= steps.num_steps
        should_update = should_update and self.call_memory('num_transitions') > 0

        if wait_for_full_episode:
            # hold the update back until the episode being played has finished
            should_update = should_update and self.current_episode_buffer.is_complete

    else:
        raise ValueError("The num_consecutive_playing_steps parameter should be either "
                         "EnvironmentSteps or EnvironmentEpisodes. Instead it is {}".format(steps.__class__))

    return should_update