From a16d7249633a4069aaa9c218abbcf1e8226c5ab4 Mon Sep 17 00:00:00 2001
From: itaicaspi-intel
Date: Wed, 12 Sep 2018 15:25:13 +0300
Subject: [PATCH] removing some of the presets from the trace tests + more
 robust replay buffer loading

---
 rl_coach/base_parameters.py                 |  1 +
 .../non_episodic/experience_replay.py       | 18 ++++++++++++++--
 rl_coach/presets/Atari_NEC.py               |  2 +-
 rl_coach/presets/CartPole_NEC.py            |  1 +
 rl_coach/presets/Doom_Basic_BC.py           | 11 ++++++++--
 rl_coach/presets/Doom_Health_DFP.py         |  1 +
 rl_coach/presets/Doom_Health_MMC.py         |  2 +-
 rl_coach/presets/Doom_Health_Supreme_DFP.py | 11 ++++++++--
 rl_coach/presets/MontezumaRevenge_BC.py     | 11 ++++++++--
 rl_coach/tests/trace_tests.py               | 21 ++++++++++---------
 10 files changed, 59 insertions(+), 20 deletions(-)

diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py
index dd70969..0ee4359 100644
--- a/rl_coach/base_parameters.py
+++ b/rl_coach/base_parameters.py
@@ -167,6 +167,7 @@ class PresetValidationParameters(Parameters):
         self.max_episodes_to_achieve_reward = 1
         self.num_workers = 1
         self.reward_test_level = None
+        self.test_using_a_trace_test = True
         self.trace_test_levels = None
         self.trace_max_env_steps = 5000
 
diff --git a/rl_coach/memories/non_episodic/experience_replay.py b/rl_coach/memories/non_episodic/experience_replay.py
index 92a0ae4..415706b 100644
--- a/rl_coach/memories/non_episodic/experience_replay.py
+++ b/rl_coach/memories/non_episodic/experience_replay.py
@@ -16,6 +16,8 @@
 from typing import List, Tuple, Union, Dict, Any
 
 import pickle
+import sys
+import time
 
 import numpy as np
 
@@ -235,5 +237,17 @@ class ExperienceReplay(Memory):
         :param file_path: The path to a pickle file to restore
         """
         with open(file_path, 'rb') as file:
-            self.transitions = pickle.load(file)
-            self._num_transitions = len(self.transitions)
+            transitions = pickle.load(file)
+            num_transitions = len(transitions)
+            start_time = time.time()
+            for transition_idx, transition in enumerate(transitions):
+                self.store(transition)
+
+                # print progress
+                if transition_idx % 100 == 0:
+                    percentage = int((100 * transition_idx) / num_transitions)
+                    sys.stdout.write("\rProgress: ({}/{})".format(transition_idx, num_transitions))
+                    sys.stdout.write(' Time (sec): {}'.format(round(time.time() - start_time, 2)))
+                    sys.stdout.write(' {}%|{}{}| '.format(percentage, '#' * int(percentage / 10),
+                                                          ' ' * (10 - int(percentage / 10))))
+                    sys.stdout.flush()
diff --git a/rl_coach/presets/Atari_NEC.py b/rl_coach/presets/Atari_NEC.py
index 158e408..9265805 100644
--- a/rl_coach/presets/Atari_NEC.py
+++ b/rl_coach/presets/Atari_NEC.py
@@ -39,7 +39,7 @@ vis_params.dump_mp4 = False
 # Test #
 ########
 preset_validation_params = PresetValidationParameters()
-preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
+preset_validation_params.test_using_a_trace_test = False
 
 graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                     schedule_params=schedule_params, vis_params=vis_params,
diff --git a/rl_coach/presets/CartPole_NEC.py b/rl_coach/presets/CartPole_NEC.py
index 7a961aa..cf00bd1 100644
--- a/rl_coach/presets/CartPole_NEC.py
+++ b/rl_coach/presets/CartPole_NEC.py
@@ -50,6 +50,7 @@ preset_validation_params = PresetValidationParameters()
 preset_validation_params.test = True
 preset_validation_params.min_reward_threshold = 150
 preset_validation_params.max_episodes_to_achieve_reward = 300
+preset_validation_params.test_using_a_trace_test = False
 
 graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                     schedule_params=schedule_params, vis_params=vis_params,
diff --git a/rl_coach/presets/Doom_Basic_BC.py b/rl_coach/presets/Doom_Basic_BC.py
index cf5e16c..0cf30ab 100644
--- a/rl_coach/presets/Doom_Basic_BC.py
+++ b/rl_coach/presets/Doom_Basic_BC.py
@@ -1,5 +1,5 @@
 from rl_coach.agents.bc_agent import BCAgentParameters
-from rl_coach.base_parameters import VisualizationParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
 from rl_coach.environments.doom_environment import DoomEnvironmentParameters
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
@@ -38,5 +38,12 @@ agent_params.memory.load_memory_from_file_path = 'datasets/doom_basic.p'
 env_params = DoomEnvironmentParameters()
 env_params.level = 'basic'
 
+########
+# Test #
+########
+preset_validation_params = PresetValidationParameters()
+preset_validation_params.test_using_a_trace_test = False
+
 graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
-                                    schedule_params=schedule_params, vis_params=VisualizationParameters())
+                                    schedule_params=schedule_params, vis_params=VisualizationParameters(),
+                                    preset_validation_params=preset_validation_params)
diff --git a/rl_coach/presets/Doom_Health_DFP.py b/rl_coach/presets/Doom_Health_DFP.py
index 259d958..863b9f8 100644
--- a/rl_coach/presets/Doom_Health_DFP.py
+++ b/rl_coach/presets/Doom_Health_DFP.py
@@ -68,6 +68,7 @@ preset_validation_params.test = True
 # reward threshold was set to 1000 since otherwise the test takes about an hour
 preset_validation_params.min_reward_threshold = 1000
 preset_validation_params.max_episodes_to_achieve_reward = 70
+preset_validation_params.test_using_a_trace_test = False
 
 graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                     schedule_params=schedule_params, vis_params=vis_params,
diff --git a/rl_coach/presets/Doom_Health_MMC.py b/rl_coach/presets/Doom_Health_MMC.py
index 1a8210e..c146b2f 100644
--- a/rl_coach/presets/Doom_Health_MMC.py
+++ b/rl_coach/presets/Doom_Health_MMC.py
@@ -45,7 +45,7 @@ vis_params.dump_mp4 = False
 # Test #
 ########
 preset_validation_params = PresetValidationParameters()
-
+preset_validation_params.test_using_a_trace_test = False
 # disabling this test for now, as it takes too long to converge
 # preset_validation_params.test = True
 # preset_validation_params.min_reward_threshold = 1000
diff --git a/rl_coach/presets/Doom_Health_Supreme_DFP.py b/rl_coach/presets/Doom_Health_Supreme_DFP.py
index 1bb3af9..4316add 100644
--- a/rl_coach/presets/Doom_Health_Supreme_DFP.py
+++ b/rl_coach/presets/Doom_Health_Supreme_DFP.py
@@ -1,5 +1,6 @@
 from rl_coach.agents.dfp_agent import DFPAgentParameters
-from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme
+from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \
+    PresetValidationParameters
 from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
 from rl_coach.environments.doom_environment import DoomEnvironmentParameters
 from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
@@ -59,6 +60,12 @@ vis_params = VisualizationParameters()
 vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
 vis_params.dump_mp4 = False
 
+########
+# Test #
+########
+preset_validation_params = PresetValidationParameters()
+preset_validation_params.test_using_a_trace_test = False
+
 graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                     schedule_params=schedule_params, vis_params=vis_params,
-                                    )
+                                    preset_validation_params=preset_validation_params)
diff --git a/rl_coach/presets/MontezumaRevenge_BC.py b/rl_coach/presets/MontezumaRevenge_BC.py
index 0e026ca..5fbd13a 100644
--- a/rl_coach/presets/MontezumaRevenge_BC.py
+++ b/rl_coach/presets/MontezumaRevenge_BC.py
@@ -1,5 +1,5 @@
 from rl_coach.agents.bc_agent import BCAgentParameters
-from rl_coach.base_parameters import VisualizationParameters
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
 from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
 from rl_coach.environments.gym_environment import Atari
@@ -39,5 +39,12 @@ vis_params = VisualizationParameters()
 vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
 vis_params.dump_mp4 = False
 
+########
+# Test #
+########
+preset_validation_params = PresetValidationParameters()
+preset_validation_params.test_using_a_trace_test = False
+
 graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
-                                    schedule_params=schedule_params, vis_params=vis_params)
+                                    schedule_params=schedule_params, vis_params=vis_params,
+                                    preset_validation_params=preset_validation_params)
diff --git a/rl_coach/tests/trace_tests.py b/rl_coach/tests/trace_tests.py
index 7a5c32a..176cfe7 100644
--- a/rl_coach/tests/trace_tests.py
+++ b/rl_coach/tests/trace_tests.py
@@ -224,21 +224,22 @@ def main():
         preset_validation_params = preset.graph_manager.preset_validation_params
 
         num_env_steps = preset_validation_params.trace_max_env_steps
-        if preset_validation_params.trace_test_levels:
-            for level in preset_validation_params.trace_test_levels:
+        if preset_validation_params.test_using_a_trace_test:
+            if preset_validation_params.trace_test_levels:
+                for level in preset_validation_params.trace_test_levels:
+                    test_count += 1
+                    test_path, log_file, p = run_trace_based_test(preset_name, num_env_steps, level)
+                    processes.append((test_path, log_file, p))
+                    test_passed = wait_and_check(args, processes)
+                    if test_passed is not None and not test_passed:
+                        fail_count += 1
+            else:
                 test_count += 1
-                test_path, log_file, p = run_trace_based_test(preset_name, num_env_steps, level)
+                test_path, log_file, p = run_trace_based_test(preset_name, num_env_steps)
                 processes.append((test_path, log_file, p))
                 test_passed = wait_and_check(args, processes)
                 if test_passed is not None and not test_passed:
                     fail_count += 1
-        else:
-            test_count += 1
-            test_path, log_file, p = run_trace_based_test(preset_name, num_env_steps)
-            processes.append((test_path, log_file, p))
-            test_passed = wait_and_check(args, processes)
-            if test_passed is not None and not test_passed:
-                fail_count += 1
 
     while len(processes) > 0:
         test_passed = wait_and_check(args, processes, force=True)
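
Usage note (outside the diff): restoring transitions through self.store() instead of
assigning the pickled list directly onto self.transitions means the buffer's usual
bookkeeping (transition counting and any max-size enforcement) runs for every loaded
transition, which is the "more robust" part of this change. For preset authors, opting
a preset out of the trace tests now follows the same pattern the presets above use.
A minimal sketch, assuming the surrounding preset already defines agent_params,
env_params, and schedule_params as the touched preset files do:

    from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
    from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager

    ########
    # Test #
    ########
    # test_using_a_trace_test defaults to True (see base_parameters.py above),
    # so trace_tests.py only skips a preset that explicitly sets it to False.
    preset_validation_params = PresetValidationParameters()
    preset_validation_params.test_using_a_trace_test = False

    # agent_params, env_params and schedule_params are assumed to be defined
    # earlier in the preset, as in the files touched by this patch.
    graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                        schedule_params=schedule_params,
                                        vis_params=VisualizationParameters(),
                                        preset_validation_params=preset_validation_params)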