mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
restructure looping mechanism inGraphManager
This commit is contained in:
@@ -596,6 +596,12 @@ class TotalStepsCounter(object):
|
||||
"""
|
||||
self.counters[key] = item
|
||||
|
||||
def __add__(self, other: Type[StepMethod]) -> Type[StepMethod]:
|
||||
return other.__class__(self.counters[other.__class__] + other.num_steps)
|
||||
|
||||
def __lt__(self, other: Type[StepMethod]):
|
||||
return self.counters[other.__class__] < other.num_steps
|
||||
|
||||
|
||||
class GradientClippingMethod(Enum):
|
||||
ClipByGlobalNorm = 0
|
||||
|
||||
@@ -279,6 +279,10 @@ class GraphManager(object):
|
||||
for environment in self.environments:
|
||||
environment.phase = val
|
||||
|
||||
@property
|
||||
def current_step_counter(self) -> TotalStepsCounter:
|
||||
return self.total_steps_counters[self.phase]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def phase_context(self, phase):
|
||||
old_phase = self.phase
|
||||
@@ -309,8 +313,8 @@ class GraphManager(object):
|
||||
self.reset_internal_state(force_environment_reset=True)
|
||||
|
||||
# act for at least steps, though don't interrupt an episode
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
count_end = self.current_step_counter + steps
|
||||
while self.current_step_counter < count_end:
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
|
||||
def handle_episode_ended(self) -> None:
|
||||
@@ -318,7 +322,7 @@ class GraphManager(object):
|
||||
End an episode and reset all the episodic parameters
|
||||
:return: None
|
||||
"""
|
||||
self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1
|
||||
self.current_step_counter[EnvironmentEpisodes] += 1
|
||||
|
||||
[environment.handle_episode_ended() for environment in self.environments]
|
||||
|
||||
@@ -331,6 +335,7 @@ class GraphManager(object):
|
||||
self.verify_graph_was_created()
|
||||
|
||||
with self.phase_context(RunPhase.TRAIN):
|
||||
self.current_step_counter[TrainingSteps] += 1
|
||||
[manager.train() for manager in self.level_managers]
|
||||
|
||||
def reset_internal_state(self, force_environment_reset=False) -> None:
|
||||
@@ -361,8 +366,8 @@ class GraphManager(object):
|
||||
data_store.load_from_store()
|
||||
|
||||
# perform several steps of playing
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
count_end = self.current_step_counter + steps
|
||||
while self.current_step_counter < count_end:
|
||||
# reset the environment if the previous episode was terminated
|
||||
if self.reset_required:
|
||||
self.reset_internal_state()
|
||||
@@ -375,11 +380,10 @@ class GraphManager(object):
|
||||
# (like in Atari) will not be counted.
|
||||
# We add at least one step so that even if no steps were made (in case no actions are taken in the training
|
||||
# phase), the loop will end eventually.
|
||||
self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin)
|
||||
self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
|
||||
|
||||
if result.game_over:
|
||||
self.handle_episode_ended()
|
||||
|
||||
self.reset_required = True
|
||||
|
||||
def train_and_act(self, steps: StepMethod) -> None:
|
||||
@@ -395,8 +399,8 @@ class GraphManager(object):
|
||||
with self.phase_context(RunPhase.TRAIN):
|
||||
self.reset_internal_state(force_environment_reset=True)
|
||||
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
count_end = self.current_step_counter + steps
|
||||
while self.current_step_counter < count_end:
|
||||
# The actual steps being done on the environment are decided by the agents themselves.
|
||||
# This is just an high-level controller.
|
||||
self.act(EnvironmentSteps(1))
|
||||
@@ -426,8 +430,8 @@ class GraphManager(object):
|
||||
self.sync()
|
||||
|
||||
# act for at least `steps`, though don't interrupt an episode
|
||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||
count_end = self.current_step_counter + steps
|
||||
while self.current_step_counter < count_end:
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
self.sync()
|
||||
|
||||
@@ -457,8 +461,8 @@ class GraphManager(object):
|
||||
else:
|
||||
screen.log_title("Starting to improve {}".format(self.name))
|
||||
|
||||
count_end = self.improve_steps.num_steps
|
||||
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
|
||||
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
|
||||
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
|
||||
self.train_and_act(self.steps_between_evaluation_periods)
|
||||
self.evaluate(self.evaluation_steps)
|
||||
|
||||
|
||||
27
rl_coach/tests/test_core_types.py
Normal file
27
rl_coach/tests/test_core_types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_add_total_steps_counter():
|
||||
counter = TotalStepsCounter()
|
||||
steps = counter + EnvironmentSteps(10)
|
||||
assert steps.num_steps == 10
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_add_total_steps_counter_non_zero():
|
||||
counter = TotalStepsCounter()
|
||||
counter[EnvironmentSteps] += 10
|
||||
steps = counter + EnvironmentSteps(10)
|
||||
assert steps.num_steps == 20
|
||||
|
||||
|
||||
@pytest.mark.unit_test
|
||||
def test_total_steps_counter_less_than():
|
||||
counter = TotalStepsCounter()
|
||||
steps = counter + EnvironmentSteps(0)
|
||||
assert not (counter < steps)
|
||||
steps = counter + EnvironmentSteps(1)
|
||||
assert counter < steps
|
||||
Reference in New Issue
Block a user