mirror of
https://github.com/gryf/coach.git
synced 2026-03-17 07:13:37 +01:00
network_imporvements branch merge
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SingleLevelSelection, SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -29,17 +28,10 @@ agent_params.algorithm.beta_entropy = 0.05
|
||||
agent_params.network_wrappers['main'].middleware_parameters = FCMiddlewareParameters()
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -48,5 +40,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.lstm_middleware import LSTMMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SingleLevelSelection, SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, AtariInputFilter
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -29,17 +28,11 @@ agent_params.algorithm.beta_entropy = 0.05
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.network_wrappers['main'].middleware_parameters = LSTMMiddlewareParameters(scheme=MiddlewareScheme.Medium,
|
||||
number_of_lstm_cells=256)
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = True
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -48,5 +41,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -25,12 +25,7 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -39,5 +34,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from rl_coach.agents.categorical_dqn_agent import CategoricalDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -24,12 +13,7 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -38,5 +22,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -24,12 +13,7 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -38,5 +22,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
@@ -1,23 +1,11 @@
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
#########
|
||||
@@ -29,12 +17,7 @@ agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000) # 12.5M training it
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -43,5 +26,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -25,12 +14,7 @@ agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -39,5 +23,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,21 +1,11 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -28,12 +18,7 @@ agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000) # 12.5M training it
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -42,5 +27,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -3,20 +3,9 @@ import math
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -26,19 +15,14 @@ agent_params = DDQNAgentParameters()
|
||||
# since we are using Adam instead of RMSProp, we adjust the learning rate as well
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
|
||||
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2)]
|
||||
agent_params.network_wrappers['main'].heads_parameters = \
|
||||
[DuelingQHeadParameters(rescale_gradient_from_head_by_factor=1/math.sqrt(2))]
|
||||
agent_params.network_wrappers['main'].clip_gradients = 10
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -47,5 +31,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,22 +1,13 @@
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters
|
||||
from rl_coach.schedules import LinearSchedule, PieceWiseSchedule, ConstantSchedule
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -38,12 +29,7 @@ agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000) # 12.5M training it
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -52,5 +38,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.nec_agent import NECAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SingleLevelSelection, SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, AtariInputFilter, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -27,14 +27,9 @@ agent_params.input_filter.remove_reward_filter('clipping')
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
env_params.random_initialization_steps = 1
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
@@ -42,5 +37,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from rl_coach.agents.n_step_q_agent import NStepQAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SingleLevelSelection, SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -22,19 +22,14 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||
agent_params = NStepQAgentParameters()
|
||||
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Conv2d([16, 8, 4]),
|
||||
Conv2d([32, 4, 2])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([256])]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Conv2d(16, 8, 4),
|
||||
Conv2d(32, 4, 2)]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(256)]
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -43,5 +38,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from rl_coach.agents.qr_dqn_agent import QuantileRegressionDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -25,12 +14,7 @@ agent_params.algorithm.huber_loss_interval = 1 # k = 0 for strict quantile loss
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -39,5 +23,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -30,12 +30,7 @@ agent_params.memory.alpha = 0.5
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -44,5 +39,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,20 +1,9 @@
|
||||
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.exploration_policies.ucb import UCBParameters
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -26,12 +15,7 @@ agent_params.exploration = UCBParameters()
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -40,5 +24,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, \
|
||||
PresetValidationParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -28,7 +27,7 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||
agent_params = DQNAgentParameters()
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.001
|
||||
agent_params.network_wrappers['main'].batch_size = 128
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([256])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(256)]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters = {
|
||||
'state': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)
|
||||
@@ -45,12 +44,10 @@ agent_params.exploration.evaluation_epsilon = 0
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.bit_flip:BitFlip'
|
||||
env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')
|
||||
env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}
|
||||
# env_params.custom_reward_threshold = -bit_length + 1
|
||||
env_params.custom_reward_threshold = -bit_length + 1
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -61,7 +58,7 @@ preset_validation_params.min_reward_threshold = -7.9
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 10000
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, \
|
||||
PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.episodic.episodic_hindsight_experience_replay import \
|
||||
@@ -30,7 +30,7 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||
agent_params = DQNAgentParameters()
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.001
|
||||
agent_params.network_wrappers['main'].batch_size = 128
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([256])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(256)]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters = {
|
||||
'state': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
@@ -55,13 +55,10 @@ agent_params.memory.goals_space = GoalsSpace(goal_name='state',
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.bit_flip:BitFlip'
|
||||
env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')
|
||||
env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}
|
||||
env_params.custom_reward_threshold = -bit_length + 1
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
# currently no tests for this preset as the max reward can be accidently achieved. will be fixed with trace based tests.
|
||||
|
||||
########
|
||||
@@ -73,7 +70,7 @@ preset_validation_params.min_reward_threshold = -15
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 10000
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@ import copy
|
||||
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes, CarlaInputFilter
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -49,12 +48,7 @@ agent_params.input_filter.copy_filters_from_one_observation_to_another('forward_
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
env_params.cameras = [CameraTypes.FRONT, CameraTypes.LEFT, CameraTypes.RIGHT]
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from logger import screen
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
# make sure you have $CARLA_ROOT/PythonClient in your PYTHONPATH
|
||||
from carla.driving_benchmark.experiment_suites import CoRL2017
|
||||
from rl_coach.logger import screen
|
||||
|
||||
from rl_coach.agents.cil_agent import CILAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense, BatchnormActivationDropout
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
||||
@@ -27,7 +27,6 @@ from rl_coach.schedules import ConstantSchedule
|
||||
from rl_coach.spaces import ImageObservationSpace
|
||||
from rl_coach.utilities.carla_dataset_to_replay_buffer import create_dataset
|
||||
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
@@ -44,38 +43,64 @@ agent_params = CILAgentParameters()
|
||||
|
||||
# forward camera and measurements input
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters = {
|
||||
'CameraRGB': InputEmbedderParameters(scheme=[Conv2d([32, 5, 2]),
|
||||
Conv2d([32, 3, 1]),
|
||||
Conv2d([64, 3, 2]),
|
||||
Conv2d([64, 3, 1]),
|
||||
Conv2d([128, 3, 2]),
|
||||
Conv2d([128, 3, 1]),
|
||||
Conv2d([256, 3, 1]),
|
||||
Conv2d([256, 3, 1]),
|
||||
Dense([512]),
|
||||
Dense([512])],
|
||||
dropout=True,
|
||||
batchnorm=True),
|
||||
'measurements': InputEmbedderParameters(scheme=[Dense([128]),
|
||||
Dense([128])])
|
||||
'CameraRGB': InputEmbedderParameters(
|
||||
scheme=[
|
||||
Conv2d(32, 5, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(32, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(64, 3, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(64, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(128, 3, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(128, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(256, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(256, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.3),
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.3)
|
||||
],
|
||||
activation_function='none' # we define the activation function for each layer explicitly
|
||||
),
|
||||
'measurements': InputEmbedderParameters(
|
||||
scheme=[
|
||||
Dense(128),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5),
|
||||
Dense(128),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5)
|
||||
],
|
||||
activation_function='none' # we define the activation function for each layer explicitly
|
||||
)
|
||||
}
|
||||
|
||||
# TODO: batch norm is currently applied to the fc layers which is not desired
|
||||
# TODO: dropout should be configured differenetly per layer [1.0] * 8 + [0.7] * 2 + [0.5] * 2 + [0.5] * 1 + [0.5, 1.] * 5
|
||||
|
||||
# simple fc middleware
|
||||
agent_params.network_wrappers['main'].middleware_parameters = FCMiddlewareParameters(scheme=[Dense([512])])
|
||||
agent_params.network_wrappers['main'].middleware_parameters = \
|
||||
FCMiddlewareParameters(
|
||||
scheme=[
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5)
|
||||
],
|
||||
activation_function='none'
|
||||
)
|
||||
|
||||
# output branches
|
||||
agent_params.network_wrappers['main'].heads_parameters = [
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters()
|
||||
RegressionHeadParameters(
|
||||
scheme=[
|
||||
Dense(256),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5),
|
||||
Dense(256),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh)
|
||||
],
|
||||
num_output_head_copies=4 # follow lane, left, right, straight
|
||||
)
|
||||
]
|
||||
# agent_params.network_wrappers['main'].num_output_head_copies = 4 # follow lane, left, right, straight
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1, 1, 1, 1]
|
||||
agent_params.network_wrappers['main'].loss_weights = [1, 1, 1, 1]
|
||||
# TODO: there should be another head predicting the speed which is connected directly to the forward camera embedding
|
||||
|
||||
agent_params.network_wrappers['main'].batch_size = 120
|
||||
@@ -125,7 +150,6 @@ if not os.path.exists(agent_params.memory.load_memory_from_file_path):
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
env_params.cameras = ['CameraRGB']
|
||||
env_params.camera_height = 600
|
||||
env_params.camera_width = 800
|
||||
@@ -134,9 +158,5 @@ env_params.allow_braking = True
|
||||
env_params.quality = CarlaEnvironmentParameters.Quality.EPIC
|
||||
env_params.experiment_suite = CoRL2017('Town01')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST)]
|
||||
vis_params.dump_mp4 = True
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -29,11 +28,6 @@ agent_params.network_wrappers['critic'].input_embedders_parameters['forward_came
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -3,9 +3,8 @@ import math
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.filters.action.box_discretization import BoxDiscretization
|
||||
from rl_coach.filters.filter import OutputFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -26,9 +25,9 @@ schedule_params.heatup_steps = EnvironmentSteps(1000)
|
||||
#########
|
||||
agent_params = DDQNAgentParameters()
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
|
||||
agent_params.network_wrappers['main'].heads_parameters = \
|
||||
[DuelingQHeadParameters(rescale_gradient_from_head_by_factor=1/math.sqrt(2))]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2), 1/math.sqrt(2)]
|
||||
agent_params.network_wrappers['main'].clip_gradients = 10
|
||||
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['forward_camera'] = \
|
||||
@@ -40,11 +39,6 @@ agent_params.output_filter.add_action_filter('discretization', BoxDiscretization
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter, Mujoco
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -33,20 +32,13 @@ agent_params.algorithm.beta_entropy = 0.01
|
||||
agent_params.network_wrappers['main'].optimizer_type = 'Adam'
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -58,5 +50,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 300
|
||||
preset_validation_params.num_workers = 8
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters, HandlingTargetsAfterEpisodeEnd
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -37,12 +36,7 @@ agent_params.algorithm.handling_targets_after_episode_end = HandlingTargetsAfter
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -53,5 +47,5 @@ preset_validation_params.min_reward_threshold = 120
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -41,12 +40,7 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||
################
|
||||
# Environment #
|
||||
################
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -57,5 +51,5 @@ preset_validation_params.min_reward_threshold = 150
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -3,9 +3,8 @@ import math
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -34,8 +33,8 @@ agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
|
||||
# NN configuration
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
|
||||
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2), 1/math.sqrt(2)]
|
||||
agent_params.network_wrappers['main'].heads_parameters = \
|
||||
[DuelingQHeadParameters(rescale_gradient_from_head_by_factor=1/math.sqrt(2))]
|
||||
|
||||
# ER size
|
||||
agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
|
||||
@@ -46,12 +45,7 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||
################
|
||||
# Environment #
|
||||
################
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -62,5 +56,5 @@ preset_validation_params.min_reward_threshold = 150
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from rl_coach.agents.nec_agent import NECAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Atari, MujocoInputFilter
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -30,18 +30,13 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(0.5, 0.1, 1000)
|
||||
agent_params.exploration.evaluation_epsilon = 0
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.memory.max_size = (MemoryGranularity.Episodes, 200)
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -53,5 +48,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 300
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from rl_coach.agents.n_step_q_agent import NStepQAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter, Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -24,18 +24,13 @@ agent_params = NStepQAgentParameters()
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -47,5 +42,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 200
|
||||
preset_validation_params.num_workers = 8
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.pal_agent import PALAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -41,12 +40,7 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||
################
|
||||
# Environment #
|
||||
################
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -57,5 +51,5 @@ preset_validation_params.min_reward_threshold = 150
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from rl_coach.agents.policy_gradients_agent import PolicyGradientsAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter, Mujoco
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -29,20 +28,13 @@ agent_params.algorithm.num_steps_between_gradient_updates = 20000
|
||||
agent_params.network_wrappers['main'].optimizer_type = 'Adam'
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0005
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -53,6 +45,6 @@ preset_validation_params.min_reward_threshold = 130
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 550
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -27,23 +27,18 @@ agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters.pop('observation')
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'] = \
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters.pop('observation')
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'].scheme = [Dense([300])]
|
||||
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense([200])]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'].scheme = [Dense([400])]
|
||||
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense([300])]
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'].scheme = [Dense(300)]
|
||||
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(200)]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'].scheme = [Dense(400)]
|
||||
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(300)]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter("rescale", RewardRescaleFilter(1/10.))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = ControlSuiteEnvironmentParameters()
|
||||
env_params.level = SingleLevelSelection(control_suite_envs)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = ControlSuiteEnvironmentParameters(level=SingleLevelSelection(control_suite_envs))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -52,5 +47,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['cartpole:swingup', 'hopper:hop']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -27,24 +25,18 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||
agent_params = ActorCriticAgentParameters()
|
||||
agent_params.algorithm.policy_gradient_rescaler = PolicyGradientRescaler.GAE
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/100.))
|
||||
agent_params.algorithm.num_steps_between_gradient_updates = 30
|
||||
agent_params.algorithm.apply_gradients_every_x_episodes = 1
|
||||
agent_params.algorithm.gae_lambda = 1.0
|
||||
agent_params.algorithm.beta_entropy = 0.01
|
||||
agent_params.network_wrappers['main'].clip_gradients = 40.
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -57,5 +49,5 @@ preset_validation_params.num_workers = 8
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -35,8 +35,7 @@ agent_params.memory.load_memory_from_file_path = 'datasets/doom_basic.p'
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
########
|
||||
# Test #
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters, HandlingTargetsAfterEpisodeEnd
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -39,12 +38,8 @@ agent_params.algorithm.handling_targets_after_episode_end = HandlingTargetsAfter
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -53,5 +48,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_max_env_steps = 2000
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -35,12 +34,8 @@ agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -52,5 +47,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 400
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -36,12 +35,8 @@ agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters, DoomEnvironment
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -44,13 +43,9 @@ agent_params.network_wrappers['main'].input_embedders_parameters['observation'].
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'BATTLE_COACH_LOCAL'
|
||||
env_params = DoomEnvironmentParameters(level='BATTLE_COACH_LOCAL')
|
||||
env_params.cameras = [DoomEnvironment.CameraTypes.OBSERVATION]
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \
|
||||
PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
|
||||
from rl_coach.core_types import EnvironmentSteps, EnvironmentEpisodes
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -54,11 +53,8 @@ agent_params.algorithm.scale_measurements_targets['GameVariable.HEALTH'] = 30.0
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'HEALTH_GATHERING'
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = DoomEnvironmentParameters(level='HEALTH_GATHERING')
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -71,5 +67,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 70
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.mmc_agent import MixedMonteCarloAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -35,11 +34,7 @@ agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'HEALTH_GATHERING'
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = DoomEnvironmentParameters(level='HEALTH_GATHERING')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -52,5 +47,5 @@ preset_validation_params.test_using_a_trace_test = False
|
||||
# preset_validation_params.max_episodes_to_achieve_reward = 300
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \
|
||||
PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
|
||||
from rl_coach.core_types import EnvironmentSteps, EnvironmentEpisodes
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -54,11 +53,8 @@ agent_params.algorithm.scale_measurements_targets['GameVariable.HEALTH'] = 30.0
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'HEALTH_GATHERING_SUPREME_COACH_LOCAL'
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = DoomEnvironmentParameters(level='HEALTH_GATHERING_SUPREME_COACH_LOCAL')
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -67,5 +63,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -29,8 +29,8 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000)
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
|
||||
agent_params.network_wrappers['main'].num_output_head_copies = num_output_head_copies
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1.0/num_output_head_copies]*num_output_head_copies
|
||||
agent_params.network_wrappers['main'].heads_parameters[0].num_output_head_copies = num_output_head_copies
|
||||
agent_params.network_wrappers['main'].heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/num_output_head_copies
|
||||
agent_params.exploration.bootstrapped_data_sharing_probability = 1.0
|
||||
agent_params.exploration.architecture_num_q_heads = num_output_head_copies
|
||||
agent_params.exploration.epsilon_schedule = ConstantSchedule(0)
|
||||
@@ -40,26 +40,9 @@ agent_params.output_filter = NoOutputFilter()
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.exploration_chain:ExplorationChain'
|
||||
|
||||
env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.exploration_chain:ExplorationChain')
|
||||
env_params.additional_simulator_parameters = {'chain_length': N, 'max_steps': N+7}
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
|
||||
# currently no test here as bootstrapped_dqn seems to be broken
|
||||
|
||||
# preset_validation_params = PresetValidationParameters()
|
||||
# preset_validation_params.test = True
|
||||
# preset_validation_params.min_reward_threshold = 1600
|
||||
# preset_validation_params.max_episodes_to_achieve_reward = 70
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,)
|
||||
# preset_validation_params=preset_validation_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -38,18 +38,8 @@ agent_params.output_filter = NoOutputFilter()
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = GymEnvironmentParameters()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.exploration_chain:ExplorationChain'
|
||||
env_params = GymEnvironmentParameters(level='rl_coach.environments.toy_problems.exploration_chain:ExplorationChain')
|
||||
env_params.additional_simulator_parameters = {'chain_length': N, 'max_steps': N+7}
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
|
||||
# preset_validation_params = PresetValidationParameters()
|
||||
# preset_validation_params.test = True
|
||||
# preset_validation_params.min_reward_threshold = 1600
|
||||
# preset_validation_params.max_episodes_to_achieve_reward = 70
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,)
|
||||
# preset_validation_params=preset_validation_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.exploration_policies.ucb import UCBParameters
|
||||
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -30,8 +30,8 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000)
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
|
||||
agent_params.network_wrappers['main'].num_output_head_copies = num_output_head_copies
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1.0/num_output_head_copies]*num_output_head_copies
|
||||
agent_params.network_wrappers['main'].heads_parameters[0].num_output_head_copies = num_output_head_copies
|
||||
agent_params.network_wrappers['main'].heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/num_output_head_copies
|
||||
agent_params.exploration = UCBParameters()
|
||||
agent_params.exploration.bootstrapped_data_sharing_probability = 1.0
|
||||
agent_params.exploration.architecture_num_q_heads = num_output_head_copies
|
||||
@@ -43,12 +43,8 @@ agent_params.output_filter = NoOutputFilter()
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.exploration_chain:ExplorationChain'
|
||||
|
||||
env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.exploration_chain:ExplorationChain')
|
||||
env_params.additional_simulator_parameters = {'chain_length': N, 'max_steps': N+7}
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, MujocoInputFilter, fetch_v1
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, fetch_v1
|
||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_clipping_filter import ObservationClippingFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -44,7 +45,7 @@ actor_network.input_embedders_parameters = {
|
||||
'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)
|
||||
}
|
||||
actor_network.middleware_parameters = FCMiddlewareParameters(scheme=[Dense([256]), Dense([256]), Dense([256])])
|
||||
actor_network.middleware_parameters = FCMiddlewareParameters(scheme=[Dense(256), Dense(256), Dense(256)])
|
||||
actor_network.heads_parameters[0].batchnorm = False
|
||||
|
||||
# critic
|
||||
@@ -59,7 +60,7 @@ critic_network.input_embedders_parameters = {
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty)
|
||||
}
|
||||
critic_network.middleware_parameters = FCMiddlewareParameters(scheme=[Dense([256]), Dense([256]), Dense([256])])
|
||||
critic_network.middleware_parameters = FCMiddlewareParameters(scheme=[Dense(256), Dense(256), Dense(256)])
|
||||
|
||||
agent_params.algorithm.discount = 0.98
|
||||
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentEpisodes(1)
|
||||
@@ -90,10 +91,10 @@ agent_params.exploration.evaluation_epsilon = 0
|
||||
agent_params.exploration.continuous_exploration_policy_parameters.noise_percentage_schedule = ConstantSchedule(0.1)
|
||||
agent_params.exploration.continuous_exploration_policy_parameters.evaluation_noise_percentage = 0
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_observation_filter('observation', 'clipping', ObservationClippingFilter(-200, 200))
|
||||
|
||||
agent_params.pre_network_filter = MujocoInputFilter()
|
||||
agent_params.pre_network_filter = InputFilter()
|
||||
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
||||
ObservationNormalizationFilter(name='normalize_observation'))
|
||||
agent_params.pre_network_filter.add_observation_filter('achieved_goal', 'normalize_achieved_goal',
|
||||
@@ -104,27 +105,17 @@ agent_params.pre_network_filter.add_observation_filter('desired_goal', 'normaliz
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(fetch_v1)
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(fetch_v1))
|
||||
env_params.custom_reward_threshold = -49
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
preset_validation_params = PresetValidationParameters()
|
||||
# preset_validation_params.test = True
|
||||
# preset_validation_params.min_reward_threshold = 200
|
||||
# preset_validation_params.max_episodes_to_achieve_reward = 600
|
||||
# preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['slide', 'pick_and_place', 'push', 'reach']
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from rl_coach.agents.policy_gradients_agent import PolicyGradientsAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco, MujocoInputFilter
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -25,7 +25,7 @@ agent_params.algorithm.apply_gradients_every_x_episodes = 5
|
||||
agent_params.algorithm.num_steps_between_gradient_updates = 20000
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0005
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/20.))
|
||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||
|
||||
@@ -33,14 +33,9 @@ agent_params.input_filter.add_observation_filter('observation', 'normalize', Obs
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = "InvertedPendulum-v2"
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level="InvertedPendulum-v2")
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from rl_coach.agents.bc_agent import BCAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import Atari
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -31,14 +30,9 @@ agent_params.memory.load_memory_from_file_path = 'datasets/montezuma_revenge.p'
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = 'MontezumaRevenge-v0'
|
||||
env_params = Atari(level='MontezumaRevenge-v0')
|
||||
env_params.random_initialization_steps = 30
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
@@ -46,5 +40,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
|
||||
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -27,21 +27,15 @@ agent_params.algorithm.num_steps_between_gradient_updates = 10000000
|
||||
agent_params.algorithm.beta_entropy = 0.0001
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.00001
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/20.))
|
||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||
|
||||
agent_params.exploration = ContinuousEntropyParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -55,7 +49,7 @@ preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.lstm_middleware import LSTMMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
|
||||
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -30,25 +30,18 @@ agent_params.algorithm.num_steps_between_gradient_updates = 20
|
||||
agent_params.algorithm.beta_entropy = 0.005
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.00002
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'] = \
|
||||
InputEmbedderParameters(scheme=[Dense([200])])
|
||||
InputEmbedderParameters(scheme=[Dense(200)])
|
||||
agent_params.network_wrappers['main'].middleware_parameters = LSTMMiddlewareParameters(scheme=MiddlewareScheme.Empty,
|
||||
number_of_lstm_cells=128)
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/20.))
|
||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||
|
||||
agent_params.exploration = ContinuousEntropyParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -62,7 +55,7 @@ preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -28,8 +29,8 @@ agent_params = ClippedPPOAgentParameters()
|
||||
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0003
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].activation_function = 'tanh'
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense([64])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([64])]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense(64)]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(64)]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.activation_function = 'tanh'
|
||||
agent_params.network_wrappers['main'].batch_size = 64
|
||||
agent_params.network_wrappers['main'].optimizer_epsilon = 1e-5
|
||||
@@ -43,21 +44,16 @@ agent_params.algorithm.discount = 0.99
|
||||
agent_params.algorithm.optimization_epochs = 10
|
||||
agent_params.algorithm.estimate_state_value_using_gae = True
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.exploration = AdditiveNoiseParameters()
|
||||
agent_params.pre_network_filter = MujocoInputFilter()
|
||||
agent_params.pre_network_filter = InputFilter()
|
||||
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
||||
ObservationNormalizationFilter(name='normalize_observation'))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -70,7 +66,7 @@ preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, EmbedderScheme
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -21,21 +21,16 @@ schedule_params.heatup_steps = EnvironmentSteps(1000)
|
||||
# Agent #
|
||||
#########
|
||||
agent_params = DDPGAgentParameters()
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense([400])]
|
||||
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense([300])]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense([400])]
|
||||
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense([300])]
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense(400)]
|
||||
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(300)]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense(400)]
|
||||
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(300)]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -48,5 +43,5 @@ preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from rl_coach.agents.naf_agent import NAFAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase, GradientClippingMethod
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, GradientClippingMethod
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -20,20 +20,15 @@ schedule_params.heatup_steps = EnvironmentSteps(1000)
|
||||
# Agent #
|
||||
#########
|
||||
agent_params = NAFAgentParameters()
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense([200])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([200])]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense(200)]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(200)]
|
||||
agent_params.network_wrappers['main'].clip_gradients = 1000
|
||||
agent_params.network_wrappers['main'].gradients_clipping_method = GradientClippingMethod.ClipByValue
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
|
||||
# this preset is currently broken - no test
|
||||
@@ -51,5 +46,5 @@ preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from rl_coach.agents.ppo_agent import PPOAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
|
||||
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -25,28 +25,18 @@ agent_params = PPOAgentParameters()
|
||||
agent_params.network_wrappers['actor'].learning_rate = 0.001
|
||||
agent_params.network_wrappers['critic'].learning_rate = 0.001
|
||||
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense([64])]
|
||||
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense([64])]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense([64])]
|
||||
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense([64])]
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters['observation'].scheme = [Dense(64)]
|
||||
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(64)]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].scheme = [Dense(64)]
|
||||
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(64)]
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||
|
||||
agent_params.exploration = ContinuousEntropyParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
|
||||
# this preset is currently broken
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
|
||||
########
|
||||
@@ -60,7 +50,7 @@ preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.agents.hac_ddpg_agent import HACDDPGAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, RunPhase, TrainingSteps
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||
from rl_coach.exploration_policies.ou_process import OUProcessParameters
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -66,7 +64,7 @@ top_agent_params.exploration.theta = 0.1
|
||||
top_actor = top_agent_params.network_wrappers['actor']
|
||||
top_actor.input_embedders_parameters = {'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
top_actor.middleware_parameters.scheme = [Dense([64])] * 3
|
||||
top_actor.middleware_parameters.scheme = [Dense(64)] * 3
|
||||
top_actor.learning_rate = 0.001
|
||||
top_actor.batch_size = 4096
|
||||
|
||||
@@ -76,7 +74,7 @@ top_critic.input_embedders_parameters = {'observation': InputEmbedderParameters(
|
||||
'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
top_critic.embedding_merger_type = EmbeddingMergerType.Concat
|
||||
top_critic.middleware_parameters.scheme = [Dense([64])] * 3
|
||||
top_critic.middleware_parameters.scheme = [Dense(64)] * 3
|
||||
top_critic.learning_rate = 0.001
|
||||
top_critic.batch_size = 4096
|
||||
|
||||
@@ -107,7 +105,7 @@ bottom_agent_params.exploration.continuous_exploration_policy_parameters.theta =
|
||||
bottom_actor = bottom_agent_params.network_wrappers['actor']
|
||||
bottom_actor.input_embedders_parameters = {'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
bottom_actor.middleware_parameters.scheme = [Dense([64])] * 3
|
||||
bottom_actor.middleware_parameters.scheme = [Dense(64)] * 3
|
||||
bottom_actor.learning_rate = 0.001
|
||||
bottom_actor.batch_size = 4096
|
||||
|
||||
@@ -117,7 +115,7 @@ bottom_critic.input_embedders_parameters = {'observation': InputEmbedderParamete
|
||||
'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
bottom_critic.embedding_merger_type = EmbeddingMergerType.Concat
|
||||
bottom_critic.middleware_parameters.scheme = [Dense([64])] * 3
|
||||
bottom_critic.middleware_parameters.scheme = [Dense(64)] * 3
|
||||
bottom_critic.learning_rate = 0.001
|
||||
bottom_critic.batch_size = 4096
|
||||
|
||||
@@ -128,8 +126,7 @@ agents_params = [top_agent_params, bottom_agent_params]
|
||||
###############
|
||||
time_limit = 1000
|
||||
|
||||
env_params = Mujoco()
|
||||
env_params.level = "rl_coach.environments.mujoco.pendulum_with_goals:PendulumWithGoals"
|
||||
env_params = GymVectorEnvironment(level="rl_coach.environments.mujoco.pendulum_with_goals:PendulumWithGoals")
|
||||
env_params.additional_simulator_parameters = {"time_limit": time_limit,
|
||||
"random_goals_instead_of_standing_goal": False,
|
||||
"polar_coordinates": polar_coordinates,
|
||||
@@ -138,8 +135,6 @@ env_params.frame_skip = 10
|
||||
env_params.custom_reward_threshold = -time_limit + 1
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST)]
|
||||
vis_params.dump_mp4 = False
|
||||
vis_params.native_rendering = False
|
||||
|
||||
graph_manager = HACGraphManager(agents_params=agents_params, env_params=env_params,
|
||||
|
||||
29
rl_coach/presets/README.md
Normal file
29
rl_coach/presets/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# Defining Presets
|
||||
|
||||
In Coach, we use a Preset mechanism in order to define reproducible experiments.
|
||||
A Preset defines all the parameters of an experiment in a single file, and can be executed from the command
|
||||
line using the file name.
|
||||
Presets can be very simple by using the default parameters of the algorithm and environment.
|
||||
They can also be explicit and define all the parameters in order to avoid hidden logic.
|
||||
The outcome of a preset is a GraphManager.
|
||||
|
||||
|
||||
Let's start with the simplest preset possible.
|
||||
We will define a preset for training the CartPole environment using Clipped PPO.
|
||||
The 3 minimal things we need to define in each preset are the agent, the environment and a schedule.
|
||||
|
||||
```
|
||||
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import SimpleSchedule
|
||||
|
||||
graph_manager = BasicRLGraphManager(
|
||||
agent_params=ClippedPPOAgentParameters(),
|
||||
env_params=GymVectorEnvironment(level='CartPole-v0'),
|
||||
schedule_params=SimpleSchedule()
|
||||
)
|
||||
```
|
||||
|
||||
Most presets in Coach are much more explicit than this. The motivation behind this is to be as transparent as
|
||||
possible regarding all the changes needed relative to the basic parameters defined in the algorithm paper.
|
||||
@@ -1,10 +1,8 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.starcraft2_environment import StarCraft2EnvironmentParameters
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -51,15 +49,9 @@ agent_params.network_wrappers['main'].adam_optimizer_beta2 = 0.999
|
||||
# Environment #
|
||||
###############
|
||||
|
||||
env_params = StarCraft2EnvironmentParameters()
|
||||
env_params.level = 'CollectMineralShards'
|
||||
env_params = StarCraft2EnvironmentParameters(level='CollectMineralShards')
|
||||
env_params.feature_screen_maps_to_use = [5]
|
||||
env_params.feature_minimap_maps_to_use = [5]
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
# vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST),MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = True
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import RunPhase
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.environments.starcraft2_environment import StarCraft2EnvironmentParameters
|
||||
from rl_coach.filters.action.box_discretization import BoxDiscretization
|
||||
from rl_coach.filters.filter import OutputFilter
|
||||
@@ -51,15 +49,10 @@ agent_params.output_filter = \
|
||||
# Environment #
|
||||
###############
|
||||
|
||||
env_params = StarCraft2EnvironmentParameters()
|
||||
env_params.level = 'CollectMineralShards'
|
||||
env_params = StarCraft2EnvironmentParameters(level='CollectMineralShards')
|
||||
env_params.feature_screen_maps_to_use = [5]
|
||||
env_params.feature_minimap_maps_to_use = [5]
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
# vis_params.dump_in_episode_signals = True
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
Reference in New Issue
Block a user