mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
restructure looping mechanism in GraphManager
This commit is contained in:
@@ -596,6 +596,12 @@ class TotalStepsCounter(object):
|
|||||||
"""
|
"""
|
||||||
self.counters[key] = item
|
self.counters[key] = item
|
||||||
|
|
||||||
|
def __add__(self, other: Type[StepMethod]) -> Type[StepMethod]:
|
||||||
|
return other.__class__(self.counters[other.__class__] + other.num_steps)
|
||||||
|
|
||||||
|
def __lt__(self, other: Type[StepMethod]):
|
||||||
|
return self.counters[other.__class__] < other.num_steps
|
||||||
|
|
||||||
|
|
||||||
class GradientClippingMethod(Enum):
|
class GradientClippingMethod(Enum):
|
||||||
ClipByGlobalNorm = 0
|
ClipByGlobalNorm = 0
|
||||||
|
|||||||
@@ -279,6 +279,10 @@ class GraphManager(object):
|
|||||||
for environment in self.environments:
|
for environment in self.environments:
|
||||||
environment.phase = val
|
environment.phase = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_step_counter(self) -> TotalStepsCounter:
|
||||||
|
return self.total_steps_counters[self.phase]
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def phase_context(self, phase):
|
def phase_context(self, phase):
|
||||||
old_phase = self.phase
|
old_phase = self.phase
|
||||||
@@ -309,8 +313,8 @@ class GraphManager(object):
|
|||||||
self.reset_internal_state(force_environment_reset=True)
|
self.reset_internal_state(force_environment_reset=True)
|
||||||
|
|
||||||
# act for at least `steps`, though don't interrupt an episode
|
# act for at least `steps`, though don't interrupt an episode
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.current_step_counter + steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.current_step_counter < count_end:
|
||||||
self.act(EnvironmentEpisodes(1))
|
self.act(EnvironmentEpisodes(1))
|
||||||
|
|
||||||
def handle_episode_ended(self) -> None:
|
def handle_episode_ended(self) -> None:
|
||||||
@@ -318,7 +322,7 @@ class GraphManager(object):
|
|||||||
End an episode and reset all the episodic parameters
|
End an episode and reset all the episodic parameters
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1
|
self.current_step_counter[EnvironmentEpisodes] += 1
|
||||||
|
|
||||||
[environment.handle_episode_ended() for environment in self.environments]
|
[environment.handle_episode_ended() for environment in self.environments]
|
||||||
|
|
||||||
@@ -331,6 +335,7 @@ class GraphManager(object):
|
|||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
|
|
||||||
with self.phase_context(RunPhase.TRAIN):
|
with self.phase_context(RunPhase.TRAIN):
|
||||||
|
self.current_step_counter[TrainingSteps] += 1
|
||||||
[manager.train() for manager in self.level_managers]
|
[manager.train() for manager in self.level_managers]
|
||||||
|
|
||||||
def reset_internal_state(self, force_environment_reset=False) -> None:
|
def reset_internal_state(self, force_environment_reset=False) -> None:
|
||||||
@@ -361,8 +366,8 @@ class GraphManager(object):
|
|||||||
data_store.load_from_store()
|
data_store.load_from_store()
|
||||||
|
|
||||||
# perform several steps of playing
|
# perform several steps of playing
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.current_step_counter + steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.current_step_counter < count_end:
|
||||||
# reset the environment if the previous episode was terminated
|
# reset the environment if the previous episode was terminated
|
||||||
if self.reset_required:
|
if self.reset_required:
|
||||||
self.reset_internal_state()
|
self.reset_internal_state()
|
||||||
@@ -375,11 +380,10 @@ class GraphManager(object):
|
|||||||
# (like in Atari) will not be counted.
|
# (like in Atari) will not be counted.
|
||||||
# We add at least one step so that even if no steps were made (in case no actions are taken in the training
|
# We add at least one step so that even if no steps were made (in case no actions are taken in the training
|
||||||
# phase), the loop will end eventually.
|
# phase), the loop will end eventually.
|
||||||
self.total_steps_counters[self.phase][EnvironmentSteps] += max(1, steps_end - steps_begin)
|
self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
|
||||||
|
|
||||||
if result.game_over:
|
if result.game_over:
|
||||||
self.handle_episode_ended()
|
self.handle_episode_ended()
|
||||||
|
|
||||||
self.reset_required = True
|
self.reset_required = True
|
||||||
|
|
||||||
def train_and_act(self, steps: StepMethod) -> None:
|
def train_and_act(self, steps: StepMethod) -> None:
|
||||||
@@ -395,8 +399,8 @@ class GraphManager(object):
|
|||||||
with self.phase_context(RunPhase.TRAIN):
|
with self.phase_context(RunPhase.TRAIN):
|
||||||
self.reset_internal_state(force_environment_reset=True)
|
self.reset_internal_state(force_environment_reset=True)
|
||||||
|
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.current_step_counter + steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.current_step_counter < count_end:
|
||||||
# The actual steps being done on the environment are decided by the agents themselves.
|
# The actual steps being done on the environment are decided by the agents themselves.
|
||||||
# This is just a high-level controller.
|
# This is just a high-level controller.
|
||||||
self.act(EnvironmentSteps(1))
|
self.act(EnvironmentSteps(1))
|
||||||
@@ -426,8 +430,8 @@ class GraphManager(object):
|
|||||||
self.sync()
|
self.sync()
|
||||||
|
|
||||||
# act for at least `steps`, though don't interrupt an episode
|
# act for at least `steps`, though don't interrupt an episode
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
count_end = self.current_step_counter + steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.current_step_counter < count_end:
|
||||||
self.act(EnvironmentEpisodes(1))
|
self.act(EnvironmentEpisodes(1))
|
||||||
self.sync()
|
self.sync()
|
||||||
|
|
||||||
@@ -457,8 +461,8 @@ class GraphManager(object):
|
|||||||
else:
|
else:
|
||||||
screen.log_title("Starting to improve {}".format(self.name))
|
screen.log_title("Starting to improve {}".format(self.name))
|
||||||
|
|
||||||
count_end = self.improve_steps.num_steps
|
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
|
||||||
while self.total_steps_counters[RunPhase.TRAIN][self.improve_steps.__class__] < count_end:
|
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
|
||||||
self.train_and_act(self.steps_between_evaluation_periods)
|
self.train_and_act(self.steps_between_evaluation_periods)
|
||||||
self.evaluate(self.evaluation_steps)
|
self.evaluate(self.evaluation_steps)
|
||||||
|
|
||||||
|
|||||||
27
rl_coach/tests/test_core_types.py
Normal file
27
rl_coach/tests/test_core_types.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from rl_coach.core_types import TotalStepsCounter, EnvironmentSteps, EnvironmentEpisodes
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_add_total_steps_counter():
|
||||||
|
counter = TotalStepsCounter()
|
||||||
|
steps = counter + EnvironmentSteps(10)
|
||||||
|
assert steps.num_steps == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_add_total_steps_counter_non_zero():
|
||||||
|
counter = TotalStepsCounter()
|
||||||
|
counter[EnvironmentSteps] += 10
|
||||||
|
steps = counter + EnvironmentSteps(10)
|
||||||
|
assert steps.num_steps == 20
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit_test
|
||||||
|
def test_total_steps_counter_less_than():
|
||||||
|
counter = TotalStepsCounter()
|
||||||
|
steps = counter + EnvironmentSteps(0)
|
||||||
|
assert not (counter < steps)
|
||||||
|
steps = counter + EnvironmentSteps(1)
|
||||||
|
assert counter < steps
|
||||||
Reference in New Issue
Block a user