diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 65645ff..0622bc4 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -20,6 +20,7 @@ import time
 from collections import OrderedDict
 from distutils.dir_util import copy_tree, remove_tree
 from typing import List, Tuple
+import contextlib
 
 from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, \
     VisualizationParameters, \
@@ -285,6 +286,13 @@ class GraphManager(object):
         for environment in self.environments:
             environment.phase = val
 
+    @contextlib.contextmanager
+    def phase_context(self, phase):
+        old_phase = self.phase
+        self.phase = phase
+        yield
+        self.phase = old_phase
+
     def set_session(self, sess) -> None:
         """
         Set the deep learning framework session for all the modules in the graph
@@ -301,20 +309,17 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         if steps.num_steps > 0:
-            self.phase = RunPhase.HEATUP
-            screen.log_title("{}: Starting heatup".format(self.name))
-            self.heatup_start_time = time.time()
+            with self.phase_context(RunPhase.HEATUP):
+                screen.log_title("{}: Starting heatup".format(self.name))
+                self.heatup_start_time = time.time()
 
-            # reset all the levels before starting to heatup
-            self.reset_internal_state(force_environment_reset=True)
+                # reset all the levels before starting to heatup
+                self.reset_internal_state(force_environment_reset=True)
 
-            # act for at least steps, though don't interrupt an episode
-            count_end = self.total_steps_counters[self.phase][EnvironmentSteps] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
-                self.act(steps, continue_until_game_over=True, return_on_game_over=True)
-
-            # training phase
-            self.phase = RunPhase.UNDEFINED
+                # act for at least steps, though don't interrupt an episode
+                count_end = self.total_steps_counters[self.phase][EnvironmentSteps] + steps.num_steps
+                while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+                    self.act(steps, continue_until_game_over=True, return_on_game_over=True)
 
     def handle_episode_ended(self) -> None:
         """
@@ -333,8 +338,8 @@ class GraphManager(object):
         """
         self.verify_graph_was_created()
 
-        [manager.train() for manager in self.level_managers]
-
+        with self.phase_context(RunPhase.TRAIN):
+            [manager.train() for manager in self.level_managers]
 
     def reset_internal_state(self, force_environment_reset=False) -> None:
         """
@@ -417,18 +422,17 @@ class GraphManager(object):
 
         # perform several steps of training interleaved with acting
         if steps.num_steps > 0:
-            self.phase = RunPhase.TRAIN
-            self.reset_internal_state(force_environment_reset=True)
-            # TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
-            # episodes in memory
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
-            while self.total_steps_counters[self.phase][steps.__class__] < count_end:
-                # The actual steps being done on the environment are decided by the agents themselves.
-                # This is just an high-level controller.
-                self.act(EnvironmentSteps(1))
-                self.train()
-                self.occasionally_save_checkpoint()
-            self.phase = RunPhase.UNDEFINED
+            with self.phase_context(RunPhase.TRAIN):
+                self.reset_internal_state(force_environment_reset=True)
+                # TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
+                # episodes in memory
+                count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
+                while self.total_steps_counters[self.phase][steps.__class__] < count_end:
+                    # The actual steps being done on the environment are decided by the agents themselves.
+                    # This is just a high-level controller.
+                    self.act(EnvironmentSteps(1))
+                    self.train()
+                    self.occasionally_save_checkpoint()
 
     def sync_graph(self) -> None:
         """
diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index d2a97db..ea82c2a 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -89,36 +89,33 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, polic
     task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
     time.sleep(30)
     graph_manager.create_graph(task_parameters)
-    graph_manager.phase = RunPhase.TRAIN
+    with graph_manager.phase_context(RunPhase.TRAIN):
+        error_compensation = 100
 
-    error_compensation = 100
+        last_checkpoint = 0
 
-    last_checkpoint = 0
+        act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
 
-    act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
+        for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
 
-    for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
+            graph_manager.act(EnvironmentSteps(num_steps=act_steps))
 
-        graph_manager.act(EnvironmentSteps(num_steps=act_steps))
+            new_checkpoint = get_latest_checkpoint(checkpoint_dir)
 
-        new_checkpoint = get_latest_checkpoint(checkpoint_dir)
+            if policy_type == 'ON':
+                while new_checkpoint < last_checkpoint + 1:
+                    if data_store:
+                        data_store.load_from_store()
+                    new_checkpoint = get_latest_checkpoint(checkpoint_dir)
 
-        if policy_type == 'ON':
-            while new_checkpoint < last_checkpoint + 1:
-                if data_store:
-                    data_store.load_from_store()
-                new_checkpoint = get_latest_checkpoint(checkpoint_dir)
-
-            graph_manager.restore_checkpoint()
-
-        if policy_type == "OFF":
-
-            if new_checkpoint > last_checkpoint:
                 graph_manager.restore_checkpoint()
-            last_checkpoint = new_checkpoint
+            if policy_type == "OFF":
 
-    graph_manager.phase = RunPhase.UNDEFINED
+                if new_checkpoint > last_checkpoint:
+                    graph_manager.restore_checkpoint()
+
+                last_checkpoint = new_checkpoint
 
 
 def main():
diff --git a/rl_coach/tests/graph_managers/test_graph_manager.py b/rl_coach/tests/graph_managers/test_graph_manager.py
new file mode 100644
index 0000000..45c7c8e
--- /dev/null
+++ b/rl_coach/tests/graph_managers/test_graph_manager.py
@@ -0,0 +1,47 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+
+import pytest
+from rl_coach.graph_managers.graph_manager import GraphManager, ScheduleParameters
+from rl_coach.base_parameters import VisualizationParameters
+from rl_coach.core_types import RunPhase
+
+
+@pytest.mark.unit_test
+def test_phase_context():
+    graph_manager = GraphManager(name='', schedule_params=ScheduleParameters(), vis_params=VisualizationParameters())
+
+    assert graph_manager.phase == RunPhase.UNDEFINED
+    with graph_manager.phase_context(RunPhase.TRAIN):
+        assert graph_manager.phase == RunPhase.TRAIN
+    assert graph_manager.phase == RunPhase.UNDEFINED
+
+
+@pytest.mark.unit_test
+def test_phase_context_nested():
+    graph_manager = GraphManager(name='', schedule_params=ScheduleParameters(), vis_params=VisualizationParameters())
+
+    assert graph_manager.phase == RunPhase.UNDEFINED
+    with graph_manager.phase_context(RunPhase.TRAIN):
+        assert graph_manager.phase == RunPhase.TRAIN
+        with graph_manager.phase_context(RunPhase.TEST):
+            assert graph_manager.phase == RunPhase.TEST
+        assert graph_manager.phase == RunPhase.TRAIN
+    assert graph_manager.phase == RunPhase.UNDEFINED
+
+
+@pytest.mark.unit_test
+def test_phase_context_double_nested():
+    graph_manager = GraphManager(name='', schedule_params=ScheduleParameters(), vis_params=VisualizationParameters())
+
+    assert graph_manager.phase == RunPhase.UNDEFINED
+    with graph_manager.phase_context(RunPhase.TRAIN):
+        assert graph_manager.phase == RunPhase.TRAIN
+        with graph_manager.phase_context(RunPhase.TEST):
+            assert graph_manager.phase == RunPhase.TEST
+            with graph_manager.phase_context(RunPhase.UNDEFINED):
+                assert graph_manager.phase == RunPhase.UNDEFINED
+            assert graph_manager.phase == RunPhase.TEST
+        assert graph_manager.phase == RunPhase.TRAIN
+    assert graph_manager.phase == RunPhase.UNDEFINED
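
Note on the phase_context helper introduced above: as written it restores the previous phase only when the body completes normally, so an exception raised inside the with-block (e.g. from act() or train()) would leave the graph manager stuck in the temporary phase. A minimal exception-safe variant is sketched below; it uses only the GraphManager attributes already visible in the diff and is not part of this patch:

    @contextlib.contextmanager
    def phase_context(self, phase):
        # remember the phase that was active before entering the block
        old_phase = self.phase
        # switch to the temporary phase for the duration of the block
        self.phase = phase
        try:
            yield
        finally:
            # restore the previous phase even if the block raised
            self.phase = old_phase

The tests in test_graph_manager.py exercise only the non-raising path, so they would pass unchanged against either version.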