mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
network_imporvements branch merge
This commit is contained in:
@@ -16,7 +16,8 @@ pages:
|
||||
- 'Control Flow' : design/control_flow.md
|
||||
- 'Network' : design/network.md
|
||||
- 'Filters' : design/filters.md
|
||||
|
||||
- API Reference:
|
||||
- 'Agent Parameters' : api_reference/agent_parameters/agent_parameters.md
|
||||
- Algorithms:
|
||||
- 'DQN' : algorithms/value_optimization/dqn.md
|
||||
- 'Double DQN' : algorithms/value_optimization/double_dqn.md
|
||||
|
||||
@@ -25,14 +25,38 @@ from rl_coach.architectures.tensorflow_components.heads.v_head import VHeadParam
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
|
||||
AgentParameters
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
|
||||
from rl_coach.logger import screen
|
||||
from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters
|
||||
from rl_coach.spaces import DiscreteActionSpace
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
|
||||
from rl_coach.utils import last_sample
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
|
||||
|
||||
class ActorCriticAlgorithmParameters(AlgorithmParameters):
|
||||
"""
|
||||
:param policy_gradient_rescaler: (PolicyGradientRescaler)
|
||||
The value that will be used to rescale the policy gradient
|
||||
|
||||
:param apply_gradients_every_x_episodes: (int)
|
||||
The number of episodes to wait before applying the accumulated gradients to the network.
|
||||
The training iterations only accumulate gradients without actually applying them.
|
||||
|
||||
:param beta_entropy: (float)
|
||||
The weight that will be given to the entropy regularization which is used in order to improve exploration.
|
||||
|
||||
:param num_steps_between_gradient_updates: (int)
|
||||
Every num_steps_between_gradient_updates transitions will be considered as a single batch and use for
|
||||
accumulating gradients. This is also the number of steps used for bootstrapping according to the n-step formulation.
|
||||
|
||||
:param gae_lambda: (float)
|
||||
If the policy gradient rescaler was defined as PolicyGradientRescaler.GAE, the generalized advantage estimation
|
||||
scheme will be used, in which case the lambda value controls the decay for the different n-step lengths.
|
||||
|
||||
:param estimate_state_value_using_gae: (bool)
|
||||
If set to True, the state value targets for the V head will be estimated using the GAE scheme.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.policy_gradient_rescaler = PolicyGradientRescaler.A_VALUE
|
||||
@@ -48,9 +72,7 @@ class ActorCriticNetworkParameters(NetworkParameters):
|
||||
super().__init__()
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters()
|
||||
self.heads_parameters = [VHeadParameters(), PolicyHeadParameters()]
|
||||
self.loss_weights = [0.5, 1.0]
|
||||
self.rescale_gradient_from_head_by_factor = [1, 1]
|
||||
self.heads_parameters = [VHeadParameters(loss_weight=0.5), PolicyHeadParameters(loss_weight=1.0)]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.clip_gradients = 40.0
|
||||
self.async_training = True
|
||||
@@ -59,8 +81,8 @@ class ActorCriticNetworkParameters(NetworkParameters):
|
||||
class ActorCriticAgentParameters(AgentParameters):
|
||||
def __init__(self):
|
||||
super().__init__(algorithm=ActorCriticAlgorithmParameters(),
|
||||
exploration=None, #TODO this should be different for continuous (ContinuousEntropyExploration)
|
||||
# and discrete (CategoricalExploration) action spaces.
|
||||
exploration={DiscreteActionSpace: CategoricalParameters(),
|
||||
BoxActionSpace: ContinuousEntropyParameters()},
|
||||
memory=SingleEpisodeBufferParameters(),
|
||||
networks={"main": ActorCriticNetworkParameters()})
|
||||
|
||||
|
||||
@@ -157,6 +157,10 @@ class Agent(AgentInterface):
|
||||
if self.ap.task_parameters.seed is not None:
|
||||
random.seed(self.ap.task_parameters.seed)
|
||||
np.random.seed(self.ap.task_parameters.seed)
|
||||
else:
|
||||
# we need to seed the RNG since the different processes are initialized with the same parent seed
|
||||
random.seed()
|
||||
np.random.seed()
|
||||
|
||||
@property
|
||||
def parent(self):
|
||||
@@ -269,6 +273,10 @@ class Agent(AgentInterface):
|
||||
spaces=self.spaces,
|
||||
replicated_device=self.replicated_device,
|
||||
worker_device=self.worker_device)
|
||||
|
||||
if self.ap.visualization.print_networks_summary:
|
||||
print(networks[network_name])
|
||||
|
||||
return networks
|
||||
|
||||
def init_environment_dependent_modules(self) -> None:
|
||||
@@ -278,6 +286,14 @@ class Agent(AgentInterface):
|
||||
:return: None
|
||||
"""
|
||||
# initialize exploration policy
|
||||
if isinstance(self.ap.exploration, dict):
|
||||
if self.spaces.action.__class__ in self.ap.exploration.keys():
|
||||
self.ap.exploration = self.ap.exploration[self.spaces.action.__class__]
|
||||
else:
|
||||
raise ValueError("The exploration parameters were defined as a mapping between action space types and "
|
||||
"exploration types, but the action space used by the environment ({}) was not part of "
|
||||
"the exploration parameters dictionary keys ({})"
|
||||
.format(self.spaces.action.__class__, list(self.ap.exploration.keys())))
|
||||
self.ap.exploration.action_space = self.spaces.action
|
||||
self.exploration_policy = dynamic_import_and_instantiate_module_from_params(self.ap.exploration)
|
||||
|
||||
@@ -543,6 +559,9 @@ class Agent(AgentInterface):
|
||||
"""
|
||||
loss = 0
|
||||
if self._should_train():
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(True)
|
||||
|
||||
for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
|
||||
# TODO: this should be network dependent
|
||||
network_parameters = list(self.ap.network_wrappers.values())[0]
|
||||
@@ -586,9 +605,14 @@ class Agent(AgentInterface):
|
||||
if self.imitation:
|
||||
self.log_to_screen()
|
||||
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(False)
|
||||
|
||||
# run additional commands after the training is done
|
||||
self.post_training_commands()
|
||||
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
def choose_action(self, curr_state):
|
||||
|
||||
@@ -26,6 +26,7 @@ from rl_coach.base_parameters import AgentParameters, AlgorithmParameters, Netwo
|
||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
|
||||
|
||||
|
||||
class BCAlgorithmParameters(AlgorithmParameters):
|
||||
@@ -40,7 +41,6 @@ class BCNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
|
||||
self.heads_parameters = [PolicyHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 32
|
||||
self.replace_mse_with_huber_loss = False
|
||||
@@ -51,7 +51,7 @@ class BCAgentParameters(AgentParameters):
|
||||
def __init__(self):
|
||||
super().__init__(algorithm=BCAlgorithmParameters(),
|
||||
exploration=EGreedyParameters(),
|
||||
memory=EpisodicExperienceReplayParameters(),
|
||||
memory=ExperienceReplayParameters(),
|
||||
networks={"main": BCNetworkParameters()})
|
||||
|
||||
@property
|
||||
|
||||
@@ -28,8 +28,8 @@ from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayPar
|
||||
class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.num_output_head_copies = 10
|
||||
self.rescale_gradient_from_head_by_factor = [1.0/self.num_output_head_copies]*self.num_output_head_copies
|
||||
self.heads_parameters[0].num_output_head_copies = 10
|
||||
self.heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/self.heads_parameters[0].num_output_head_copies
|
||||
|
||||
|
||||
class BootstrappedDQNAgentParameters(AgentParameters):
|
||||
|
||||
@@ -38,7 +38,6 @@ class CILNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
|
||||
self.heads_parameters = [RegressionHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 32
|
||||
self.replace_mse_with_huber_loss = False
|
||||
|
||||
@@ -31,10 +31,11 @@ from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, Batch, EnvResponse, StateType
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.logger import screen
|
||||
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
|
||||
from rl_coach.schedules import ConstantSchedule
|
||||
from rl_coach.spaces import DiscreteActionSpace
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
|
||||
|
||||
|
||||
class ClippedPPONetworkParameters(NetworkParameters):
|
||||
@@ -43,8 +44,6 @@ class ClippedPPONetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='tanh')}
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='tanh')
|
||||
self.heads_parameters = [VHeadParameters(), PPOHeadParameters()]
|
||||
self.loss_weights = [1.0, 1.0]
|
||||
self.rescale_gradient_from_head_by_factor = [1, 1]
|
||||
self.batch_size = 64
|
||||
self.optimizer_type = 'Adam'
|
||||
self.clip_gradients = None
|
||||
@@ -79,7 +78,8 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
|
||||
class ClippedPPOAgentParameters(AgentParameters):
|
||||
def __init__(self):
|
||||
super().__init__(algorithm=ClippedPPOAlgorithmParameters(),
|
||||
exploration=AdditiveNoiseParameters(),
|
||||
exploration={DiscreteActionSpace: CategoricalParameters(),
|
||||
BoxActionSpace: AdditiveNoiseParameters()},
|
||||
memory=EpisodicExperienceReplayParameters(),
|
||||
networks={"main": ClippedPPONetworkParameters()})
|
||||
|
||||
@@ -253,6 +253,9 @@ class ClippedPPOAgent(ActorCriticAgent):
|
||||
|
||||
def train(self):
|
||||
if self._should_train(wait_for_full_episode=True):
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(True)
|
||||
|
||||
dataset = self.memory.transitions
|
||||
dataset = self.pre_network_filter.filter(dataset, deep_copy=False)
|
||||
batch = Batch(dataset)
|
||||
@@ -269,6 +272,9 @@ class ClippedPPOAgent(ActorCriticAgent):
|
||||
|
||||
self.train_network(batch, self.ap.algorithm.optimization_epochs)
|
||||
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(False)
|
||||
|
||||
self.post_training_commands()
|
||||
self.training_iteration += 1
|
||||
# should be done in order to update the data that has been accumulated * while not playing *
|
||||
|
||||
@@ -41,8 +41,6 @@ class DDPGCriticNetworkParameters(NetworkParameters):
|
||||
'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
|
||||
self.middleware_parameters = FCMiddlewareParameters()
|
||||
self.heads_parameters = [VHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 64
|
||||
self.async_training = False
|
||||
@@ -58,8 +56,6 @@ class DDPGActorNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
|
||||
self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
|
||||
self.heads_parameters = [DDPGActorHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 64
|
||||
self.async_training = False
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Union
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.agents.agent import Agent
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.measurements_prediction_head import \
|
||||
MeasurementsPredictionHeadParameters
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
@@ -48,28 +48,27 @@ class DFPNetworkParameters(NetworkParameters):
|
||||
'goal': InputEmbedderParameters(activation_function='leaky_relu')}
|
||||
|
||||
self.input_embedders_parameters['observation'].scheme = [
|
||||
Conv2d([32, 8, 4]),
|
||||
Conv2d([64, 4, 2]),
|
||||
Conv2d([64, 3, 1]),
|
||||
Dense([512]),
|
||||
Conv2d(32, 8, 4),
|
||||
Conv2d(64, 4, 2),
|
||||
Conv2d(64, 3, 1),
|
||||
Dense(512),
|
||||
]
|
||||
|
||||
self.input_embedders_parameters['measurements'].scheme = [
|
||||
Dense([128]),
|
||||
Dense([128]),
|
||||
Dense([128]),
|
||||
Dense(128),
|
||||
Dense(128),
|
||||
Dense(128),
|
||||
]
|
||||
|
||||
self.input_embedders_parameters['goal'].scheme = [
|
||||
Dense([128]),
|
||||
Dense([128]),
|
||||
Dense([128]),
|
||||
Dense(128),
|
||||
Dense(128),
|
||||
Dense(128),
|
||||
]
|
||||
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='leaky_relu',
|
||||
scheme=MiddlewareScheme.Empty)
|
||||
self.heads_parameters = [MeasurementsPredictionHeadParameters(activation_function='leaky_relu')]
|
||||
self.loss_weights = [1.0]
|
||||
self.async_training = False
|
||||
self.batch_size = 64
|
||||
self.adam_optimizer_beta1 = 0.95
|
||||
|
||||
@@ -44,7 +44,6 @@ class DQNNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
|
||||
self.heads_parameters = [QHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 32
|
||||
self.replace_mse_with_huber_loss = True
|
||||
|
||||
@@ -46,8 +46,6 @@ class HumanNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.input_embedders_parameters['observation'].scheme = EmbedderScheme.Medium
|
||||
self.middleware_parameters = FCMiddlewareParameters()
|
||||
self.heads_parameters = [PolicyHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 32
|
||||
self.replace_mse_with_huber_loss = False
|
||||
|
||||
@@ -37,7 +37,6 @@ class NStepQNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters()
|
||||
self.heads_parameters = [QHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.async_training = True
|
||||
self.shared_optimizer = True
|
||||
|
||||
@@ -37,7 +37,6 @@ class NAFNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters()
|
||||
self.heads_parameters = [NAFHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.learning_rate = 0.001
|
||||
self.async_training = True
|
||||
|
||||
@@ -39,8 +39,6 @@ class NECNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters()
|
||||
self.heads_parameters = [DNDQHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
|
||||
|
||||
|
||||
@@ -26,9 +26,10 @@ from rl_coach.base_parameters import NetworkParameters, AlgorithmParameters, \
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.logger import screen
|
||||
from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters
|
||||
from rl_coach.spaces import DiscreteActionSpace
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
|
||||
|
||||
|
||||
class PolicyGradientNetworkParameters(NetworkParameters):
|
||||
@@ -37,7 +38,6 @@ class PolicyGradientNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
|
||||
self.middleware_parameters = FCMiddlewareParameters()
|
||||
self.heads_parameters = [PolicyHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.async_training = True
|
||||
|
||||
|
||||
@@ -53,7 +53,8 @@ class PolicyGradientAlgorithmParameters(AlgorithmParameters):
|
||||
class PolicyGradientsAgentParameters(AgentParameters):
|
||||
def __init__(self):
|
||||
super().__init__(algorithm=PolicyGradientAlgorithmParameters(),
|
||||
exploration=AdditiveNoiseParameters(),
|
||||
exploration={DiscreteActionSpace: CategoricalParameters(),
|
||||
BoxActionSpace: AdditiveNoiseParameters()},
|
||||
memory=SingleEpisodeBufferParameters(),
|
||||
networks={"main": PolicyGradientNetworkParameters()})
|
||||
|
||||
|
||||
@@ -93,6 +93,8 @@ class PolicyOptimizationAgent(Agent):
|
||||
|
||||
total_loss = 0
|
||||
if num_steps_passed_since_last_update > 0:
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(True)
|
||||
|
||||
# we need to update the returns of the episode until now
|
||||
episode.update_returns()
|
||||
@@ -124,6 +126,9 @@ class PolicyOptimizationAgent(Agent):
|
||||
network.apply_gradients_and_sync_networks()
|
||||
self.training_iteration += 1
|
||||
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(False)
|
||||
|
||||
# run additional commands after the training is done
|
||||
self.post_training_commands()
|
||||
|
||||
|
||||
@@ -31,9 +31,10 @@ from rl_coach.architectures.tensorflow_components.embedders.embedder import Inpu
|
||||
|
||||
from rl_coach.core_types import EnvironmentSteps, Batch
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.logger import screen
|
||||
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
|
||||
from rl_coach.spaces import DiscreteActionSpace
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
|
||||
from rl_coach.utils import force_list
|
||||
|
||||
|
||||
@@ -43,7 +44,6 @@ class PPOCriticNetworkParameters(NetworkParameters):
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='tanh')}
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='tanh')
|
||||
self.heads_parameters = [VHeadParameters()]
|
||||
self.loss_weights = [1.0]
|
||||
self.async_training = True
|
||||
self.l2_regularization = 0
|
||||
self.create_target_network = True
|
||||
@@ -57,7 +57,6 @@ class PPOActorNetworkParameters(NetworkParameters):
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='tanh')
|
||||
self.heads_parameters = [PPOHeadParameters()]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.loss_weights = [1.0]
|
||||
self.async_training = True
|
||||
self.l2_regularization = 0
|
||||
self.create_target_network = True
|
||||
@@ -84,7 +83,8 @@ class PPOAlgorithmParameters(AlgorithmParameters):
|
||||
class PPOAgentParameters(AgentParameters):
|
||||
def __init__(self):
|
||||
super().__init__(algorithm=PPOAlgorithmParameters(),
|
||||
exploration=AdditiveNoiseParameters(),
|
||||
exploration={DiscreteActionSpace: CategoricalParameters(),
|
||||
BoxActionSpace: AdditiveNoiseParameters()},
|
||||
memory=EpisodicExperienceReplayParameters(),
|
||||
networks={"critic": PPOCriticNetworkParameters(), "actor": PPOActorNetworkParameters()})
|
||||
|
||||
@@ -313,6 +313,9 @@ class PPOAgent(ActorCriticAgent):
|
||||
def train(self):
|
||||
loss = 0
|
||||
if self._should_train(wait_for_full_episode=True):
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(True)
|
||||
|
||||
for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
|
||||
self.networks['actor'].sync()
|
||||
self.networks['critic'].sync()
|
||||
@@ -330,6 +333,9 @@ class PPOAgent(ActorCriticAgent):
|
||||
self.value_loss.add_sample(value_loss)
|
||||
self.policy_loss.add_sample(policy_loss)
|
||||
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(False)
|
||||
|
||||
self.post_training_commands()
|
||||
self.training_iteration += 1
|
||||
self.update_log() # should be done in order to update the data that has been accumulated * while not playing *
|
||||
|
||||
@@ -199,6 +199,16 @@ class NetworkWrapper(object):
|
||||
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
|
||||
return global_variables
|
||||
|
||||
def set_is_training(self, state: bool):
|
||||
"""
|
||||
Set the phase of the network between training and testing
|
||||
:param state: The current state (True = Training, False = Testing)
|
||||
:return: None
|
||||
"""
|
||||
self.online_network.set_is_training(state)
|
||||
if self.has_target:
|
||||
self.target_network.set_is_training(state)
|
||||
|
||||
def set_session(self, sess):
|
||||
self.sess = sess
|
||||
self.online_network.set_session(sess)
|
||||
@@ -207,3 +217,18 @@ class NetworkWrapper(object):
|
||||
if self.target_network:
|
||||
self.target_network.set_session(sess)
|
||||
|
||||
def __str__(self):
|
||||
sub_networks = []
|
||||
if self.global_network:
|
||||
sub_networks.append("global network")
|
||||
if self.online_network:
|
||||
sub_networks.append("online network")
|
||||
if self.target_network:
|
||||
sub_networks.append("target network")
|
||||
|
||||
result = []
|
||||
result.append("Network: {}, Copies: {} ({})".format(self.name, len(sub_networks), ' | '.join(sub_networks)))
|
||||
result.append("-"*len(result[-1]))
|
||||
result.append(str(self.online_network))
|
||||
result.append("")
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -13,9 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import math
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@@ -27,135 +25,6 @@ from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import force_list, squeeze_list
|
||||
|
||||
|
||||
def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout, dropout_rate, layer_idx):
|
||||
layers = [input_layer]
|
||||
|
||||
# batchnorm
|
||||
if batchnorm:
|
||||
layers.append(
|
||||
tf.layers.batch_normalization(layers[-1], name="batchnorm{}".format(layer_idx))
|
||||
)
|
||||
|
||||
# activation
|
||||
if activation_function:
|
||||
layers.append(
|
||||
activation_function(layers[-1], name="activation{}".format(layer_idx))
|
||||
)
|
||||
|
||||
# dropout
|
||||
if dropout:
|
||||
layers.append(
|
||||
tf.layers.dropout(layers[-1], dropout_rate, name="dropout{}".format(layer_idx))
|
||||
)
|
||||
|
||||
# remove the input layer from the layers list
|
||||
del layers[0]
|
||||
|
||||
return layers
|
||||
|
||||
|
||||
class Conv2d(object):
|
||||
def __init__(self, params: List):
|
||||
"""
|
||||
:param params: list of [num_filters, kernel_size, strides]
|
||||
"""
|
||||
self.params = params
|
||||
|
||||
def __call__(self, input_layer, name: str=None):
|
||||
"""
|
||||
returns a tensorflow conv2d layer
|
||||
:param input_layer: previous layer
|
||||
:param name: layer name
|
||||
:return: conv2d layer
|
||||
"""
|
||||
return tf.layers.conv2d(input_layer, filters=self.params[0], kernel_size=self.params[1], strides=self.params[2],
|
||||
data_format='channels_last', name=name)
|
||||
|
||||
|
||||
class Dense(object):
|
||||
def __init__(self, params: Union[List, int]):
|
||||
"""
|
||||
:param params: list of [num_output_neurons]
|
||||
"""
|
||||
self.params = force_list(params)
|
||||
|
||||
def __call__(self, input_layer, name: str=None, kernel_initializer=None, activation=None):
|
||||
"""
|
||||
returns a tensorflow dense layer
|
||||
:param input_layer: previous layer
|
||||
:param name: layer name
|
||||
:return: dense layer
|
||||
"""
|
||||
return tf.layers.dense(input_layer, self.params[0], name=name, kernel_initializer=kernel_initializer,
|
||||
activation=activation)
|
||||
|
||||
|
||||
class NoisyNetDense(object):
|
||||
"""
|
||||
A factorized Noisy Net layer
|
||||
|
||||
https://arxiv.org/abs/1706.10295.
|
||||
"""
|
||||
|
||||
def __init__(self, params: List):
|
||||
"""
|
||||
:param params: list of [num_output_neurons]
|
||||
"""
|
||||
self.params = force_list(params)
|
||||
self.sigma0 = 0.5
|
||||
|
||||
def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None):
|
||||
"""
|
||||
returns a NoisyNet dense layer
|
||||
:param input_layer: previous layer
|
||||
:param name: layer name
|
||||
:param kernel_initializer: initializer for kernels. Default is to use Gaussian noise that preserves stddev.
|
||||
:param activation: the activation function
|
||||
:return: dense layer
|
||||
"""
|
||||
#TODO: noise sampling should be externally controlled. DQN is fine with sampling noise for every
|
||||
# forward (either act or train, both for online and target networks).
|
||||
# A3C, on the other hand, should sample noise only when policy changes (i.e. after every t_max steps)
|
||||
|
||||
num_inputs = input_layer.get_shape()[-1].value
|
||||
num_outputs = self.params[0]
|
||||
|
||||
stddev = 1 / math.sqrt(num_inputs)
|
||||
activation = activation if activation is not None else (lambda x: x)
|
||||
|
||||
if kernel_initializer is None:
|
||||
kernel_mean_initializer = tf.random_uniform_initializer(-stddev, stddev)
|
||||
kernel_stddev_initializer = tf.random_uniform_initializer(-stddev * self.sigma0, stddev * self.sigma0)
|
||||
else:
|
||||
kernel_mean_initializer = kernel_stddev_initializer = kernel_initializer
|
||||
with tf.variable_scope(None, default_name=name):
|
||||
weight_mean = tf.get_variable('weight_mean', shape=(num_inputs, num_outputs),
|
||||
initializer=kernel_mean_initializer)
|
||||
bias_mean = tf.get_variable('bias_mean', shape=(num_outputs,), initializer=tf.zeros_initializer())
|
||||
|
||||
weight_stddev = tf.get_variable('weight_stddev', shape=(num_inputs, num_outputs),
|
||||
initializer=kernel_stddev_initializer)
|
||||
bias_stddev = tf.get_variable('bias_stddev', shape=(num_outputs,),
|
||||
initializer=kernel_stddev_initializer)
|
||||
bias_noise = self.f(tf.random_normal((num_outputs,)))
|
||||
weight_noise = self.factorized_noise(num_inputs, num_outputs)
|
||||
|
||||
bias = bias_mean + bias_stddev * bias_noise
|
||||
weight = weight_mean + weight_stddev * weight_noise
|
||||
return activation(tf.matmul(input_layer, weight) + bias)
|
||||
|
||||
def factorized_noise(self, inputs, outputs):
|
||||
# TODO: use factorized noise only for compute intensive algos (e.g. DQN).
|
||||
# lighter algos (e.g. DQN) should not use it
|
||||
noise1 = self.f(tf.random_normal((inputs, 1)))
|
||||
noise2 = self.f(tf.random_normal((1, outputs)))
|
||||
return tf.matmul(noise1, noise2)
|
||||
|
||||
@staticmethod
|
||||
def f(values):
|
||||
return tf.sqrt(tf.abs(values)) * tf.sign(values)
|
||||
|
||||
|
||||
def variable_summaries(var):
|
||||
"""Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
|
||||
with tf.name_scope('summaries'):
|
||||
@@ -720,6 +589,14 @@ class TensorFlowArchitecture(Architecture):
|
||||
"""
|
||||
self.sess.run(assign_op, feed_dict={placeholder: value})
|
||||
|
||||
def set_is_training(self, state: bool):
|
||||
"""
|
||||
Set the phase of the network between training and testing
|
||||
:param state: The current state (True = Training, False = Testing)
|
||||
:return: None
|
||||
"""
|
||||
self.set_variable_value(self.assign_is_training, state, self.is_training_placeholder)
|
||||
|
||||
def reset_internal_memory(self):
|
||||
"""
|
||||
Reset any internal memory used by the network. For example, an LSTM internal state
|
||||
@@ -728,4 +605,4 @@ class TensorFlowArchitecture(Architecture):
|
||||
# initialize LSTM hidden states
|
||||
if self.middleware.__class__.__name__ == 'LSTMMiddleware':
|
||||
self.curr_rnn_c_in = self.middleware.c_init
|
||||
self.curr_rnn_h_in = self.middleware.h_init
|
||||
self.curr_rnn_h_in = self.middleware.h_init
|
||||
|
||||
@@ -15,20 +15,23 @@
|
||||
#
|
||||
|
||||
from typing import List, Union
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense, \
|
||||
BatchnormActivationDropout
|
||||
from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
|
||||
|
||||
from rl_coach.core_types import InputEmbedding
|
||||
from rl_coach.utils import force_list
|
||||
|
||||
|
||||
class InputEmbedderParameters(NetworkComponentParameters):
|
||||
def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
|
||||
batchnorm: bool=False, dropout=False, name: str='embedder', input_rescaling=None, input_offset=None,
|
||||
input_clipping=None, dense_layer=Dense):
|
||||
input_clipping=None, dense_layer=Dense, is_training=False):
|
||||
super().__init__(dense_layer=dense_layer)
|
||||
self.activation_function = activation_function
|
||||
self.scheme = scheme
|
||||
@@ -44,6 +47,7 @@ class InputEmbedderParameters(NetworkComponentParameters):
|
||||
self.input_offset = input_offset
|
||||
self.input_clipping = input_clipping
|
||||
self.name = name
|
||||
self.is_training = is_training
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
@@ -61,7 +65,8 @@ class InputEmbedder(object):
|
||||
"""
|
||||
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
||||
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout: bool=False,
|
||||
name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense):
|
||||
name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense,
|
||||
is_training=False):
|
||||
self.name = name
|
||||
self.input_size = input_size
|
||||
self.activation_function = activation_function
|
||||
@@ -72,11 +77,29 @@ class InputEmbedder(object):
|
||||
self.output = None
|
||||
self.scheme = scheme
|
||||
self.return_type = InputEmbedding
|
||||
self.layers_params = []
|
||||
self.layers = []
|
||||
self.input_rescaling = input_rescaling
|
||||
self.input_offset = input_offset
|
||||
self.input_clipping = input_clipping
|
||||
self.dense_layer = dense_layer
|
||||
self.is_training = is_training
|
||||
|
||||
# layers order is conv -> batchnorm -> activation -> dropout
|
||||
if isinstance(self.scheme, EmbedderScheme):
|
||||
self.layers_params = copy.copy(self.schemes[self.scheme])
|
||||
else:
|
||||
self.layers_params = copy.copy(self.scheme)
|
||||
|
||||
# we allow adding batchnorm, dropout or activation functions after each layer.
|
||||
# The motivation is to simplify the transition between a network with batchnorm and a network without
|
||||
# batchnorm to a single flag (the same applies to activation function and dropout)
|
||||
if self.batchnorm or self.activation_function or self.dropout:
|
||||
for layer_idx in reversed(range(len(self.layers_params))):
|
||||
self.layers_params.insert(layer_idx+1,
|
||||
BatchnormActivationDropout(batchnorm=self.batchnorm,
|
||||
activation_function=self.activation_function,
|
||||
dropout_rate=self.dropout_rate))
|
||||
|
||||
def __call__(self, prev_input_placeholder=None):
|
||||
with tf.variable_scope(self.get_name()):
|
||||
@@ -102,19 +125,11 @@ class InputEmbedder(object):
|
||||
|
||||
self.layers.append(input_layer)
|
||||
|
||||
# layers order is conv -> batchnorm -> activation -> dropout
|
||||
if isinstance(self.scheme, EmbedderScheme):
|
||||
layers_params = self.schemes[self.scheme]
|
||||
else:
|
||||
layers_params = self.scheme
|
||||
for idx, layer_params in enumerate(layers_params):
|
||||
self.layers.append(
|
||||
layer_params(input_layer=self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx))
|
||||
)
|
||||
|
||||
self.layers.extend(batchnorm_activation_dropout(self.layers[-1], self.batchnorm,
|
||||
self.activation_function, self.dropout,
|
||||
self.dropout_rate, idx))
|
||||
for idx, layer_params in enumerate(self.layers_params):
|
||||
self.layers.extend(force_list(
|
||||
layer_params(input_layer=self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx),
|
||||
is_training=self.is_training)
|
||||
))
|
||||
|
||||
self.output = tf.contrib.layers.flatten(self.layers[-1])
|
||||
|
||||
@@ -140,4 +155,14 @@ class InputEmbedder(object):
|
||||
"configurations.")
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
return self.name
|
||||
|
||||
def __str__(self):
|
||||
result = []
|
||||
if self.input_rescaling != 1.0 or self.input_offset != 0.0:
|
||||
result.append('Input Normalization (scale = {}, offset = {})'.format(self.input_rescaling, self.input_offset))
|
||||
result.extend([str(l) for l in self.layers_params])
|
||||
if self.layers_params:
|
||||
return '\n'.join(result)
|
||||
else:
|
||||
return 'No layers'
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder
|
||||
from rl_coach.base_parameters import EmbedderScheme
|
||||
from rl_coach.core_types import InputImageEmbedding
|
||||
@@ -34,9 +34,9 @@ class ImageEmbedder(InputEmbedder):
|
||||
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
||||
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout: bool=False,
|
||||
name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
|
||||
dense_layer=Dense):
|
||||
dense_layer=Dense, is_training=False):
|
||||
super().__init__(input_size, activation_function, scheme, batchnorm, dropout, name, input_rescaling,
|
||||
input_offset, input_clipping, dense_layer=dense_layer)
|
||||
input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
|
||||
self.return_type = InputImageEmbedding
|
||||
if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
|
||||
raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"
|
||||
@@ -50,28 +50,28 @@ class ImageEmbedder(InputEmbedder):
|
||||
|
||||
EmbedderScheme.Shallow:
|
||||
[
|
||||
Conv2d([32, 3, 1])
|
||||
Conv2d(32, 3, 1)
|
||||
],
|
||||
|
||||
# atari dqn
|
||||
EmbedderScheme.Medium:
|
||||
[
|
||||
Conv2d([32, 8, 4]),
|
||||
Conv2d([64, 4, 2]),
|
||||
Conv2d([64, 3, 1])
|
||||
Conv2d(32, 8, 4),
|
||||
Conv2d(64, 4, 2),
|
||||
Conv2d(64, 3, 1)
|
||||
],
|
||||
|
||||
# carla
|
||||
EmbedderScheme.Deep: \
|
||||
[
|
||||
Conv2d([32, 5, 2]),
|
||||
Conv2d([32, 3, 1]),
|
||||
Conv2d([64, 3, 2]),
|
||||
Conv2d([64, 3, 1]),
|
||||
Conv2d([128, 3, 2]),
|
||||
Conv2d([128, 3, 1]),
|
||||
Conv2d([256, 3, 2]),
|
||||
Conv2d([256, 3, 1])
|
||||
Conv2d(32, 5, 2),
|
||||
Conv2d(32, 3, 1),
|
||||
Conv2d(64, 3, 2),
|
||||
Conv2d(64, 3, 1),
|
||||
Conv2d(128, 3, 2),
|
||||
Conv2d(128, 3, 1),
|
||||
Conv2d(256, 3, 2),
|
||||
Conv2d(256, 3, 1)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder
|
||||
from rl_coach.base_parameters import EmbedderScheme
|
||||
from rl_coach.core_types import InputVectorEmbedding
|
||||
@@ -33,9 +33,10 @@ class VectorEmbedder(InputEmbedder):
|
||||
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
||||
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout: bool=False,
|
||||
name: str= "embedder", input_rescaling: float=1.0, input_offset:float=0.0, input_clipping=None,
|
||||
dense_layer=Dense):
|
||||
dense_layer=Dense, is_training=False):
|
||||
super().__init__(input_size, activation_function, scheme, batchnorm, dropout, name,
|
||||
input_rescaling, input_offset, input_clipping, dense_layer=dense_layer)
|
||||
input_rescaling, input_offset, input_clipping, dense_layer=dense_layer,
|
||||
is_training=is_training)
|
||||
|
||||
self.return_type = InputVectorEmbedding
|
||||
if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty:
|
||||
@@ -49,20 +50,20 @@ class VectorEmbedder(InputEmbedder):
|
||||
|
||||
EmbedderScheme.Shallow:
|
||||
[
|
||||
self.dense_layer([128])
|
||||
self.dense_layer(128)
|
||||
],
|
||||
|
||||
# dqn
|
||||
EmbedderScheme.Medium:
|
||||
[
|
||||
self.dense_layer([256])
|
||||
self.dense_layer(256)
|
||||
],
|
||||
|
||||
# carla
|
||||
EmbedderScheme.Deep: \
|
||||
[
|
||||
self.dense_layer([128]),
|
||||
self.dense_layer([128]),
|
||||
self.dense_layer([128])
|
||||
self.dense_layer(128),
|
||||
self.dense_layer(128),
|
||||
self.dense_layer(128)
|
||||
]
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ from rl_coach.architectures.tensorflow_components.middlewares.middleware import
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.core_types import PredictionType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
|
||||
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params
|
||||
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
|
||||
|
||||
|
||||
class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
@@ -80,6 +80,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
return ret_dict
|
||||
|
||||
self.available_return_types = fill_return_types()
|
||||
self.is_training = None
|
||||
|
||||
def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
|
||||
prediction_type: PredictionType) -> Dict[str, np.ndarray]:
|
||||
@@ -161,7 +162,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy)
|
||||
return module
|
||||
|
||||
def get_output_head(self, head_params: HeadParameters, head_idx: int, loss_weight: float=1.):
|
||||
def get_output_head(self, head_params: HeadParameters, head_idx: int):
|
||||
"""
|
||||
Given a head type, creates the head and returns it
|
||||
:param head_params: the parameters of the head to create
|
||||
@@ -176,7 +177,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
head_params_copy.activation_function = self.get_activation_function(head_params_copy.activation_function)
|
||||
return dynamic_import_and_instantiate_module_from_params(head_params_copy, extra_kwargs={
|
||||
'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
|
||||
'head_idx': head_idx, 'loss_weight': loss_weight, 'is_local': self.network_is_local})
|
||||
'head_idx': head_idx, 'is_local': self.network_is_local})
|
||||
|
||||
def get_model(self):
|
||||
# validate the configuration
|
||||
@@ -189,11 +190,10 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
if self.network_parameters.middleware_parameters is None:
|
||||
raise ValueError("Exactly one middleware type should be defined")
|
||||
|
||||
if len(self.network_parameters.loss_weights) == 0:
|
||||
raise ValueError("At least one loss weight should be defined")
|
||||
|
||||
if len(self.network_parameters.heads_parameters) != len(self.network_parameters.loss_weights):
|
||||
raise ValueError("Number of loss weights should match the number of output types")
|
||||
# ops for defining the training / testing phase
|
||||
self.is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
self.is_training_placeholder = tf.placeholder("bool")
|
||||
self.assign_is_training = tf.assign(self.is_training, self.is_training_placeholder)
|
||||
|
||||
for network_idx in range(self.num_networks):
|
||||
with tf.variable_scope('network_{}'.format(network_idx)):
|
||||
@@ -245,28 +245,27 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
|
||||
head_count = 0
|
||||
for head_idx in range(self.num_heads_per_network):
|
||||
for head_copy_idx in range(self.network_parameters.num_output_head_copies):
|
||||
if self.network_parameters.use_separate_networks_per_head:
|
||||
# if we use separate networks per head, then the head type corresponds top the network idx
|
||||
head_type_idx = network_idx
|
||||
head_count = network_idx
|
||||
else:
|
||||
# if we use a single network with multiple embedders, then the head type is the current head idx
|
||||
head_type_idx = head_idx
|
||||
|
||||
if self.network_parameters.use_separate_networks_per_head:
|
||||
# if we use separate networks per head, then the head type corresponds to the network idx
|
||||
head_type_idx = network_idx
|
||||
head_count = network_idx
|
||||
else:
|
||||
# if we use a single network with multiple embedders, then the head type is the current head idx
|
||||
head_type_idx = head_idx
|
||||
head_params = self.network_parameters.heads_parameters[head_type_idx]
|
||||
|
||||
for head_copy_idx in range(head_params.num_output_head_copies):
|
||||
# create output head and add it to the output heads list
|
||||
self.output_heads.append(
|
||||
self.get_output_head(self.network_parameters.heads_parameters[head_type_idx],
|
||||
head_idx*self.network_parameters.num_output_head_copies + head_copy_idx,
|
||||
self.network_parameters.loss_weights[head_type_idx])
|
||||
self.get_output_head(head_params,
|
||||
head_idx*head_params.num_output_head_copies + head_copy_idx)
|
||||
)
|
||||
|
||||
# rescale the gradients from the head
|
||||
self.gradients_from_head_rescalers.append(
|
||||
tf.get_variable('gradients_from_head_{}-{}_rescalers'.format(head_idx, head_copy_idx),
|
||||
initializer=float(
|
||||
self.network_parameters.rescale_gradient_from_head_by_factor[head_count]
|
||||
),
|
||||
initializer=float(head_params.rescale_gradient_from_head_by_factor),
|
||||
dtype=tf.float32))
|
||||
|
||||
self.gradients_from_head_rescalers_placeholders.append(
|
||||
@@ -344,4 +343,46 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
else:
|
||||
raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))
|
||||
|
||||
def __str__(self):
|
||||
result = []
|
||||
|
||||
for network in range(self.num_networks):
|
||||
network_structure = []
|
||||
|
||||
# embedder
|
||||
for embedder in self.input_embedders:
|
||||
network_structure.append("Input Embedder: {}".format(embedder.name))
|
||||
network_structure.append(indent_string(str(embedder)))
|
||||
|
||||
if len(self.input_embedders) > 1:
|
||||
network_structure.append("{} ({})".format(self.network_parameters.embedding_merger_type.name,
|
||||
", ".join(["{} embedding".format(e.name) for e in self.input_embedders])))
|
||||
|
||||
# middleware
|
||||
network_structure.append("Middleware:")
|
||||
network_structure.append(indent_string(str(self.middleware)))
|
||||
|
||||
# head
|
||||
if self.network_parameters.use_separate_networks_per_head:
|
||||
heads = range(network, network+1)
|
||||
else:
|
||||
heads = range(0, len(self.output_heads))
|
||||
|
||||
for head_idx in heads:
|
||||
head = self.output_heads[head_idx]
|
||||
head_params = self.network_parameters.heads_parameters[head_idx]
|
||||
if head_params.num_output_head_copies > 1:
|
||||
network_structure.append("Output Head: {} (num copies = {})".format(head.name, head_params.num_output_head_copies))
|
||||
else:
|
||||
network_structure.append("Output Head: {}".format(head.name))
|
||||
network_structure.append(indent_string(str(head)))
|
||||
|
||||
# finalize network
|
||||
if self.num_networks > 1:
|
||||
result.append("Sub-network for head: {}".format(self.output_heads[network].name))
|
||||
result.append(indent_string('\n'.join(network_structure)))
|
||||
else:
|
||||
result.append('\n'.join(network_structure))
|
||||
|
||||
result = '\n'.join(result)
|
||||
return result
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
@@ -25,9 +25,13 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
class CategoricalQHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='categorical_q_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='relu', name: str='categorical_q_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=CategoricalQHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class CategoricalQHead(Head):
|
||||
@@ -54,3 +58,12 @@ class CategoricalQHead(Head):
|
||||
self.target = self.distributions
|
||||
self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
|
||||
tf.losses.add_loss(self.loss)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"Dense (num outputs = {})".format(self.num_actions * self.num_atoms),
|
||||
"Reshape (output size = {} x {})".format(self.num_actions, self.num_atoms),
|
||||
"Softmax"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
@@ -16,27 +16,34 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense, batchnorm_activation_dropout
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import QActionStateValue
|
||||
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
|
||||
from rl_coach.utils import force_list
|
||||
|
||||
|
||||
class RegressionHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='q_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='relu', name: str='q_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense, scheme=[Dense(256), Dense(256)]):
|
||||
super().__init__(parameterized_class=RegressionHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class RegressionHead(Head):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
|
||||
dense_layer=Dense):
|
||||
dense_layer=Dense, scheme=[Dense(256), Dense(256)]):
|
||||
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
|
||||
dense_layer=dense_layer)
|
||||
self.name = 'regression_head'
|
||||
self.scheme = scheme
|
||||
self.layers = []
|
||||
if isinstance(self.spaces.action, BoxActionSpace):
|
||||
self.num_actions = self.spaces.action.shape[0]
|
||||
elif isinstance(self.spaces.action, DiscreteActionSpace):
|
||||
@@ -48,9 +55,18 @@ class RegressionHead(Head):
|
||||
self.loss_type = tf.losses.mean_squared_error
|
||||
|
||||
def _build_module(self, input_layer):
|
||||
self.fc1 = self.dense_layer(256)(input_layer)
|
||||
self.fc2 = self.dense_layer(256)(self.fc1)
|
||||
self.output = self.dense_layer(self.num_actions)(self.fc2, name='output')
|
||||
self.layers.append(input_layer)
|
||||
for idx, layer_params in enumerate(self.scheme):
|
||||
self.layers.extend(force_list(
|
||||
layer_params(input_layer=self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx))
|
||||
))
|
||||
|
||||
self.layers.append(self.dense_layer(self.num_actions)(self.layers[-1], name='output'))
|
||||
self.output = self.layers[-1]
|
||||
|
||||
def __str__(self):
|
||||
result = []
|
||||
for layer in self.layers:
|
||||
result.append(str(layer))
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
@@ -25,9 +25,12 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
class DDPGActorHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True,
|
||||
dense_layer=Dense):
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=DDPGActor, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
self.batchnorm = batchnorm
|
||||
|
||||
|
||||
@@ -56,7 +59,7 @@ class DDPGActor(Head):
|
||||
pre_activation_policy_values_mean = self.dense_layer(self.num_actions)(input_layer, name='fc_mean')
|
||||
policy_values_mean = batchnorm_activation_dropout(pre_activation_policy_values_mean, self.batchnorm,
|
||||
self.activation_function,
|
||||
False, 0, 0)[-1]
|
||||
False, 0, is_training=False, name="BatchnormActivationDropout_0")[-1]
|
||||
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
|
||||
|
||||
if self.is_local:
|
||||
@@ -66,3 +69,9 @@ class DDPGActor(Head):
|
||||
[self.action_penalty * tf.reduce_mean(tf.square(pre_activation_policy_values_mean))]
|
||||
|
||||
self.output = [self.policy_mean]
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
'Dense (num outputs = {})'.format(self.num_actions[0])
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.q_head import QHead
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
@@ -24,9 +24,13 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
class DNDQHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=DNDQHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class DNDQHead(QHead):
|
||||
@@ -89,3 +93,9 @@ class DNDQHead(QHead):
|
||||
# DND gradients
|
||||
self.dnd_embeddings_grad = tf.gradients(self.loss[0], self.dnd_embeddings)
|
||||
self.dnd_values_grad = tf.gradients(self.loss[0], self.dnd_values)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"DND fetch (num outputs = {})".format(self.num_actions)
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.q_head import QHead
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
@@ -24,9 +24,13 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
class DuelingQHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='dueling_q_head_params', dense_layer=Dense):
|
||||
super().__init__(parameterized_class=DuelingQHead, activation_function=activation_function, name=name, dense_layer=dense_layer)
|
||||
|
||||
def __init__(self, activation_function: str ='relu', name: str='dueling_q_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=DuelingQHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
class DuelingQHead(QHead):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||
@@ -51,3 +55,16 @@ class DuelingQHead(QHead):
|
||||
|
||||
# merge to state-action value function Q
|
||||
self.output = tf.add(self.state_value, self.action_advantage, name='output')
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"State Value Stream - V",
|
||||
"\tDense (num outputs = 512)",
|
||||
"\tDense (num outputs = 1)",
|
||||
"Action Advantage Stream - A",
|
||||
"\tDense (num outputs = 512)",
|
||||
"\tDense (num outputs = {})".format(self.num_actions),
|
||||
"\tSubtract(A, Mean(A))".format(self.num_actions),
|
||||
"Add (V, A)"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Type
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.ops.losses.losses_impl import Reduction
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import AgentParameters, Parameters, NetworkComponentParameters
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import force_list
|
||||
@@ -35,10 +35,14 @@ def normalized_columns_initializer(std=1.0):
|
||||
|
||||
class HeadParameters(NetworkComponentParameters):
|
||||
def __init__(self, parameterized_class: Type['Head'], activation_function: str = 'relu', name: str= 'head',
|
||||
dense_layer=Dense):
|
||||
num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0,
|
||||
loss_weight: float=1.0, dense_layer=Dense):
|
||||
super().__init__(dense_layer=dense_layer)
|
||||
self.activation_function = activation_function
|
||||
self.name = name
|
||||
self.num_output_head_copies = num_output_head_copies
|
||||
self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor
|
||||
self.loss_weight = loss_weight
|
||||
self.parameterized_class_name = parameterized_class.__name__
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
@@ -26,9 +26,12 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
class MeasurementsPredictionHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='measurements_prediction_head_params',
|
||||
dense_layer=Dense):
|
||||
super().__init__(parameterized_class=MeasurementsPredictionHead,
|
||||
activation_function=activation_function, name=name, dense_layer=dense_layer)
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=MeasurementsPredictionHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class MeasurementsPredictionHead(Head):
|
||||
@@ -68,3 +71,17 @@ class MeasurementsPredictionHead(Head):
|
||||
targets_nonan = tf.where(tf.is_nan(self.target), self.output, self.target)
|
||||
self.loss = tf.reduce_sum(tf.reduce_mean(tf.square(targets_nonan - self.output), reduction_indices=0))
|
||||
tf.losses.add_loss(self.loss_weight[0] * self.loss)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"State Value Stream - V",
|
||||
"\tDense (num outputs = 256)",
|
||||
"\tDense (num outputs = {})".format(self.multi_step_measurements_size),
|
||||
"Action Advantage Stream - A",
|
||||
"\tDense (num outputs = 256)",
|
||||
"\tDense (num outputs = {})".format(self.num_actions * self.multi_step_measurements_size),
|
||||
"\tReshape (new size = {} x {})".format(self.num_actions, self.multi_step_measurements_size),
|
||||
"\tSubtract(A, Mean(A))".format(self.num_actions),
|
||||
"Add (V, A)"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import QActionStateValue
|
||||
@@ -25,9 +25,13 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
class NAFHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='naf_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='naf_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=NAFHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class NAFHead(Head):
|
||||
@@ -90,3 +94,21 @@ class NAFHead(Head):
|
||||
self.Q = tf.add(self.V, self.A, name='Q')
|
||||
|
||||
self.output = self.Q
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"State Value Stream - V",
|
||||
"\tDense (num outputs = 1)",
|
||||
"Action Advantage Stream - A",
|
||||
"\tDense (num outputs = {})".format((self.num_actions * (self.num_actions + 1)) / 2),
|
||||
"\tReshape to lower triangular matrix L (new size = {} x {})".format(self.num_actions, self.num_actions),
|
||||
"\tP = L*L^T",
|
||||
"\tA = -1/2 * (u - mu)^T * P * (u - mu)",
|
||||
"Action Stream - mu",
|
||||
"\tDense (num outputs = {})".format(self.num_actions),
|
||||
"\tActivation (type = {})".format(self.activation_function.__name__),
|
||||
"\tMultiply (factor = {})".format(self.output_scale),
|
||||
"State-Action Value Stream - Q",
|
||||
"\tAdd (V, A)"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -17,20 +17,25 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, CompoundActionSpace
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import eps
|
||||
from rl_coach.utils import eps, indent_string
|
||||
|
||||
|
||||
class PolicyHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='policy_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=PolicyHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
|
||||
class PolicyHead(Head):
|
||||
@@ -112,7 +117,7 @@ class PolicyHead(Head):
|
||||
self.actions.append(tf.placeholder(tf.float32, [None, num_actions], name="actions"))
|
||||
|
||||
# output activation function
|
||||
if np.all(self.spaces.action.max_abs_range < np.inf):
|
||||
if np.all(action_space.max_abs_range < np.inf):
|
||||
# bounded actions
|
||||
self.output_scale = action_space.max_abs_range
|
||||
self.continuous_output_activation = self.activation_function
|
||||
@@ -158,3 +163,45 @@ class PolicyHead(Head):
|
||||
if self.action_penalty and self.action_penalty != 0:
|
||||
self.regularizations += [
|
||||
self.action_penalty * tf.reduce_mean(tf.square(pre_activation_policy_values_mean))]
|
||||
|
||||
def __str__(self):
|
||||
action_spaces = [self.spaces.action]
|
||||
if isinstance(self.spaces.action, CompoundActionSpace):
|
||||
action_spaces = self.spaces.action.sub_action_spaces
|
||||
|
||||
result = []
|
||||
for action_space_idx, action_space in enumerate(action_spaces):
|
||||
action_head_mean_result = []
|
||||
if isinstance(action_space, DiscreteActionSpace):
|
||||
# create a discrete action network (softmax probabilities output)
|
||||
action_head_mean_result.append("Dense (num outputs = {})".format(len(action_space.actions)))
|
||||
action_head_mean_result.append("Softmax")
|
||||
elif isinstance(action_space, BoxActionSpace):
|
||||
# create a continuous action network (bounded mean and stdev outputs)
|
||||
action_head_mean_result.append("Dense (num outputs = {})".format(action_space.shape))
|
||||
if np.all(action_space.max_abs_range < np.inf):
|
||||
# bounded actions
|
||||
action_head_mean_result.append("Activation (type = {})".format(self.activation_function.__name__))
|
||||
action_head_mean_result.append("Multiply (factor = {})".format(action_space.max_abs_range))
|
||||
|
||||
action_head_stdev_result = []
|
||||
if isinstance(self.exploration_policy, ContinuousEntropyParameters):
|
||||
action_head_stdev_result.append("Dense (num outputs = {})".format(action_space.shape))
|
||||
action_head_stdev_result.append("Softplus")
|
||||
|
||||
action_head_result = []
|
||||
if action_head_stdev_result:
|
||||
action_head_result.append("Mean Stream")
|
||||
action_head_result.append(indent_string('\n'.join(action_head_mean_result)))
|
||||
action_head_result.append("Stdev Stream")
|
||||
action_head_result.append(indent_string('\n'.join(action_head_stdev_result)))
|
||||
else:
|
||||
action_head_result.append('\n'.join(action_head_mean_result))
|
||||
|
||||
if len(action_spaces) > 1:
|
||||
result.append("Action head {}".format(action_space_idx))
|
||||
result.append(indent_string('\n'.join(action_head_result)))
|
||||
else:
|
||||
result.append('\n'.join(action_head_result))
|
||||
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters, normalized_columns_initializer
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
@@ -27,9 +27,13 @@ from rl_coach.utils import eps
|
||||
|
||||
|
||||
class PPOHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=PPOHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class PPOHead(Head):
|
||||
@@ -146,3 +150,15 @@ class PPOHead(Head):
|
||||
self.old_policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.old_policy_mean, self.old_policy_std + eps)
|
||||
|
||||
self.output = [self.policy_mean, self.policy_std]
|
||||
|
||||
def __str__(self):
|
||||
action_head_mean_result = []
|
||||
if isinstance(self.spaces.action, DiscreteActionSpace):
|
||||
# create a discrete action network (softmax probabilities output)
|
||||
action_head_mean_result.append("Dense (num outputs = {})".format(len(self.spaces.action.actions)))
|
||||
action_head_mean_result.append("Softmax")
|
||||
elif isinstance(self.spaces.action, BoxActionSpace):
|
||||
# create a continuous action network (bounded mean and stdev outputs)
|
||||
action_head_mean_result.append("Dense (num outputs = {})".format(self.spaces.action.shape))
|
||||
|
||||
return '\n'.join(action_head_mean_result)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
@@ -25,9 +25,13 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
class PPOVHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='ppo_v_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='relu', name: str='ppo_v_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=PPOVHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class PPOVHead(Head):
|
||||
@@ -55,3 +59,9 @@ class PPOVHead(Head):
|
||||
self.vf_loss = tf.reduce_mean(tf.maximum(value_loss_1, value_loss_2))
|
||||
self.loss = self.vf_loss
|
||||
tf.losses.add_loss(self.loss)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"Dense (num outputs = 1)"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
@@ -25,9 +25,13 @@ from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpac
|
||||
|
||||
|
||||
class QHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='q_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='relu', name: str='q_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=QHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class QHead(Head):
|
||||
@@ -51,5 +55,10 @@ class QHead(Head):
|
||||
# Standard Q Network
|
||||
self.output = self.dense_layer(self.num_actions)(input_layer, name='output')
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"Dense (num outputs = {})".format(self.num_actions)
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
@@ -26,9 +26,12 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
class QuantileRegressionQHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='quantile_regression_q_head_params',
|
||||
dense_layer=Dense):
|
||||
super().__init__(parameterized_class=QuantileRegressionQHead, activation_function=activation_function,
|
||||
name=name, dense_layer=dense_layer)
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=QuantileRegressionQHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class QuantileRegressionQHead(Head):
|
||||
@@ -79,3 +82,11 @@ class QuantileRegressionQHead(Head):
|
||||
quantile_regression_loss = tf.reduce_sum(quantile_huber_loss) / float(self.num_atoms)
|
||||
self.loss = quantile_regression_loss
|
||||
tf.losses.add_loss(self.loss)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"Dense (num outputs = {})".format(self.num_actions * self.num_atoms),
|
||||
"Reshape (new size = {} x {})".format(self.num_actions, self.num_atoms)
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters, Head
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import QActionStateValue
|
||||
@@ -24,9 +24,13 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
class RainbowQHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='rainbow_q_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='relu', name: str='rainbow_q_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=RainbowQHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class RainbowQHead(Head):
|
||||
@@ -69,3 +73,17 @@ class RainbowQHead(Head):
|
||||
self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
|
||||
tf.losses.add_loss(self.loss)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"State Value Stream - V",
|
||||
"\tDense (num outputs = 512)",
|
||||
"\tDense (num outputs = {})".format(self.num_atoms),
|
||||
"Action Advantage Stream - A",
|
||||
"\tDense (num outputs = 512)",
|
||||
"\tDense (num outputs = {})".format(self.num_actions * self.num_atoms),
|
||||
"\tReshape (new size = {} x {})".format(self.num_actions, self.num_atoms),
|
||||
"\tSubtract(A, Mean(A))".format(self.num_actions),
|
||||
"Add (V, A)",
|
||||
"Softmax"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
@@ -25,9 +25,13 @@ from rl_coach.spaces import SpacesDefinition
|
||||
|
||||
|
||||
class VHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='v_head_params', dense_layer=Dense):
|
||||
def __init__(self, activation_function: str ='relu', name: str='v_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=Dense):
|
||||
super().__init__(parameterized_class=VHead, activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class VHead(Head):
|
||||
@@ -48,3 +52,9 @@ class VHead(Head):
|
||||
# Standard V Network
|
||||
self.output = self.dense_layer(1)(input_layer, name='output',
|
||||
kernel_initializer=normalized_columns_initializer(1.0))
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"Dense (num outputs = 1)"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
167
rl_coach/architectures/tensorflow_components/layers.py
Normal file
167
rl_coach/architectures/tensorflow_components/layers.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.utils import force_list
|
||||
|
||||
|
||||
def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout, dropout_rate, is_training, name):
|
||||
layers = [input_layer]
|
||||
|
||||
# batchnorm
|
||||
if batchnorm:
|
||||
layers.append(
|
||||
tf.layers.batch_normalization(layers[-1], name="{}_batchnorm".format(name), training=is_training)
|
||||
)
|
||||
|
||||
# activation
|
||||
if activation_function:
|
||||
layers.append(
|
||||
activation_function(layers[-1], name="{}_activation".format(name))
|
||||
)
|
||||
|
||||
# dropout
|
||||
if dropout:
|
||||
layers.append(
|
||||
tf.layers.dropout(layers[-1], dropout_rate, name="{}_dropout".format(name), training=is_training)
|
||||
)
|
||||
|
||||
# remove the input layer from the layers list
|
||||
del layers[0]
|
||||
|
||||
return layers
|
||||
|
||||
|
||||
class Conv2d(object):
|
||||
def __init__(self, num_filters: int, kernel_size: int, strides: int):
|
||||
self.num_filters = num_filters
|
||||
self.kernel_size = kernel_size
|
||||
self.strides = strides
|
||||
|
||||
def __call__(self, input_layer, name: str=None, is_training=None):
|
||||
"""
|
||||
returns a tensorflow conv2d layer
|
||||
:param input_layer: previous layer
|
||||
:param name: layer name
|
||||
:return: conv2d layer
|
||||
"""
|
||||
return tf.layers.conv2d(input_layer, filters=self.num_filters, kernel_size=self.kernel_size,
|
||||
strides=self.strides, data_format='channels_last', name=name)
|
||||
|
||||
def __str__(self):
|
||||
return "Convolution (num filters = {}, kernel size = {}, stride = {})"\
|
||||
.format(self.num_filters, self.kernel_size, self.strides)
|
||||
|
||||
|
||||
class BatchnormActivationDropout(object):
|
||||
def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0):
|
||||
self.batchnorm = batchnorm
|
||||
self.activation_function = activation_function
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
def __call__(self, input_layer, name: str=None, is_training=None):
|
||||
"""
|
||||
returns a list of tensorflow batchnorm, activation and dropout layers
|
||||
:param input_layer: previous layer
|
||||
:param name: layer name
|
||||
:return: batchnorm, activation and dropout layers
|
||||
"""
|
||||
return batchnorm_activation_dropout(input_layer, batchnorm=self.batchnorm,
|
||||
activation_function=self.activation_function,
|
||||
dropout=self.dropout_rate > 0, dropout_rate=self.dropout_rate,
|
||||
is_training=is_training, name=name)
|
||||
|
||||
def __str__(self):
|
||||
result = []
|
||||
if self.batchnorm:
|
||||
result += ["Batch Normalization"]
|
||||
if self.activation_function:
|
||||
result += ["Activation (type = {})".format(self.activation_function.__name__)]
|
||||
if self.dropout_rate > 0:
|
||||
result += ["Dropout (rate = {})".format(self.dropout_rate)]
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
class Dense(object):
|
||||
def __init__(self, units: int):
|
||||
self.units = units
|
||||
|
||||
def __call__(self, input_layer, name: str=None, kernel_initializer=None, activation=None, is_training=None):
|
||||
"""
|
||||
returns a tensorflow dense layer
|
||||
:param input_layer: previous layer
|
||||
:param name: layer name
|
||||
:return: dense layer
|
||||
"""
|
||||
return tf.layers.dense(input_layer, self.units, name=name, kernel_initializer=kernel_initializer,
|
||||
activation=activation)
|
||||
|
||||
def __str__(self):
|
||||
return "Dense (num outputs = {})".format(self.units)
|
||||
|
||||
|
||||
class NoisyNetDense(object):
|
||||
"""
|
||||
A factorized Noisy Net layer
|
||||
|
||||
https://arxiv.org/abs/1706.10295.
|
||||
"""
|
||||
|
||||
def __init__(self, units: int):
|
||||
self.units = units
|
||||
self.sigma0 = 0.5
|
||||
|
||||
def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None, is_training=None):
|
||||
"""
|
||||
returns a NoisyNet dense layer
|
||||
:param input_layer: previous layer
|
||||
:param name: layer name
|
||||
:param kernel_initializer: initializer for kernels. Default is to use Gaussian noise that preserves stddev.
|
||||
:param activation: the activation function
|
||||
:return: dense layer
|
||||
"""
|
||||
#TODO: noise sampling should be externally controlled. DQN is fine with sampling noise for every
|
||||
# forward (either act or train, both for online and target networks).
|
||||
# A3C, on the other hand, should sample noise only when policy changes (i.e. after every t_max steps)
|
||||
|
||||
num_inputs = input_layer.get_shape()[-1].value
|
||||
num_outputs = self.units
|
||||
|
||||
stddev = 1 / math.sqrt(num_inputs)
|
||||
activation = activation if activation is not None else (lambda x: x)
|
||||
|
||||
if kernel_initializer is None:
|
||||
kernel_mean_initializer = tf.random_uniform_initializer(-stddev, stddev)
|
||||
kernel_stddev_initializer = tf.random_uniform_initializer(-stddev * self.sigma0, stddev * self.sigma0)
|
||||
else:
|
||||
kernel_mean_initializer = kernel_stddev_initializer = kernel_initializer
|
||||
with tf.variable_scope(None, default_name=name):
|
||||
weight_mean = tf.get_variable('weight_mean', shape=(num_inputs, num_outputs),
|
||||
initializer=kernel_mean_initializer)
|
||||
bias_mean = tf.get_variable('bias_mean', shape=(num_outputs,), initializer=tf.zeros_initializer())
|
||||
|
||||
weight_stddev = tf.get_variable('weight_stddev', shape=(num_inputs, num_outputs),
|
||||
initializer=kernel_stddev_initializer)
|
||||
bias_stddev = tf.get_variable('bias_stddev', shape=(num_outputs,),
|
||||
initializer=kernel_stddev_initializer)
|
||||
bias_noise = self.f(tf.random_normal((num_outputs,)))
|
||||
weight_noise = self.factorized_noise(num_inputs, num_outputs)
|
||||
|
||||
bias = bias_mean + bias_stddev * bias_noise
|
||||
weight = weight_mean + weight_stddev * weight_noise
|
||||
return activation(tf.matmul(input_layer, weight) + bias)
|
||||
|
||||
def factorized_noise(self, inputs, outputs):
|
||||
# TODO: use factorized noise only for compute intensive algos (e.g. DQN).
|
||||
# lighter algos (e.g. DQN) should not use it
|
||||
noise1 = self.f(tf.random_normal((inputs, 1)))
|
||||
noise2 = self.f(tf.random_normal((1, outputs)))
|
||||
return tf.matmul(noise1, noise2)
|
||||
|
||||
@staticmethod
|
||||
def f(values):
|
||||
return tf.sqrt(tf.abs(values)) * tf.sign(values)
|
||||
|
||||
def __str__(self):
|
||||
return "Noisy Dense (num outputs = {})".format(self.units)
|
||||
@@ -17,46 +17,41 @@ from typing import Union, List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware, MiddlewareParameters
|
||||
from rl_coach.base_parameters import MiddlewareScheme
|
||||
from rl_coach.core_types import Middleware_FC_Embedding
|
||||
from rl_coach.utils import force_list
|
||||
|
||||
|
||||
class FCMiddlewareParameters(MiddlewareParameters):
|
||||
def __init__(self, activation_function='relu',
|
||||
scheme: Union[List, MiddlewareScheme] = MiddlewareScheme.Medium,
|
||||
batchnorm: bool = False, dropout: bool = False,
|
||||
name="middleware_fc_embedder", dense_layer=Dense):
|
||||
name="middleware_fc_embedder", dense_layer=Dense, is_training=False):
|
||||
super().__init__(parameterized_class=FCMiddleware, activation_function=activation_function,
|
||||
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name, dense_layer=dense_layer)
|
||||
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name, dense_layer=dense_layer,
|
||||
is_training=is_training)
|
||||
|
||||
|
||||
class FCMiddleware(Middleware):
|
||||
def __init__(self, activation_function=tf.nn.relu,
|
||||
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
|
||||
batchnorm: bool = False, dropout: bool = False,
|
||||
name="middleware_fc_embedder", dense_layer=Dense):
|
||||
name="middleware_fc_embedder", dense_layer=Dense, is_training=False):
|
||||
super().__init__(activation_function=activation_function, batchnorm=batchnorm,
|
||||
dropout=dropout, scheme=scheme, name=name, dense_layer=dense_layer)
|
||||
dropout=dropout, scheme=scheme, name=name, dense_layer=dense_layer, is_training=is_training)
|
||||
self.return_type = Middleware_FC_Embedding
|
||||
self.layers = []
|
||||
|
||||
def _build_module(self):
|
||||
self.layers.append(self.input)
|
||||
|
||||
if isinstance(self.scheme, MiddlewareScheme):
|
||||
layers_params = self.schemes[self.scheme]
|
||||
else:
|
||||
layers_params = self.scheme
|
||||
for idx, layer_params in enumerate(layers_params):
|
||||
self.layers.append(
|
||||
layer_params(self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx))
|
||||
)
|
||||
|
||||
self.layers.extend(batchnorm_activation_dropout(self.layers[-1], self.batchnorm,
|
||||
self.activation_function, self.dropout,
|
||||
self.dropout_rate, idx))
|
||||
for idx, layer_params in enumerate(self.layers_params):
|
||||
self.layers.extend(force_list(
|
||||
layer_params(self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx),
|
||||
is_training=self.is_training)
|
||||
))
|
||||
|
||||
self.output = self.layers[-1]
|
||||
|
||||
@@ -69,20 +64,20 @@ class FCMiddleware(Middleware):
|
||||
# ppo
|
||||
MiddlewareScheme.Shallow:
|
||||
[
|
||||
self.dense_layer([64])
|
||||
self.dense_layer(64)
|
||||
],
|
||||
|
||||
# dqn
|
||||
MiddlewareScheme.Medium:
|
||||
[
|
||||
self.dense_layer([512])
|
||||
self.dense_layer(512)
|
||||
],
|
||||
|
||||
MiddlewareScheme.Deep: \
|
||||
[
|
||||
self.dense_layer([128]),
|
||||
self.dense_layer([128]),
|
||||
self.dense_layer([128])
|
||||
self.dense_layer(128),
|
||||
self.dense_layer(128),
|
||||
self.dense_layer(128)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -18,19 +18,21 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware, MiddlewareParameters
|
||||
from rl_coach.base_parameters import MiddlewareScheme
|
||||
from rl_coach.core_types import Middleware_LSTM_Embedding
|
||||
from rl_coach.utils import force_list
|
||||
|
||||
|
||||
class LSTMMiddlewareParameters(MiddlewareParameters):
|
||||
def __init__(self, activation_function='relu', number_of_lstm_cells=256,
|
||||
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
|
||||
batchnorm: bool = False, dropout: bool = False,
|
||||
name="middleware_lstm_embedder", dense_layer=Dense):
|
||||
name="middleware_lstm_embedder", dense_layer=Dense, is_training=False):
|
||||
super().__init__(parameterized_class=LSTMMiddleware, activation_function=activation_function,
|
||||
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name, dense_layer=dense_layer)
|
||||
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name, dense_layer=dense_layer,
|
||||
is_training=is_training)
|
||||
self.number_of_lstm_cells = number_of_lstm_cells
|
||||
|
||||
|
||||
@@ -38,9 +40,9 @@ class LSTMMiddleware(Middleware):
|
||||
def __init__(self, activation_function=tf.nn.relu, number_of_lstm_cells: int=256,
|
||||
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
|
||||
batchnorm: bool = False, dropout: bool = False,
|
||||
name="middleware_lstm_embedder", dense_layer=Dense):
|
||||
name="middleware_lstm_embedder", dense_layer=Dense, is_training=False):
|
||||
super().__init__(activation_function=activation_function, batchnorm=batchnorm,
|
||||
dropout=dropout, scheme=scheme, name=name, dense_layer=dense_layer)
|
||||
dropout=dropout, scheme=scheme, name=name, dense_layer=dense_layer, is_training=is_training)
|
||||
self.return_type = Middleware_LSTM_Embedding
|
||||
self.number_of_lstm_cells = number_of_lstm_cells
|
||||
self.layers = []
|
||||
@@ -57,19 +59,12 @@ class LSTMMiddleware(Middleware):
|
||||
|
||||
self.layers.append(self.input)
|
||||
|
||||
# optionally insert some dense layers before the LSTM
|
||||
if isinstance(self.scheme, MiddlewareScheme):
|
||||
layers_params = self.schemes[self.scheme]
|
||||
else:
|
||||
layers_params = self.scheme
|
||||
for idx, layer_params in enumerate(layers_params):
|
||||
self.layers.append(
|
||||
tf.layers.dense(self.layers[-1], layer_params[0], name='fc{}'.format(idx))
|
||||
)
|
||||
|
||||
self.layers.extend(batchnorm_activation_dropout(self.layers[-1], self.batchnorm,
|
||||
self.activation_function, self.dropout,
|
||||
self.dropout_rate, idx))
|
||||
# optionally insert some layers before the LSTM
|
||||
for idx, layer_params in enumerate(self.layers_params):
|
||||
self.layers.extend(force_list(
|
||||
layer_params(self.layers[-1], name='fc{}'.format(idx),
|
||||
is_training=self.is_training)
|
||||
))
|
||||
|
||||
# add the LSTM layer
|
||||
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.number_of_lstm_cells, state_is_tuple=True)
|
||||
@@ -97,20 +92,20 @@ class LSTMMiddleware(Middleware):
|
||||
# ppo
|
||||
MiddlewareScheme.Shallow:
|
||||
[
|
||||
[64]
|
||||
self.dense_layer(64)
|
||||
],
|
||||
|
||||
# dqn
|
||||
MiddlewareScheme.Medium:
|
||||
[
|
||||
[512]
|
||||
self.dense_layer(512)
|
||||
],
|
||||
|
||||
MiddlewareScheme.Deep: \
|
||||
[
|
||||
[128],
|
||||
[128],
|
||||
[128]
|
||||
self.dense_layer(128),
|
||||
self.dense_layer(128),
|
||||
self.dense_layer(128)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -13,25 +13,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import copy
|
||||
from typing import Type, Union, List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.base_parameters import MiddlewareScheme, Parameters, NetworkComponentParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense, BatchnormActivationDropout
|
||||
from rl_coach.base_parameters import MiddlewareScheme, NetworkComponentParameters
|
||||
from rl_coach.core_types import MiddlewareEmbedding
|
||||
|
||||
|
||||
class MiddlewareParameters(NetworkComponentParameters):
|
||||
def __init__(self, parameterized_class: Type['Middleware'],
|
||||
activation_function: str='relu', scheme: Union[List, MiddlewareScheme]=MiddlewareScheme.Medium,
|
||||
batchnorm: bool=False, dropout: bool=False, name='middleware', dense_layer=Dense):
|
||||
batchnorm: bool=False, dropout: bool=False, name='middleware', dense_layer=Dense, is_training=False):
|
||||
super().__init__(dense_layer=dense_layer)
|
||||
self.activation_function = activation_function
|
||||
self.scheme = scheme
|
||||
self.batchnorm = batchnorm
|
||||
self.dropout = dropout
|
||||
self.name = name
|
||||
self.is_training = is_training
|
||||
self.parameterized_class_name = parameterized_class.__name__
|
||||
|
||||
|
||||
@@ -43,7 +45,8 @@ class Middleware(object):
|
||||
"""
|
||||
def __init__(self, activation_function=tf.nn.relu,
|
||||
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
|
||||
batchnorm: bool = False, dropout: bool = False, name="middleware_embedder", dense_layer=Dense):
|
||||
batchnorm: bool = False, dropout: bool = False, name="middleware_embedder", dense_layer=Dense,
|
||||
is_training=False):
|
||||
self.name = name
|
||||
self.input = None
|
||||
self.output = None
|
||||
@@ -54,6 +57,23 @@ class Middleware(object):
|
||||
self.scheme = scheme
|
||||
self.return_type = MiddlewareEmbedding
|
||||
self.dense_layer = dense_layer
|
||||
self.is_training = is_training
|
||||
|
||||
# layers order is conv -> batchnorm -> activation -> dropout
|
||||
if isinstance(self.scheme, MiddlewareScheme):
|
||||
self.layers_params = copy.copy(self.schemes[self.scheme])
|
||||
else:
|
||||
self.layers_params = copy.copy(self.scheme)
|
||||
|
||||
# we allow adding batchnorm, dropout or activation functions after each layer.
|
||||
# The motivation is to simplify the transition between a network with batchnorm and a network without
|
||||
# batchnorm to a single flag (the same applies to activation function and dropout)
|
||||
if self.batchnorm or self.activation_function or self.dropout:
|
||||
for layer_idx in reversed(range(len(self.layers_params))):
|
||||
self.layers_params.insert(layer_idx+1,
|
||||
BatchnormActivationDropout(batchnorm=self.batchnorm,
|
||||
activation_function=self.activation_function,
|
||||
dropout_rate=self.dropout_rate))
|
||||
|
||||
def __call__(self, input_layer):
|
||||
with tf.variable_scope(self.get_name()):
|
||||
@@ -72,3 +92,10 @@ class Middleware(object):
|
||||
def schemes(self):
|
||||
raise NotImplementedError("Inheriting middleware must define schemes matching its allowed default "
|
||||
"configurations.")
|
||||
|
||||
def __str__(self):
|
||||
result = [str(l) for l in self.layers_params]
|
||||
if self.layers_params:
|
||||
return '\n'.join(result)
|
||||
else:
|
||||
return 'No layers'
|
||||
|
||||
@@ -22,7 +22,8 @@ from collections import OrderedDict
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase
|
||||
# from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.filters.filter import NoInputFilter
|
||||
|
||||
|
||||
@@ -173,7 +174,71 @@ class PresetValidationParameters(Parameters):
|
||||
|
||||
|
||||
class NetworkParameters(Parameters):
|
||||
def __init__(self):
|
||||
def __init__(self,
|
||||
force_cpu = False,
|
||||
async_training = False,
|
||||
shared_optimizer = True,
|
||||
scale_down_gradients_by_number_of_workers_for_sync_training = True,
|
||||
clip_gradients = None,
|
||||
gradients_clipping_method = GradientClippingMethod.ClipByGlobalNorm,
|
||||
l2_regularization = 0,
|
||||
learning_rate = 0.00025,
|
||||
learning_rate_decay_rate = 0,
|
||||
learning_rate_decay_steps = 0,
|
||||
input_embedders_parameters = {},
|
||||
embedding_merger_type = EmbeddingMergerType.Concat,
|
||||
middleware_parameters = None,
|
||||
heads_parameters = [],
|
||||
use_separate_networks_per_head = False,
|
||||
optimizer_type = 'Adam',
|
||||
optimizer_epsilon = 0.0001,
|
||||
adam_optimizer_beta1 = 0.9,
|
||||
adam_optimizer_beta2 = 0.99,
|
||||
rms_prop_optimizer_decay = 0.9,
|
||||
batch_size = 32,
|
||||
replace_mse_with_huber_loss = False,
|
||||
create_target_network = False,
|
||||
tensorflow_support = True):
|
||||
"""
|
||||
:param force_cpu:
|
||||
Force the neural networks to run on the CPU even if a GPU is available
|
||||
:param async_training:
|
||||
If set to True, asynchronous training will be used, meaning that each workers will progress in its own
|
||||
speed, while not waiting for the rest of the workers to calculate their gradients.
|
||||
:param shared_optimizer:
|
||||
If set to True, a central optimizer which will be shared with all the workers will be used for applying
|
||||
gradients to the network. Otherwise, each worker will have its own optimizer with its own internal
|
||||
parameters that will only be affected by the gradients calculated by that worker
|
||||
:param scale_down_gradients_by_number_of_workers_for_sync_training:
|
||||
If set to True, in synchronous training, the gradients of each worker will be scaled down by the
|
||||
number of workers. This essentially means that the gradients applied to the network are the average
|
||||
of the gradients over all the workers.
|
||||
:param clip_gradients:
|
||||
A value that will be used for clipping the gradients of the network. If set to None, no gradient clipping
|
||||
will be applied. Otherwise, the gradients will be clipped according to the gradients_clipping_method.
|
||||
:param gradients_clipping_method:
|
||||
A gradient clipping method, defined by a GradientClippingMethod enum, and that will be used to clip the
|
||||
gradients of the network. This will only be used if the clip_gradients value is defined as a value other
|
||||
than None.
|
||||
:param l2_regularization:
|
||||
:param learning_rate:
|
||||
:param learning_rate_decay_rate:
|
||||
:param learning_rate_decay_steps:
|
||||
:param input_embedders_parameters:
|
||||
:param embedding_merger_type:
|
||||
:param middleware_parameters:
|
||||
:param heads_parameters:
|
||||
:param use_separate_networks_per_head:
|
||||
:param optimizer_type:
|
||||
:param optimizer_epsilon:
|
||||
:param adam_optimizer_beta1:
|
||||
:param adam_optimizer_beta2:
|
||||
:param rms_prop_optimizer_decay:
|
||||
:param batch_size:
|
||||
:param replace_mse_with_huber_loss:
|
||||
:param create_target_network:
|
||||
:param tensorflow_support:
|
||||
"""
|
||||
super().__init__()
|
||||
self.framework = Frameworks.tensorflow
|
||||
self.sess = None
|
||||
@@ -182,9 +247,6 @@ class NetworkParameters(Parameters):
|
||||
self.force_cpu = False
|
||||
|
||||
# distributed training options
|
||||
self.num_threads = 1
|
||||
self.synchronize_over_num_threads = 1
|
||||
self.distributed = False
|
||||
self.async_training = False
|
||||
self.shared_optimizer = True
|
||||
self.scale_down_gradients_by_number_of_workers_for_sync_training = True
|
||||
@@ -192,7 +254,6 @@ class NetworkParameters(Parameters):
|
||||
# regularization
|
||||
self.clip_gradients = None
|
||||
self.gradients_clipping_method = GradientClippingMethod.ClipByGlobalNorm
|
||||
self.kl_divergence_constraint = None
|
||||
self.l2_regularization = 0
|
||||
|
||||
# learning rate
|
||||
@@ -205,9 +266,6 @@ class NetworkParameters(Parameters):
|
||||
self.embedding_merger_type = EmbeddingMergerType.Concat
|
||||
self.middleware_parameters = None
|
||||
self.heads_parameters = []
|
||||
self.num_output_head_copies = 1
|
||||
self.loss_weights = []
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.use_separate_networks_per_head = False
|
||||
self.optimizer_type = 'Adam'
|
||||
self.optimizer_epsilon = 0.0001
|
||||
@@ -227,35 +285,113 @@ class NetworkComponentParameters(Parameters):
|
||||
self.dense_layer = dense_layer
|
||||
|
||||
|
||||
|
||||
class VisualizationParameters(Parameters):
|
||||
def __init__(self):
|
||||
def __init__(self,
|
||||
print_networks_summary=False,
|
||||
dump_csv=True,
|
||||
dump_signals_to_csv_every_x_episodes=5,
|
||||
dump_gifs=False,
|
||||
dump_mp4=False,
|
||||
video_dump_methods=[],
|
||||
dump_in_episode_signals=False,
|
||||
dump_parameters_documentation=True,
|
||||
render=False,
|
||||
native_rendering=False,
|
||||
max_fps_for_human_control=10,
|
||||
tensorboard=False,
|
||||
add_rendered_image_to_env_response=False):
|
||||
"""
|
||||
:param print_networks_summary:
|
||||
If set to True, a summary of all the networks structure will be printed at the beginning of the experiment
|
||||
:param dump_csv:
|
||||
If set to True, the logger will dump logs to a csv file once in every dump_signals_to_csv_every_x_episodes
|
||||
episodes. The logs can be later used to visualize the training process using Coach Dashboard.
|
||||
:param dump_signals_to_csv_every_x_episodes:
|
||||
Defines the number of episodes between writing new data to the csv log files. Lower values can affect
|
||||
performance, as writing to disk may take time, and it is done synchronously.
|
||||
:param dump_gifs:
|
||||
If set to True, GIF videos of the environment will be stored into the experiment directory according to
|
||||
the filters defined in video_dump_methods.
|
||||
:param dump_mp4:
|
||||
If set to True, MP4 videos of the environment will be stored into the experiment directory according to
|
||||
the filters defined in video_dump_methods.
|
||||
:param dump_in_episode_signals:
|
||||
If set to True, csv files will be dumped for each episode for inspecting different metrics within the
|
||||
episode. This means that for each step in each episode, different metrics such as the reward, the
|
||||
future return, etc. will be saved. Setting this to True may affect performance severely, and therefore
|
||||
this should be used only for debugging purposes.
|
||||
:param dump_parameters_documentation:
|
||||
If set to True, a json file containing all the agent parameters will be saved in the experiment directory.
|
||||
This may be very useful for inspecting the values defined for each parameters and making sure that all
|
||||
the parameters are defined as expected.
|
||||
:param render:
|
||||
If set to True, the environment render function will be called for each step, rendering the image of the
|
||||
environment. This may affect the performance of training, and is highly dependent on the environment.
|
||||
By default, Coach uses PyGame to render the environment image instead of the environment specific rendered.
|
||||
To change this, use the native_rendering flag.
|
||||
:param native_rendering:
|
||||
If set to True, the environment native renderer will be used for rendering the environment image.
|
||||
In some cases this can be slower than rendering using PyGame through Coach, but in other cases the
|
||||
environment opens its native renderer by default, so rendering with PyGame is an unnecessary overhead.
|
||||
:param max_fps_for_human_control:
|
||||
The maximum number of frames per second used while playing the environment as a human. This only has
|
||||
effect while using the --play flag for Coach.
|
||||
:param tensorboard:
|
||||
If set to True, TensorBoard summaries will be stored in the experiment directory. This can later be
|
||||
loaded in TensorBoard in order to visualize the training process.
|
||||
:param video_dump_methods:
|
||||
A list of dump methods that will be used as filters for deciding when to save videos.
|
||||
The filters in the list will be checked one after the other until the first dump method that returns
|
||||
false for should_dump() in the environment class. This list will only be used if dump_mp4 or dump_gif are
|
||||
set to True.
|
||||
:param add_rendered_image_to_env_response:
|
||||
Some environments have a different observation compared to the one displayed while rendering.
|
||||
For some cases it can be useful to pass the rendered image to the agent for visualization purposes.
|
||||
If this flag is set to True, the rendered image will be added to the environment EnvResponse object,
|
||||
which will be passed to the agent and allow using those images.
|
||||
"""
|
||||
super().__init__()
|
||||
# Visualization parameters
|
||||
self.print_summary = True
|
||||
self.dump_csv = True
|
||||
self.dump_gifs = False
|
||||
self.dump_mp4 = False
|
||||
self.dump_signals_to_csv_every_x_episodes = 5
|
||||
self.dump_in_episode_signals = False
|
||||
self.dump_parameters_documentation = True
|
||||
self.render = False
|
||||
self.native_rendering = False
|
||||
self.max_fps_for_human_control = 10
|
||||
self.tensorboard = False
|
||||
self.video_dump_methods = [] # a list of dump methods which will be checked one after the other until the first
|
||||
# dump method that returns false for should_dump()
|
||||
self.add_rendered_image_to_env_response = False
|
||||
self.print_networks_summary = print_networks_summary
|
||||
self.dump_csv = dump_csv
|
||||
self.dump_gifs = dump_gifs
|
||||
self.dump_mp4 = dump_mp4
|
||||
self.dump_signals_to_csv_every_x_episodes = dump_signals_to_csv_every_x_episodes
|
||||
self.dump_in_episode_signals = dump_in_episode_signals
|
||||
self.dump_parameters_documentation = dump_parameters_documentation
|
||||
self.render = render
|
||||
self.native_rendering = native_rendering
|
||||
self.max_fps_for_human_control = max_fps_for_human_control
|
||||
self.tensorboard = tensorboard
|
||||
self.video_dump_methods = video_dump_methods
|
||||
self.add_rendered_image_to_env_response = add_rendered_image_to_env_response
|
||||
|
||||
|
||||
class AgentParameters(Parameters):
|
||||
def __init__(self, algorithm: AlgorithmParameters, exploration: 'ExplorationParameters', memory: 'MemoryParameters',
|
||||
networks: Dict[str, NetworkParameters], visualization: VisualizationParameters=VisualizationParameters()):
|
||||
"""
|
||||
:param algorithm: the algorithmic parameters
|
||||
:param exploration: the exploration policy parameters
|
||||
:param memory: the memory module parameters
|
||||
:param networks: the parameters for the networks of the agent
|
||||
:param visualization: the visualization parameters
|
||||
:param algorithm:
|
||||
A class inheriting AlgorithmParameters.
|
||||
The parameters used for the specific algorithm used by the agent.
|
||||
These parameters can be later referenced in the agent implementation through self.ap.algorithm.
|
||||
:param exploration:
|
||||
Either a class inheriting ExplorationParameters or a dictionary mapping between action
|
||||
space types and their corresponding ExplorationParameters. If a dictionary was used,
|
||||
when the agent will be instantiated, the correct exploration policy parameters will be used
|
||||
according to the real type of the environment action space.
|
||||
These parameters will be used to instantiate the exporation policy.
|
||||
:param memory:
|
||||
A class inheriting MemoryParameters. It defines all the parameters used by the memory module.
|
||||
:param networks:
|
||||
A dictionary mapping between network names and their corresponding network parmeters, defined
|
||||
as a class inheriting NetworkParameters. Each element will be used in order to instantiate
|
||||
a NetworkWrapper class, and all the network wrappers will be stored in the agent under
|
||||
self.network_wrappers. self.network_wrappers is a dict mapping between the network name that
|
||||
was given in the networks dict, and the instantiated network wrapper.
|
||||
:param visualization:
|
||||
A class inheriting VisualizationParameters and defining various parameters that can be
|
||||
used for visualization purposes, such as printing to the screen, rendering, and saving videos.
|
||||
"""
|
||||
super().__init__()
|
||||
self.visualization = visualization
|
||||
@@ -278,13 +414,14 @@ class AgentParameters(Parameters):
|
||||
|
||||
|
||||
class TaskParameters(Parameters):
|
||||
def __init__(self, framework_type: str, evaluate_only: bool=False, use_cpu: bool=False, experiment_path=None,
|
||||
seed=None):
|
||||
def __init__(self, framework_type: str="tensorflow", evaluate_only: bool=False, use_cpu: bool=False,
|
||||
experiment_path="./experiments/test/", seed=None, save_checkpoint_secs=None):
|
||||
"""
|
||||
:param framework_type: deep learning framework type. currently only tensorflow is supported
|
||||
:param evaluate_only: the task will be used only for evaluating the model
|
||||
:param use_cpu: use the cpu for this task
|
||||
:param experiment_path: the path to the directory which will store all the experiment outputs
|
||||
:param save_checkpoint_secs: the number of seconds between each checkpoint saving
|
||||
:param seed: a seed to use for the random numbers generator
|
||||
"""
|
||||
self.framework_type = framework_type
|
||||
@@ -292,6 +429,7 @@ class TaskParameters(Parameters):
|
||||
self.evaluate_only = evaluate_only
|
||||
self.use_cpu = use_cpu
|
||||
self.experiment_path = experiment_path
|
||||
self.save_checkpoint_secs = save_checkpoint_secs
|
||||
self.seed = seed
|
||||
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager':
|
||||
graph_manager.visualization_parameters.dump_mp4 = graph_manager.visualization_parameters.dump_mp4 or args.dump_mp4
|
||||
graph_manager.visualization_parameters.render = args.render
|
||||
graph_manager.visualization_parameters.tensorboard = args.tensorboard
|
||||
graph_manager.visualization_parameters.print_networks_summary = args.print_networks_summary
|
||||
|
||||
# update the custom parameters
|
||||
if args.custom_parameter is not None:
|
||||
@@ -291,8 +292,8 @@ def main():
|
||||
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('--print_parameters',
|
||||
help="(flag) Print tuning_parameters to stdout",
|
||||
parser.add_argument('--print_networks_summary',
|
||||
help="(flag) Print network summary to stdout",
|
||||
action='store_true')
|
||||
parser.add_argument('-tb', '--tensorboard',
|
||||
help="(flag) When using the TensorFlow backend, enable TensorBoard log dumps. ",
|
||||
@@ -336,7 +337,8 @@ def main():
|
||||
evaluate_only=args.evaluate,
|
||||
experiment_path=args.experiment_path,
|
||||
seed=args.seed,
|
||||
use_cpu=args.use_cpu)
|
||||
use_cpu=args.use_cpu,
|
||||
save_checkpoint_secs=args.save_checkpoint_secs)
|
||||
task_parameters.__dict__ = add_items_to_dict(task_parameters.__dict__, args.__dict__)
|
||||
|
||||
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
||||
|
||||
@@ -97,8 +97,8 @@ class CarlaEnvironmentParameters(EnvironmentParameters):
|
||||
LOW = "Low"
|
||||
EPIC = "Epic"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, level="town1"):
|
||||
super().__init__(level=level)
|
||||
self.frame_skip = 3 # the frame skip affects the fps of the server directly. fps = 30 / frameskip
|
||||
self.server_height = 512
|
||||
self.server_width = 720
|
||||
@@ -106,7 +106,7 @@ class CarlaEnvironmentParameters(EnvironmentParameters):
|
||||
self.camera_width = 180
|
||||
self.experiment_suite = None # an optional CARLA experiment suite to use
|
||||
self.config = None
|
||||
self.level = 'town1'
|
||||
self.level = level
|
||||
self.quality = self.Quality.LOW
|
||||
self.cameras = [CameraTypes.FRONT]
|
||||
self.weather_id = [1]
|
||||
|
||||
@@ -43,8 +43,8 @@ class ObservationType(Enum):
|
||||
|
||||
# Parameters
|
||||
class ControlSuiteEnvironmentParameters(EnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.observation_type = ObservationType.Measurements
|
||||
self.default_input_filter = ControlSuiteInputFilter
|
||||
self.default_output_filter = ControlSuiteOutputFilter
|
||||
|
||||
@@ -104,8 +104,8 @@ DoomOutputFilter.add_action_filter('to_discrete', FullDiscreteActionSpaceMap())
|
||||
|
||||
|
||||
class DoomEnvironmentParameters(EnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.default_input_filter = DoomInputFilter
|
||||
self.default_output_filter = DoomOutputFilter
|
||||
self.cameras = [DoomEnvironment.CameraTypes.OBSERVATION]
|
||||
|
||||
@@ -92,9 +92,9 @@ class CustomWrapper(object):
|
||||
|
||||
|
||||
class EnvironmentParameters(Parameters):
|
||||
def __init__(self):
|
||||
def __init__(self, level=None):
|
||||
super().__init__()
|
||||
self.level = None
|
||||
self.level = level
|
||||
self.frame_skip = 4
|
||||
self.seed = None
|
||||
self.human_control = False
|
||||
|
||||
@@ -18,6 +18,7 @@ import gym
|
||||
import numpy as np
|
||||
import scipy.ndimage
|
||||
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.utils import lower_under_to_upper, short_dynamic_import
|
||||
|
||||
try:
|
||||
@@ -40,7 +41,7 @@ except ImportError:
|
||||
failed_imports.append("PyBullet")
|
||||
|
||||
from typing import Dict, Any, Union
|
||||
from rl_coach.core_types import RunPhase
|
||||
from rl_coach.core_types import RunPhase, EnvironmentSteps
|
||||
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \
|
||||
StateSpace, RewardSpace
|
||||
@@ -57,10 +58,9 @@ from rl_coach.logger import screen
|
||||
|
||||
|
||||
# Parameters
|
||||
|
||||
class GymEnvironmentParameters(EnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.random_initialization_steps = 0
|
||||
self.max_over_num_frames = 1
|
||||
self.additional_simulator_parameters = None
|
||||
@@ -70,64 +70,32 @@ class GymEnvironmentParameters(EnvironmentParameters):
|
||||
return 'rl_coach.environments.gym_environment:GymEnvironment'
|
||||
|
||||
|
||||
"""
|
||||
Roboschool Environment Components
|
||||
"""
|
||||
RoboSchoolInputFilters = NoInputFilter()
|
||||
RoboSchoolOutputFilters = NoOutputFilter()
|
||||
|
||||
|
||||
class Roboschool(GymEnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Generic parameters for vector environments such as mujoco, roboschool, bullet, etc.
|
||||
class GymVectorEnvironment(GymEnvironmentParameters):
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.frame_skip = 1
|
||||
self.default_input_filter = RoboSchoolInputFilters
|
||||
self.default_output_filter = RoboSchoolOutputFilters
|
||||
self.default_input_filter = NoInputFilter()
|
||||
self.default_output_filter = NoOutputFilter()
|
||||
|
||||
|
||||
# Roboschool
|
||||
gym_roboschool_envs = ['inverted_pendulum', 'inverted_pendulum_swingup', 'inverted_double_pendulum', 'reacher',
|
||||
'hopper', 'walker2d', 'half_cheetah', 'ant', 'humanoid', 'humanoid_flagrun',
|
||||
'humanoid_flagrun_harder', 'pong']
|
||||
roboschool_v0 = {e: "{}".format(lower_under_to_upper(e) + '-v0') for e in gym_roboschool_envs}
|
||||
|
||||
"""
|
||||
Mujoco Environment Components
|
||||
"""
|
||||
MujocoInputFilter = NoInputFilter()
|
||||
MujocoOutputFilter = NoOutputFilter()
|
||||
|
||||
|
||||
class Mujoco(GymEnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.frame_skip = 1
|
||||
self.default_input_filter = MujocoInputFilter
|
||||
self.default_output_filter = MujocoOutputFilter
|
||||
|
||||
|
||||
# Mujoco
|
||||
gym_mujoco_envs = ['inverted_pendulum', 'inverted_double_pendulum', 'reacher', 'hopper', 'walker2d', 'half_cheetah',
|
||||
'ant', 'swimmer', 'humanoid', 'humanoid_standup', 'pusher', 'thrower', 'striker']
|
||||
|
||||
mujoco_v2 = {e: "{}".format(lower_under_to_upper(e) + '-v2') for e in gym_mujoco_envs}
|
||||
mujoco_v2['walker2d'] = 'Walker2d-v2'
|
||||
|
||||
# Fetch
|
||||
gym_fetch_envs = ['reach', 'slide', 'push', 'pick_and_place']
|
||||
fetch_v1 = {e: "{}".format('Fetch' + lower_under_to_upper(e) + '-v1') for e in gym_fetch_envs}
|
||||
|
||||
"""
|
||||
Bullet Environment Components
|
||||
"""
|
||||
BulletInputFilter = NoInputFilter()
|
||||
BulletOutputFilter = NoOutputFilter()
|
||||
|
||||
|
||||
class Bullet(GymEnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.frame_skip = 1
|
||||
self.default_input_filter = BulletInputFilter
|
||||
self.default_output_filter = BulletOutputFilter
|
||||
|
||||
|
||||
"""
|
||||
Atari Environment Components
|
||||
@@ -145,8 +113,8 @@ AtariOutputFilter = NoOutputFilter()
|
||||
|
||||
|
||||
class Atari(GymEnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.frame_skip = 4
|
||||
self.max_over_num_frames = 2
|
||||
self.random_initialization_steps = 30
|
||||
@@ -167,6 +135,14 @@ atari_deterministic_v4 = {e: "{}".format(lower_under_to_upper(e) + 'Deterministi
|
||||
atari_no_frameskip_v4 = {e: "{}".format(lower_under_to_upper(e) + 'NoFrameskip-v4') for e in gym_atari_envs}
|
||||
|
||||
|
||||
# default atari schedule used in the DeepMind papers
|
||||
atari_schedule = ScheduleParameters()
|
||||
atari_schedule.improve_steps = EnvironmentSteps(50000000)
|
||||
atari_schedule.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
atari_schedule.evaluation_steps = EnvironmentSteps(135000)
|
||||
atari_schedule.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
|
||||
class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
|
||||
def __init__(self, env, frameskip=4, max_over_num_frames=2):
|
||||
super().__init__(env)
|
||||
|
||||
@@ -85,8 +85,8 @@ StarcraftNormalizingOutputFilter.add_action_filter(
|
||||
|
||||
|
||||
class StarCraft2EnvironmentParameters(EnvironmentParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, level=None):
|
||||
super().__init__(level=level)
|
||||
self.screen_size = 84
|
||||
self.minimap_size = 64
|
||||
self.feature_minimap_maps_to_use = range(7)
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import List, Dict
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import NoisyNetDense
|
||||
from rl_coach.architectures.tensorflow_components.layers import NoisyNetDense
|
||||
from rl_coach.base_parameters import AgentParameters, NetworkParameters
|
||||
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
|
||||
from rl_coach.environments.environment import Environment
|
||||
from rl_coach.level_manager import LevelManager
|
||||
from rl_coach.logger import screen, Logger
|
||||
from rl_coach.utils import set_cpu
|
||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||
|
||||
|
||||
class ScheduleParameters(Parameters):
|
||||
@@ -51,6 +51,27 @@ class HumanPlayScheduleParameters(ScheduleParameters):
|
||||
self.improve_steps = TrainingSteps(10000000000)
|
||||
|
||||
|
||||
class SimpleScheduleWithoutEvaluation(ScheduleParameters):
|
||||
def __init__(self, improve_steps=TrainingSteps(10000000000)):
|
||||
super().__init__()
|
||||
self.heatup_steps = EnvironmentSteps(0)
|
||||
self.evaluation_steps = EnvironmentEpisodes(0)
|
||||
self.steps_between_evaluation_periods = improve_steps
|
||||
self.improve_steps = improve_steps
|
||||
|
||||
|
||||
class SimpleSchedule(ScheduleParameters):
|
||||
def __init__(self,
|
||||
improve_steps=TrainingSteps(10000000000),
|
||||
steps_between_evaluation_periods=EnvironmentEpisodes(50),
|
||||
evaluation_steps=EnvironmentEpisodes(5)):
|
||||
super().__init__()
|
||||
self.heatup_steps = EnvironmentSteps(0)
|
||||
self.evaluation_steps = evaluation_steps
|
||||
self.steps_between_evaluation_periods = steps_between_evaluation_periods
|
||||
self.improve_steps = improve_steps
|
||||
|
||||
|
||||
class GraphManager(object):
|
||||
"""
|
||||
A graph manager is responsible for creating and initializing a graph of agents, including all its internal
|
||||
@@ -78,6 +99,7 @@ class GraphManager(object):
|
||||
|
||||
# timers
|
||||
self.graph_initialization_time = time.time()
|
||||
self.graph_creation_time = None
|
||||
self.heatup_start_time = None
|
||||
self.training_start_time = None
|
||||
self.last_evaluation_start_time = None
|
||||
@@ -94,7 +116,8 @@ class GraphManager(object):
|
||||
self.checkpoint_saver = None
|
||||
self.graph_logger = Logger()
|
||||
|
||||
def create_graph(self, task_parameters: TaskParameters):
|
||||
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
|
||||
self.graph_creation_time = time.time()
|
||||
self.task_parameters = task_parameters
|
||||
|
||||
if isinstance(task_parameters, DistributedTaskParameters):
|
||||
@@ -129,6 +152,8 @@ class GraphManager(object):
|
||||
|
||||
self.setup_logger()
|
||||
|
||||
return self
|
||||
|
||||
def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
|
||||
"""
|
||||
Create all the graph modules and the graph scheduler
|
||||
@@ -207,6 +232,29 @@ class GraphManager(object):
|
||||
# restore from checkpoint if given
|
||||
self.restore_checkpoint()
|
||||
|
||||
# tf.train.write_graph(tf.get_default_graph(),
|
||||
# logdir=self.task_parameters.save_checkpoint_dir,
|
||||
# name='graphdef.pb',
|
||||
# as_text=False)
|
||||
# self.save_checkpoint()
|
||||
#
|
||||
# output_nodes = []
|
||||
# for level in self.level_managers:
|
||||
# for agent in level.agents.values():
|
||||
# for network in agent.networks.values():
|
||||
# for output in network.online_network.outputs:
|
||||
# output_nodes.append(output.name.split(":")[0])
|
||||
#
|
||||
# freeze_graph_command = [
|
||||
# "python -m tensorflow.python.tools.freeze_graph",
|
||||
# "--input_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "graphdef.pb")),
|
||||
# "--input_binary=true",
|
||||
# "--output_node_names='{}'".format(','.join(output_nodes)),
|
||||
# "--input_checkpoint={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "0_Step-0.ckpt")),
|
||||
# "--output_graph={}".format(os.path.join(self.task_parameters.save_checkpoint_dir, "frozen_graph.pb"))
|
||||
# ]
|
||||
# start_shell_command_and_wait(" ".join(freeze_graph_command))
|
||||
|
||||
def setup_logger(self) -> None:
|
||||
# dump documentation
|
||||
logger_prefix = "{graph_name}".format(graph_name=self.name)
|
||||
@@ -250,6 +298,8 @@ class GraphManager(object):
|
||||
:param steps: the number of steps as a tuple of steps time and steps count
|
||||
:return: None
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
steps_copy = copy.copy(steps)
|
||||
|
||||
if steps_copy.num_steps > 0:
|
||||
@@ -284,6 +334,8 @@ class GraphManager(object):
|
||||
:param steps: number of training iterations to perform
|
||||
:return: None
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
# perform several steps of training interleaved with acting
|
||||
count_end = self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] + steps.num_steps
|
||||
while self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] < count_end:
|
||||
@@ -299,6 +351,8 @@ class GraphManager(object):
|
||||
lives available
|
||||
:return: None
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
self.reset_required = False
|
||||
[environment.reset_internal_state(force_environment_reset) for environment in self.environments]
|
||||
[manager.reset_internal_state() for manager in self.level_managers]
|
||||
@@ -314,6 +368,8 @@ class GraphManager(object):
|
||||
:return: the actual number of steps done, a boolean value that represent if the episode was done when finishing
|
||||
the function call
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
# perform several steps of playing
|
||||
result = None
|
||||
|
||||
@@ -366,6 +422,8 @@ class GraphManager(object):
|
||||
:param steps: the number of steps as a tuple of steps time and steps count
|
||||
:return: None
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
# perform several steps of training interleaved with acting
|
||||
if steps.num_steps > 0:
|
||||
self.phase = RunPhase.TRAIN
|
||||
@@ -395,6 +453,8 @@ class GraphManager(object):
|
||||
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
|
||||
:return: None
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
if steps.num_steps > 0:
|
||||
self.phase = RunPhase.TEST
|
||||
self.last_evaluation_start_time = time.time()
|
||||
@@ -411,6 +471,8 @@ class GraphManager(object):
|
||||
self.phase = RunPhase.UNDEFINED
|
||||
|
||||
def restore_checkpoint(self):
|
||||
self.verify_graph_was_created()
|
||||
|
||||
# TODO: find better way to load checkpoints that were saved with a global network into the online network
|
||||
if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
|
||||
import tensorflow as tf
|
||||
@@ -473,6 +535,7 @@ class GraphManager(object):
|
||||
2.2. Evaluate
|
||||
:return: None
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
# initialize the network parameters from the global network
|
||||
self.sync_graph()
|
||||
@@ -491,6 +554,14 @@ class GraphManager(object):
|
||||
self.train_and_act(self.steps_between_evaluation_periods)
|
||||
self.evaluate(self.evaluation_steps)
|
||||
|
||||
def verify_graph_was_created(self):
|
||||
"""
|
||||
Verifies that the graph was already created, and if not, it creates it with the default task parameters
|
||||
:return: None
|
||||
"""
|
||||
if self.graph_creation_time is None:
|
||||
self.create_graph()
|
||||
|
||||
def __str__(self):
|
||||
result = ""
|
||||
for key, val in self.__dict__.items():
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SingleLevelSelection, SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -29,17 +28,10 @@ agent_params.algorithm.beta_entropy = 0.05
|
||||
agent_params.network_wrappers['main'].middleware_parameters = FCMiddlewareParameters()
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -48,5 +40,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.lstm_middleware import LSTMMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SingleLevelSelection, SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, AtariInputFilter
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -29,17 +28,11 @@ agent_params.algorithm.beta_entropy = 0.05
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.network_wrappers['main'].middleware_parameters = LSTMMiddlewareParameters(scheme=MiddlewareScheme.Medium,
|
||||
number_of_lstm_cells=256)
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = True
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -48,5 +41,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -25,12 +25,7 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -39,5 +34,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from rl_coach.agents.categorical_dqn_agent import CategoricalDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -24,12 +13,7 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -38,5 +22,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -24,12 +13,7 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -38,5 +22,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
@@ -1,23 +1,11 @@
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
#########
|
||||
@@ -29,12 +17,7 @@ agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000) # 12.5M training it
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -43,5 +26,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -25,12 +14,7 @@ agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -39,5 +23,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,21 +1,11 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -28,12 +18,7 @@ agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000) # 12.5M training it
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -42,5 +27,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -3,20 +3,9 @@ import math
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -26,19 +15,14 @@ agent_params = DDQNAgentParameters()
|
||||
# since we are using Adam instead of RMSProp, we adjust the learning rate as well
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
|
||||
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2)]
|
||||
agent_params.network_wrappers['main'].heads_parameters = \
|
||||
[DuelingQHeadParameters(rescale_gradient_from_head_by_factor=1/math.sqrt(2))]
|
||||
agent_params.network_wrappers['main'].clip_gradients = 10
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -47,5 +31,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,22 +1,13 @@
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters
|
||||
from rl_coach.schedules import LinearSchedule, PieceWiseSchedule, ConstantSchedule
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -38,12 +29,7 @@ agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000) # 12.5M training it
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -52,5 +38,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.nec_agent import NECAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SingleLevelSelection, SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, AtariInputFilter, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -27,14 +27,9 @@ agent_params.input_filter.remove_reward_filter('clipping')
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
env_params.random_initialization_steps = 1
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
@@ -42,5 +37,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from rl_coach.agents.n_step_q_agent import NStepQAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SingleLevelSelection, SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -22,19 +22,14 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||
agent_params = NStepQAgentParameters()
|
||||
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Conv2d([16, 8, 4]),
|
||||
Conv2d([32, 4, 2])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([256])]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Conv2d(16, 8, 4),
|
||||
Conv2d(32, 4, 2)]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(256)]
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -43,5 +38,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from rl_coach.agents.qr_dqn_agent import QuantileRegressionDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -25,12 +14,7 @@ agent_params.algorithm.huber_loss_interval = 1 # k = 0 for strict quantile loss
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -39,5 +23,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.rainbow_dqn_agent import RainbowDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -30,12 +30,7 @@ agent_params.memory.alpha = 0.5
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -44,5 +39,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,20 +1,9 @@
|
||||
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
|
||||
from rl_coach.exploration_policies.ucb import UCBParameters
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(50000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(250000)
|
||||
schedule_params.evaluation_steps = EnvironmentSteps(135000)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(50000)
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
@@ -26,12 +15,7 @@ agent_params.exploration = UCBParameters()
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = SingleLevelSelection(atari_deterministic_v4)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -40,5 +24,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=atari_schedule, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, \
|
||||
PresetValidationParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -28,7 +27,7 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||
agent_params = DQNAgentParameters()
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.001
|
||||
agent_params.network_wrappers['main'].batch_size = 128
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([256])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(256)]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters = {
|
||||
'state': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)
|
||||
@@ -45,12 +44,10 @@ agent_params.exploration.evaluation_epsilon = 0
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.bit_flip:BitFlip'
|
||||
env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')
|
||||
env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}
|
||||
# env_params.custom_reward_threshold = -bit_length + 1
|
||||
env_params.custom_reward_threshold = -bit_length + 1
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -61,7 +58,7 @@ preset_validation_params.min_reward_threshold = -7.9
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 10000
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, \
|
||||
PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.episodic.episodic_hindsight_experience_replay import \
|
||||
@@ -30,7 +30,7 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||
agent_params = DQNAgentParameters()
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.001
|
||||
agent_params.network_wrappers['main'].batch_size = 128
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([256])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(256)]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters = {
|
||||
'state': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
@@ -55,13 +55,10 @@ agent_params.memory.goals_space = GoalsSpace(goal_name='state',
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.bit_flip:BitFlip'
|
||||
env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.bit_flip:BitFlip')
|
||||
env_params.additional_simulator_parameters = {'bit_length': bit_length, 'mean_zero': True}
|
||||
env_params.custom_reward_threshold = -bit_length + 1
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
# currently no tests for this preset as the max reward can be accidently achieved. will be fixed with trace based tests.
|
||||
|
||||
########
|
||||
@@ -73,7 +70,7 @@ preset_validation_params.min_reward_threshold = -15
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 10000
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@ import copy
|
||||
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes, CarlaInputFilter
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -49,12 +48,7 @@ agent_params.input_filter.copy_filters_from_one_observation_to_another('forward_
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
env_params.cameras = [CameraTypes.FRONT, CameraTypes.LEFT, CameraTypes.RIGHT]
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from logger import screen
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
# make sure you have $CARLA_ROOT/PythonClient in your PYTHONPATH
|
||||
from carla.driving_benchmark.experiment_suites import CoRL2017
|
||||
from rl_coach.logger import screen
|
||||
|
||||
from rl_coach.agents.cil_agent import CILAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.cil_head import RegressionHeadParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense, BatchnormActivationDropout
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters, CameraTypes
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
||||
@@ -27,7 +27,6 @@ from rl_coach.schedules import ConstantSchedule
|
||||
from rl_coach.spaces import ImageObservationSpace
|
||||
from rl_coach.utilities.carla_dataset_to_replay_buffer import create_dataset
|
||||
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
@@ -44,38 +43,64 @@ agent_params = CILAgentParameters()
|
||||
|
||||
# forward camera and measurements input
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters = {
|
||||
'CameraRGB': InputEmbedderParameters(scheme=[Conv2d([32, 5, 2]),
|
||||
Conv2d([32, 3, 1]),
|
||||
Conv2d([64, 3, 2]),
|
||||
Conv2d([64, 3, 1]),
|
||||
Conv2d([128, 3, 2]),
|
||||
Conv2d([128, 3, 1]),
|
||||
Conv2d([256, 3, 1]),
|
||||
Conv2d([256, 3, 1]),
|
||||
Dense([512]),
|
||||
Dense([512])],
|
||||
dropout=True,
|
||||
batchnorm=True),
|
||||
'measurements': InputEmbedderParameters(scheme=[Dense([128]),
|
||||
Dense([128])])
|
||||
'CameraRGB': InputEmbedderParameters(
|
||||
scheme=[
|
||||
Conv2d(32, 5, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(32, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(64, 3, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(64, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(128, 3, 2),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(128, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(256, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Conv2d(256, 3, 1),
|
||||
BatchnormActivationDropout(batchnorm=True, activation_function=tf.tanh),
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.3),
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.3)
|
||||
],
|
||||
activation_function='none' # we define the activation function for each layer explicitly
|
||||
),
|
||||
'measurements': InputEmbedderParameters(
|
||||
scheme=[
|
||||
Dense(128),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5),
|
||||
Dense(128),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5)
|
||||
],
|
||||
activation_function='none' # we define the activation function for each layer explicitly
|
||||
)
|
||||
}
|
||||
|
||||
# TODO: batch norm is currently applied to the fc layers which is not desired
|
||||
# TODO: dropout should be configured differenetly per layer [1.0] * 8 + [0.7] * 2 + [0.5] * 2 + [0.5] * 1 + [0.5, 1.] * 5
|
||||
|
||||
# simple fc middleware
|
||||
agent_params.network_wrappers['main'].middleware_parameters = FCMiddlewareParameters(scheme=[Dense([512])])
|
||||
agent_params.network_wrappers['main'].middleware_parameters = \
|
||||
FCMiddlewareParameters(
|
||||
scheme=[
|
||||
Dense(512),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5)
|
||||
],
|
||||
activation_function='none'
|
||||
)
|
||||
|
||||
# output branches
|
||||
agent_params.network_wrappers['main'].heads_parameters = [
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters(),
|
||||
RegressionHeadParameters()
|
||||
RegressionHeadParameters(
|
||||
scheme=[
|
||||
Dense(256),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh, dropout_rate=0.5),
|
||||
Dense(256),
|
||||
BatchnormActivationDropout(activation_function=tf.tanh)
|
||||
],
|
||||
num_output_head_copies=4 # follow lane, left, right, straight
|
||||
)
|
||||
]
|
||||
# agent_params.network_wrappers['main'].num_output_head_copies = 4 # follow lane, left, right, straight
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1, 1, 1, 1]
|
||||
agent_params.network_wrappers['main'].loss_weights = [1, 1, 1, 1]
|
||||
# TODO: there should be another head predicting the speed which is connected directly to the forward camera embedding
|
||||
|
||||
agent_params.network_wrappers['main'].batch_size = 120
|
||||
@@ -125,7 +150,6 @@ if not os.path.exists(agent_params.memory.load_memory_from_file_path):
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
env_params.cameras = ['CameraRGB']
|
||||
env_params.camera_height = 600
|
||||
env_params.camera_width = 800
|
||||
@@ -134,9 +158,5 @@ env_params.allow_braking = True
|
||||
env_params.quality = CarlaEnvironmentParameters.Quality.EPIC
|
||||
env_params.experiment_suite = CoRL2017('Town01')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST)]
|
||||
vis_params.dump_mp4 = True
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
@@ -29,11 +28,6 @@ agent_params.network_wrappers['critic'].input_embedders_parameters['forward_came
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -3,9 +3,8 @@ import math
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.carla_environment import CarlaEnvironmentParameters
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.filters.action.box_discretization import BoxDiscretization
|
||||
from rl_coach.filters.filter import OutputFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -26,9 +25,9 @@ schedule_params.heatup_steps = EnvironmentSteps(1000)
|
||||
#########
|
||||
agent_params = DDQNAgentParameters()
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
|
||||
agent_params.network_wrappers['main'].heads_parameters = \
|
||||
[DuelingQHeadParameters(rescale_gradient_from_head_by_factor=1/math.sqrt(2))]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2), 1/math.sqrt(2)]
|
||||
agent_params.network_wrappers['main'].clip_gradients = 10
|
||||
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['forward_camera'] = \
|
||||
@@ -40,11 +39,6 @@ agent_params.output_filter.add_action_filter('discretization', BoxDiscretization
|
||||
# Environment #
|
||||
###############
|
||||
env_params = CarlaEnvironmentParameters()
|
||||
env_params.level = 'town1'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter, Mujoco
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -33,20 +32,13 @@ agent_params.algorithm.beta_entropy = 0.01
|
||||
agent_params.network_wrappers['main'].optimizer_type = 'Adam'
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -58,5 +50,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 300
|
||||
preset_validation_params.num_workers = 8
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters, HandlingTargetsAfterEpisodeEnd
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -37,12 +36,7 @@ agent_params.algorithm.handling_targets_after_episode_end = HandlingTargetsAfter
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -53,5 +47,5 @@ preset_validation_params.min_reward_threshold = 120
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -41,12 +40,7 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||
################
|
||||
# Environment #
|
||||
################
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -57,5 +51,5 @@ preset_validation_params.min_reward_threshold = 150
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -3,9 +3,8 @@ import math
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -34,8 +33,8 @@ agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1)
|
||||
# NN configuration
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
|
||||
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2), 1/math.sqrt(2)]
|
||||
agent_params.network_wrappers['main'].heads_parameters = \
|
||||
[DuelingQHeadParameters(rescale_gradient_from_head_by_factor=1/math.sqrt(2))]
|
||||
|
||||
# ER size
|
||||
agent_params.memory.max_size = (MemoryGranularity.Transitions, 40000)
|
||||
@@ -46,12 +45,7 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||
################
|
||||
# Environment #
|
||||
################
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -62,5 +56,5 @@ preset_validation_params.min_reward_threshold = 150
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from rl_coach.agents.nec_agent import NECAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Atari, MujocoInputFilter
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -30,18 +30,13 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(0.5, 0.1, 1000)
|
||||
agent_params.exploration.evaluation_epsilon = 0
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.memory.max_size = (MemoryGranularity.Episodes, 200)
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -53,5 +48,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 300
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from rl_coach.agents.n_step_q_agent import NStepQAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter, Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -24,18 +24,13 @@ agent_params = NStepQAgentParameters()
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100)
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -47,5 +42,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 200
|
||||
preset_validation_params.num_workers = 8
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.pal_agent import PALAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -41,12 +40,7 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||
################
|
||||
# Environment #
|
||||
################
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -57,5 +51,5 @@ preset_validation_params.min_reward_threshold = 150
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from rl_coach.agents.policy_gradients_agent import PolicyGradientsAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter, Mujoco
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -29,20 +28,13 @@ agent_params.algorithm.num_steps_between_gradient_updates = 20000
|
||||
agent_params.network_wrappers['main'].optimizer_type = 'Adam'
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0005
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'CartPole-v0'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -53,6 +45,6 @@ preset_validation_params.min_reward_threshold = 130
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 550
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -27,23 +27,18 @@ agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters.pop('observation')
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'] = \
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters.pop('observation')
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'].scheme = [Dense([300])]
|
||||
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense([200])]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'].scheme = [Dense([400])]
|
||||
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense([300])]
|
||||
agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'].scheme = [Dense(300)]
|
||||
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [Dense(200)]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'].scheme = [Dense(400)]
|
||||
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [Dense(300)]
|
||||
agent_params.network_wrappers['critic'].input_embedders_parameters['action'].scheme = EmbedderScheme.Empty
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter("rescale", RewardRescaleFilter(1/10.))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = ControlSuiteEnvironmentParameters()
|
||||
env_params.level = SingleLevelSelection(control_suite_envs)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = ControlSuiteEnvironmentParameters(level=SingleLevelSelection(control_suite_envs))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -52,5 +47,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_test_levels = ['cartpole:swingup', 'hopper:hop']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.environments.gym_environment import MujocoInputFilter
|
||||
from rl_coach.exploration_policies.categorical import CategoricalParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -27,24 +25,18 @@ schedule_params.heatup_steps = EnvironmentSteps(0)
|
||||
agent_params = ActorCriticAgentParameters()
|
||||
agent_params.algorithm.policy_gradient_rescaler = PolicyGradientRescaler.GAE
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/100.))
|
||||
agent_params.algorithm.num_steps_between_gradient_updates = 30
|
||||
agent_params.algorithm.apply_gradients_every_x_episodes = 1
|
||||
agent_params.algorithm.gae_lambda = 1.0
|
||||
agent_params.algorithm.beta_entropy = 0.01
|
||||
agent_params.network_wrappers['main'].clip_gradients = 40.
|
||||
agent_params.exploration = CategoricalParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -57,5 +49,5 @@ preset_validation_params.num_workers = 8
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -35,8 +35,7 @@ agent_params.memory.load_memory_from_file_path = 'datasets/doom_basic.p'
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
########
|
||||
# Test #
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters, HandlingTargetsAfterEpisodeEnd
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -39,12 +38,8 @@ agent_params.algorithm.handling_targets_after_episode_end = HandlingTargetsAfter
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -53,5 +48,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.trace_max_env_steps = 2000
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -35,12 +34,8 @@ agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -52,5 +47,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 400
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -36,12 +35,8 @@ agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'basic'
|
||||
env_params = DoomEnvironmentParameters(level='basic')
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters, DoomEnvironment
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -44,13 +43,9 @@ agent_params.network_wrappers['main'].input_embedders_parameters['observation'].
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'BATTLE_COACH_LOCAL'
|
||||
env_params = DoomEnvironmentParameters(level='BATTLE_COACH_LOCAL')
|
||||
env_params.cameras = [DoomEnvironment.CameraTypes.OBSERVATION]
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \
|
||||
PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
|
||||
from rl_coach.core_types import EnvironmentSteps, EnvironmentEpisodes
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -54,11 +53,8 @@ agent_params.algorithm.scale_measurements_targets['GameVariable.HEALTH'] = 30.0
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'HEALTH_GATHERING'
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = DoomEnvironmentParameters(level='HEALTH_GATHERING')
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -71,5 +67,5 @@ preset_validation_params.max_episodes_to_achieve_reward = 70
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from rl_coach.agents.mmc_agent import MixedMonteCarloAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.memories.memory import MemoryGranularity
|
||||
@@ -35,11 +34,7 @@ agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'HEALTH_GATHERING'
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = DoomEnvironmentParameters(level='HEALTH_GATHERING')
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -52,5 +47,5 @@ preset_validation_params.test_using_a_trace_test = False
|
||||
# preset_validation_params.max_episodes_to_achieve_reward = 300
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from rl_coach.agents.dfp_agent import DFPAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \
|
||||
PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
|
||||
from rl_coach.core_types import EnvironmentSteps, EnvironmentEpisodes
|
||||
from rl_coach.environments.doom_environment import DoomEnvironmentParameters
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
from rl_coach.schedules import LinearSchedule
|
||||
@@ -54,11 +53,8 @@ agent_params.algorithm.scale_measurements_targets['GameVariable.HEALTH'] = 30.0
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = DoomEnvironmentParameters()
|
||||
env_params.level = 'HEALTH_GATHERING_SUPREME_COACH_LOCAL'
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = DoomEnvironmentParameters(level='HEALTH_GATHERING_SUPREME_COACH_LOCAL')
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -67,5 +63,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -29,8 +29,8 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000)
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
|
||||
agent_params.network_wrappers['main'].num_output_head_copies = num_output_head_copies
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1.0/num_output_head_copies]*num_output_head_copies
|
||||
agent_params.network_wrappers['main'].heads_parameters[0].num_output_head_copies = num_output_head_copies
|
||||
agent_params.network_wrappers['main'].heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/num_output_head_copies
|
||||
agent_params.exploration.bootstrapped_data_sharing_probability = 1.0
|
||||
agent_params.exploration.architecture_num_q_heads = num_output_head_copies
|
||||
agent_params.exploration.epsilon_schedule = ConstantSchedule(0)
|
||||
@@ -40,26 +40,9 @@ agent_params.output_filter = NoOutputFilter()
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.exploration_chain:ExplorationChain'
|
||||
|
||||
env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.exploration_chain:ExplorationChain')
|
||||
env_params.additional_simulator_parameters = {'chain_length': N, 'max_steps': N+7}
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
|
||||
# currently no test here as bootstrapped_dqn seems to be broken
|
||||
|
||||
# preset_validation_params = PresetValidationParameters()
|
||||
# preset_validation_params.test = True
|
||||
# preset_validation_params.min_reward_threshold = 1600
|
||||
# preset_validation_params.max_episodes_to_achieve_reward = 70
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,)
|
||||
# preset_validation_params=preset_validation_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -38,18 +38,8 @@ agent_params.output_filter = NoOutputFilter()
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = GymEnvironmentParameters()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.exploration_chain:ExplorationChain'
|
||||
env_params = GymEnvironmentParameters(level='rl_coach.environments.toy_problems.exploration_chain:ExplorationChain')
|
||||
env_params.additional_simulator_parameters = {'chain_length': N, 'max_steps': N+7}
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
|
||||
# preset_validation_params = PresetValidationParameters()
|
||||
# preset_validation_params.test = True
|
||||
# preset_validation_params.min_reward_threshold = 1600
|
||||
# preset_validation_params.max_episodes_to_achieve_reward = 70
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,)
|
||||
# preset_validation_params=preset_validation_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rl_coach.agents.bootstrapped_dqn_agent import BootstrappedDQNAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import Mujoco
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.exploration_policies.ucb import UCBParameters
|
||||
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -30,8 +30,8 @@ agent_params.network_wrappers['main'].learning_rate = 0.00025
|
||||
agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000)
|
||||
agent_params.algorithm.discount = 0.99
|
||||
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
|
||||
agent_params.network_wrappers['main'].num_output_head_copies = num_output_head_copies
|
||||
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1.0/num_output_head_copies]*num_output_head_copies
|
||||
agent_params.network_wrappers['main'].heads_parameters[0].num_output_head_copies = num_output_head_copies
|
||||
agent_params.network_wrappers['main'].heads_parameters[0].rescale_gradient_from_head_by_factor = 1.0/num_output_head_copies
|
||||
agent_params.exploration = UCBParameters()
|
||||
agent_params.exploration.bootstrapped_data_sharing_probability = 1.0
|
||||
agent_params.exploration.architecture_num_q_heads = num_output_head_copies
|
||||
@@ -43,12 +43,8 @@ agent_params.output_filter = NoOutputFilter()
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = 'rl_coach.environments.toy_problems.exploration_chain:ExplorationChain'
|
||||
|
||||
env_params = GymVectorEnvironment(level='rl_coach.environments.toy_problems.exploration_chain:ExplorationChain')
|
||||
env_params.additional_simulator_parameters = {'chain_length': N, 'max_steps': N+7}
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps, RunPhase
|
||||
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, MujocoInputFilter, fetch_v1
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, fetch_v1
|
||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_clipping_filter import ObservationClippingFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -44,7 +45,7 @@ actor_network.input_embedders_parameters = {
|
||||
'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty)
|
||||
}
|
||||
actor_network.middleware_parameters = FCMiddlewareParameters(scheme=[Dense([256]), Dense([256]), Dense([256])])
|
||||
actor_network.middleware_parameters = FCMiddlewareParameters(scheme=[Dense(256), Dense(256), Dense(256)])
|
||||
actor_network.heads_parameters[0].batchnorm = False
|
||||
|
||||
# critic
|
||||
@@ -59,7 +60,7 @@ critic_network.input_embedders_parameters = {
|
||||
'desired_goal': InputEmbedderParameters(scheme=EmbedderScheme.Empty),
|
||||
'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty)
|
||||
}
|
||||
critic_network.middleware_parameters = FCMiddlewareParameters(scheme=[Dense([256]), Dense([256]), Dense([256])])
|
||||
critic_network.middleware_parameters = FCMiddlewareParameters(scheme=[Dense(256), Dense(256), Dense(256)])
|
||||
|
||||
agent_params.algorithm.discount = 0.98
|
||||
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentEpisodes(1)
|
||||
@@ -90,10 +91,10 @@ agent_params.exploration.evaluation_epsilon = 0
|
||||
agent_params.exploration.continuous_exploration_policy_parameters.noise_percentage_schedule = ConstantSchedule(0.1)
|
||||
agent_params.exploration.continuous_exploration_policy_parameters.evaluation_noise_percentage = 0
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_observation_filter('observation', 'clipping', ObservationClippingFilter(-200, 200))
|
||||
|
||||
agent_params.pre_network_filter = MujocoInputFilter()
|
||||
agent_params.pre_network_filter = InputFilter()
|
||||
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
||||
ObservationNormalizationFilter(name='normalize_observation'))
|
||||
agent_params.pre_network_filter.add_observation_filter('achieved_goal', 'normalize_achieved_goal',
|
||||
@@ -104,27 +105,17 @@ agent_params.pre_network_filter.add_observation_filter('desired_goal', 'normaliz
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(fetch_v1)
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(fetch_v1))
|
||||
env_params.custom_reward_threshold = -49
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
preset_validation_params = PresetValidationParameters()
|
||||
# preset_validation_params.test = True
|
||||
# preset_validation_params.min_reward_threshold = 200
|
||||
# preset_validation_params.max_episodes_to_achieve_reward = 600
|
||||
# preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['slide', 'pick_and_place', 'push', 'reach']
|
||||
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from rl_coach.agents.policy_gradients_agent import PolicyGradientsAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.environments.gym_environment import Mujoco, MujocoInputFilter
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -25,7 +25,7 @@ agent_params.algorithm.apply_gradients_every_x_episodes = 5
|
||||
agent_params.algorithm.num_steps_between_gradient_updates = 20000
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0005
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/20.))
|
||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||
|
||||
@@ -33,14 +33,9 @@ agent_params.input_filter.add_observation_filter('observation', 'normalize', Obs
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = "InvertedPendulum-v2"
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level="InvertedPendulum-v2")
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params)
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters())
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from rl_coach.agents.bc_agent import BCAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.gym_environment import Atari
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -31,14 +30,9 @@ agent_params.memory.load_memory_from_file_path = 'datasets/montezuma_revenge.p'
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Atari()
|
||||
env_params.level = 'MontezumaRevenge-v0'
|
||||
env_params = Atari(level='MontezumaRevenge-v0')
|
||||
env_params.random_initialization_steps = 30
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
@@ -46,5 +40,5 @@ preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test_using_a_trace_test = False
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
|
||||
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -27,21 +27,15 @@ agent_params.algorithm.num_steps_between_gradient_updates = 10000000
|
||||
agent_params.algorithm.beta_entropy = 0.0001
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.00001
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/20.))
|
||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||
|
||||
agent_params.exploration = ContinuousEntropyParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -55,7 +49,7 @@ preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.lstm_middleware import LSTMMiddlewareParameters
|
||||
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
|
||||
from rl_coach.exploration_policies.continuous_entropy import ContinuousEntropyParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
@@ -30,25 +30,18 @@ agent_params.algorithm.num_steps_between_gradient_updates = 20
|
||||
agent_params.algorithm.beta_entropy = 0.005
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.00002
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'] = \
|
||||
InputEmbedderParameters(scheme=[Dense([200])])
|
||||
InputEmbedderParameters(scheme=[Dense(200)])
|
||||
agent_params.network_wrappers['main'].middleware_parameters = LSTMMiddlewareParameters(scheme=MiddlewareScheme.Empty,
|
||||
number_of_lstm_cells=128)
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/20.))
|
||||
agent_params.input_filter.add_observation_filter('observation', 'normalize', ObservationNormalizationFilter())
|
||||
|
||||
agent_params.exploration = ContinuousEntropyParameters()
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -62,7 +55,7 @@ preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
|
||||
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter
|
||||
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -28,8 +29,8 @@ agent_params = ClippedPPOAgentParameters()
|
||||
|
||||
agent_params.network_wrappers['main'].learning_rate = 0.0003
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].activation_function = 'tanh'
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense([64])]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense([64])]
|
||||
agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense(64)]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(64)]
|
||||
agent_params.network_wrappers['main'].middleware_parameters.activation_function = 'tanh'
|
||||
agent_params.network_wrappers['main'].batch_size = 64
|
||||
agent_params.network_wrappers['main'].optimizer_epsilon = 1e-5
|
||||
@@ -43,21 +44,16 @@ agent_params.algorithm.discount = 0.99
|
||||
agent_params.algorithm.optimization_epochs = 10
|
||||
agent_params.algorithm.estimate_state_value_using_gae = True
|
||||
|
||||
agent_params.input_filter = MujocoInputFilter()
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.exploration = AdditiveNoiseParameters()
|
||||
agent_params.pre_network_filter = MujocoInputFilter()
|
||||
agent_params.pre_network_filter = InputFilter()
|
||||
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
||||
ObservationNormalizationFilter(name='normalize_observation'))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = Mujoco()
|
||||
env_params.level = SingleLevelSelection(mujoco_v2)
|
||||
|
||||
vis_params = VisualizationParameters()
|
||||
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
|
||||
vis_params.dump_mp4 = False
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
########
|
||||
# Test #
|
||||
@@ -70,7 +66,7 @@ preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=vis_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user