mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
restructure looping mechanism inGraphManager
This commit is contained in:
@@ -279,6 +279,10 @@ class GraphManager(object):
|
||||
for environment in self.environments:
|
||||
environment.phase = val
|
||||
|
||||
@property
|
||||
def current_step_counter(self) -> TotalStepsCounter:
|
||||
return self.total_steps_counters[self.phase]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def phase_context(self, phase):
|
||||
old_phase = self.phase
|
||||
@@ -309,8 +313,8 @@ class GraphManager(object):
|
||||
self.reset_internal_state(force_environment_reset=True)
|
||||
|
||||
# act for at least steps, though don't interrupt an episode
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
count_end = self.current_step_counter + steps
|
||||
while self.current_step_counter < count_end:
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
|
||||
def handle_episode_ended(self) -> None:
|
||||
@@ -318,7 +322,7 @@ class GraphManager(object):
|
||||
End an episode and reset all the episodic parameters
|
||||
:return: None
|
||||
"""
|
||||
self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1
|
||||
self.current_step_counter[EnvironmentEpisodes] += 1
|
||||
|
||||
[environment.handle_episode_ended() for environment in self.environments]
|
||||
|
||||
@@ -331,6 +335,7 @@ class GraphManager(object):
|
||||
self.verify_graph_was_created()
|
||||
|
||||
with self.phase_context(RunPhase.TRAIN):
|
||||
self.current_step_counter[TrainingSteps] += 1
|
||||
[manager.train() for manager in self.level_managers]
|
||||
|
||||
def reset_internal_state(self, force_environment_reset=False) -> None:
|
||||
@@ -361,8 +366,8 @@ class GraphManager(object):
|
||||
data_store.load_from_store()
|
||||
|
||||
# perform several steps of playing
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
count_end = self.current_step_counter + steps
|
||||
while self.current_step_counter < count_end:
|
||||
# reset the environment if the previous episode was terminated
|
||||
if self.reset_required:
|
||||
self.reset_internal_state()
|
||||
@@ -375,11 +380,10 @@ class GraphManager(object):
|
||||
# (like in Atari) will not be counted.
|
||||
# We add at least one step so that even if no steps were made (in case no actions are taken in the training
|
||||
# phase), the loop will end eventually.
|
||||
self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin)
|
||||
self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
|
||||
|
||||
if result.game_over:
|
||||
self.handle_episode_ended()
|
||||
|
||||
self.reset_required = True
|
||||
|
||||
def train_and_act(self, steps: StepMethod) -> None:
|
||||
@@ -395,8 +399,8 @@ class GraphManager(object):
|
||||
with self.phase_context(RunPhase.TRAIN):
|
||||
self.reset_internal_state(force_environment_reset=True)
|
||||
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
count_end = self.current_step_counter + steps
|
||||
while self.current_step_counter < count_end:
|
||||
# The actual steps being done on the environment are decided by the agents themselves.
|
||||
# This is just an high-level controller.
|
||||
self.act(EnvironmentSteps(1))
|
||||
@@ -426,8 +430,8 @@ class GraphManager(object):
|
||||
self.sync()
|
||||
|
||||
# act for at least `steps`, though don't interrupt an episode
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
count_end = self.current_step_counter + steps
|
||||
while self.current_step_counter < count_end:
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
self.sync()
|
||||
|
||||
@@ -457,8 +461,8 @@ class GraphManager(object):
|
||||
else:
|
||||
screen.log_title("Starting to improve {}".format(self.name))
|
||||
|
||||
count_end = self.improve_steps.num_steps
|
||||
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
|
||||
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
|
||||
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
|
||||
self.train_and_act(self.steps_between_evaluation_periods)
|
||||
self.evaluate(self.evaluation_steps)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user