mirror of
https://github.com/gryf/coach.git
synced 2026-02-12 11:45:45 +01:00
SAC algorithm (#282)
* SAC algorithm * SAC - updates to agent (learn_from_batch), sac_head and sac_q_head to fix problem in gradient calculation. Now SAC agents is able to train. gym_environment - fixing an error in access to gym.spaces * Soft Actor Critic - code cleanup * code cleanup * V-head initialization fix * SAC benchmarks * SAC Documentation * typo fix * documentation fixes * documentation and version update * README typo
This commit is contained in:
321
rl_coach/agents/soft_actor_critic_agent.py
Normal file
321
rl_coach/agents/soft_actor_critic_agent.py
Normal file
@@ -0,0 +1,321 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Union
|
||||
import copy
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
from rl_coach.agents.agent import Agent
|
||||
from rl_coach.agents.policy_optimization_agent import PolicyOptimizationAgent
|
||||
|
||||
from rl_coach.architectures.head_parameters import SACQHeadParameters,SACPolicyHeadParameters,VHeadParameters
|
||||
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
|
||||
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, AgentParameters, EmbedderScheme, MiddlewareScheme
|
||||
from rl_coach.core_types import ActionInfo, EnvironmentSteps, RunPhase
|
||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.spaces import BoxActionSpace
|
||||
|
||||
|
||||
# There are 3 networks in SAC implementation. All have the same topology but parameters are not shared.
|
||||
# The networks are:
|
||||
# 1. State Value Network - SACValueNetwork
|
||||
# 2. Soft Q Value Network - SACCriticNetwork
|
||||
# 3. Policy Network - SACPolicyNetwork - currently supporting only Gaussian Policy
|
||||
|
||||
|
||||
# 1. State Value Network - SACValueNetwork
|
||||
# this is the state value network in SAC.
|
||||
# The network is trained to predict (regression) the state value in the max-entropy settings
|
||||
# The objective to be minimized is given in equation (5) in the paper:
|
||||
#
|
||||
# J(psi)= E_(s~D)[0.5*(V_psi(s)-y(s))^2]
|
||||
# where y(s) = E_(a~pi)[Q_theta(s,a)-log(pi(a|s))]
|
||||
|
||||
|
||||
# Default parameters for value network:
|
||||
# topology :
|
||||
# input embedder : EmbedderScheme.Medium (Dense(256)) , relu activation
|
||||
# middleware : EmbedderScheme.Medium (Dense(256)) , relu activation
|
||||
|
||||
|
||||
class SACValueNetworkParameters(NetworkParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='relu')}
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='relu')
|
||||
self.heads_parameters = [VHeadParameters(initializer='xavier')]
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 256
|
||||
self.async_training = False
|
||||
self.learning_rate = 0.0003 # 3e-4 see appendix D in the paper
|
||||
self.create_target_network = True # tau is set in SoftActorCriticAlgorithmParameters.rate_for_copying_weights_to_target
|
||||
|
||||
|
||||
# 2. Soft Q Value Network - SACCriticNetwork
|
||||
# the whole network is built in the SACQHeadParameters. we use empty input embedder and middleware
|
||||
class SACCriticNetworkParameters(NetworkParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(scheme=EmbedderScheme.Empty)}
|
||||
self.middleware_parameters = FCMiddlewareParameters(scheme=MiddlewareScheme.Empty)
|
||||
self.heads_parameters = [SACQHeadParameters()] # SACQHeadParameters includes the topology of the head
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 256
|
||||
self.async_training = False
|
||||
self.learning_rate = 0.0003
|
||||
self.create_target_network = False
|
||||
|
||||
|
||||
# 3. policy Network
|
||||
# Default parameters for policy network:
|
||||
# topology :
|
||||
# input embedder : EmbedderScheme.Medium (Dense(256)) , relu activation
|
||||
# middleware : EmbedderScheme = [Dense(256)] , relu activation --> scheme should be overridden in preset
|
||||
class SACPolicyNetworkParameters(NetworkParameters):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function='relu')}
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='relu')
|
||||
self.heads_parameters = [SACPolicyHeadParameters()]
|
||||
self.rescale_gradient_from_head_by_factor = [1]
|
||||
self.optimizer_type = 'Adam'
|
||||
self.batch_size = 256
|
||||
self.async_training = False
|
||||
self.learning_rate = 0.0003
|
||||
self.create_target_network = False
|
||||
self.l2_regularization = 0 # weight decay regularization. not used in the original paper
|
||||
|
||||
|
||||
# Algorithm Parameters
|
||||
|
||||
class SoftActorCriticAlgorithmParameters(AlgorithmParameters):
|
||||
"""
|
||||
:param num_steps_between_copying_online_weights_to_target: (StepMethod)
|
||||
The number of steps between copying the online network weights to the target network weights.
|
||||
|
||||
:param rate_for_copying_weights_to_target: (float)
|
||||
When copying the online network weights to the target network weights, a soft update will be used, which
|
||||
weight the new online network weights by rate_for_copying_weights_to_target. (Tau as defined in the paper)
|
||||
|
||||
:param use_deterministic_for_evaluation: (bool)
|
||||
If True, during the evaluation phase, action are chosen deterministically according to the policy mean
|
||||
and not sampled from the policy distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(1)
|
||||
self.rate_for_copying_weights_to_target = 0.005
|
||||
self.use_deterministic_for_evaluation = True # evaluate agent using deterministic policy (i.e. take the mean value)
|
||||
|
||||
|
||||
class SoftActorCriticAgentParameters(AgentParameters):
|
||||
def __init__(self):
|
||||
super().__init__(algorithm=SoftActorCriticAlgorithmParameters(),
|
||||
exploration=AdditiveNoiseParameters(),
|
||||
memory=ExperienceReplayParameters(), # SAC doesnt use episodic related data
|
||||
# network wrappers:
|
||||
networks=OrderedDict([("policy", SACPolicyNetworkParameters()),
|
||||
("q", SACCriticNetworkParameters()),
|
||||
("v", SACValueNetworkParameters())]))
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return 'rl_coach.agents.soft_actor_critic_agent:SoftActorCriticAgent'
|
||||
|
||||
|
||||
# Soft Actor Critic - https://arxiv.org/abs/1801.01290
|
||||
class SoftActorCriticAgent(PolicyOptimizationAgent):
|
||||
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
|
||||
super().__init__(agent_parameters, parent)
|
||||
self.last_gradient_update_step_idx = 0
|
||||
|
||||
# register signals to track (in learn_from_batch)
|
||||
self.policy_means = self.register_signal('Policy_mu_avg')
|
||||
self.policy_logsig = self.register_signal('Policy_logsig')
|
||||
self.policy_logprob_sampled = self.register_signal('Policy_logp_sampled')
|
||||
self.policy_grads = self.register_signal('Policy_grads_sumabs')
|
||||
|
||||
self.q1_values = self.register_signal("Q1")
|
||||
self.TD_err1 = self.register_signal("TD err1")
|
||||
self.q2_values = self.register_signal("Q2")
|
||||
self.TD_err2 = self.register_signal("TD err2")
|
||||
self.v_tgt_ns = self.register_signal('V_tgt_ns')
|
||||
self.v_onl_ys = self.register_signal('V_onl_ys')
|
||||
self.action_signal = self.register_signal("actions")
|
||||
|
||||
def learn_from_batch(self, batch):
|
||||
#########################################
|
||||
# need to update the following networks:
|
||||
# 1. actor (policy)
|
||||
# 2. state value (v)
|
||||
# 3. critic (q1 and q2)
|
||||
# 4. target network - probably already handled by V
|
||||
|
||||
#########################################
|
||||
# define the networks to be used
|
||||
|
||||
# State Value Network
|
||||
value_network = self.networks['v']
|
||||
value_network_keys = self.ap.network_wrappers['v'].input_embedders_parameters.keys()
|
||||
|
||||
# Critic Network
|
||||
q_network = self.networks['q'].online_network
|
||||
q_head = q_network.output_heads[0]
|
||||
q_network_keys = self.ap.network_wrappers['q'].input_embedders_parameters.keys()
|
||||
|
||||
# Actor (policy) Network
|
||||
policy_network = self.networks['policy'].online_network
|
||||
policy_network_keys = self.ap.network_wrappers['policy'].input_embedders_parameters.keys()
|
||||
|
||||
##########################################
|
||||
# 1. updating the actor - according to (13) in the paper
|
||||
policy_inputs = copy.copy(batch.states(policy_network_keys))
|
||||
policy_results = policy_network.predict(policy_inputs)
|
||||
|
||||
policy_mu, policy_std, sampled_raw_actions, sampled_actions, sampled_actions_logprob, \
|
||||
sampled_actions_logprob_mean = policy_results
|
||||
|
||||
self.policy_means.add_sample(policy_mu)
|
||||
self.policy_logsig.add_sample(policy_std)
|
||||
self.policy_logprob_sampled.add_sample(sampled_actions_logprob_mean)
|
||||
|
||||
# get the state-action values for the replayed states and their corresponding actions from the policy
|
||||
q_inputs = copy.copy(batch.states(q_network_keys))
|
||||
q_inputs['output_0_0'] = sampled_actions
|
||||
log_target = q_network.predict(q_inputs)[0].squeeze()
|
||||
|
||||
# log internal q values
|
||||
q1_vals, q2_vals = q_network.predict(q_inputs, outputs=[q_head.q1_output, q_head.q2_output])
|
||||
self.q1_values.add_sample(q1_vals)
|
||||
self.q2_values.add_sample(q2_vals)
|
||||
|
||||
# calculate the gradients according to (13)
|
||||
# get the gradients of log_prob w.r.t the weights (parameters) - indicated as phi in the paper
|
||||
initial_feed_dict = {policy_network.gradients_weights_ph[5]: np.array(1.0)}
|
||||
dlogp_dphi = policy_network.predict(policy_inputs,
|
||||
outputs=policy_network.weighted_gradients[5],
|
||||
initial_feed_dict=initial_feed_dict)
|
||||
|
||||
# calculate dq_da
|
||||
dq_da = q_network.predict(q_inputs,
|
||||
outputs=q_network.gradients_wrt_inputs[1]['output_0_0'])
|
||||
|
||||
# calculate da_dphi
|
||||
initial_feed_dict = {policy_network.gradients_weights_ph[3]: dq_da}
|
||||
dq_dphi = policy_network.predict(policy_inputs,
|
||||
outputs=policy_network.weighted_gradients[3],
|
||||
initial_feed_dict=initial_feed_dict)
|
||||
|
||||
# now given dlogp_dphi, dq_dphi we need to calculate the policy gradients according to (13)
|
||||
policy_grads = [dlogp_dphi[l] - dq_dphi[l] for l in range(len(dlogp_dphi))]
|
||||
|
||||
# apply the gradients to policy networks
|
||||
policy_network.apply_gradients(policy_grads)
|
||||
grads_sumabs = np.sum([np.sum(np.abs(policy_grads[l])) for l in range(len(policy_grads))])
|
||||
self.policy_grads.add_sample(grads_sumabs)
|
||||
|
||||
##########################################
|
||||
# 2. updating the state value online network weights
|
||||
# done by calculating the targets for the v head according to (5) in the paper
|
||||
# value_targets = log_targets-sampled_actions_logprob
|
||||
value_inputs = copy.copy(batch.states(value_network_keys))
|
||||
value_targets = log_target - sampled_actions_logprob
|
||||
|
||||
self.v_onl_ys.add_sample(value_targets)
|
||||
|
||||
# call value_network apply gradients with this target
|
||||
value_loss = value_network.online_network.train_on_batch(value_inputs, value_targets[:,None])[0]
|
||||
|
||||
##########################################
|
||||
# 3. updating the critic (q networks)
|
||||
# updating q networks according to (7) in the paper
|
||||
|
||||
# define the input to the q network: state has been already updated previously, but now we need
|
||||
# the actions from the batch (and not those sampled by the policy)
|
||||
q_inputs['output_0_0'] = batch.actions(len(batch.actions().shape) == 1)
|
||||
|
||||
# define the targets : scale_reward * reward + (1-terminal)*discount*v_target_next_state
|
||||
# define v_target_next_state
|
||||
value_inputs = copy.copy(batch.next_states(value_network_keys))
|
||||
v_target_next_state = value_network.target_network.predict(value_inputs)
|
||||
self.v_tgt_ns.add_sample(v_target_next_state)
|
||||
# Note: reward is assumed to be rescaled by RewardRescaleFilter in the preset parameters
|
||||
TD_targets = batch.rewards(expand_dims=True) + \
|
||||
(1.0 - batch.game_overs(expand_dims=True)) * self.ap.algorithm.discount * v_target_next_state
|
||||
|
||||
# call critic network update
|
||||
result = q_network.train_on_batch(q_inputs, TD_targets, additional_fetches=[q_head.q1_loss, q_head.q2_loss])
|
||||
total_loss, losses, unclipped_grads = result[:3]
|
||||
q1_loss, q2_loss = result[3]
|
||||
self.TD_err1.add_sample(q1_loss)
|
||||
self.TD_err2.add_sample(q2_loss)
|
||||
|
||||
##########################################
|
||||
# 4. updating the value target network
|
||||
# I just need to set the parameter rate_for_copying_weights_to_target in the agent parameters to be 1-tau
|
||||
# where tau is the hyper parameter as defined in sac original implementation
|
||||
|
||||
return total_loss, losses, unclipped_grads
|
||||
|
||||
def get_prediction(self, states):
|
||||
"""
|
||||
get the mean and stdev of the policy distribution given 'states'
|
||||
:param states: the states for which we need to sample actions from the policy
|
||||
:return: mean and stdev
|
||||
"""
|
||||
tf_input_state = self.prepare_batch_for_inference(states, 'policy')
|
||||
return self.networks['policy'].online_network.predict(tf_input_state)
|
||||
|
||||
def train(self):
|
||||
# since the algorithm works with experience replay buffer (non-episodic),
|
||||
# we cant use the policy optimization train method. we need Agent.train
|
||||
# note that since in Agent.train there is no apply_gradients, we need to do it in learn from batch
|
||||
return Agent.train(self)
|
||||
|
||||
def choose_action(self, curr_state):
|
||||
"""
|
||||
choose_action - chooses the most likely action
|
||||
if 'deterministic' - take the mean of the policy which is the prediction of the policy network.
|
||||
else - use the exploration policy
|
||||
:param curr_state:
|
||||
:return: action wrapped in ActionInfo
|
||||
"""
|
||||
if not isinstance(self.spaces.action, BoxActionSpace):
|
||||
raise ValueError("SAC works only for continuous control problems")
|
||||
# convert to batch so we can run it through the network
|
||||
tf_input_state = self.prepare_batch_for_inference(curr_state, 'policy')
|
||||
# use the online network for prediction
|
||||
policy_network = self.networks['policy'].online_network
|
||||
policy_head = policy_network.output_heads[0]
|
||||
result = policy_network.predict(tf_input_state,
|
||||
outputs=[policy_head.policy_mean, policy_head.actions])
|
||||
action_mean, action_sample = result
|
||||
|
||||
# if using deterministic policy, take the mean values. else, use exploration policy to sample from the pdf
|
||||
if self.phase == RunPhase.TEST and self.ap.algorithm.use_deterministic_for_evaluation:
|
||||
action = action_mean[0]
|
||||
else:
|
||||
action = action_sample[0]
|
||||
|
||||
self.action_signal.add_sample(action)
|
||||
|
||||
action_info = ActionInfo(action=action)
|
||||
return action_info
|
||||
@@ -36,7 +36,6 @@ class HeadParameters(NetworkComponentParameters):
|
||||
return 'rl_coach.architectures.tensorflow_components.heads:' + self.parameterized_class_name
|
||||
|
||||
|
||||
|
||||
class PPOHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
@@ -50,11 +49,12 @@ class PPOHeadParameters(HeadParameters):
|
||||
class VHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='v_head_params',
|
||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||
loss_weight: float = 1.0, dense_layer=None):
|
||||
loss_weight: float = 1.0, dense_layer=None, initializer='normalized_columns'):
|
||||
super().__init__(parameterized_class_name="VHead", activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
self.initializer = initializer
|
||||
|
||||
|
||||
class CategoricalQHeadParameters(HeadParameters):
|
||||
@@ -196,3 +196,17 @@ class ACERPolicyHeadParameters(HeadParameters):
|
||||
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||
loss_weight=loss_weight)
|
||||
|
||||
|
||||
class SACPolicyHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='sac_policy_head_params', dense_layer=None):
|
||||
super().__init__(parameterized_class_name='SACPolicyHead', activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
|
||||
|
||||
class SACQHeadParameters(HeadParameters):
|
||||
def __init__(self, activation_function: str ='relu', name: str='sac_q_head_params', dense_layer=None,
|
||||
layers_sizes: tuple = (256, 256)):
|
||||
super().__init__(parameterized_class_name='SACQHead', activation_function=activation_function, name=name,
|
||||
dense_layer=dense_layer)
|
||||
self.network_layers_sizes = layers_sizes
|
||||
|
||||
@@ -12,6 +12,8 @@ from .quantile_regression_q_head import QuantileRegressionQHead
|
||||
from .rainbow_q_head import RainbowQHead
|
||||
from .v_head import VHead
|
||||
from .acer_policy_head import ACERPolicyHead
|
||||
from .sac_head import SACPolicyHead
|
||||
from .sac_q_head import SACQHead
|
||||
from .classification_head import ClassificationHead
|
||||
from .cil_head import RegressionHead
|
||||
|
||||
@@ -30,6 +32,8 @@ __all__ = [
|
||||
'RainbowQHead',
|
||||
'VHead',
|
||||
'ACERPolicyHead',
|
||||
'ClassificationHead'
|
||||
'SACPolicyHead',
|
||||
'SACQHead',
|
||||
'ClassificationHead',
|
||||
'RegressionHead'
|
||||
]
|
||||
|
||||
107
rl_coach/architectures/tensorflow_components/heads/sac_head.py
Normal file
107
rl_coach/architectures/tensorflow_components/heads/sac_head.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import ActionProbabilities
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import eps
|
||||
|
||||
LOG_SIG_CAP_MAX = 2
|
||||
LOG_SIG_CAP_MIN = -20
|
||||
|
||||
|
||||
class SACPolicyHead(Head):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
|
||||
squash: bool = True, dense_layer=Dense):
|
||||
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
|
||||
dense_layer=dense_layer)
|
||||
self.name = 'sac_policy_head'
|
||||
self.return_type = ActionProbabilities
|
||||
self.num_actions = self.spaces.action.shape # continuous actions
|
||||
self.squash = squash # squashing using tanh
|
||||
|
||||
def _build_module(self, input_layer):
|
||||
self.given_raw_actions = tf.placeholder(tf.float32, [None, self.num_actions], name="actions")
|
||||
self.input = [self.given_raw_actions]
|
||||
self.output = []
|
||||
|
||||
# build the network
|
||||
self._build_continuous_net(input_layer, self.spaces.action)
|
||||
|
||||
def _squash_correction(self,actions):
|
||||
'''
|
||||
correct squash operation (in case of bounded actions) according to appendix C in the paper.
|
||||
NOTE : this correction assume the squash is done with tanh.
|
||||
:param actions: unbounded actions
|
||||
:return: the correction to be applied to the log_prob of the actions, assuming tanh squash
|
||||
'''
|
||||
if not self.squash:
|
||||
return 0
|
||||
return tf.reduce_sum(tf.log(1 - tf.tanh(actions) ** 2 + eps), axis=1)
|
||||
|
||||
def _build_continuous_net(self, input_layer, action_space):
|
||||
num_actions = action_space.shape[0]
|
||||
|
||||
self.policy_mu_and_logsig = self.dense_layer(2*num_actions)(input_layer, name='policy_mu_logsig')
|
||||
self.policy_mean = tf.identity(self.policy_mu_and_logsig[..., :num_actions], name='policy_mean')
|
||||
self.policy_log_std = tf.clip_by_value(self.policy_mu_and_logsig[..., num_actions:],
|
||||
LOG_SIG_CAP_MIN, LOG_SIG_CAP_MAX,name='policy_log_std')
|
||||
|
||||
self.output.append(self.policy_mean) # output[0]
|
||||
self.output.append(self.policy_log_std) # output[1]
|
||||
|
||||
# define the distributions for the policy
|
||||
# Tensorflow's multivariate normal distribution supports reparameterization
|
||||
tfd = tf.contrib.distributions
|
||||
self.policy_distribution = tfd.MultivariateNormalDiag(loc=self.policy_mean,
|
||||
scale_diag=tf.exp(self.policy_log_std))
|
||||
|
||||
# define network outputs
|
||||
# note that tensorflow supports reparametrization.
|
||||
# i.e. policy_action_sample is a tensor through which gradients can flow
|
||||
self.raw_actions = self.policy_distribution.sample()
|
||||
|
||||
if self.squash:
|
||||
self.actions = tf.tanh(self.raw_actions)
|
||||
# correct log_prob in case of squash (see appendix C in the paper)
|
||||
squash_correction = self._squash_correction(self.raw_actions)
|
||||
else:
|
||||
self.actions = self.raw_actions
|
||||
squash_correction = 0
|
||||
|
||||
# policy_action_logprob is a tensor through which gradients can flow
|
||||
self.sampled_actions_logprob = self.policy_distribution.log_prob(self.raw_actions) - squash_correction
|
||||
self.sampled_actions_logprob_mean = tf.reduce_mean(self.sampled_actions_logprob)
|
||||
|
||||
self.output.append(self.raw_actions) # output[2] : sampled raw action (before squash)
|
||||
self.output.append(self.actions) # output[3] : squashed (if needed) version of sampled raw_actions
|
||||
self.output.append(self.sampled_actions_logprob) # output[4]: log prob of sampled action (squash corrected)
|
||||
self.output.append(self.sampled_actions_logprob_mean) # output[5]: mean of log prob of sampled actions (squash corrected)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"policy head:"
|
||||
"\t\tDense (num outputs = 256)",
|
||||
"\t\tDense (num outputs = 256)",
|
||||
"\t\tDense (num outputs = {0})".format(2*self.num_actions),
|
||||
"policy_mu = output[:num_actions], policy_std = output[num_actions:]"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
116
rl_coach/architectures/tensorflow_components/heads/sac_q_head.py
Normal file
116
rl_coach/architectures/tensorflow_components/heads/sac_q_head.py
Normal file
@@ -0,0 +1,116 @@
|
||||
#
|
||||
# Copyright (c) 2019 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import Head
|
||||
from rl_coach.base_parameters import AgentParameters
|
||||
from rl_coach.core_types import QActionStateValue
|
||||
from rl_coach.spaces import SpacesDefinition, BoxActionSpace
|
||||
|
||||
|
||||
class SACQHead(Head):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
|
||||
dense_layer=Dense):
|
||||
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
|
||||
dense_layer=dense_layer)
|
||||
self.name = 'q_values_head'
|
||||
if isinstance(self.spaces.action, BoxActionSpace):
|
||||
self.num_actions = self.spaces.action.shape # continuous actions
|
||||
else:
|
||||
raise ValueError(
|
||||
'SACQHead does not support action spaces of type: {class_name}'.format(
|
||||
class_name=self.spaces.action.__class__.__name__,
|
||||
)
|
||||
)
|
||||
self.return_type = QActionStateValue
|
||||
# extract the topology from the SACQHeadParameters
|
||||
self.network_layers_sizes = agent_parameters.network_wrappers['q'].heads_parameters[0].network_layers_sizes
|
||||
|
||||
def _build_module(self, input_layer):
|
||||
# SAC Q network is basically 2 networks running in parallel on the same input (state , action)
|
||||
# state is the observation fed through the input_layer, action is fed through placeholder to the header
|
||||
# each is calculating q value : q1(s,a) and q2(s,a)
|
||||
# the output of the head is min(q1,q2)
|
||||
self.actions = tf.placeholder(tf.float32, [None, self.num_actions], name="actions")
|
||||
self.target = tf.placeholder(tf.float32, [None, 1], name="q_targets")
|
||||
self.input = [self.actions]
|
||||
self.output = []
|
||||
# Note (1) : in the author's implementation of sac (in rllab) they summarize the embedding of observation and
|
||||
# action (broadcasting the bias) in the first layer of the network.
|
||||
|
||||
# build q1 network head
|
||||
with tf.variable_scope("q1_head"):
|
||||
layer_size = self.network_layers_sizes[0]
|
||||
qi_obs_emb = self.dense_layer(layer_size)(input_layer, activation=self.activation_function)
|
||||
qi_act_emb = self.dense_layer(layer_size)(self.actions, activation=self.activation_function)
|
||||
qi_output = qi_obs_emb + qi_act_emb # merging the inputs by summarizing them (see Note (1))
|
||||
for layer_size in self.network_layers_sizes[1:]:
|
||||
qi_output = self.dense_layer(layer_size)(qi_output, activation=self.activation_function)
|
||||
# the output layer
|
||||
self.q1_output = self.dense_layer(1)(qi_output, name='q1_output')
|
||||
|
||||
# build q2 network head
|
||||
with tf.variable_scope("q2_head"):
|
||||
layer_size = self.network_layers_sizes[0]
|
||||
qi_obs_emb = self.dense_layer(layer_size)(input_layer, activation=self.activation_function)
|
||||
qi_act_emb = self.dense_layer(layer_size)(self.actions, activation=self.activation_function)
|
||||
qi_output = qi_obs_emb + qi_act_emb # merging the inputs by summarizing them (see Note (1))
|
||||
for layer_size in self.network_layers_sizes[1:]:
|
||||
qi_output = self.dense_layer(layer_size)(qi_output, activation=self.activation_function)
|
||||
# the output layer
|
||||
self.q2_output = self.dense_layer(1)(qi_output, name='q2_output')
|
||||
|
||||
# take the minimum as the network's output. this is the log_target (in the original implementation)
|
||||
self.q_output = tf.minimum(self.q1_output, self.q2_output, name='q_output')
|
||||
# the policy gradients
|
||||
# self.q_output_mean = tf.reduce_mean(self.q1_output) # option 1: use q1
|
||||
self.q_output_mean = tf.reduce_mean(self.q_output) # option 2: use min(q1,q2)
|
||||
|
||||
self.output.append(self.q_output)
|
||||
self.output.append(self.q_output_mean)
|
||||
|
||||
# defining the loss
|
||||
self.q1_loss = 0.5*tf.reduce_mean(tf.square(self.q1_output - self.target))
|
||||
self.q2_loss = 0.5*tf.reduce_mean(tf.square(self.q2_output - self.target))
|
||||
# eventually both losses are depends on different parameters so we can sum them up
|
||||
self.loss = self.q1_loss+self.q2_loss
|
||||
tf.losses.add_loss(self.loss)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
"q1 output"
|
||||
"\t\tDense (num outputs = 256)",
|
||||
"\t\tDense (num outputs = 256)",
|
||||
"\t\tDense (num outputs = 1)",
|
||||
"q2 output"
|
||||
"\t\tDense (num outputs = 256)",
|
||||
"\t\tDense (num outputs = 256)",
|
||||
"\t\tDense (num outputs = 1)",
|
||||
"min(Q1,Q2)"
|
||||
]
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from rl_coach.spaces import SpacesDefinition
|
||||
class VHead(Head):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
|
||||
dense_layer=Dense):
|
||||
dense_layer=Dense, initializer='normalized_columns'):
|
||||
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
|
||||
dense_layer=dense_layer)
|
||||
self.name = 'v_values_head'
|
||||
@@ -37,10 +37,15 @@ class VHead(Head):
|
||||
else:
|
||||
self.loss_type = tf.losses.mean_squared_error
|
||||
|
||||
self.initializer = initializer
|
||||
|
||||
def _build_module(self, input_layer):
|
||||
# Standard V Network
|
||||
self.output = self.dense_layer(1)(input_layer, name='output',
|
||||
kernel_initializer=normalized_columns_initializer(1.0))
|
||||
if self.initializer == 'normalized_columns':
|
||||
self.output = self.dense_layer(1)(input_layer, name='output',
|
||||
kernel_initializer=normalized_columns_initializer(1.0))
|
||||
elif self.initializer == 'xavier' or self.initializer is None:
|
||||
self.output = self.dense_layer(1)(input_layer, name='output')
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
|
||||
71
rl_coach/presets/Mujoco_SAC.py
Normal file
71
rl_coach/presets/Mujoco_SAC.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from rl_coach.agents.soft_actor_critic_agent import SoftActorCriticAgentParameters
|
||||
from rl_coach.architectures.layers import Dense
|
||||
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||
from rl_coach.core_types import EnvironmentEpisodes, EnvironmentSteps
|
||||
from rl_coach.filters.filter import InputFilter
|
||||
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter
|
||||
from rl_coach.environments.environment import SingleLevelSelection
|
||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
|
||||
|
||||
####################
|
||||
# Graph Scheduling #
|
||||
####################
|
||||
|
||||
# see graph_manager.py for possible schedule parameters
|
||||
schedule_params = ScheduleParameters()
|
||||
schedule_params.improve_steps = EnvironmentSteps(3000000)
|
||||
schedule_params.steps_between_evaluation_periods = EnvironmentSteps(1000)
|
||||
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
|
||||
schedule_params.heatup_steps = EnvironmentSteps(10000)
|
||||
|
||||
|
||||
#########
|
||||
# Agent #
|
||||
#########
|
||||
agent_params = SoftActorCriticAgentParameters()
|
||||
# override default parameters:
|
||||
# value (v) networks parameters
|
||||
agent_params.network_wrappers['v'].batch_size = 256
|
||||
agent_params.network_wrappers['v'].learning_rate = 0.0003
|
||||
agent_params.network_wrappers['v'].middleware_parameters.scheme = [Dense(256)]
|
||||
|
||||
# critic (q) network parameters
|
||||
agent_params.network_wrappers['q'].heads_parameters[0].network_layers_sizes = (256, 256)
|
||||
agent_params.network_wrappers['q'].batch_size = 256
|
||||
agent_params.network_wrappers['q'].learning_rate = 0.0003
|
||||
|
||||
# actor (policy) network parameters
|
||||
agent_params.network_wrappers['policy'].batch_size = 256
|
||||
agent_params.network_wrappers['policy'].learning_rate = 0.0003
|
||||
agent_params.network_wrappers['policy'].middleware_parameters.scheme = [Dense(256)]
|
||||
|
||||
# Input Filter
|
||||
# SAC requires reward scaling for Mujoco environments.
|
||||
# according to the paper:
|
||||
# Hopper, Walker-2d, HalfCheetah, Ant - requires scaling of 5
|
||||
# Humanoid - requires scaling of 20
|
||||
|
||||
agent_params.input_filter = InputFilter()
|
||||
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(5))
|
||||
|
||||
###############
|
||||
# Environment #
|
||||
###############
|
||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||
|
||||
########
|
||||
# Test #
|
||||
########
|
||||
preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test = True
|
||||
preset_validation_params.min_reward_threshold = 400
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 2200
|
||||
preset_validation_params.reward_test_level = 'inverted_pendulum'
|
||||
preset_validation_params.trace_test_levels = ['inverted_pendulum', 'hopper']
|
||||
|
||||
graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params, vis_params=VisualizationParameters(),
|
||||
preset_validation_params=preset_validation_params)
|
||||
Reference in New Issue
Block a user