mirror of https://github.com/gryf/coach.git

restructure looping mechanism in GraphManager

This commit is contained in:
Zach Dwiel
2018-10-05 11:36:42 -04:00
committed by zach dwiel
parent 52560a2aae
commit 201a2237a1
3 changed files with 50 additions and 13 deletions

View File

@@ -596,6 +596,12 @@ class TotalStepsCounter(object):
         """
         self.counters[key] = item

+    def __add__(self, other: Type[StepMethod]) -> Type[StepMethod]:
+        return other.__class__(self.counters[other.__class__] + other.num_steps)
+
+    def __lt__(self, other: Type[StepMethod]):
+        return self.counters[other.__class__] < other.num_steps
+
 class GradientClippingMethod(Enum):
     ClipByGlobalNorm = 0
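
In effect, this hunk gives TotalStepsCounter arithmetic against StepMethod instances: counter + steps builds a new step object of the same concrete type, holding the running total plus the request, and counter < steps compares the stored total for that step type. (At runtime `other` is an instance, despite the Type[StepMethod] annotation.) A minimal runnable sketch of those semantics, using stand-in classes rather than the real ones from rl_coach.core_types:

from collections import defaultdict


class StepMethod:
    """Stand-in for Coach's step-count types."""
    def __init__(self, num_steps: int):
        self.num_steps = num_steps


class EnvironmentSteps(StepMethod):
    pass


class TotalStepsCounter:
    """Stand-in: running totals keyed by concrete StepMethod subclass."""
    def __init__(self):
        self.counters = defaultdict(int)

    def __getitem__(self, key):
        return self.counters[key]

    def __setitem__(self, key, item):
        self.counters[key] = item

    def __add__(self, other):
        # counter at 30 EnvironmentSteps + EnvironmentSteps(10) -> EnvironmentSteps(40)
        return other.__class__(self.counters[other.__class__] + other.num_steps)

    def __lt__(self, other):
        # compare this counter's total for other's type against other.num_steps
        return self.counters[other.__class__] < other.num_steps


counter = TotalStepsCounter()
counter[EnvironmentSteps] += 30
target = counter + EnvironmentSteps(10)   # EnvironmentSteps(40)
assert target.num_steps == 40 and counter < target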

View File

@@ -279,6 +279,10 @@ class GraphManager(object):
         for environment in self.environments:
             environment.phase = val

+    @property
+    def current_step_counter(self) -> TotalStepsCounter:
+        return self.total_steps_counters[self.phase]
+
     @contextlib.contextmanager
     def phase_context(self, phase):
         old_phase = self.phase
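
current_step_counter is just sugar for indexing the phase-keyed counter dict, and the hunks below rewrite self.total_steps_counters[self.phase] accordingly. A stand-in sketch of that relationship, reusing the classes from the first sketch (assuming, as the indexing implies, one TotalStepsCounter per RunPhase; the enum members here are illustrative, not Coach's actual ones):

from enum import Enum


class RunPhase(Enum):
    # illustrative members; Coach's RunPhase defines its own
    HEATUP = 0
    TRAIN = 1
    TEST = 2


class PhaseCounters:
    """Stand-in for GraphManager's counter plumbing."""
    def __init__(self):
        self.phase = RunPhase.HEATUP
        self.total_steps_counters = {p: TotalStepsCounter() for p in RunPhase}

    @property
    def current_step_counter(self):
        # the property the diff adds: counters of whatever phase is active
        return self.total_steps_counters[self.phase]


gm = PhaseCounters()
gm.current_step_counter[EnvironmentSteps] += 1
gm.phase = RunPhase.TRAIN
assert gm.current_step_counter[EnvironmentSteps] == 0   # totals are per phase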
@@ -309,8 +313,8 @@ class GraphManager(object):
             self.reset_internal_state(force_environment_reset=True)

             # act for at least steps, though don't interrupt an episode
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+            count_end = self.current_step_counter + steps
+            while self.current_step_counter < count_end:
                 self.act(EnvironmentEpisodes(1))

     def handle_episode_ended(self) -> None:
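
This is the loop idiom the commit restructures, and it recurs in four more hunks below: snapshot a target as counter + steps, a StepMethod of the same concrete type, then keep acting until the counter for that step type catches up. Distilled with the stand-ins from the first sketch:

def run_for_at_least(counter, steps, act_once):
    """Call act_once() until `steps` more steps of steps' type are counted."""
    count_end = counter + steps        # StepMethod holding current total + request
    while counter < count_end:         # counter[type] < count_end.num_steps
        act_once()                     # must advance the counter, or this never ends


counter = TotalStepsCounter()

def act_once():
    counter[EnvironmentSteps] += 1     # stand-in for self.act(EnvironmentSteps(1))

run_for_at_least(counter, EnvironmentSteps(3), act_once)
assert counter[EnvironmentSteps] == 3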
@@ -318,7 +322,7 @@ class GraphManager(object):
         End an episode and reset all the episodic parameters
         :return: None
         """
-        self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1
+        self.current_step_counter[EnvironmentEpisodes] += 1

         [environment.handle_episode_ended() for environment in self.environments]
@@ -331,6 +335,7 @@ class GraphManager(object):
         self.verify_graph_was_created()

         with self.phase_context(RunPhase.TRAIN):
+            self.current_step_counter[TrainingSteps] += 1
             [manager.train() for manager in self.level_managers]

     def reset_internal_state(self, force_environment_reset=False) -> None:
@@ -361,8 +366,8 @@ class GraphManager(object):
                 data_store.load_from_store()

         # perform several steps of playing
-        count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-        while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+        count_end = self.current_step_counter + steps
+        while self.current_step_counter < count_end:
             # reset the environment if the previous episode was terminated
             if self.reset_required:
                 self.reset_internal_state()
@@ -375,11 +380,10 @@ class GraphManager(object):
             # (like in Atari) will not be counted.
             # We add at least one step so that even if no steps were made (in case no actions are taken in the training
             # phase), the loop will end eventually.
-            self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin)
-
+            self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)

             if result.game_over:
                 self.handle_episode_ended()
                 self.reset_required = True

     def train_and_act(self, steps: StepMethod) -> None:
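
The max(1, ...) in this hunk is what keeps the rewritten while self.current_step_counter < count_end loops terminating: an iteration in which the agents take no environment actions still advances the counter by one. With the stand-ins from above:

counter = TotalStepsCounter()
steps_begin = steps_end = 0                        # no env actions this iteration
counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
assert counter[EnvironmentSteps] == 1              # progress is still made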
@@ -395,8 +399,8 @@ class GraphManager(object):
         with self.phase_context(RunPhase.TRAIN):
             self.reset_internal_state(force_environment_reset=True)

-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+            count_end = self.current_step_counter + steps
+            while self.current_step_counter < count_end:
                 # The actual steps being done on the environment are decided by the agents themselves.
                 # This is just a high-level controller.
                 self.act(EnvironmentSteps(1))
@@ -426,8 +430,8 @@ class GraphManager(object):
             self.sync()

             # act for at least `steps`, though don't interrupt an episode
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+            count_end = self.current_step_counter + steps
+            while self.current_step_counter < count_end:
                 self.act(EnvironmentEpisodes(1))
                 self.sync()
@@ -457,8 +461,8 @@ class GraphManager(object):
         else:
             screen.log_title("Starting to improve {}".format(self.name))

-        count_end = self.improve_steps.num_steps
-        while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
+        count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
+        while self.total_steps_counters[RunPhase.TRAIN] < count_end:
             self.train_and_act(self.steps_between_evaluation_periods)
             self.evaluate(self.evaluation_steps)
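
Note the behavioral nuance in this last hunk: the old bound was the absolute improve_steps.num_steps, so a counter that already held training steps would cut an improve run short, while the new bound adds improve_steps to the current total. With the stand-ins from above (EnvironmentSteps is used as the step type purely for illustration):

counter = TotalStepsCounter()             # plays total_steps_counters[RunPhase.TRAIN]
counter[EnvironmentSteps] += 500          # steps accumulated before improve()
improve_steps = EnvironmentSteps(1000)

count_end = counter + improve_steps       # EnvironmentSteps(1500); old bound was 1000
while counter < count_end:
    counter[EnvironmentSteps] += 1        # stand-in for train_and_act(...)
assert counter[EnvironmentSteps] == 1500  # the full 1000-step budget ran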

View File

@@ -0,0 +1,27 @@
+from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes
+
+import pytest
+
+
+@pytest.mark.unit_test
+def test_add_total_steps_counter():
+    counter = TotalStepsCounter()
+    steps = counter + EnvironmentSteps(10)
+    assert steps.num_steps == 10
+
+
+@pytest.mark.unit_test
+def test_add_total_steps_counter_non_zero():
+    counter = TotalStepsCounter()
+    counter[EnvironmentSteps] += 10
+    steps = counter + EnvironmentSteps(10)
+    assert steps.num_steps == 20
+
+
+@pytest.mark.unit_test
+def test_total_steps_counter_less_than():
+    counter = TotalStepsCounter()
+    steps = counter + EnvironmentSteps(0)
+    assert not (counter < steps)
+    steps = counter + EnvironmentSteps(1)
+    assert counter < steps