
clean up input embeddings setup

Zach Dwiel
2017-11-03 13:51:02 -07:00
committed by galleibo-intel
parent 1ff0da2165
commit 9ae2905a76
2 changed files with 20 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -100,7 +100,6 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
 
         local_network_in_distributed_training = self.global_network is not None and self.network_is_local
         tuning_parameters.activation_function = self.activation_function
-        done_creating_input_placeholders = False
 
         for network_idx in range(self.num_networks):
             with tf.variable_scope('network_{}'.format(network_idx)):
@@ -111,18 +110,19 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
                 state_embedding = []
                 for idx, input_type in enumerate(self.tp.agent.input_types):
                     # get the class of the input embedder
-                    self.input_embedders.append(self.get_input_embedder(input_type))
+                    input_embedder = self.get_input_embedder(input_type)
+                    self.input_embedders.append(input_embedder)
 
-                    # in the case each head uses a different network, we still reuse the input placeholders
-                    prev_network_input_placeholder = self.inputs[idx] if done_creating_input_placeholders else None
-
-                    # create the input embedder instance and store the input placeholder and the embedding
-                    input_placeholder, embedding = self.input_embedders[-1](prev_network_input_placeholder)
-                    if len(self.inputs) < len(self.tp.agent.input_types):
+                    # input placeholders are reused between networks. on the first network, store the placeholders
+                    # generated by the input_embedders in self.inputs. on the rest of the networks, pass
+                    # the existing input_placeholders into the input_embedders.
+                    if network_idx == 0:
+                        input_placeholder, embedding = input_embedder()
                         self.inputs.append(input_placeholder)
-                    state_embedding.append(embedding)
-                done_creating_input_placeholders = True
+                    else:
+                        input_placeholder, embedding = input_embedder(self.inputs[idx])
+                    state_embedding.append(embedding)
 
                 ##############
                 # Middleware #
                 ##############
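
The comment block added in this hunk summarizes the new control flow: the first network creates the input placeholders, and every later network reuses them, replacing the old done_creating_input_placeholders flag. A minimal standalone sketch of that pattern (TensorFlow 1.x assumed; FakeEmbedder, the input sizes, and the two-network loop are invented for illustration and are not Coach's real classes):

# Standalone sketch of the placeholder-reuse pattern (TF 1.x).
# FakeEmbedder is hypothetical; it mimics the call convention above:
# calling it with no placeholder creates a fresh one, calling it with
# an existing placeholder reuses it.
import tensorflow as tf

class FakeEmbedder(object):
    def __init__(self, name, input_size):
        self.name = name
        self.input_size = input_size

    def __call__(self, prev_input_placeholder=None):
        if prev_input_placeholder is None:
            input_placeholder = tf.placeholder(tf.float32, [None, self.input_size], name=self.name)
        else:
            input_placeholder = prev_input_placeholder
        embedding = tf.layers.dense(input_placeholder, 32)
        return input_placeholder, embedding

inputs = []
for network_idx in range(2):  # two networks sharing the same input placeholders
    with tf.variable_scope('network_{}'.format(network_idx)):
        state_embedding = []
        for idx, input_size in enumerate([10, 4]):
            input_embedder = FakeEmbedder('input_{}'.format(idx), input_size)
            if network_idx == 0:
                # first network: create and remember the placeholders
                input_placeholder, embedding = input_embedder()
                inputs.append(input_placeholder)
            else:
                # later networks: reuse the stored placeholders
                input_placeholder, embedding = input_embedder(inputs[idx])
            state_embedding.append(embedding)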

View File

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -41,6 +41,14 @@ class MiddlewareEmbedder(object):
 class LSTM_Embedder(MiddlewareEmbedder):
     def _build_module(self):
+        """
+        self.state_in: tuple of placeholders containing the initial state
+        self.state_out: tuple of output state
+
+        todo: it appears that the shape of the output is batch, feature
+        the code here seems to be slicing off the first element in the batch
+        which would definitely be wrong. need to double check the shape
+        """
         middleware = tf.layers.dense(self.input, 512, activation=self.activation_function)
         lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
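
The todo in the new docstring questions the LSTM output shape. For reference, a minimal shape check under one common wiring (TensorFlow 1.x with tf.contrib assumed; the 3-D placeholder sizes and the tf.nn.dynamic_rnn call are illustrative, since this hunk does not show how the cell is actually driven): with the default time_major=False, dynamic_rnn returns outputs shaped [batch, time, units], so indexing [0] takes the first batch element rather than a time step.

# Shape-check sketch (TF 1.x); sizes are arbitrary, dynamic_rnn wiring assumed.
import tensorflow as tf

batch_size, time_steps, features = 8, 5, 512
middleware = tf.placeholder(tf.float32, [batch_size, time_steps, features])
lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
outputs, state_out = tf.nn.dynamic_rnn(lstm_cell, middleware, dtype=tf.float32)

print(outputs.get_shape())            # (8, 5, 256) -> batch x time x units
print(outputs[0].get_shape())         # (5, 256)    -> first BATCH element, all time steps
print(outputs[:, -1, :].get_shape())  # (8, 256)    -> last time step for every batch element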