From 201a2237a1f11a933ca7142468283026859deda2 Mon Sep 17 00:00:00 2001
From: Zach Dwiel
Date: Fri, 5 Oct 2018 11:36:42 -0400
Subject: [PATCH] restructure looping mechanism in GraphManager

---
 rl_coach/core_types.py                   |  6 +++++
 rl_coach/graph_managers/graph_manager.py | 30 ++++++++++++++----------
 rl_coach/tests/test_core_types.py        | 27 +++++++++++++++++++++
 3 files changed, 50 insertions(+), 13 deletions(-)
 create mode 100644 rl_coach/tests/test_core_types.py

diff --git a/rl_coach/core_types.py b/rl_coach/core_types.py
index 8610c4f..33c7f67 100644
--- a/rl_coach/core_types.py
+++ b/rl_coach/core_types.py
@@ -596,6 +596,12 @@ class TotalStepsCounter(object):
         """
         self.counters[key] = item
 
+    def __add__(self, other: StepMethod) -> StepMethod:
+        return other.__class__(self.counters[other.__class__] + other.num_steps)
+
+    def __lt__(self, other: StepMethod) -> bool:
+        return self.counters[other.__class__] < other.num_steps
+
 
 class GradientClippingMethod(Enum):
     ClipByGlobalNorm = 0
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index d942ede..b7ed951 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -279,6 +279,10 @@ class GraphManager(object):
         for environment in self.environments:
             environment.phase = val
 
+    @property
+    def current_step_counter(self) -> TotalStepsCounter:
+        return self.total_steps_counters[self.phase]
+
     @contextlib.contextmanager
     def phase_context(self, phase):
         old_phase = self.phase
@@ -309,8 +313,8 @@ class GraphManager(object):
             self.reset_internal_state(force_environment_reset=True)
 
             # act for at least steps, though don't interrupt an episode
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+            count_end = self.current_step_counter + steps
+            while self.current_step_counter < count_end:
                 self.act(EnvironmentEpisodes(1))
 
     def handle_episode_ended(self) -> None:
@@ -318,7 +322,7 @@ class GraphManager(object):
         End an episode and reset all the episodic parameters
         :return: None
         """
-        self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1
+        self.current_step_counter[EnvironmentEpisodes] += 1
 
         [environment.handle_episode_ended() for environment in self.environments]
 
@@ -331,6 +335,7 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         with self.phase_context(RunPhase.TRAIN):
+            self.current_step_counter[TrainingSteps] += 1
             [manager.train() for manager in self.level_managers]
 
     def reset_internal_state(self, force_environment_reset=False) -> None:
@@ -361,8 +366,8 @@ class GraphManager(object):
             data_store.load_from_store()
 
         # perform several steps of playing
-        count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-        while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+        count_end = self.current_step_counter + steps
+        while self.current_step_counter < count_end:
             # reset the environment if the previous episode was terminated
             if self.reset_required:
                 self.reset_internal_state()
@@ -375,11 +380,10 @@ class GraphManager(object):
             # (like in Atari) will not be counted.
             # We add at least one step so that even if no steps were made (in case no actions are taken in the training
            # phase), the loop will end eventually.
-            self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin)
+            self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
 
             if result.game_over:
                 self.handle_episode_ended()
-
                 self.reset_required = True
 
     def train_and_act(self, steps: StepMethod) -> None:
@@ -395,8 +399,8 @@ class GraphManager(object):
         with self.phase_context(RunPhase.TRAIN):
             self.reset_internal_state(force_environment_reset=True)
 
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+            count_end = self.current_step_counter + steps
+            while self.current_step_counter < count_end:
                 # The actual steps being done on the environment are decided by the agents themselves.
                 # This is just an high-level controller.
                 self.act(EnvironmentSteps(1))
@@ -426,8 +430,8 @@ class GraphManager(object):
             self.sync()
 
             # act for at least `steps`, though don't interrupt an episode
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+            count_end = self.current_step_counter + steps
+            while self.current_step_counter < count_end:
                 self.act(EnvironmentEpisodes(1))
                 self.sync()
 
@@ -457,8 +461,8 @@ class GraphManager(object):
         else:
             screen.log_title("Starting to improve {}".format(self.name))
 
-        count_end = self.improve_steps.num_steps
-        while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
+        count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
+        while self.total_steps_counters[RunPhase.TRAIN] < count_end:
             self.train_and_act(self.steps_between_evaluation_periods)
             self.evaluate(self.evaluation_steps)
 
diff --git a/rl_coach/tests/test_core_types.py b/rl_coach/tests/test_core_types.py
new file mode 100644
index 0000000..58df0ad
--- /dev/null
+++ b/rl_coach/tests/test_core_types.py
@@ -0,0 +1,27 @@
+from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes
+
+import pytest
+
+
+@pytest.mark.unit_test
+def test_add_total_steps_counter():
+    counter = TotalStepsCounter()
+    steps = counter + EnvironmentSteps(10)
+    assert steps.num_steps == 10
+
+
+@pytest.mark.unit_test
+def test_add_total_steps_counter_non_zero():
+    counter = TotalStepsCounter()
+    counter[EnvironmentSteps] += 10
+    steps = counter + EnvironmentSteps(10)
+    assert steps.num_steps == 20
+
+
+@pytest.mark.unit_test
+def test_total_steps_counter_less_than():
+    counter = TotalStepsCounter()
+    steps = counter + EnvironmentSteps(0)
+    assert not (counter < steps)
+    steps = counter + EnvironmentSteps(1)
+    assert counter < steps
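
Note on the new looping mechanism: every restructured loop follows the same idiom. count_end = self.current_step_counter + steps snapshots a fixed target once, and while self.current_step_counter < count_end re-reads the live counter until it catches up. This works because TotalStepsCounter.__add__ returns a StepMethod instance of the same class as its right-hand operand, holding "current count plus requested steps", and __lt__ compares the live count for that class against that target. Taking the snapshot once up front keeps "act for at least N more steps from now" well defined even when a single act() advances the counter by more than one step (the max(1, steps_end - steps_begin) case). Below is a minimal, self-contained sketch of the mechanism; the classes are simplified stand-ins for the rl_coach ones, not the actual implementations.

from collections import defaultdict


class StepMethod(object):
    """Simplified stand-in for rl_coach's StepMethod."""
    def __init__(self, num_steps):
        self.num_steps = num_steps


class EnvironmentSteps(StepMethod):
    pass


class TotalStepsCounter(object):
    """Simplified stand-in: one running count per StepMethod subclass."""
    def __init__(self):
        self.counters = defaultdict(int)

    def __getitem__(self, key):
        return self.counters[key]

    def __setitem__(self, key, item):
        self.counters[key] = item

    def __add__(self, other):
        # counter + EnvironmentSteps(n) -> EnvironmentSteps(current + n):
        # a fixed target expressed in the same step class as `other`
        return other.__class__(self.counters[other.__class__] + other.num_steps)

    def __lt__(self, other):
        # compare the live count for other's class against the target
        return self.counters[other.__class__] < other.num_steps


counter = TotalStepsCounter()
count_end = counter + EnvironmentSteps(3)  # target fixed once, up front
while counter < count_end:                 # re-reads the live counter each pass
    counter[EnvironmentSteps] += 1         # stand-in for one act() step
assert counter[EnvironmentSteps] == 3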