mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
clean up input embeddings setup
This commit is contained in:
committed by
galleibo-intel
parent
1ff0da2165
commit
9ae2905a76
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -41,6 +41,14 @@ class MiddlewareEmbedder(object):
|
||||
|
||||
class LSTM_Embedder(MiddlewareEmbedder):
|
||||
def _build_module(self):
|
||||
"""
|
||||
self.state_in: tuple of placeholders containing the initial state
|
||||
self.state_out: tuple of output state
|
||||
|
||||
todo: it appears that the shape of the output is batch, feature
|
||||
the code here seems to be slicing off the first element in the batch
|
||||
which would definitely be wrong. need to double check the shape
|
||||
"""
|
||||
|
||||
middleware = tf.layers.dense(self.input, 512, activation=self.activation_function)
|
||||
lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)
|
||||
|
||||
Reference in New Issue
Block a user