clean up input embeddings setup
commit 9ae2905a76
parent 1ff0da2165
committed by galleibo-intel
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -100,7 +100,6 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         local_network_in_distributed_training = self.global_network is not None and self.network_is_local
 
         tuning_parameters.activation_function = self.activation_function
-        done_creating_input_placeholders = False
 
         for network_idx in range(self.num_networks):
             with tf.variable_scope('network_{}'.format(network_idx)):
@@ -111,18 +110,19 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
                 state_embedding = []
                 for idx, input_type in enumerate(self.tp.agent.input_types):
                     # get the class of the input embedder
-                    self.input_embedders.append(self.get_input_embedder(input_type))
+                    input_embedder = self.get_input_embedder(input_type)
+                    self.input_embedders.append(input_embedder)
 
-                    # in the case each head uses a different network, we still reuse the input placeholders
-                    prev_network_input_placeholder = self.inputs[idx] if done_creating_input_placeholders else None
-
-                    # create the input embedder instance and store the input placeholder and the embedding
-                    input_placeholder, embedding = self.input_embedders[-1](prev_network_input_placeholder)
-                    if len(self.inputs) < len(self.tp.agent.input_types):
+                    # input placeholders are reused between networks. on the first network, store the placeholders
+                    # generated by the input_embedders in self.inputs. on the rest of the networks, pass
+                    # the existing input_placeholders into the input_embedders.
+                    if network_idx == 0:
+                        input_placeholder, embedding = input_embedder()
                         self.inputs.append(input_placeholder)
+                        state_embedding.append(embedding)
+                    else:
+                        input_placeholder, embedding = input_embedder(self.inputs[idx])
 
-                    done_creating_input_placeholders = True
-                    state_embedding.append(embedding)
 
                 ##############
                 # Middleware #
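The rewritten loop above is the core of this commit: on the first network (network_idx == 0) each input embedder creates its own placeholder, which is stored in self.inputs; every subsequent network passes the stored placeholder back into its embedder, so all networks share one set of input placeholders. The following self-contained sketch reproduces that control flow outside of Coach; ToyEmbedder, the input sizes, and num_networks are hypothetical stand-ins, and TensorFlow 1.x APIs are assumed.

# A minimal sketch of the placeholder-reuse pattern introduced above.
# ToyEmbedder is a hypothetical stand-in for Coach's input embedders;
# assumes TensorFlow 1.x (tf.placeholder, tf.variable_scope, tf.layers).
import tensorflow as tf

class ToyEmbedder(object):
    """Called with no placeholder it creates one; called with an existing
    placeholder it reuses it. Returns (placeholder, embedding)."""
    def __init__(self, input_size, name):
        self.input_size = input_size
        self.name = name

    def __call__(self, prev_input_placeholder=None):
        if prev_input_placeholder is None:
            placeholder = tf.placeholder(
                tf.float32, [None, self.input_size],
                name='observation_{}'.format(self.name))
        else:
            placeholder = prev_input_placeholder
        embedding = tf.layers.dense(placeholder, 64, activation=tf.nn.relu)
        return placeholder, embedding

input_sizes = [4, 8]   # two observation components (hypothetical)
num_networks = 2       # e.g. separate networks per head
inputs = []            # placeholders, created once by network 0

for network_idx in range(num_networks):
    with tf.variable_scope('network_{}'.format(network_idx)):
        state_embedding = []
        for idx, size in enumerate(input_sizes):
            embedder = ToyEmbedder(size, 'input_{}'.format(idx))
            # network 0 creates and stores the placeholders; later
            # networks reuse the stored ones, mirroring the diff above
            if network_idx == 0:
                input_placeholder, embedding = embedder()
                inputs.append(input_placeholder)
            else:
                input_placeholder, embedding = embedder(inputs[idx])
            state_embedding.append(embedding)

Note that the sketch appends the embedding after the if/else so both branches contribute to state_embedding; in the diff as rendered, the append lands only in the network_idx == 0 branch.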
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -41,6 +41,14 @@ class MiddlewareEmbedder(object):
 
 class LSTM_Embedder(MiddlewareEmbedder):
     def _build_module(self):
+        """
+        self.state_in: tuple of placeholders containing the initial state
+        self.state_out: tuple of output state
+
+        todo: it appears that the shape of the output is batch, feature
+        the code here seems to be slicing off the first element in the batch
+        which would definitely be wrong. need to double check the shape
+        """
 
         middleware = tf.layers.dense(self.input, 512, activation=self.activation_function)
         lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
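The todo in the docstring above can be settled with a small shape check. Below is a sketch under assumed TensorFlow 1.x APIs (the same tf.contrib.rnn.BasicLSTMCell, batch-major tf.nn.dynamic_rnn); it is not code from this commit. With the default time_major=False, dynamic_rnn returns outputs of shape [batch, time, num_units], so the last step of each sequence is outputs[:, -1, :], while outputs[0] would indeed slice out only the first batch element.

# Shape check for the docstring's todo; a sketch assuming TF 1.x,
# not code from this commit.
import tensorflow as tf

batch_size, time_steps, features = 32, 10, 512
sequence = tf.placeholder(tf.float32, [batch_size, time_steps, features])

lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
state_in = lstm_cell.zero_state(batch_size, tf.float32)  # (c, h) tuple

# batch-major dynamic_rnn: outputs has shape [batch, time, num_units]
outputs, state_out = tf.nn.dynamic_rnn(lstm_cell, sequence,
                                       initial_state=state_in)

print(outputs.get_shape())        # (32, 10, 256)
last_step = outputs[:, -1, :]     # last time step for every batch element
first_in_batch = outputs[0]       # (10, 256): first batch element only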