mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
clean up input embeddings setup
This commit is contained in:
committed by
galleibo-intel
parent
1ff0da2165
commit
9ae2905a76
@@ -1,5 +1,5 @@
|
|||||||
#
|
#
|
||||||
# Copyright (c) 2017 Intel Corporation
|
# Copyright (c) 2017 Intel Corporation
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -100,7 +100,6 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
|||||||
local_network_in_distributed_training = self.global_network is not None and self.network_is_local
|
local_network_in_distributed_training = self.global_network is not None and self.network_is_local
|
||||||
|
|
||||||
tuning_parameters.activation_function = self.activation_function
|
tuning_parameters.activation_function = self.activation_function
|
||||||
done_creating_input_placeholders = False
|
|
||||||
|
|
||||||
for network_idx in range(self.num_networks):
|
for network_idx in range(self.num_networks):
|
||||||
with tf.variable_scope('network_{}'.format(network_idx)):
|
with tf.variable_scope('network_{}'.format(network_idx)):
|
||||||
@@ -111,18 +110,19 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
|||||||
state_embedding = []
|
state_embedding = []
|
||||||
for idx, input_type in enumerate(self.tp.agent.input_types):
|
for idx, input_type in enumerate(self.tp.agent.input_types):
|
||||||
# get the class of the input embedder
|
# get the class of the input embedder
|
||||||
self.input_embedders.append(self.get_input_embedder(input_type))
|
input_embedder = self.get_input_embedder(input_type)
|
||||||
|
self.input_embedders.append(input_embedder)
|
||||||
|
|
||||||
# in the case each head uses a different network, we still reuse the input placeholders
|
# input placeholders are reused between networks. on the first network, store the placeholders
|
||||||
prev_network_input_placeholder = self.inputs[idx] if done_creating_input_placeholders else None
|
# generated by the input_embedders in self.inputs. on the rest of the networks, pass
|
||||||
|
# the existing input_placeholders into the input_embedders.
|
||||||
# create the input embedder instance and store the input placeholder and the embedding
|
if network_idx == 0:
|
||||||
input_placeholder, embedding = self.input_embedders[-1](prev_network_input_placeholder)
|
input_placeholder, embedding = input_embedder()
|
||||||
if len(self.inputs) < len(self.tp.agent.input_types):
|
|
||||||
self.inputs.append(input_placeholder)
|
self.inputs.append(input_placeholder)
|
||||||
state_embedding.append(embedding)
|
else:
|
||||||
|
input_placeholder, embedding = input_embedder(self.inputs[idx])
|
||||||
|
|
||||||
done_creating_input_placeholders = True
|
state_embedding.append(embedding)
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Middleware #
|
# Middleware #
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#
|
#
|
||||||
# Copyright (c) 2017 Intel Corporation
|
# Copyright (c) 2017 Intel Corporation
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -41,6 +41,14 @@ class MiddlewareEmbedder(object):
|
|||||||
|
|
||||||
class LSTM_Embedder(MiddlewareEmbedder):
|
class LSTM_Embedder(MiddlewareEmbedder):
|
||||||
def _build_module(self):
|
def _build_module(self):
|
||||||
|
"""
|
||||||
|
self.state_in: tuple of placeholders containing the initial state
|
||||||
|
self.state_out: tuple of output state
|
||||||
|
|
||||||
|
todo: it appears that the shape of the output is batch, feature
|
||||||
|
the code here seems to be slicing off the first element in the batch
|
||||||
|
which would definitely be wrong. need to double check the shape
|
||||||
|
"""
|
||||||
|
|
||||||
middleware = tf.layers.dense(self.input, 512, activation=self.activation_function)
|
middleware = tf.layers.dense(self.input, 512, activation=self.activation_function)
|
||||||
lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
|
lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user