mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Export graph to ONNX (#61)
Implements the ONNX graph exporting feature. Currently does not work for NAF, C51 and A3C_LSTM due to unsupported TF layers in the tf2onnx library.
This commit is contained in:
@@ -127,11 +127,15 @@ class PPOHead(Head):
|
||||
self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
|
||||
self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
|
||||
kernel_initializer=normalized_columns_initializer(0.01))
|
||||
if self.is_local:
|
||||
|
||||
# for local networks in distributed settings, we need to move variables we create manually to the
|
||||
# tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in
|
||||
# Architecture does not apply to them
|
||||
if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters):
|
||||
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',
|
||||
collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std")
|
||||
else:
|
||||
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32')
|
||||
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std")
|
||||
|
||||
self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user