1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Export graph to ONNX (#61)

Implements the ONNX graph exporting feature. 
Currently does not work for NAF, C51 and A3C_LSTM due to unsupported TF layers in the tf2onnx library.
This commit is contained in:
Itai Caspi
2018-11-06 10:55:21 +02:00
committed by GitHub
parent d75df17d97
commit 811152126c
4 changed files with 107 additions and 27 deletions

View File

@@ -127,11 +127,15 @@ class PPOHead(Head):
self.input = [self.actions, self.old_policy_mean, self.old_policy_std]
self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean',
kernel_initializer=normalized_columns_initializer(0.01))
if self.is_local:
# for local networks in distributed settings, we need to move variables we create manually to the
# tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in
# Architecture does not apply to them
if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters):
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32',
collections=[tf.GraphKeys.LOCAL_VARIABLES])
collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std")
else:
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32')
self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std")
self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std')