mirror of
https://github.com/gryf/coach.git
synced 2026-03-31 00:53:32 +02:00
network_imporvements branch merge
This commit is contained in:
@@ -1,13 +1,11 @@
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.agents.hac_ddpg_agent import HACDDPGAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, RunPhase, TrainingSteps
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||
from rl_coach.exploration_policies.ou_process import OUProcessParameters
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -66,7 +64,7 @@ top_agent_params.exploration.theta = 0.1
|
||||
top_actor = top_agent_params.network_wrappers['actor']
|
||||
top_actor.input_embedders_parameters = {'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
top_actor.middleware_parameters.scheme = [Dense([64])] * 3
|
||||
top_actor.middleware_parameters.scheme = [Dense(64)] * 3
|
||||
top_actor.learning_rate = 0.001
|
||||
top_actor.batch_size = 4096
|
||||
|
||||
@@ -76,7 +74,7 @@ top_critic.input_embedders_parameters = {'observation': InputEmbedderParameters(
|
||||
'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
top_critic.embedding_merger_type = EmbeddingMergerType.Concat
|
||||
top_critic.middleware_parameters.scheme = [Dense([64])] * 3
|
||||
top_critic.middleware_parameters.scheme = [Dense(64)] * 3
|
||||
top_critic.learning_rate = 0.001
|
||||
top_critic.batch_size = 4096
|
||||
|
||||
@@ -107,7 +105,7 @@ bottom_agent_params.exploration.continuous_exploration_policy_parameters.theta =
|
||||
bottom_actor = bottom_agent_params.network_wrappers['actor']
|
||||
bottom_actor.input_embedders_parameters = {'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
bottom_actor.middleware_parameters.scheme = [Dense([64])] * 3
|
||||
bottom_actor.middleware_parameters.scheme = [Dense(64)] * 3
|
||||
bottom_actor.learning_rate = 0.001
|
||||
bottom_actor.batch_size = 4096
|
||||
|
||||
@@ -117,7 +115,7 @@ bottom_critic.input_embedders_parameters = {'observation': InputEmbedderParamete
|
||||
'action': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
bottom_critic.embedding_merger_type = EmbeddingMergerType.Concat
|
||||
bottom_critic.middleware_parameters.scheme = [Dense([64])] * 3
|
||||
bottom_critic.middleware_parameters.scheme = [Dense(64)] * 3
|
||||
bottom_critic.learning_rate = 0.001
|
||||
bottom_critic.batch_size = 4096
|
||||
|
||||
@@ -128,8 +126,7 @@ agents_params = [top_agent_params, bottom_agent_params]
|
||||
###############
|
||||
time_limit = 1000
|
||||
|
||||
env_params = Mujoco()
|
||||
env_params.level = "rl_coach.environments.mujoco.pendulum_with_goals:PendulumWithGoals"
|
||||
env_params = GymVectorEnvironment(level="rl_coach.environments.mujoco.pendulum_with_goals:PendulumWithGoals")
|
||||
env_params.additional_simulator_parameters = {"time_limit": time_limit,
|
||||
"random_goals_instead_of_standing_goal": False,
|
||||
"polar_coordinates": polar_coordinates,
|
||||
@@ -138,8 +135,6 @@ env_params.frame_skip = 10
|
||||
env_params.custom_reward_threshold = -time_limit + 1
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST)]
|
||||
vis_params.dump_mp4 = False
|
||||
vis_params.native_rendering = False
|
||||
|
||||
graph_manager = HACGraphManager(agents_params=agents_params, env_params=env_params,
|
||||
|
||||
Reference in New Issue
Block a user