
restructure looping mechanism in GraphManager

This commit is contained in:
Zach Dwiel
2018-10-05 11:36:42 -04:00
committed by zach dwiel
parent 52560a2aae
commit 201a2237a1
3 changed files with 50 additions and 13 deletions

View File

@@ -596,6 +596,12 @@ class TotalStepsCounter(object):
         """
         self.counters[key] = item
 
+    def __add__(self, other: Type[StepMethod]) -> Type[StepMethod]:
+        return other.__class__(self.counters[other.__class__] + other.num_steps)
+
+    def __lt__(self, other: Type[StepMethod]):
+        return self.counters[other.__class__] < other.num_steps
+
 
 class GradientClippingMethod(Enum):
     ClipByGlobalNorm = 0
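
The two operators above are what the GraphManager changes below lean on: adding a StepMethod instance to the counter yields a new StepMethod of the same class marking an absolute target count, and comparing the counter against that target asks whether the matching bucket has reached it yet. A minimal, self-contained sketch of those semantics (mock classes for illustration only; the real StepMethod hierarchy and TotalStepsCounter live in rl_coach/core_types.py):

    from collections import defaultdict

    class StepMethod:
        """Mock of rl_coach's StepMethod: a count tagged with a unit (its class)."""
        def __init__(self, num_steps):
            self.num_steps = num_steps

    class EnvironmentSteps(StepMethod):
        pass

    class TotalStepsCounter:
        """Per-StepMethod-subclass step buckets, as in the hunk above."""
        def __init__(self):
            self.counters = defaultdict(int)

        def __getitem__(self, key):
            return self.counters[key]

        def __setitem__(self, key, item):
            self.counters[key] = item

        def __add__(self, other):
            # counter + EnvironmentSteps(10) -> EnvironmentSteps(current_count + 10)
            return other.__class__(self.counters[other.__class__] + other.num_steps)

        def __lt__(self, other):
            # counter < EnvironmentSteps(n) -> fewer than n environment steps so far?
            return self.counters[other.__class__] < other.num_steps

    counter = TotalStepsCounter()
    count_end = counter + EnvironmentSteps(3)   # target: 3 steps past the current count
    while counter < count_end:
        counter[EnvironmentSteps] += 1          # simulate taking one environment step
    assert counter[EnvironmentSteps] == 3

Because __add__ snapshots the current count into an absolute target, the loop condition can keep re-reading the live counter without the target drifting.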

View File

@@ -279,6 +279,10 @@ class GraphManager(object):
         for environment in self.environments:
             environment.phase = val
 
+    @property
+    def current_step_counter(self) -> TotalStepsCounter:
+        return self.total_steps_counters[self.phase]
+
     @contextlib.contextmanager
     def phase_context(self, phase):
         old_phase = self.phase
@@ -309,8 +313,8 @@ class GraphManager(object):
         self.reset_internal_state(force_environment_reset=True)
 
         # act for at least steps, though don't interrupt an episode
-        count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-        while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+        count_end = self.current_step_counter + steps
+        while self.current_step_counter < count_end:
             self.act(EnvironmentEpisodes(1))
 
     def handle_episode_ended(self) -> None:
@@ -318,7 +322,7 @@ class GraphManager(object):
         End an episode and reset all the episodic parameters
 
         :return: None
         """
-        self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1
+        self.current_step_counter[EnvironmentEpisodes] += 1
 
         [environment.handle_episode_ended() for environment in self.environments]
@@ -331,6 +335,7 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         with self.phase_context(RunPhase.TRAIN):
+            self.current_step_counter[TrainingSteps] += 1
             [manager.train() for manager in self.level_managers]
 
     def reset_internal_state(self, force_environment_reset=False) -> None:
@@ -361,8 +366,8 @@ class GraphManager(object):
             data_store.load_from_store()
 
         # perform several steps of playing
-        count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-        while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+        count_end = self.current_step_counter + steps
+        while self.current_step_counter < count_end:
             # reset the environment if the previous episode was terminated
             if self.reset_required:
                 self.reset_internal_state()
@@ -375,11 +380,10 @@ class GraphManager(object):
             # (like in Atari) will not be counted.
             # We add at least one step so that even if no steps were made (in case no actions are taken in the training
            # phase), the loop will end eventually.
-            self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin)
+            self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
 
             if result.game_over:
                 self.handle_episode_ended()
                 self.reset_required = True
 
     def train_and_act(self, steps: StepMethod) -> None:
@@ -395,8 +399,8 @@ class GraphManager(object):
         with self.phase_context(RunPhase.TRAIN):
             self.reset_internal_state(force_environment_reset=True)
 
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+            count_end = self.current_step_counter + steps
+            while self.current_step_counter < count_end:
                 # The actual steps being done on the environment are decided by the agents themselves.
                 # This is just an high-level controller.
                 self.act(EnvironmentSteps(1))
@@ -426,8 +430,8 @@ class GraphManager(object):
             self.sync()
 
             # act for at least `steps`, though don't interrupt an episode
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+            count_end = self.current_step_counter + steps
+            while self.current_step_counter < count_end:
                 self.act(EnvironmentEpisodes(1))
                 self.sync()
@@ -457,8 +461,8 @@ class GraphManager(object):
         else:
             screen.log_title("Starting to improve {}".format(self.name))
 
-            count_end = self.improve_steps.num_steps
-            while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
+            count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
+            while self.total_steps_counters[RunPhase.TRAIN] < count_end:
                 self.train_and_act(self.steps_between_evaluation_periods)
                 self.evaluate(self.evaluation_steps)
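
With the property and operators in place, every wait-loop in GraphManager collapses to one idiom. Schematically (a sketch, with self standing in for a GraphManager instance):

    # before this commit: each call site spelled out the phase and StepMethod class
    count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
    while self.total_steps_counters[self.phase][steps.__class__] < count_end:
        ...

    # after: current_step_counter picks the active phase's counter, and the
    # __add__/__lt__ operators hide the per-class indexing
    count_end = self.current_step_counter + steps   # StepMethod marking the absolute target
    while self.current_step_counter < count_end:    # __lt__ re-reads the live count
        ...

The improve() loop is the one exception: it indexes total_steps_counters[RunPhase.TRAIN] directly, presumably because no TRAIN phase context is active at that point in the code.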

View File

@@ -0,0 +1,27 @@
+from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes
+
+import pytest
+
+
+@pytest.mark.unit_test
+def test_add_total_steps_counter():
+    counter = TotalStepsCounter()
+    steps = counter + EnvironmentSteps(10)
+    assert steps.num_steps == 10
+
+
+@pytest.mark.unit_test
+def test_add_total_steps_counter_non_zero():
+    counter = TotalStepsCounter()
+    counter[EnvironmentSteps] += 10
+    steps = counter + EnvironmentSteps(10)
+    assert steps.num_steps == 20
+
+
+@pytest.mark.unit_test
+def test_total_steps_counter_less_than():
+    counter = TotalStepsCounter()
+    steps = counter + EnvironmentSteps(0)
+    assert not (counter < steps)
+    steps = counter + EnvironmentSteps(1)
+    assert counter < steps
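
One property these tests leave implicit: both operators consult only the bucket keyed by the right-hand side's class, so counts of other step types never affect the result. Extending the mock sketch from the first file (EnvironmentEpisodes here is the same kind of hypothetical StepMethod subclass, mirroring rl_coach's real one):

    class EnvironmentEpisodes(StepMethod):   # mock subclass, illustration only
        pass

    counter = TotalStepsCounter()
    counter[EnvironmentEpisodes] += 5
    assert counter < EnvironmentSteps(1)            # 0 environment steps taken so far
    assert not (counter < EnvironmentEpisodes(5))   # 5 episodes is not fewer than 5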