1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

removing some of the presets from the trace tests + more robust replay buffer loading

This commit is contained in:
itaicaspi-intel
2018-09-12 15:25:13 +03:00
parent 171fe97a3a
commit a16d724963
10 changed files with 59 additions and 20 deletions

View File

@@ -167,6 +167,7 @@ class PresetValidationParameters(Parameters):
self.max_episodes_to_achieve_reward = 1 self.max_episodes_to_achieve_reward = 1
self.num_workers = 1 self.num_workers = 1
self.reward_test_level = None self.reward_test_level = None
self.test_using_a_trace_test = True
self.trace_test_levels = None self.trace_test_levels = None
self.trace_max_env_steps = 5000 self.trace_max_env_steps = 5000

View File

@@ -16,6 +16,8 @@
from typing import List, Tuple, Union, Dict, Any from typing import List, Tuple, Union, Dict, Any
import pickle import pickle
import sys
import time
import numpy as np import numpy as np
@@ -235,5 +237,17 @@ class ExperienceReplay(Memory):
:param file_path: The path to a pickle file to restore :param file_path: The path to a pickle file to restore
""" """
with open(file_path, 'rb') as file: with open(file_path, 'rb') as file:
self.transitions = pickle.load(file) transitions = pickle.load(file)
self._num_transitions = len(self.transitions) num_transitions = len(transitions)
start_time = time.time()
for transition_idx, transition in enumerate(transitions):
self.store(transition)
# print progress
if transition_idx % 100 == 0:
percentage = int((100 * transition_idx) / num_transitions)
sys.stdout.write("\rProgress: ({}/{})".format(transition_idx, num_transitions))
sys.stdout.write(' Time (sec): {}'.format(round(time.time() - start_time, 2)))
sys.stdout.write(' {}%|{}{}| '.format(percentage, '#' * int(percentage / 10),
' ' * (10 - int(percentage / 10))))
sys.stdout.flush()

View File

@@ -39,7 +39,7 @@ vis_params.dump_mp4 = False
# Test # # Test #
######## ########
preset_validation_params = PresetValidationParameters() preset_validation_params = PresetValidationParameters()
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders'] preset_validation_params.test_using_a_trace_test = False
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params, schedule_params=schedule_params, vis_params=vis_params,

View File

@@ -50,6 +50,7 @@ preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150 preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 300 preset_validation_params.max_episodes_to_achieve_reward = 300
preset_validation_params.test_using_a_trace_test = False
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params, schedule_params=schedule_params, vis_params=vis_params,

View File

@@ -1,5 +1,5 @@
from rl_coach.agents.bc_agent import BCAgentParameters from rl_coach.agents.bc_agent import BCAgentParameters
from rl_coach.base_parameters import VisualizationParameters from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.doom_environment import DoomEnvironmentParameters from rl_coach.environments.doom_environment import DoomEnvironmentParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
@@ -38,5 +38,12 @@ agent_params.memory.load_memory_from_file_path = 'datasets/doom_basic.p'
env_params = DoomEnvironmentParameters() env_params = DoomEnvironmentParameters()
env_params.level = 'basic' env_params.level = 'basic'
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test_using_a_trace_test = False
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=VisualizationParameters()) schedule_params=schedule_params, vis_params=VisualizationParameters(),
preset_validation_params=preset_validation_params)

View File

@@ -68,6 +68,7 @@ preset_validation_params.test = True
# reward threshold was set to 1000 since otherwise the test takes about an hour # reward threshold was set to 1000 since otherwise the test takes about an hour
preset_validation_params.min_reward_threshold = 1000 preset_validation_params.min_reward_threshold = 1000
preset_validation_params.max_episodes_to_achieve_reward = 70 preset_validation_params.max_episodes_to_achieve_reward = 70
preset_validation_params.test_using_a_trace_test = False
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params, schedule_params=schedule_params, vis_params=vis_params,

View File

@@ -45,7 +45,7 @@ vis_params.dump_mp4 = False
# Test # # Test #
######## ########
preset_validation_params = PresetValidationParameters() preset_validation_params = PresetValidationParameters()
preset_validation_params.test_using_a_trace_test = False
# disabling this test for now, as it takes too long to converge # disabling this test for now, as it takes too long to converge
# preset_validation_params.test = True # preset_validation_params.test = True
# preset_validation_params.min_reward_threshold = 1000 # preset_validation_params.min_reward_threshold = 1000

View File

@@ -1,5 +1,6 @@
from rl_coach.agents.dfp_agent import DFPAgentParameters from rl_coach.agents.dfp_agent import DFPAgentParameters
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \
PresetValidationParameters
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from rl_coach.environments.doom_environment import DoomEnvironmentParameters from rl_coach.environments.doom_environment import DoomEnvironmentParameters
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
@@ -59,6 +60,12 @@ vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()] vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = False vis_params.dump_mp4 = False
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test_using_a_trace_test = False
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params, schedule_params=schedule_params, vis_params=vis_params,
) preset_validation_params=preset_validation_params)

View File

@@ -1,5 +1,5 @@
from rl_coach.agents.bc_agent import BCAgentParameters from rl_coach.agents.bc_agent import BCAgentParameters
from rl_coach.base_parameters import VisualizationParameters from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
from rl_coach.environments.gym_environment import Atari from rl_coach.environments.gym_environment import Atari
@@ -39,5 +39,12 @@ vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()] vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = False vis_params.dump_mp4 = False
########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test_using_a_trace_test = False
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params) schedule_params=schedule_params, vis_params=vis_params,
preset_validation_params=preset_validation_params)

View File

@@ -224,6 +224,7 @@ def main():
preset_validation_params = preset.graph_manager.preset_validation_params preset_validation_params = preset.graph_manager.preset_validation_params
num_env_steps = preset_validation_params.trace_max_env_steps num_env_steps = preset_validation_params.trace_max_env_steps
if preset_validation_params.test_using_a_trace_test:
if preset_validation_params.trace_test_levels: if preset_validation_params.trace_test_levels:
for level in preset_validation_params.trace_test_levels: for level in preset_validation_params.trace_test_levels:
test_count += 1 test_count += 1