From 9ae2905a768f658d6ee6a15b63fff62b7340d686 Mon Sep 17 00:00:00 2001 From: Zach Dwiel Date: Fri, 3 Nov 2017 13:51:02 -0700 Subject: [PATCH] clean up input embeddings setup --- .../tensorflow_components/general_network.py | 22 +++++++++---------- .../tensorflow_components/middleware.py | 10 ++++++++- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/architectures/tensorflow_components/general_network.py b/architectures/tensorflow_components/general_network.py index 9b20082..842ad66 100644 --- a/architectures/tensorflow_components/general_network.py +++ b/architectures/tensorflow_components/general_network.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -100,7 +100,6 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): local_network_in_distributed_training = self.global_network is not None and self.network_is_local tuning_parameters.activation_function = self.activation_function - done_creating_input_placeholders = False for network_idx in range(self.num_networks): with tf.variable_scope('network_{}'.format(network_idx)): @@ -111,18 +110,19 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): state_embedding = [] for idx, input_type in enumerate(self.tp.agent.input_types): # get the class of the input embedder - self.input_embedders.append(self.get_input_embedder(input_type)) + input_embedder = self.get_input_embedder(input_type) + self.input_embedders.append(input_embedder) - # in the case each head uses a different network, we still reuse the input placeholders - prev_network_input_placeholder = self.inputs[idx] if done_creating_input_placeholders else None - - # create the input embedder instance and store the input placeholder and the embedding - input_placeholder, embedding = self.input_embedders[-1](prev_network_input_placeholder) - if len(self.inputs) < len(self.tp.agent.input_types): + # input placeholders are reused between networks. on the first network, store the placeholders + # generated by the input_embedders in self.inputs. on the rest of the networks, pass + # the existing input_placeholders into the input_embedders. + if network_idx == 0: + input_placeholder, embedding = input_embedder() self.inputs.append(input_placeholder) - state_embedding.append(embedding) + else: + input_placeholder, embedding = input_embedder(self.inputs[idx]) - done_creating_input_placeholders = True + state_embedding.append(embedding) ############## # Middleware # diff --git a/architectures/tensorflow_components/middleware.py b/architectures/tensorflow_components/middleware.py index 318953c..e491bf0 100644 --- a/architectures/tensorflow_components/middleware.py +++ b/architectures/tensorflow_components/middleware.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,6 +41,14 @@ class MiddlewareEmbedder(object): class LSTM_Embedder(MiddlewareEmbedder): def _build_module(self): + """ + self.state_in: tuple of placeholders containing the initial state + self.state_out: tuple of output state + + todo: it appears that the shape of the output is batch, feature + the code here seems to be slicing off the first element in the batch + which would definitely be wrong. need to double check the shape + """ middleware = tf.layers.dense(self.input, 512, activation=self.activation_function) lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)