From c2991819b4ad7af366366b9e6274fc36a1060729 Mon Sep 17 00:00:00 2001
From: Ajay Deshpande
Date: Fri, 14 Sep 2018 16:17:34 -0700
Subject: [PATCH] Adding right arguments to the agent

---
 rl_coach/presets/CartPole_DQN_distributed.py | 90 +++++++++-----------
 rl_coach/rollout_worker.py                   |  3 +
 2 files changed, 44 insertions(+), 49 deletions(-)

diff --git a/rl_coach/presets/CartPole_DQN_distributed.py b/rl_coach/presets/CartPole_DQN_distributed.py
index d3e8513..f4259c4 100644
--- a/rl_coach/presets/CartPole_DQN_distributed.py
+++ b/rl_coach/presets/CartPole_DQN_distributed.py
@@ -9,63 +9,55 @@ from rl_coach.memories.memory import MemoryGranularity
 from rl_coach.schedules import LinearSchedule
 
-def construct_graph(redis_ip='localhost', redis_port=6379):
-    ####################
-    # Graph Scheduling #
-    ####################
-    schedule_params = ScheduleParameters()
-    schedule_params.improve_steps = TrainingSteps(10000000000)
-    schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
-    schedule_params.evaluation_steps = EnvironmentEpisodes(1)
-    schedule_params.heatup_steps = EnvironmentSteps(1000)
+####################
+# Graph Scheduling #
+####################
 
-    #########
-    # Agent #
-    #########
-    agent_params = DQNAgentParametersDistributed()
+schedule_params = ScheduleParameters()
+schedule_params.improve_steps = TrainingSteps(10000000000)
+schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
+schedule_params.evaluation_steps = EnvironmentEpisodes(1)
+schedule_params.heatup_steps = EnvironmentSteps(1000)
 
-    # DQN params
-    agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
-    agent_params.algorithm.discount = 0.99
-    agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
+#########
+# Agent #
+#########
+agent_params = DQNAgentParametersDistributed()
 
-    # NN configuration
-    agent_params.network_wrappers['main'].learning_rate = 0.00025
-    agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
+# DQN params
+agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
+agent_params.algorithm.discount = 0.99
+agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
 
-    # ER size
-    agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
+# NN configuration
+agent_params.network_wrappers['main'].learning_rate = 0.00025
+agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
 
-    # E-Greedy schedule
-    agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
+# ER size
+agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
 
-    # Redis parameters
-    agent_params.memory.redis_ip = redis_ip
-    agent_params.memory.redis_port = redis_port
+# E-Greedy schedule
+agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
 
-    ###############
-    # Environment #
-    ###############
-    env_params = Mujoco()
-    env_params.level = 'CartPole-v0'
+###############
+# Environment #
+###############
+env_params = Mujoco()
+env_params.level = 'CartPole-v0'
 
-    vis_params = VisualizationParameters()
-    vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
-    vis_params.dump_mp4 = False
+vis_params = VisualizationParameters()
+vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
+vis_params.dump_mp4 = False
 
-    ########
-    # Test #
-    ########
-    preset_validation_params = PresetValidationParameters()
-    preset_validation_params.test = True
-    preset_validation_params.min_reward_threshold = 150
-    preset_validation_params.max_episodes_to_achieve_reward = 250
+########
+# Test #
+########
+preset_validation_params = PresetValidationParameters()
+preset_validation_params.test = True
+preset_validation_params.min_reward_threshold = 150
+preset_validation_params.max_episodes_to_achieve_reward = 250
 
-    graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
-                                        schedule_params=schedule_params, vis_params=vis_params,
-                                        preset_validation_params=preset_validation_params)
-    return graph_manager
-
-
-graph_manager = construct_graph()
+graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
+                                    schedule_params=schedule_params, vis_params=vis_params,
+                                    preset_validation_params=preset_validation_params)
diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index f7a29f0..f1c5363 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -40,6 +40,9 @@ def main():
 
     graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
 
+    graph_manager.agent_params.memory.redis_ip = args.redis_ip
+    graph_manager.agent_params.memory.redis_port = args.redis_port
+
     rollout_worker(
         graph_manager=graph_manager,
         checkpoint_dir=args.checkpoint_dir,
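
Note: the sketch below shows how a worker-side script could consume this preset now that graph_manager is built at import time and the Redis endpoint is injected at runtime. It is a minimal, hypothetical example and not rl_coach's actual rollout_worker.py; the --redis_ip/--redis_port flag names, their defaults, and the direct preset import are assumptions made for illustration only.

# Minimal sketch (assumptions noted above): point the preset's distributed
# memory at a Redis server chosen on the command line.
import argparse

# After this patch the preset exposes a module-level graph_manager.
from rl_coach.presets.CartPole_DQN_distributed import graph_manager


def main():
    parser = argparse.ArgumentParser()
    # Flag names and defaults are illustrative assumptions.
    parser.add_argument('--redis_ip', type=str, default='localhost',
                        help='address of the Redis server backing the shared memory')
    parser.add_argument('--redis_port', type=int, default=6379,
                        help='port of the Redis server')
    args = parser.parse_args()

    # Same pattern as the rollout_worker.py hunk above: override the memory's
    # Redis endpoint on the loaded graph manager before rollouts start.
    graph_manager.agent_params.memory.redis_ip = args.redis_ip
    graph_manager.agent_params.memory.redis_port = args.redis_port

    # ... hand graph_manager to the rollout loop here ...


if __name__ == '__main__':
    main()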