Mirror of https://github.com/gryf/coach.git, synced 2025-12-18 11:40:18 +01:00
Corrected MXNet's PPO Head for Continuous Action Spaces (#84)
* Changes required for Continuous PPO Head with MXNet. Used in MountainCarContinuous_ClippedPPO.
* Simplified changes for continuous ppo.
* Cleaned up to avoid duplicate code, and simplified covariance creation.
Committed by Scott Leishman
parent fde73ced13
commit 3358e04a6a
@@ -299,8 +299,7 @@ class MxnetArchitecture(Architecture):
         assert outputs is None, "outputs must be None"
 
         output = self._predict(inputs)
-        output = list(o.asnumpy() for o in output)
-
+        output = tuple(o.asnumpy() for o in output)
         if squeeze_output:
             output = squeeze_list(output)
         return output
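A quick standalone sketch of what the new tuple(...) line yields — a tuple of host-side numpy arrays rather than a list; `raw` below is an illustrative stand-in for `self._predict(inputs)`, not Coach code:

import mxnet as mx

# stand-in for the framework's raw predict output (a sequence of NDArrays)
raw = (mx.nd.ones((2, 3)), mx.nd.zeros((2, 1)))
output = tuple(o.asnumpy() for o in raw)
print(type(output), output[0].shape, output[1].shape)   # <class 'tuple'> (2, 3) (2, 1)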
@@ -412,7 +412,10 @@ class SingleModel(HybridBlock):
         # Head
         outputs = tuple()
         for head in self._output_heads:
-            outputs += (head(state_embedding),)
+            out = head(state_embedding)
+            if not isinstance(out, tuple):
+                out = (out,)
+            outputs += out
 
         return outputs
 
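The isinstance guard is needed because some heads (e.g. the continuous PPO head below) now return a tuple of tensors while others still return a single tensor. A minimal plain-Python sketch of the flattening, with strings standing in for tensors:

def collect(head_outputs):
    outputs = tuple()
    for out in head_outputs:
        if not isinstance(out, tuple):
            out = (out,)        # wrap single-tensor heads
        outputs += out          # splice multi-output heads in flat
    return outputs

print(collect(['v_value', ('policy_means', 'policy_std')]))
# ('v_value', 'policy_means', 'policy_std')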
@@ -28,7 +28,7 @@ class MultivariateNormalDist:
                  sigma: nd_sym_type,
                  F: ModuleType=mx.nd) -> None:
         """
         Distribution object for Multivariate Normal. Works with batches.
         Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
         mean, sigma and data for log_prob must all include a time_step dimension.
 
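For reference, a numpy sketch (independent of Coach's MultivariateNormalDist) of the batched log-density this class computes when the covariance is diagonal:

import numpy as np

def diag_mvn_log_prob(x, mean, var):
    # x, mean, var: (batch_size, num_actions); var holds the covariance diagonal
    k = x.shape[-1]
    return -0.5 * (np.sum((x - mean) ** 2 / var + np.log(var), axis=-1)
                   + k * np.log(2 * np.pi))

x = np.zeros((2, 3)); mean = np.zeros((2, 3)); var = np.ones((2, 3))
print(diag_mvn_log_prob(x, mean, var))   # both entries ≈ -2.757 (= -1.5 * log(2*pi))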
@@ -264,12 +264,12 @@ class ContinuousPPOHead(nn.HybridBlock):
             # but since we assume the action probability variables are independent,
             # only the diagonal entries of the covariance matrix are specified.
             self.log_std = self.params.get('log_std',
-                                           shape=num_actions,
+                                           shape=(num_actions,),
                                            init=mx.init.Zero(),
                                            allow_deferred_init=True)
             # todo: is_local?
 
-    def hybrid_forward(self, F: ModuleType, x: nd_sym_type, log_std: nd_sym_type) -> List[nd_sym_type]:
+    def hybrid_forward(self, F: ModuleType, x: nd_sym_type, log_std: nd_sym_type) -> Tuple[nd_sym_type, nd_sym_type]:
         """
         Used for forward pass through head network.
 
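The shape=(num_actions,) fix matters because Gluon parameter shapes are tuples of dimensions; a minimal standalone check (plain Gluon, not Coach's head classes):

import mxnet as mx

p = mx.gluon.Parameter('log_std', shape=(3,), init=mx.init.Zero())
p.initialize()
print(p.data().shape)   # (3,)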
@@ -282,8 +282,8 @@ class ContinuousPPOHead(nn.HybridBlock):
         of shape (batch_size, time_step, action_mean).
         """
         policy_means = self.dense(x)
-        policy_std = log_std.exp()
-        return [policy_means, policy_std]
+        policy_std = log_std.exp().expand_dims(0).broadcast_like(policy_means)
+        return policy_means, policy_std
 
 
 class ClippedPPOLossDiscrete(HeadLoss):
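A shape check of the new broadcast, with illustrative values ((4, 3) is an assumed batch of 4 samples with 3 actions): the single learned std row is expanded to one row per sample, so downstream loss code can treat means and stds uniformly.

import mxnet as mx

policy_means = mx.nd.zeros((4, 3))          # (batch_size, num_actions)
log_std = mx.nd.array([0.0, 0.5, 1.0])      # one learned log-std per action
policy_std = log_std.exp().expand_dims(0).broadcast_like(policy_means)
print(policy_std.shape)                     # (4, 3)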
@@ -490,8 +490,8 @@ class ClippedPPOLossContinuous(HeadLoss):
         of shape (batch_size, num_actions) or
         of shape (batch_size, time_step, num_actions).
         :param actions: true actions taken during rollout,
-        of shape (batch_size) or
-        of shape (batch_size, time_step).
+        of shape (batch_size, num_actions) or
+        of shape (batch_size, time_step, num_actions).
         :param old_policy_means: action means for previous policy,
         of shape (batch_size, num_actions) or
         of shape (batch_size, time_step, num_actions).
@@ -500,20 +500,24 @@ class ClippedPPOLossContinuous(HeadLoss):
         of shape (batch_size, time_step, num_actions).
         :param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
         :param advantages: change in state value after taking action (a.k.a advantage)
-        of shape (batch_size) or
+        of shape (batch_size,) or
         of shape (batch_size, time_step).
         :param kl_coefficient: loss coefficient applied kl divergence loss (also see high_kl_penalty_coefficient).
         :return: loss, of shape (batch_size).
         """
-        old_var = old_policy_stds ** 2
-        # sets diagonal in (batch size and time step) covariance matrices
-        old_covar = mx.nd.eye(N=self.num_actions) * (old_var + eps).broadcast_like(old_policy_means).expand_dims(-2)
+        def diagonal_covariance(stds, size):
+            vars = stds ** 2
+            # sets diagonal in (batch size and time step) covariance matrices
+            vars_tiled = vars.expand_dims(2).tile((1, 1, size))
+            covars = F.broadcast_mul(vars_tiled, F.eye(size))
+            return covars
+
+        old_covar = diagonal_covariance(stds=old_policy_stds, size=self.num_actions)
         old_policy_dist = MultivariateNormalDist(self.num_actions, old_policy_means, old_covar, F=F)
         action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
 
-        new_var = new_policy_stds ** 2
-        # sets diagonal in (batch size and time step) covariance matrices
-        new_covar = mx.nd.eye(N=self.num_actions) * (new_var + eps).broadcast_like(new_policy_means).expand_dims(-2)
+        new_covar = diagonal_covariance(stds=new_policy_stds, size=self.num_actions)
         new_policy_dist = MultivariateNormalDist(self.num_actions, new_policy_means, new_covar, F=F)
         action_probs_wrt_new_policy = new_policy_dist.log_prob(actions)
 
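What diagonal_covariance builds, checked numerically for the 2-D (batch_size, num_actions) case — a standalone sketch using mx.nd directly, not Coach code:

import mxnet as mx

stds = mx.nd.array([[0.5, 1.0, 2.0]])                   # (batch_size=1, num_actions=3)
vars = stds ** 2                                        # per-action variances
vars_tiled = vars.expand_dims(2).tile((1, 1, 3))        # (1, 3, 3): rows of variances
covars = mx.nd.broadcast_mul(vars_tiled, mx.nd.eye(3))  # zero out off-diagonals
print(covars.asnumpy()[0])
# [[0.25 0.   0.  ]
#  [0.   1.   0.  ]
#  [0.   0.   4.  ]]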
@@ -607,7 +611,7 @@ class PPOHead(Head):
         if isinstance(self.spaces.action, DiscreteActionSpace):
             self.net = DiscretePPOHead(num_actions=len(self.spaces.action.actions))
         elif isinstance(self.spaces.action, BoxActionSpace):
-            self.net = ContinuousPPOHead(num_actions=len(self.spaces.action.actions))
+            self.net = ContinuousPPOHead(num_actions=self.spaces.action.shape[0])
         else:
             raise ValueError("Only discrete or continuous action spaces are supported for PPO.")
 
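The .actions to .shape[0] change reflects the two space types' interfaces: a discrete space enumerates its actions, while a box space only carries a dimensionality. A toy sketch with hypothetical stand-in classes (not Coach's real space types):

class DiscreteActionSpace:
    def __init__(self, actions):
        self.actions = actions          # enumerable list of actions

class BoxActionSpace:
    def __init__(self, shape):
        self.shape = shape              # dimensionality only, nothing to enumerate

print(len(DiscreteActionSpace([0, 1, 2]).actions))   # 3 discrete actions
print(BoxActionSpace((2,)).shape[0])                 # 2 continuous action dimensions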
@@ -635,7 +639,7 @@ class PPOHead(Head):
                                           self.kl_cutoff, self.high_kl_penalty_coefficient,
                                           self.loss_weight)
         elif isinstance(self.spaces.action, BoxActionSpace):
-            loss = ClippedPPOLossContinuous(len(self.spaces.action.actions),
+            loss = ClippedPPOLossContinuous(self.spaces.action.shape[0],
                                             self.clip_likelihood_ratio_using_epsilon,
                                             self.beta,
                                             self.use_kl_regularization, self.initial_kl_coefficient,