
parameter noise exploration - using Noisy Nets

Author: Gal Leibovich
Date:   2018-08-27 18:19:01 +03:00
parent 658b437079
commit 1aa2ab0590

49 changed files with 536 additions and 433 deletions
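
This commit wires factorized NoisyNets (https://arxiv.org/abs/1706.10295) into Coach by threading a dense_layer argument through the embedders and heads, so a preset can swap the plain Dense layer for the new NoisyNetDense layer. As background for the diff below: a factorized noisy layer perturbs learned weight/bias means by learned standard deviations times rank-1 Gaussian noise. A minimal NumPy sketch of that forward pass (illustrative only, not Coach code):

import numpy as np

def f(x):
    # signed square-root transform applied to the factorized noise
    return np.sign(x) * np.sqrt(np.abs(x))

def noisy_dense_forward(x, w_mean, w_std, b_mean, b_std):
    num_in, num_out = w_mean.shape
    eps_in = f(np.random.randn(num_in, 1))
    eps_out = f(np.random.randn(1, num_out))
    w_noise = eps_in @ eps_out            # rank-1 factorized weight noise
    b_noise = eps_out.ravel()
    weight = w_mean + w_std * w_noise     # perturb learned means by learned stddevs
    bias = b_mean + b_std * b_noise
    return x @ weight + bias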

View File

@@ -24,11 +24,12 @@ from rl_coach.architectures.tensorflow_components.heads.policy_head import Polic
 from rl_coach.architectures.tensorflow_components.heads.v_head import VHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
 from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
-    AgentParameters, InputEmbedderParameters
+    AgentParameters
 from rl_coach.logger import screen
 from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters
 from rl_coach.spaces import DiscreteActionSpace
 from rl_coach.utils import last_sample
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters


 class ActorCriticAlgorithmParameters(AlgorithmParameters):

View File

@@ -21,10 +21,11 @@ import numpy as np
 from rl_coach.agents.imitation_agent import ImitationAgent
 from rl_coach.architectures.tensorflow_components.heads.policy_head import PolicyHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
-from rl_coach.base_parameters import AgentParameters, AlgorithmParameters, NetworkParameters, InputEmbedderParameters, \
+from rl_coach.base_parameters import AgentParameters, AlgorithmParameters, NetworkParameters, \
     MiddlewareScheme
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters


 class BCAlgorithmParameters(AlgorithmParameters):

View File

@@ -18,9 +18,11 @@ from typing import Union
 import numpy as np

-from rl_coach.agents.dqn_agent import DQNAgentParameters, DQNNetworkParameters
+from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters
 from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
+from rl_coach.base_parameters import AgentParameters
 from rl_coach.exploration_policies.bootstrapped import BootstrappedParameters
+from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters


 class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
@@ -30,11 +32,12 @@ class BootstrappedDQNNetworkParameters(DQNNetworkParameters):
         self.rescale_gradient_from_head_by_factor = [1.0/self.num_output_head_copies]*self.num_output_head_copies


-class BootstrappedDQNAgentParameters(DQNAgentParameters):
+class BootstrappedDQNAgentParameters(AgentParameters):
     def __init__(self):
-        super().__init__()
-        self.network_wrappers = {"main": BootstrappedDQNNetworkParameters()}
-        self.exploration = BootstrappedParameters()
+        super().__init__(algorithm=DQNAlgorithmParameters(),
+                         exploration=BootstrappedParameters(),
+                         memory=ExperienceReplayParameters(),
+                         networks={"main": BootstrappedDQNNetworkParameters()})

     @property
     def path(self):

View File

@@ -18,7 +18,7 @@ from typing import Union
 import numpy as np

-from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters
+from rl_coach.agents.dqn_agent import DQNNetworkParameters, DQNAlgorithmParameters, DQNAgentParameters
 from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
 from rl_coach.architectures.tensorflow_components.heads.categorical_q_head import CategoricalQHeadParameters
 from rl_coach.base_parameters import AgentParameters
@@ -49,12 +49,12 @@ class CategoricalDQNExplorationParameters(EGreedyParameters):
         self.evaluation_epsilon = 0.001


-class CategoricalDQNAgentParameters(AgentParameters):
+class CategoricalDQNAgentParameters(DQNAgentParameters):
     def __init__(self):
-        super().__init__(algorithm=CategoricalDQNAlgorithmParameters(),
-                         exploration=CategoricalDQNExplorationParameters(),
-                         memory=ExperienceReplayParameters(),
-                         networks={"main": CategoricalDQNNetworkParameters()})
+        super().__init__()
+        self.algorithm = CategoricalDQNAlgorithmParameters()
+        self.exploration = CategoricalDQNExplorationParameters()
+        self.network_wrappers = {"main": CategoricalDQNNetworkParameters()}

     @property
     def path(self):

View File

@@ -27,7 +27,8 @@ from rl_coach.architectures.tensorflow_components.heads.ppo_head import PPOHeadP
 from rl_coach.architectures.tensorflow_components.heads.v_head import VHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
 from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
-    AgentParameters, InputEmbedderParameters
+    AgentParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.core_types import EnvironmentSteps, Batch, EnvResponse, StateType
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.logger import screen

View File

@@ -22,11 +22,12 @@ import numpy as np
 from rl_coach.agents.actor_critic_agent import ActorCriticAgent
 from rl_coach.agents.agent import Agent
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.architectures.tensorflow_components.heads.ddpg_actor_head import DDPGActorHeadParameters
 from rl_coach.architectures.tensorflow_components.heads.v_head import VHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
 from rl_coach.base_parameters import NetworkParameters, AlgorithmParameters, \
-    AgentParameters, InputEmbedderParameters, EmbedderScheme
+    AgentParameters, EmbedderScheme
 from rl_coach.core_types import ActionInfo, EnvironmentSteps
 from rl_coach.exploration_policies.ou_process import OUProcessParameters
 from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters

View File

@@ -26,8 +26,9 @@ from rl_coach.architectures.tensorflow_components.heads.measurements_prediction_
     MeasurementsPredictionHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
 from rl_coach.base_parameters import AlgorithmParameters, AgentParameters, NetworkParameters, \
-    InputEmbedderParameters, MiddlewareScheme
+    MiddlewareScheme
 from rl_coach.core_types import ActionInfo, EnvironmentSteps, RunPhase
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
 from rl_coach.memories.memory import MemoryGranularity

View File

@@ -22,7 +22,8 @@ from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
 from rl_coach.architectures.tensorflow_components.heads.q_head import QHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
 from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, AgentParameters, \
-    InputEmbedderParameters, MiddlewareScheme
+    MiddlewareScheme
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.core_types import EnvironmentSteps
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters

View File

@@ -25,8 +25,9 @@ from rl_coach.agents.agent import Agent
 from rl_coach.agents.bc_agent import BCNetworkParameters
 from rl_coach.architectures.tensorflow_components.heads.policy_head import PolicyHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
-from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, InputEmbedderParameters, EmbedderScheme, \
+from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, EmbedderScheme, \
     AgentParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.core_types import ActionInfo
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.logger import screen

View File

@@ -22,8 +22,9 @@ from rl_coach.agents.policy_optimization_agent import PolicyOptimizationAgent
 from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
 from rl_coach.architectures.tensorflow_components.heads.q_head import QHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
-from rl_coach.base_parameters import AlgorithmParameters, AgentParameters, NetworkParameters, \
-    InputEmbedderParameters
+from rl_coach.base_parameters import AlgorithmParameters, AgentParameters, NetworkParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.core_types import EnvironmentSteps
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters

View File

@@ -22,7 +22,9 @@ from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
 from rl_coach.architectures.tensorflow_components.heads.naf_head import NAFHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
 from rl_coach.base_parameters import AlgorithmParameters, AgentParameters, \
-    NetworkParameters, InputEmbedderParameters
+    NetworkParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.core_types import ActionInfo, EnvironmentSteps
 from rl_coach.exploration_policies.ou_process import OUProcessParameters
 from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters

View File

@@ -23,8 +23,9 @@ import numpy as np
 from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
 from rl_coach.architectures.tensorflow_components.heads.dnd_q_head import DNDQHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
-from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, AgentParameters, \
-    InputEmbedderParameters
+from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, AgentParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.core_types import RunPhase, EnvironmentSteps, Episode, StateType
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.logger import screen

View File

@@ -22,7 +22,9 @@ from rl_coach.agents.policy_optimization_agent import PolicyOptimizationAgent, P
 from rl_coach.architectures.tensorflow_components.heads.policy_head import PolicyHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
 from rl_coach.base_parameters import NetworkParameters, AlgorithmParameters, \
-    AgentParameters, InputEmbedderParameters
+    AgentParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.logger import screen
 from rl_coach.memories.episodic.single_episode_buffer import SingleEpisodeBufferParameters

View File

@@ -26,7 +26,9 @@ from rl_coach.architectures.tensorflow_components.heads.ppo_head import PPOHeadP
 from rl_coach.architectures.tensorflow_components.heads.v_head import VHeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
 from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
-    AgentParameters, InputEmbedderParameters, DistributedTaskParameters
+    AgentParameters, DistributedTaskParameters
+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.core_types import EnvironmentSteps, Batch
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.logger import screen

View File

@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import math
 import time
-from typing import List
+from typing import List, Union

 import numpy as np
 import tensorflow as tf
@@ -73,20 +73,87 @@ class Conv2d(object):
 class Dense(object):
-    def __init__(self, params: List):
+    def __init__(self, params: Union[List, int]):
         """
         :param params: list of [num_output_neurons]
         """
-        self.params = params
+        self.params = force_list(params)

-    def __call__(self, input_layer, name: str):
+    def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None):
         """
         returns a tensorflow dense layer
         :param input_layer: previous layer
         :param name: layer name
         :return: dense layer
         """
-        return tf.layers.dense(input_layer, self.params[0], name=name)
+        return tf.layers.dense(input_layer, self.params[0], name=name, kernel_initializer=kernel_initializer,
+                               activation=activation)
+
+
+class NoisyNetDense(object):
+    """
+    A factorized Noisy Net layer
+    https://arxiv.org/abs/1706.10295.
+    """
+
+    def __init__(self, params: List):
+        """
+        :param params: list of [num_output_neurons]
+        """
+        self.params = force_list(params)
+        self.sigma0 = 0.5
+
+    def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None):
+        """
+        returns a NoisyNet dense layer
+        :param input_layer: previous layer
+        :param name: layer name
+        :param kernel_initializer: initializer for kernels. Default is to use Gaussian noise that preserves stddev.
+        :param activation: the activation function
+        :return: dense layer
+        """
+        # TODO: noise sampling should be externally controlled. DQN is fine with sampling noise for every
+        #       forward (either act or train, both for online and target networks).
+        #       A3C, on the other hand, should sample noise only when policy changes (i.e. after every t_max steps)
+        num_inputs = input_layer.get_shape()[-1].value
+        num_outputs = self.params[0]
+
+        stddev = 1 / math.sqrt(num_inputs)
+        activation = activation if activation is not None else (lambda x: x)
+
+        if kernel_initializer is None:
+            kernel_mean_initializer = tf.random_uniform_initializer(-stddev, stddev)
+            kernel_stddev_initializer = tf.random_uniform_initializer(-stddev * self.sigma0, stddev * self.sigma0)
+        else:
+            kernel_mean_initializer = kernel_stddev_initializer = kernel_initializer
+        with tf.variable_scope(None, default_name=name):
+            weight_mean = tf.get_variable('weight_mean', shape=(num_inputs, num_outputs),
+                                          initializer=kernel_mean_initializer)
+            bias_mean = tf.get_variable('bias_mean', shape=(num_outputs,), initializer=tf.zeros_initializer())
+
+            weight_stddev = tf.get_variable('weight_stddev', shape=(num_inputs, num_outputs),
+                                            initializer=kernel_stddev_initializer)
+            bias_stddev = tf.get_variable('bias_stddev', shape=(num_outputs,),
+                                          initializer=kernel_stddev_initializer)
+            bias_noise = self.f(tf.random_normal((num_outputs,)))
+            weight_noise = self.factorized_noise(num_inputs, num_outputs)
+
+            bias = bias_mean + bias_stddev * bias_noise
+            weight = weight_mean + weight_stddev * weight_noise
+            return activation(tf.matmul(input_layer, weight) + bias)
+
+    def factorized_noise(self, inputs, outputs):
+        # TODO: use factorized noise only for compute intensive algos (e.g. DQN).
+        #       lighter algos (e.g. DQN) should not use it
+        noise1 = self.f(tf.random_normal((inputs, 1)))
+        noise2 = self.f(tf.random_normal((1, outputs)))
+        return tf.matmul(noise1, noise2)
+
+    @staticmethod
+    def f(values):
+        return tf.sqrt(tf.abs(values)) * tf.sign(values)
+
+
 def variable_summaries(var):
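
A small usage sketch of the layer added above (a hypothetical snippet in the TF 1.x style this commit uses; the import path matches where the class is defined in this diff):

import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import NoisyNetDense

inputs = tf.placeholder(tf.float32, shape=(None, 128))
noisy_fc = NoisyNetDense([64])           # 64 output neurons
out = noisy_fc(inputs, name='noisy_fc')  # fresh noise is drawn each time the graph is evaluated, per the TODO above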

View File

@@ -19,11 +19,40 @@ from typing import List, Union
 import numpy as np
 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout
-from rl_coach.base_parameters import EmbedderScheme
+from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
+from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
 from rl_coach.core_types import InputEmbedding


+class InputEmbedderParameters(NetworkComponentParameters):
+    def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
+                 batchnorm: bool=False, dropout=False, name: str='embedder', input_rescaling=None, input_offset=None,
+                 input_clipping=None, dense_layer=Dense):
+        super().__init__(dense_layer=dense_layer)
+        self.activation_function = activation_function
+        self.scheme = scheme
+        self.batchnorm = batchnorm
+        self.dropout = dropout
+
+        if input_rescaling is None:
+            input_rescaling = {'image': 255.0, 'vector': 1.0}
+        if input_offset is None:
+            input_offset = {'image': 0.0, 'vector': 0.0}
+
+        self.input_rescaling = input_rescaling
+        self.input_offset = input_offset
+        self.input_clipping = input_clipping
+        self.name = name
+
+    @property
+    def path(self):
+        return {
+            "image": 'image_embedder:ImageEmbedder',
+            "vector": 'vector_embedder:VectorEmbedder'
+        }
+
+
 class InputEmbedder(object):
     """
     An input embedder is the first part of the network, which takes the input from the state and produces a vector
@@ -32,7 +61,7 @@ class InputEmbedder(object):
""" """
def __init__(self, input_size: List[int], activation_function=tf.nn.relu, def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout: bool=False, scheme: EmbedderScheme=None, batchnorm: bool=False, dropout: bool=False,
name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None): name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense):
self.name = name self.name = name
self.input_size = input_size self.input_size = input_size
self.activation_function = activation_function self.activation_function = activation_function
@@ -47,6 +76,7 @@ class InputEmbedder(object):
         self.input_rescaling = input_rescaling
         self.input_offset = input_offset
         self.input_clipping = input_clipping
+        self.dense_layer = dense_layer

     def __call__(self, prev_input_placeholder=None):
         with tf.variable_scope(self.get_name()):
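
Since InputEmbedder now stores self.dense_layer, an embedder can be asked to build its fully connected layers from NoisyNetDense instead of Dense simply by passing it through the parameters class added above. A hedged sketch of that injection (how presets actually select this is not shown in this excerpt):

from rl_coach.architectures.tensorflow_components.architecture import NoisyNetDense
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters

# every dense layer built by the resulting embedder's scheme would then be a NoisyNetDense
noisy_embedder_params = InputEmbedderParameters(dense_layer=NoisyNetDense)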

View File

@@ -18,7 +18,7 @@ from typing import List
 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.architecture import Conv2d
+from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
 from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder
 from rl_coach.base_parameters import EmbedderScheme
 from rl_coach.core_types import InputImageEmbedding
@@ -30,45 +30,49 @@ class ImageEmbedder(InputEmbedder):
     The embedder is intended for image like inputs, where the channels are expected to be the last axis.
     The embedder also allows custom rescaling of the input prior to the neural network.
     """
-    schemes = {
-        EmbedderScheme.Empty:
-            [],
-        EmbedderScheme.Shallow:
-            [
-                Conv2d([32, 3, 1])
-            ],
-        # atari dqn
-        EmbedderScheme.Medium:
-            [
-                Conv2d([32, 8, 4]),
-                Conv2d([64, 4, 2]),
-                Conv2d([64, 3, 1])
-            ],
-        # carla
-        EmbedderScheme.Deep: \
-            [
-                Conv2d([32, 5, 2]),
-                Conv2d([32, 3, 1]),
-                Conv2d([64, 3, 2]),
-                Conv2d([64, 3, 1]),
-                Conv2d([128, 3, 2]),
-                Conv2d([128, 3, 1]),
-                Conv2d([256, 3, 2]),
-                Conv2d([256, 3, 1])
-            ]
-    }
-
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout: bool=False,
-                 name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None):
+                 name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
+                 dense_layer=Dense):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout, name, input_rescaling,
-                         input_offset, input_clipping)
+                         input_offset, input_clipping, dense_layer=dense_layer)
         self.return_type = InputImageEmbedding
         if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
             raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"
                              .format(input_size))
+
+    @property
+    def schemes(self):
+        return {
+            EmbedderScheme.Empty:
+                [],
+            EmbedderScheme.Shallow:
+                [
+                    Conv2d([32, 3, 1])
+                ],
+            # atari dqn
+            EmbedderScheme.Medium:
+                [
+                    Conv2d([32, 8, 4]),
+                    Conv2d([64, 4, 2]),
+                    Conv2d([64, 3, 1])
+                ],
+            # carla
+            EmbedderScheme.Deep: \
+                [
+                    Conv2d([32, 5, 2]),
+                    Conv2d([32, 3, 1]),
+                    Conv2d([64, 3, 2]),
+                    Conv2d([64, 3, 1]),
+                    Conv2d([128, 3, 2]),
+                    Conv2d([128, 3, 1]),
+                    Conv2d([256, 3, 2]),
+                    Conv2d([256, 3, 1])
+                ]
+        }

View File

@@ -29,36 +29,40 @@ class VectorEmbedder(InputEmbedder):
     An input embedder that is intended for inputs that can be represented as vectors.
     The embedder flattens the input, applies several dense layers to it and returns the output.
     """
-    schemes = {
-        EmbedderScheme.Empty:
-            [],
-        EmbedderScheme.Shallow:
-            [
-                Dense([128])
-            ],
-        # dqn
-        EmbedderScheme.Medium:
-            [
-                Dense([256])
-            ],
-        # carla
-        EmbedderScheme.Deep: \
-            [
-                Dense([128]),
-                Dense([128]),
-                Dense([128])
-            ]
-    }
-
     def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                  scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout: bool=False,
-                 name: str= "embedder", input_rescaling: float=1.0, input_offset:float=0.0, input_clipping=None):
+                 name: str= "embedder", input_rescaling: float=1.0, input_offset:float=0.0, input_clipping=None,
+                 dense_layer=Dense):
         super().__init__(input_size, activation_function, scheme, batchnorm, dropout, name,
-                         input_rescaling, input_offset, input_clipping)
+                         input_rescaling, input_offset, input_clipping, dense_layer=dense_layer)
         self.return_type = InputVectorEmbedding
         if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty:
             raise ValueError("The input size of a vector embedder must contain only a single dimension")
+
+    @property
+    def schemes(self):
+        return {
+            EmbedderScheme.Empty:
+                [],
+            EmbedderScheme.Shallow:
+                [
+                    self.dense_layer([128])
+                ],
+            # dqn
+            EmbedderScheme.Medium:
+                [
+                    self.dense_layer([256])
+                ],
+            # carla
+            EmbedderScheme.Deep: \
+                [
+                    self.dense_layer([128]),
+                    self.dense_layer([128]),
+                    self.dense_layer([128])
+                ]
+        }

View File

@@ -20,10 +20,11 @@ from typing import Dict
 import numpy as np
 import tensorflow as tf

+from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
 from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
 from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import MiddlewareParameters
-from rl_coach.base_parameters import AgentParameters, InputEmbedderParameters, EmbeddingMergerType
+from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
 from rl_coach.core_types import PredictionType
 from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
 from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params

View File

@@ -16,6 +16,8 @@
 import tensorflow as tf

+from rl_coach.architectures.tensorflow_components.architecture import Dense
 from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import QActionStateValue
@@ -23,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition
 class CategoricalQHeadParameters(HeadParameters):
-    def __init__(self, activation_function: str ='relu', name: str='categorical_q_head_params'):
-        super().__init__(parameterized_class=CategoricalQHead, activation_function=activation_function, name=name)
+    def __init__(self, activation_function: str ='relu', name: str='categorical_q_head_params', dense_layer=Dense):
+        super().__init__(parameterized_class=CategoricalQHead, activation_function=activation_function, name=name,
+                         dense_layer=dense_layer)


 class CategoricalQHead(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
-                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu'):
-        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
+                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str ='relu',
+                 dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
         self.name = 'categorical_dqn_head'
         self.num_actions = len(self.spaces.action.actions)
         self.num_atoms = agent_parameters.algorithm.atoms
@@ -40,7 +45,7 @@ class CategoricalQHead(Head):
         self.actions = tf.placeholder(tf.int32, [None], name="actions")
         self.input = [self.actions]

-        values_distribution = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output')
+        values_distribution = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='output')
         values_distribution = tf.reshape(values_distribution, (tf.shape(values_distribution)[0], self.num_actions,
                                                                self.num_atoms))
         # softmax on atoms dimension

View File

@@ -16,7 +16,7 @@
 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout
+from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
 from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import ActionProbabilities
@@ -24,16 +24,19 @@ from rl_coach.spaces import SpacesDefinition
 class DDPGActorHeadParameters(HeadParameters):
-    def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True):
-        super().__init__(parameterized_class=DDPGActor, activation_function=activation_function, name=name)
+    def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', batchnorm: bool=True,
+                 dense_layer=Dense):
+        super().__init__(parameterized_class=DDPGActor, activation_function=activation_function, name=name,
+                         dense_layer=dense_layer)
         self.batchnorm = batchnorm


 class DDPGActor(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
-                 batchnorm: bool=True):
-        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
+                 batchnorm: bool=True, dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
         self.name = 'ddpg_actor_head'
         self.return_type = ActionProbabilities
@@ -50,7 +53,7 @@ class DDPGActor(Head):
     def _build_module(self, input_layer):
         # mean
-        pre_activation_policy_values_mean = tf.layers.dense(input_layer, self.num_actions, name='fc_mean')
+        pre_activation_policy_values_mean = self.dense_layer(self.num_actions)(input_layer, name='fc_mean')
         policy_values_mean = batchnorm_activation_dropout(pre_activation_policy_values_mean, self.batchnorm,
                                                           self.activation_function,
                                                           False, 0, 0)[-1]

View File

@@ -15,6 +15,7 @@
 #
 import tensorflow as tf

+from rl_coach.architectures.tensorflow_components.architecture import Dense
 from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
 from rl_coach.architectures.tensorflow_components.heads.q_head import QHead
 from rl_coach.base_parameters import AgentParameters
@@ -23,14 +24,17 @@ from rl_coach.spaces import SpacesDefinition
 class DNDQHeadParameters(HeadParameters):
-    def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params'):
-        super().__init__(parameterized_class=DNDQHead, activation_function=activation_function, name=name)
+    def __init__(self, activation_function: str ='relu', name: str='dnd_q_head_params', dense_layer=Dense):
+        super().__init__(parameterized_class=DNDQHead, activation_function=activation_function, name=name,
+                         dense_layer=dense_layer)


 class DNDQHead(QHead):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
-                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
-        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
+                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
+                 dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
         self.name = 'dnd_q_values_head'
         self.DND_size = agent_parameters.algorithm.dnd_size
         self.DND_key_error_threshold = agent_parameters.algorithm.DND_key_error_threshold

View File

@@ -16,6 +16,7 @@
 import tensorflow as tf

+from rl_coach.architectures.tensorflow_components.architecture import Dense
 from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
 from rl_coach.architectures.tensorflow_components.heads.q_head import QHead
 from rl_coach.base_parameters import AgentParameters
@@ -23,27 +24,29 @@ from rl_coach.spaces import SpacesDefinition
 class DuelingQHeadParameters(HeadParameters):
-    def __init__(self, activation_function: str ='relu', name: str='dueling_q_head_params'):
-        super().__init__(parameterized_class=DuelingQHead, activation_function=activation_function, name=name)
+    def __init__(self, activation_function: str ='relu', name: str='dueling_q_head_params', dense_layer=Dense):
+        super().__init__(parameterized_class=DuelingQHead, activation_function=activation_function, name=name, dense_layer=dense_layer)


 class DuelingQHead(QHead):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
-                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
-        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
+                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
+                 dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
         self.name = 'dueling_q_values_head'

     def _build_module(self, input_layer):
         # state value tower - V
         with tf.variable_scope("state_value"):
-            state_value = tf.layers.dense(input_layer, 512, activation=self.activation_function, name='fc1')
-            state_value = tf.layers.dense(state_value, 1, name='fc2')
+            state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
+            state_value = self.dense_layer(1)(state_value, name='fc2')
             # state_value = tf.expand_dims(state_value, axis=-1)

         # action advantage tower - A
         with tf.variable_scope("action_advantage"):
-            action_advantage = tf.layers.dense(input_layer, 512, activation=self.activation_function, name='fc1')
-            action_advantage = tf.layers.dense(action_advantage, self.num_actions, name='fc2')
+            action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
+            action_advantage = self.dense_layer(self.num_actions)(action_advantage, name='fc2')
             action_advantage = action_advantage - tf.reduce_mean(action_advantage)

         # merge to state-action value function Q

View File

@@ -18,8 +18,8 @@ from typing import Type
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.ops.losses.losses_impl import Reduction

+from rl_coach.architectures.tensorflow_components.architecture import Dense
-from rl_coach.base_parameters import AgentParameters, Parameters
+from rl_coach.base_parameters import AgentParameters, Parameters, NetworkComponentParameters
 from rl_coach.spaces import SpacesDefinition
 from rl_coach.utils import force_list
@@ -33,9 +33,10 @@ def normalized_columns_initializer(std=1.0):
     return _initializer


-class HeadParameters(Parameters):
-    def __init__(self, parameterized_class: Type['Head'], activation_function: str = 'relu', name: str= 'head'):
-        super().__init__()
+class HeadParameters(NetworkComponentParameters):
+    def __init__(self, parameterized_class: Type['Head'], activation_function: str = 'relu', name: str= 'head',
+                 dense_layer=Dense):
+        super().__init__(dense_layer=dense_layer)
         self.activation_function = activation_function
         self.name = name
         self.parameterized_class_name = parameterized_class.__name__
@@ -48,7 +49,8 @@ class Head(object):
     an assigned loss function. The heads are algorithm dependent.
     """
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
-                 head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu'):
+                 head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
+                 dense_layer=Dense):
         self.head_idx = head_idx
         self.network_name = network_name
         self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -66,6 +68,7 @@ class Head(object):
         self.spaces = spaces
         self.return_type = None
         self.activation_function = activation_function
+        self.dense_layer = dense_layer

     def __call__(self, input_layer):
         """

View File

@@ -16,6 +16,8 @@
 import tensorflow as tf

+from rl_coach.architectures.tensorflow_components.architecture import Dense
 from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import Measurements
@@ -23,15 +25,18 @@ from rl_coach.spaces import SpacesDefinition
 class MeasurementsPredictionHeadParameters(HeadParameters):
-    def __init__(self, activation_function: str ='relu', name: str='measurements_prediction_head_params'):
+    def __init__(self, activation_function: str ='relu', name: str='measurements_prediction_head_params',
+                 dense_layer=Dense):
         super().__init__(parameterized_class=MeasurementsPredictionHead,
-                         activation_function=activation_function, name=name)
+                         activation_function=activation_function, name=name, dense_layer=dense_layer)


 class MeasurementsPredictionHead(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
-                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'):
-        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
+                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
+                 dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
         self.name = 'future_measurements_head'
         self.num_actions = len(self.spaces.action.actions)
         self.num_measurements = self.spaces.state['measurements'].shape[0]
@@ -43,15 +48,15 @@ class MeasurementsPredictionHead(Head):
         # This is almost exactly the same as Dueling Network but we predict the future measurements for each action
         # actions expectation tower (expectation stream) - E
         with tf.variable_scope("expectation_stream"):
-            expectation_stream = tf.layers.dense(input_layer, 256, activation=self.activation_function, name='fc1')
-            expectation_stream = tf.layers.dense(expectation_stream, self.multi_step_measurements_size, name='output')
+            expectation_stream = self.dense_layer(256)(input_layer, activation=self.activation_function, name='fc1')
+            expectation_stream = self.dense_layer(self.multi_step_measurements_size)(expectation_stream, name='output')
             expectation_stream = tf.expand_dims(expectation_stream, axis=1)

         # action fine differences tower (action stream) - A
         with tf.variable_scope("action_stream"):
-            action_stream = tf.layers.dense(input_layer, 256, activation=self.activation_function, name='fc1')
-            action_stream = tf.layers.dense(action_stream, self.num_actions * self.multi_step_measurements_size,
-                                            name='output')
+            action_stream = self.dense_layer(256)(input_layer, activation=self.activation_function, name='fc1')
+            action_stream = self.dense_layer(self.num_actions * self.multi_step_measurements_size)(action_stream,
+                                                                                                   name='output')
             action_stream = tf.reshape(action_stream,
                                        (tf.shape(action_stream)[0], self.num_actions, self.multi_step_measurements_size))
             action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keepdims=True)

View File

@@ -16,6 +16,7 @@
 import tensorflow as tf

+from rl_coach.architectures.tensorflow_components.architecture import Dense
 from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import QActionStateValue
@@ -24,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition
 class NAFHeadParameters(HeadParameters):
-    def __init__(self, activation_function: str ='tanh', name: str='naf_head_params'):
-        super().__init__(parameterized_class=NAFHead, activation_function=activation_function, name=name)
+    def __init__(self, activation_function: str ='tanh', name: str='naf_head_params', dense_layer=Dense):
+        super().__init__(parameterized_class=NAFHead, activation_function=activation_function, name=name,
+                         dense_layer=dense_layer)


 class NAFHead(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
-                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True,activation_function: str='relu'):
-        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
+                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True,activation_function: str='relu',
+                 dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
         if not isinstance(self.spaces.action, BoxActionSpace):
             raise ValueError("NAF works only for continuous action spaces (BoxActionSpace)")
@@ -50,15 +54,15 @@ class NAFHead(Head):
         self.input = self.action

         # V Head
-        self.V = tf.layers.dense(input_layer, 1, name='V')
+        self.V = self.dense_layer(1)(input_layer, name='V')

         # mu Head
-        mu_unscaled = tf.layers.dense(input_layer, self.num_actions, activation=self.activation_function, name='mu_unscaled')
+        mu_unscaled = self.dense_layer(self.num_actions)(input_layer, activation=self.activation_function, name='mu_unscaled')
         self.mu = tf.multiply(mu_unscaled, self.output_scale, name='mu')

         # A Head
         # l_vector is a vector that includes a lower-triangular matrix values
-        self.l_vector = tf.layers.dense(input_layer, (self.num_actions * (self.num_actions + 1)) / 2, name='l_vector')
+        self.l_vector = self.dense_layer((self.num_actions * (self.num_actions + 1)) / 2)(input_layer, name='l_vector')

         # Convert l to a lower triangular matrix and exponentiate its diagonal

View File

@@ -17,6 +17,7 @@
 import numpy as np
 import tensorflow as tf

+from rl_coach.architectures.tensorflow_components.architecture import Dense
 from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import ActionProbabilities
@@ -27,14 +28,17 @@ from rl_coach.utils import eps
 class PolicyHeadParameters(HeadParameters):
-    def __init__(self, activation_function: str ='tanh', name: str='policy_head_params'):
-        super().__init__(parameterized_class=PolicyHead, activation_function=activation_function, name=name)
+    def __init__(self, activation_function: str ='tanh', name: str='policy_head_params', dense_layer=Dense):
+        super().__init__(parameterized_class=PolicyHead, activation_function=activation_function, name=name,
+                         dense_layer=dense_layer)


 class PolicyHead(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
-                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh'):
-        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function)
+                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
+                 dense_layer=Dense):
+        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
+                         dense_layer=dense_layer)
         self.name = 'policy_values_head'
         self.return_type = ActionProbabilities
         self.beta = None
@@ -90,7 +94,7 @@ class PolicyHead(Head):
             num_actions = len(action_space.actions)
             self.actions.append(tf.placeholder(tf.int32, [None], name="actions"))
-            policy_values = tf.layers.dense(input_layer, num_actions, name='fc')
+            policy_values = self.dense_layer(num_actions)(input_layer, name='fc')
             self.policy_probs = tf.nn.softmax(policy_values, name="policy")

             # define the distributions for the policy and the old policy
@@ -114,7 +118,7 @@ class PolicyHead(Head):
             self.continuous_output_activation = None

             # mean
-            pre_activation_policy_values_mean = tf.layers.dense(input_layer, num_actions, name='fc_mean')
+            pre_activation_policy_values_mean = self.dense_layer(num_actions)(input_layer, name='fc_mean')
             policy_values_mean = self.continuous_output_activation(pre_activation_policy_values_mean)
             self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
@@ -123,8 +127,9 @@ class PolicyHead(Head):
# standard deviation # standard deviation
if isinstance(self.exploration_policy, ContinuousEntropyParameters): if isinstance(self.exploration_policy, ContinuousEntropyParameters):
# the stdev is an output of the network and uses a softplus activation as defined in A3C # the stdev is an output of the network and uses a softplus activation as defined in A3C
policy_values_std = tf.layers.dense(input_layer, num_actions, policy_values_std = self.dense_layer(num_actions)(input_layer,
kernel_initializer=normalized_columns_initializer(0.01), name='fc_std') kernel_initializer=normalized_columns_initializer(0.01),
name='fc_std')
self.policy_std = tf.nn.softplus(policy_values_std, name='output_variance') + eps self.policy_std = tf.nn.softplus(policy_values_std, name='output_variance') + eps
self.output.append(self.policy_std) self.output.append(self.policy_std)
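The softplus in the hunk above guarantees a strictly positive standard deviation for the Gaussian policy. The same transformation in a tiny NumPy sketch (the exact value of rl_coach.utils.eps is assumed here, not taken from the source):

import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

eps = 1e-8                              # assumed small constant, standing in for rl_coach.utils.eps
raw_std = np.array([-2.0, 0.0, 3.0])    # unconstrained outputs of the 'fc_std' layer
policy_std = softplus(raw_std) + eps    # always > 0, as a Gaussian standard deviation must be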

View File

@@ -17,6 +17,7 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters, normalized_columns_initializer from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters, normalized_columns_initializer
from rl_coach.base_parameters import AgentParameters from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities from rl_coach.core_types import ActionProbabilities
@@ -26,14 +27,17 @@ from rl_coach.utils import eps
class PPOHeadParameters(HeadParameters): class PPOHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params'): def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params', dense_layer=Dense):
super().__init__(parameterized_class=PPOHead, activation_function=activation_function, name=name) super().__init__(parameterized_class=PPOHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class PPOHead(Head): class PPOHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh'): head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function) dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'ppo_head' self.name = 'ppo_head'
self.return_type = ActionProbabilities self.return_type = ActionProbabilities
@@ -110,7 +114,7 @@ class PPOHead(Head):
# Policy Head # Policy Head
self.input = [self.actions, self.old_policy_mean] self.input = [self.actions, self.old_policy_mean]
policy_values = tf.layers.dense(input_layer, num_actions, name='policy_fc') policy_values = self.dense_layer(num_actions)(input_layer, name='policy_fc')
self.policy_mean = tf.nn.softmax(policy_values, name="policy") self.policy_mean = tf.nn.softmax(policy_values, name="policy")
# define the distributions for the policy and the old policy # define the distributions for the policy and the old policy
@@ -127,7 +131,7 @@ class PPOHead(Head):
self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std") self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std")
self.input = [self.actions, self.old_policy_mean, self.old_policy_std] self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
self.policy_mean = tf.layers.dense(input_layer, num_actions, name='policy_mean', self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
kernel_initializer=normalized_columns_initializer(0.01)) kernel_initializer=normalized_columns_initializer(0.01))
if self.is_local: if self.is_local:
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',

View File

@@ -16,6 +16,8 @@
import tensorflow as tf import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
from rl_coach.base_parameters import AgentParameters from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import ActionProbabilities from rl_coach.core_types import ActionProbabilities
@@ -23,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition
class PPOVHeadParameters(HeadParameters): class PPOVHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='ppo_v_head_params'): def __init__(self, activation_function: str ='relu', name: str='ppo_v_head_params', dense_layer=Dense):
super().__init__(parameterized_class=PPOVHead, activation_function=activation_function, name=name) super().__init__(parameterized_class=PPOVHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class PPOVHead(Head): class PPOVHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'): head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function) dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'ppo_v_head' self.name = 'ppo_v_head'
self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
self.return_type = ActionProbabilities self.return_type = ActionProbabilities
@@ -38,7 +43,7 @@ class PPOVHead(Head):
def _build_module(self, input_layer): def _build_module(self, input_layer):
self.old_policy_value = tf.placeholder(tf.float32, [None], "old_policy_values") self.old_policy_value = tf.placeholder(tf.float32, [None], "old_policy_values")
self.input = [self.old_policy_value] self.input = [self.old_policy_value]
self.output = tf.layers.dense(input_layer, 1, name='output', self.output = self.dense_layer(1)(input_layer, name='output',
kernel_initializer=normalized_columns_initializer(1.0)) kernel_initializer=normalized_columns_initializer(1.0))
self.target = self.total_return = tf.placeholder(tf.float32, [None], name="total_return") self.target = self.total_return = tf.placeholder(tf.float32, [None], name="total_return")

View File

@@ -16,6 +16,8 @@
import tensorflow as tf import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue from rl_coach.core_types import QActionStateValue
@@ -23,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpac
class QHeadParameters(HeadParameters): class QHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='q_head_params'): def __init__(self, activation_function: str ='relu', name: str='q_head_params', dense_layer=Dense):
super().__init__(parameterized_class=QHead, activation_function=activation_function, name=name) super().__init__(parameterized_class=QHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class QHead(Head): class QHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'): head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function) dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'q_values_head' self.name = 'q_values_head'
if isinstance(self.spaces.action, BoxActionSpace): if isinstance(self.spaces.action, BoxActionSpace):
self.num_actions = 1 self.num_actions = 1
@@ -44,7 +49,7 @@ class QHead(Head):
def _build_module(self, input_layer): def _build_module(self, input_layer):
# Standard Q Network # Standard Q Network
self.output = tf.layers.dense(input_layer, self.num_actions, name='output') self.output = self.dense_layer(self.num_actions)(input_layer, name='output')

View File

@@ -16,6 +16,8 @@
import tensorflow as tf import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters from rl_coach.architectures.tensorflow_components.heads.head import Head, HeadParameters
from rl_coach.base_parameters import AgentParameters from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import QActionStateValue from rl_coach.core_types import QActionStateValue
@@ -23,15 +25,18 @@ from rl_coach.spaces import SpacesDefinition
class QuantileRegressionQHeadParameters(HeadParameters): class QuantileRegressionQHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='quantile_regression_q_head_params'): def __init__(self, activation_function: str ='relu', name: str='quantile_regression_q_head_params',
dense_layer=Dense):
super().__init__(parameterized_class=QuantileRegressionQHead, activation_function=activation_function, super().__init__(parameterized_class=QuantileRegressionQHead, activation_function=activation_function,
name=name) name=name, dense_layer=dense_layer)
class QuantileRegressionQHead(Head): class QuantileRegressionQHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'): head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function) dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'quantile_regression_dqn_head' self.name = 'quantile_regression_dqn_head'
self.num_actions = len(self.spaces.action.actions) self.num_actions = len(self.spaces.action.actions)
self.num_atoms = agent_parameters.algorithm.atoms # we use atom / quantile interchangeably self.num_atoms = agent_parameters.algorithm.atoms # we use atom / quantile interchangeably
@@ -44,7 +49,7 @@ class QuantileRegressionQHead(Head):
self.input = [self.actions, self.quantile_midpoints] self.input = [self.actions, self.quantile_midpoints]
# the output of the head is the N unordered quantile locations {theta_1, ..., theta_N} # the output of the head is the N unordered quantile locations {theta_1, ..., theta_N}
quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms, name='output') quantiles_locations = self.dense_layer(self.num_actions * self.num_atoms)(input_layer, name='output')
quantiles_locations = tf.reshape(quantiles_locations, (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms)) quantiles_locations = tf.reshape(quantiles_locations, (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms))
self.output = quantiles_locations self.output = quantiles_locations
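The dense layer above emits all of the quantile locations for every action as one flat vector, and the reshape then splits them per action. The same shape arithmetic in a small NumPy sketch (hypothetical sizes, for illustration only):

import numpy as np

batch, num_actions, num_atoms = 2, 3, 4
flat = np.arange(batch * num_actions * num_atoms, dtype=np.float32).reshape(batch, num_actions * num_atoms)

# (batch, num_actions * num_atoms) -> (batch, num_actions, num_atoms), mirroring the tf.reshape above
quantiles_locations = flat.reshape(batch, num_actions, num_atoms)
print(quantiles_locations[0, 1])  # the 4 quantile locations of action 1 for the first sample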

View File

@@ -16,6 +16,8 @@
import tensorflow as tf import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer, HeadParameters
from rl_coach.base_parameters import AgentParameters from rl_coach.base_parameters import AgentParameters
from rl_coach.core_types import VStateValue from rl_coach.core_types import VStateValue
@@ -23,14 +25,17 @@ from rl_coach.spaces import SpacesDefinition
class VHeadParameters(HeadParameters): class VHeadParameters(HeadParameters):
def __init__(self, activation_function: str ='relu', name: str='v_head_params'): def __init__(self, activation_function: str ='relu', name: str='v_head_params', dense_layer=Dense):
super().__init__(parameterized_class=VHead, activation_function=activation_function, name=name) super().__init__(parameterized_class=VHead, activation_function=activation_function, name=name,
dense_layer=dense_layer)
class VHead(Head): class VHead(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu'): head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function) dense_layer=Dense):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
self.name = 'v_values_head' self.name = 'v_values_head'
self.return_type = VStateValue self.return_type = VStateValue
@@ -41,5 +46,5 @@ class VHead(Head):
def _build_module(self, input_layer): def _build_module(self, input_layer):
# Standard V Network # Standard V Network
self.output = tf.layers.dense(input_layer, 1, name='output', self.output = self.dense_layer(1)(input_layer, name='output',
kernel_initializer=normalized_columns_initializer(1.0)) kernel_initializer=normalized_columns_initializer(1.0))

View File

@@ -27,42 +27,18 @@ class FCMiddlewareParameters(MiddlewareParameters):
def __init__(self, activation_function='relu', def __init__(self, activation_function='relu',
scheme: Union[List, MiddlewareScheme] = MiddlewareScheme.Medium, scheme: Union[List, MiddlewareScheme] = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False, batchnorm: bool = False, dropout: bool = False,
name="middleware_fc_embedder"): name="middleware_fc_embedder", dense_layer=Dense):
super().__init__(parameterized_class=FCMiddleware, activation_function=activation_function, super().__init__(parameterized_class=FCMiddleware, activation_function=activation_function,
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name) scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name, dense_layer=dense_layer)
class FCMiddleware(Middleware): class FCMiddleware(Middleware):
schemes = {
MiddlewareScheme.Empty:
[],
# ppo
MiddlewareScheme.Shallow:
[
Dense([64])
],
# dqn
MiddlewareScheme.Medium:
[
Dense([512])
],
MiddlewareScheme.Deep: \
[
Dense([128]),
Dense([128]),
Dense([128])
]
}
def __init__(self, activation_function=tf.nn.relu, def __init__(self, activation_function=tf.nn.relu,
scheme: MiddlewareScheme = MiddlewareScheme.Medium, scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False, batchnorm: bool = False, dropout: bool = False,
name="middleware_fc_embedder"): name="middleware_fc_embedder", dense_layer=Dense):
super().__init__(activation_function=activation_function, batchnorm=batchnorm, super().__init__(activation_function=activation_function, batchnorm=batchnorm,
dropout=dropout, scheme=scheme, name=name) dropout=dropout, scheme=scheme, name=name, dense_layer=dense_layer)
self.return_type = Middleware_FC_Embedding self.return_type = Middleware_FC_Embedding
self.layers = [] self.layers = []
@@ -70,7 +46,7 @@ class FCMiddleware(Middleware):
self.layers.append(self.input) self.layers.append(self.input)
if isinstance(self.scheme, MiddlewareScheme): if isinstance(self.scheme, MiddlewareScheme):
layers_params = FCMiddleware.schemes[self.scheme] layers_params = self.schemes[self.scheme]
else: else:
layers_params = self.scheme layers_params = self.scheme
for idx, layer_params in enumerate(layers_params): for idx, layer_params in enumerate(layers_params):
@@ -84,3 +60,29 @@ class FCMiddleware(Middleware):
self.output = self.layers[-1] self.output = self.layers[-1]
@property
def schemes(self):
return {
MiddlewareScheme.Empty:
[],
# ppo
MiddlewareScheme.Shallow:
[
self.dense_layer([64])
],
# dqn
MiddlewareScheme.Medium:
[
self.dense_layer([512])
],
MiddlewareScheme.Deep: \
[
self.dense_layer([128]),
self.dense_layer([128]),
self.dense_layer([128])
]
}
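Turning schemes into a property is what makes the dense-layer injection work: the predefined schemes are now built per instance from self.dense_layer, so when parameter-noise exploration swaps Dense for NoisyNetDense the default schemes pick the new class up automatically. A minimal, self-contained sketch of that pattern (toy classes, not the actual Coach ones):

class Dense:
    def __init__(self, sizes):
        self.sizes = sizes

class NoisyNetDense(Dense):
    pass

class SketchMiddleware:
    def __init__(self, dense_layer=Dense):
        self.dense_layer = dense_layer  # Dense by default, NoisyNetDense under parameter noise

    @property
    def schemes(self):
        # built lazily per instance, so the injected layer class decides what the scheme creates
        return {'medium': [self.dense_layer([512])]}

print(type(SketchMiddleware(NoisyNetDense).schemes['medium'][0]).__name__)  # -> NoisyNetDense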

View File

@@ -18,7 +18,7 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout from rl_coach.architectures.tensorflow_components.architecture import batchnorm_activation_dropout, Dense
from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware, MiddlewareParameters from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware, MiddlewareParameters
from rl_coach.base_parameters import MiddlewareScheme from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.core_types import Middleware_LSTM_Embedding from rl_coach.core_types import Middleware_LSTM_Embedding
@@ -28,43 +28,19 @@ class LSTMMiddlewareParameters(MiddlewareParameters):
def __init__(self, activation_function='relu', number_of_lstm_cells=256, def __init__(self, activation_function='relu', number_of_lstm_cells=256,
scheme: MiddlewareScheme = MiddlewareScheme.Medium, scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False, batchnorm: bool = False, dropout: bool = False,
name="middleware_lstm_embedder"): name="middleware_lstm_embedder", dense_layer=Dense):
super().__init__(parameterized_class=LSTMMiddleware, activation_function=activation_function, super().__init__(parameterized_class=LSTMMiddleware, activation_function=activation_function,
scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name) scheme=scheme, batchnorm=batchnorm, dropout=dropout, name=name, dense_layer=dense_layer)
self.number_of_lstm_cells = number_of_lstm_cells self.number_of_lstm_cells = number_of_lstm_cells
class LSTMMiddleware(Middleware): class LSTMMiddleware(Middleware):
schemes = {
MiddlewareScheme.Empty:
[],
# ppo
MiddlewareScheme.Shallow:
[
[64]
],
# dqn
MiddlewareScheme.Medium:
[
[512]
],
MiddlewareScheme.Deep: \
[
[128],
[128],
[128]
]
}
def __init__(self, activation_function=tf.nn.relu, number_of_lstm_cells: int=256, def __init__(self, activation_function=tf.nn.relu, number_of_lstm_cells: int=256,
scheme: MiddlewareScheme = MiddlewareScheme.Medium, scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False, batchnorm: bool = False, dropout: bool = False,
name="middleware_lstm_embedder"): name="middleware_lstm_embedder", dense_layer=Dense):
super().__init__(activation_function=activation_function, batchnorm=batchnorm, super().__init__(activation_function=activation_function, batchnorm=batchnorm,
dropout=dropout, scheme=scheme, name=name) dropout=dropout, scheme=scheme, name=name, dense_layer=dense_layer)
self.return_type = Middleware_LSTM_Embedding self.return_type = Middleware_LSTM_Embedding
self.number_of_lstm_cells = number_of_lstm_cells self.number_of_lstm_cells = number_of_lstm_cells
self.layers = [] self.layers = []
@@ -83,7 +59,7 @@ class LSTMMiddleware(Middleware):
# optionally insert some dense layers before the LSTM # optionally insert some dense layers before the LSTM
if isinstance(self.scheme, MiddlewareScheme): if isinstance(self.scheme, MiddlewareScheme):
layers_params = LSTMMiddleware.schemes[self.scheme] layers_params = self.schemes[self.scheme]
else: else:
layers_params = self.scheme layers_params = self.scheme
for idx, layer_params in enumerate(layers_params): for idx, layer_params in enumerate(layers_params):
@@ -111,3 +87,30 @@ class LSTMMiddleware(Middleware):
lstm_c, lstm_h = lstm_state lstm_c, lstm_h = lstm_state
self.state_out = (lstm_c[:1, :], lstm_h[:1, :]) self.state_out = (lstm_c[:1, :], lstm_h[:1, :])
self.output = tf.reshape(lstm_outputs, [-1, self.number_of_lstm_cells]) self.output = tf.reshape(lstm_outputs, [-1, self.number_of_lstm_cells])
@property
def schemes(self):
return {
MiddlewareScheme.Empty:
[],
# ppo
MiddlewareScheme.Shallow:
[
[64]
],
# dqn
MiddlewareScheme.Medium:
[
[512]
],
MiddlewareScheme.Deep: \
[
[128],
[128],
[128]
]
}

View File

@@ -17,16 +17,16 @@ from typing import Type, Union, List
import tensorflow as tf import tensorflow as tf
from rl_coach.base_parameters import MiddlewareScheme, Parameters from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import MiddlewareScheme, Parameters, NetworkComponentParameters
from rl_coach.core_types import MiddlewareEmbedding from rl_coach.core_types import MiddlewareEmbedding
class MiddlewareParameters(Parameters): class MiddlewareParameters(NetworkComponentParameters):
def __init__(self, parameterized_class: Type['Middleware'], def __init__(self, parameterized_class: Type['Middleware'],
activation_function: str='relu', scheme: Union[List, MiddlewareScheme]=MiddlewareScheme.Medium, activation_function: str='relu', scheme: Union[List, MiddlewareScheme]=MiddlewareScheme.Medium,
batchnorm: bool=False, dropout: bool=False, batchnorm: bool=False, dropout: bool=False, name='middleware', dense_layer=Dense):
name='middleware'): super().__init__(dense_layer=dense_layer)
super().__init__()
self.activation_function = activation_function self.activation_function = activation_function
self.scheme = scheme self.scheme = scheme
self.batchnorm = batchnorm self.batchnorm = batchnorm
@@ -43,7 +43,7 @@ class Middleware(object):
""" """
def __init__(self, activation_function=tf.nn.relu, def __init__(self, activation_function=tf.nn.relu,
scheme: MiddlewareScheme = MiddlewareScheme.Medium, scheme: MiddlewareScheme = MiddlewareScheme.Medium,
batchnorm: bool = False, dropout: bool = False, name="middleware_embedder"): batchnorm: bool = False, dropout: bool = False, name="middleware_embedder", dense_layer=Dense):
self.name = name self.name = name
self.input = None self.input = None
self.output = None self.output = None
@@ -53,6 +53,7 @@ class Middleware(object):
self.dropout_rate = 0 self.dropout_rate = 0
self.scheme = scheme self.scheme = scheme
self.return_type = MiddlewareEmbedding self.return_type = MiddlewareEmbedding
self.dense_layer = dense_layer
def __call__(self, input_layer): def __call__(self, input_layer):
with tf.variable_scope(self.get_name()): with tf.variable_scope(self.get_name()):
@@ -66,3 +67,8 @@ class Middleware(object):
def get_name(self): def get_name(self):
return self.name return self.name
@property
def schemes(self):
raise NotImplementedError("Inheriting middleware must define schemes matching its allowed default "
"configurations.")

View File

@@ -199,7 +199,7 @@ class NetworkParameters(Parameters):
self.learning_rate_decay_steps = 0 self.learning_rate_decay_steps = 0
# structure # structure
self.input_embedders_parameters = [] self.input_embedders_parameters = {}
self.embedding_merger_type = EmbeddingMergerType.Concat self.embedding_merger_type = EmbeddingMergerType.Concat
self.middleware_parameters = None self.middleware_parameters = None
self.heads_parameters = [] self.heads_parameters = []
@@ -220,32 +220,9 @@ class NetworkParameters(Parameters):
self.tensorflow_support = True self.tensorflow_support = True
class InputEmbedderParameters(Parameters): class NetworkComponentParameters(Parameters):
def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium, def __init__(self, dense_layer):
batchnorm: bool=False, dropout=False, name: str='embedder', input_rescaling=None, input_offset=None, self.dense_layer = dense_layer
input_clipping=None):
super().__init__()
self.activation_function = activation_function
self.scheme = scheme
self.batchnorm = batchnorm
self.dropout = dropout
if input_rescaling is None:
input_rescaling = {'image': 255.0, 'vector': 1.0}
if input_offset is None:
input_offset = {'image': 0.0, 'vector': 0.0}
self.input_rescaling = input_rescaling
self.input_offset = input_offset
self.input_clipping = input_clipping
self.name = name
@property
def path(self):
return {
"image": 'image_embedder:ImageEmbedder',
"vector": 'vector_embedder:VectorEmbedder'
}
class VisualizationParameters(Parameters): class VisualizationParameters(Parameters):
@@ -287,7 +264,7 @@ class AgentParameters(Parameters):
self.input_filter = None self.input_filter = None
self.output_filter = None self.output_filter = None
self.pre_network_filter = NoInputFilter() self.pre_network_filter = NoInputFilter()
self.full_name_id = None # TODO: do we really want to hold this parameters here? self.full_name_id = None # TODO: do we really want to hold this parameter here?
self.name = None self.name = None
self.is_a_highest_level_agent = True self.is_a_highest_level_agent = True
self.is_a_lowest_level_agent = True self.is_a_lowest_level_agent = True

View File

@@ -18,6 +18,7 @@ import random
import sys import sys
from os import path, environ from os import path, environ
from rl_coach.filters.action.partial_discrete_action_space_map import PartialDiscreteActionSpaceMap
from rl_coach.filters.observation.observation_rgb_to_y_filter import ObservationRGBToYFilter from rl_coach.filters.observation.observation_rgb_to_y_filter import ObservationRGBToYFilter
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
@@ -208,7 +209,6 @@ class CarlaEnvironment(Environment):
[self.gas_strength, self.steering_strength], [self.gas_strength, self.steering_strength],
[self.brake_strength, -self.steering_strength], [self.brake_strength, -self.steering_strength],
[self.brake_strength, self.steering_strength]], [self.brake_strength, self.steering_strength]],
target_action_space=self.action_space,
descriptions=['NO-OP', 'TURN_LEFT', 'TURN_RIGHT', 'GAS', 'BRAKE', descriptions=['NO-OP', 'TURN_LEFT', 'TURN_RIGHT', 'GAS', 'BRAKE',
'GAS_AND_TURN_LEFT', 'GAS_AND_TURN_RIGHT', 'GAS_AND_TURN_LEFT', 'GAS_AND_TURN_RIGHT',
'BRAKE_AND_TURN_LEFT', 'BRAKE_AND_TURN_RIGHT'] 'BRAKE_AND_TURN_LEFT', 'BRAKE_AND_TURN_RIGHT']

View File

@@ -1,149 +0,0 @@
########################################################################################################################
####### Currently we are ignoring more complex cases including EnvironmentGroups - DO NOT USE THIS FILE ****************
########################################################################################################################
# #
# # Copyright (c) 2017 Intel Corporation
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
# #
#
# from typing import Union, List, Dict
# import numpy as np
# from environments import create_environment
# from environments.environment import Environment
# from environments.environment_interface import EnvironmentInterface, ActionType, ActionSpace
# from core_types import GoalType, Transition
#
#
# class EnvironmentGroup(EnvironmentInterface):
# """
# An EnvironmentGroup is a group of different environments.
# In the simple case, it will contain a single environment. But it can also contain multiple environments,
# where the agent can then act on them as a batch, such that the prediction of the action is more efficient.
# """
# def __init__(self, environments_parameters: List[Environment]):
# self.environments_parameters = environments_parameters
# self.environments = []
# self.action_space = []
# self.outgoing_control = []
# self._last_env_response = []
#
# @property
# def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
# """
# Get the action space of the environment
# :return: the action space
# """
# return self.action_space
#
# @action_space.setter
# def action_space(self, val: Union[List[ActionSpace], ActionSpace]):
# """
# Set the action space of the environment
# :return: None
# """
# self.action_space = val
#
# @property
# def phase(self) -> RunPhase:
# """
# Get the phase of the environments group
# :return: the current phase
# """
# return self.phase
#
# @phase.setter
# def phase(self, val: RunPhase):
# """
# Change the phase of each one of the environments in the group
# :param val: the new phase
# :return: None
# """
# self.phase = val
# call_method_for_all(self.environments, 'phase', val)
#
# def _create_environments(self):
# """
# Create the environments using the given parameters and update the environments list
# :return: None
# """
# for environment_parameters in self.environments_parameters:
# environment = create_environment(environment_parameters)
# self.action_space = self.action_space.append(environment.action_space)
# self.environments.append(environment)
#
# @property
# def last_env_response(self) -> Union[List[Transition], Transition]:
# """
# Get the last environment response
# :return: a dictionary that contains the state, reward, etc.
# """
# return squeeze_list(self._last_env_response)
#
# @last_env_response.setter
# def last_env_response(self, val: Union[List[Transition], Transition]):
# """
# Set the last environment response
# :param val: the last environment response
# """
# self._last_env_response = force_list(val)
#
# def step(self, actions: Union[List[ActionType], ActionType]) -> List[Transition]:
# """
# Act in all the environments in the group.
# :param actions: can be either a single action if there is a single environment in the group, or a list of
# actions in case there are multiple environments in the group. Each action can be an action index
# or a numpy array representing a continuous action for example.
# :return: The responses from all the environments in the group
# """
#
# actions = force_list(actions)
# if len(actions) != len(self.environments):
# raise ValueError("The number of actions does not match the number of environments in the group")
#
# result = []
# for environment, action in zip(self.environments, actions):
# result.append(environment.step(action))
#
# self.last_env_response = result
#
# return result
#
# def reset(self, force_environment_reset: bool=False) -> List[Transition]:
# """
# Reset all the environments in the group
# :param force_environment_reset: force the reset of each one of the environments
# :return: a list of the environments responses
# """
# return call_method_for_all(self.environments, 'reset', force_environment_reset)
#
# def get_random_action(self) -> List[ActionType]:
# """
# Get a list of random action that can be applied on the environments in the group
# :return: a list of random actions
# """
# return call_method_for_all(self.environments, 'get_random_action')
#
# def set_goal(self, goal: GoalType) -> None:
# """
# Set the goal of each one of the environments in the group to be the given goal
# :param goal: a goal vector
# :return: None
# """
# # TODO: maybe enable setting multiple goals?
# call_method_for_all(self.environments, 'set_goal', goal)

View File

@@ -0,0 +1,79 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Dict
import numpy as np
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import NoisyNetDense
from rl_coach.base_parameters import AgentParameters, NetworkParameters
from rl_coach.spaces import ActionSpace, BoxActionSpace, DiscreteActionSpace
from rl_coach.core_types import ActionType
from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters
class ParameterNoiseParameters(ExplorationParameters):
def __init__(self, agent_params: AgentParameters):
super().__init__()
if not isinstance(agent_params, DQNAgentParameters):
raise ValueError("Currently only DQN variants are supported for using an exploration type of "
"ParameterNoise.")
self.network_params = agent_params.network_wrappers
@property
def path(self):
return 'rl_coach.exploration_policies.parameter_noise:ParameterNoise'
class ParameterNoise(ExplorationPolicy):
def __init__(self, network_params: Dict[str, NetworkParameters], action_space: ActionSpace):
"""
:param network_params: the network parameters whose dense layers will be replaced with noisy dense layers
:param action_space: the action space used by the environment
"""
super().__init__(action_space)
self.network_params = network_params
self._replace_network_dense_layers()
def get_action(self, action_values: List[ActionType]) -> ActionType:
if type(self.action_space) == DiscreteActionSpace:
return np.argmax(action_values)
elif type(self.action_space) == BoxActionSpace:
action_values_mean = action_values[0].squeeze()
action_values_std = action_values[1].squeeze()
return np.random.normal(action_values_mean, action_values_std)
else:
raise ValueError("ActionSpace type {} is not supported for ParameterNoise.".format(type(self.action_space)))
def get_control_param(self):
return 0
def _replace_network_dense_layers(self):
# replace the dense layer type of all the network components (embedders, middleware, heads) with NoisyNetDense
# NOTE: we are mutating network parameters on an already instantiated (non-params) class, which could be
# error prone, but since the networks are only built much later
# (after agent.init_environment_dependent_modules() is called), this is safe.
for network_wrapper_params in self.network_params.values():
for component_params in list(network_wrapper_params.input_embedders_parameters.values()) + \
[network_wrapper_params.middleware_parameters] + \
network_wrapper_params.heads_parameters:
component_params.dense_layer = NoisyNetDense
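The exploration policy above only swaps the dense_layer class on every network component; the NoisyNetDense layer itself lives in architecture.py and is not shown in this hunk. As a rough sketch only, assuming the factorised-Gaussian parameterisation of Fortunato et al. (2017) and hypothetical variable names rather than the actual Coach implementation, a noisy dense layer looks roughly like this:

import numpy as np
import tensorflow as tf

def noisy_dense(input_layer, units, name, activation=None, sigma0=0.5):
    # illustrative factorised-Gaussian noisy dense layer; not the Coach NoisyNetDense class itself
    num_inputs = input_layer.get_shape().as_list()[-1]

    def f(x):
        # scaling function applied to the factorised noise
        return tf.sign(x) * tf.sqrt(tf.abs(x))

    with tf.variable_scope(name):
        mu_init = tf.random_uniform_initializer(-1.0 / np.sqrt(num_inputs), 1.0 / np.sqrt(num_inputs))
        sigma_init = tf.constant_initializer(sigma0 / np.sqrt(num_inputs))

        w_mu = tf.get_variable('w_mu', [num_inputs, units], initializer=mu_init)
        w_sigma = tf.get_variable('w_sigma', [num_inputs, units], initializer=sigma_init)
        b_mu = tf.get_variable('b_mu', [units], initializer=mu_init)
        b_sigma = tf.get_variable('b_sigma', [units], initializer=sigma_init)

        # factorised noise: one noise vector over the inputs, one over the outputs
        eps_in = f(tf.random_normal([num_inputs, 1]))
        eps_out = f(tf.random_normal([1, units]))
        weight = w_mu + w_sigma * (eps_in * eps_out)
        bias = b_mu + b_sigma * tf.squeeze(eps_out)

        out = tf.matmul(input_layer, weight) + bias
        return activation(out) if activation is not None else out

In a preset, enabling this exploration scheme then presumably reduces to a single line such as agent_params.exploration = ParameterNoiseParameters(agent_params), after which _replace_network_dense_layers makes every embedder, middleware and head build its fully connected layers with NoisyNetDense instead of Dense.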

View File

@@ -1,8 +1,10 @@
from rl_coach.agents.dqn_agent import DQNAgentParameters from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import Dense from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, InputEmbedderParameters, \ from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, \
PresetValidationParameters PresetValidationParameters
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.gym_environment import Mujoco from rl_coach.environments.gym_environment import Mujoco
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.graph_managers.graph_manager import ScheduleParameters

View File

@@ -1,8 +1,9 @@
from rl_coach.agents.dqn_agent import DQNAgentParameters from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import Dense from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, InputEmbedderParameters, \ from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, \
PresetValidationParameters PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.environments.gym_environment import Mujoco from rl_coach.environments.gym_environment import Mujoco
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.graph_managers.graph_manager import ScheduleParameters

View File

@@ -1,6 +1,6 @@
from rl_coach.agents.ddpg_agent import DDPGAgentParameters from rl_coach.agents.ddpg_agent import DDPGAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import Dense from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection

View File

@@ -1,7 +1,7 @@
from rl_coach.agents.dfp_agent import DFPAgentParameters from rl_coach.agents.dfp_agent import DFPAgentParameters
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \ from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \
PresetValidationParameters PresetValidationParameters
from rl_coach.core_types import EnvironmentSteps, RunPhase from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from rl_coach.environments.doom_environment import DoomEnvironmentParameters from rl_coach.environments.doom_environment import DoomEnvironmentParameters
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager

View File

@@ -1,7 +1,6 @@
from rl_coach.agents.dfp_agent import DFPAgentParameters from rl_coach.agents.dfp_agent import DFPAgentParameters
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme, \ from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, MiddlewareScheme
PresetValidationParameters from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from rl_coach.core_types import EnvironmentSteps, RunPhase
from rl_coach.environments.doom_environment import DoomEnvironmentParameters from rl_coach.environments.doom_environment import DoomEnvironmentParameters
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
@@ -62,4 +61,4 @@ vis_params.dump_mp4 = False
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
schedule_params=schedule_params, vis_params=vis_params, schedule_params=schedule_params, vis_params=vis_params,
preset_validation_params=preset_validation_params) )

View File

@@ -1,9 +1,9 @@
from rl_coach.agents.ddpg_agent import DDPGAgentParameters from rl_coach.agents.ddpg_agent import DDPGAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import Dense from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, InputEmbedderParameters, \ from rl_coach.base_parameters import VisualizationParameters, EmbedderScheme, PresetValidationParameters
PresetValidationParameters from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, TrainingSteps, RunPhase
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod, SingleLevelSelection from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod, SingleLevelSelection
from rl_coach.environments.gym_environment import Mujoco, MujocoInputFilter, fetch_v1 from rl_coach.environments.gym_environment import Mujoco, MujocoInputFilter, fetch_v1
from rl_coach.exploration_policies.e_greedy import EGreedyParameters from rl_coach.exploration_policies.e_greedy import EGreedyParameters

View File

@@ -1,8 +1,8 @@
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import Dense from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.architectures.tensorflow_components.middlewares.lstm_middleware import LSTMMiddlewareParameters from rl_coach.architectures.tensorflow_components.middlewares.lstm_middleware import LSTMMiddlewareParameters
from rl_coach.base_parameters import VisualizationParameters, InputEmbedderParameters, MiddlewareScheme, \ from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
PresetValidationParameters from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod, SingleLevelSelection
from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter from rl_coach.environments.gym_environment import Mujoco, mujoco_v2, MujocoInputFilter

View File

@@ -2,8 +2,9 @@ import numpy as np
from rl_coach.agents.hac_ddpg_agent import HACDDPGAgentParameters from rl_coach.agents.hac_ddpg_agent import HACDDPGAgentParameters
from rl_coach.architectures.tensorflow_components.architecture import Dense from rl_coach.architectures.tensorflow_components.architecture import Dense
from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme, \ from rl_coach.base_parameters import VisualizationParameters, EmbeddingMergerType, EmbedderScheme
InputEmbedderParameters from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, RunPhase, TrainingSteps from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps, RunPhase, TrainingSteps
from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod
from rl_coach.environments.gym_environment import Mujoco from rl_coach.environments.gym_environment import Mujoco

View File

@@ -1,8 +1,9 @@
from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters from rl_coach.agents.actor_critic_agent import ActorCriticAgentParameters
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
from rl_coach.base_parameters import VisualizationParameters, InputEmbedderParameters from rl_coach.base_parameters import VisualizationParameters
from rl_coach.core_types import RunPhase from rl_coach.core_types import RunPhase
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
from rl_coach.environments.starcraft2_environment import StarCraft2EnvironmentParameters from rl_coach.environments.starcraft2_environment import StarCraft2EnvironmentParameters
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters

View File

@@ -2,8 +2,9 @@ from collections import OrderedDict
from rl_coach.agents.ddqn_agent import DDQNAgentParameters from rl_coach.agents.ddqn_agent import DDQNAgentParameters
from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters from rl_coach.architectures.tensorflow_components.heads.dueling_q_head import DuelingQHeadParameters
from rl_coach.base_parameters import VisualizationParameters, InputEmbedderParameters from rl_coach.base_parameters import VisualizationParameters
from rl_coach.core_types import RunPhase from rl_coach.core_types import RunPhase
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod from rl_coach.environments.environment import MaxDumpMethod, SelectedPhaseOnlyDumpMethod
from rl_coach.environments.starcraft2_environment import StarCraft2EnvironmentParameters from rl_coach.environments.starcraft2_environment import StarCraft2EnvironmentParameters