mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
simplify GraphManager.act by removing arguments: continue_until_game_over and return_on_game_over
This commit is contained in:
@@ -319,7 +319,7 @@ class GraphManager(object):
|
|||||||
# act for at least steps, though don't interrupt an episode
|
# act for at least steps, though don't interrupt an episode
|
||||||
count_end = self.total_steps_counters[self.phase][EnvironmentSteps] + steps.num_steps
|
count_end = self.total_steps_counters[self.phase][EnvironmentSteps] + steps.num_steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
self.act(steps, continue_until_game_over=True, return_on_game_over=True)
|
self.act(EnvironmentEpisodes(1))
|
||||||
|
|
||||||
def handle_episode_ended(self) -> None:
|
def handle_episode_ended(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -356,8 +356,7 @@ class GraphManager(object):
|
|||||||
[environment.reset_internal_state(force_environment_reset) for environment in self.environments]
|
[environment.reset_internal_state(force_environment_reset) for environment in self.environments]
|
||||||
[manager.reset_internal_state() for manager in self.level_managers]
|
[manager.reset_internal_state() for manager in self.level_managers]
|
||||||
|
|
||||||
def act(self, steps: PlayingStepsType, return_on_game_over: bool=False, continue_until_game_over=False,
|
def act(self, steps: PlayingStepsType, keep_networks_in_sync=False) -> (int, bool):
|
||||||
keep_networks_in_sync=False) -> (int, bool):
|
|
||||||
"""
|
"""
|
||||||
Do several steps of acting on the environment
|
Do several steps of acting on the environment
|
||||||
:param steps: the number of steps as a tuple of steps time and steps count
|
:param steps: the number of steps as a tuple of steps time and steps count
|
||||||
@@ -380,7 +379,7 @@ class GraphManager(object):
|
|||||||
# The assumption here is that the total_steps_counters are each updated when an event
|
# The assumption here is that the total_steps_counters are each updated when an event
|
||||||
# takes place (i.e. an episode ends)
|
# takes place (i.e. an episode ends)
|
||||||
# TODO - The counter of frames is not updated correctly. need to fix that.
|
# TODO - The counter of frames is not updated correctly. need to fix that.
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end or continue_until_game_over:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
# reset the environment if the previous episode was terminated
|
# reset the environment if the previous episode was terminated
|
||||||
if self.reset_required:
|
if self.reset_required:
|
||||||
self.reset_internal_state()
|
self.reset_internal_state()
|
||||||
@@ -402,14 +401,11 @@ class GraphManager(object):
|
|||||||
self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin)
|
self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin)
|
||||||
|
|
||||||
if result.game_over:
|
if result.game_over:
|
||||||
continue_until_game_over = False
|
|
||||||
self.handle_episode_ended()
|
self.handle_episode_ended()
|
||||||
# TODO: why not just reset right now?
|
# TODO: why not just reset right now?
|
||||||
self.reset_required = True
|
self.reset_required = True
|
||||||
if keep_networks_in_sync:
|
if keep_networks_in_sync:
|
||||||
self.sync_graph()
|
self.sync_graph()
|
||||||
if return_on_game_over:
|
|
||||||
return
|
|
||||||
|
|
||||||
def train_and_act(self, steps: StepMethod) -> None:
|
def train_and_act(self, steps: StepMethod) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -460,8 +456,7 @@ class GraphManager(object):
|
|||||||
# act for at least `steps`, though don't interrupt an episode
|
# act for at least `steps`, though don't interrupt an episode
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
self.act(steps, continue_until_game_over=True, return_on_game_over=True,
|
self.act(EnvironmentEpisodes(1), keep_networks_in_sync=keep_networks_in_sync)
|
||||||
keep_networks_in_sync=keep_networks_in_sync)
|
|
||||||
|
|
||||||
self.phase = RunPhase.UNDEFINED
|
self.phase = RunPhase.UNDEFINED
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user