1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

clean up input embeddings setup

This commit is contained in:
Zach Dwiel
2017-11-03 13:51:02 -07:00
committed by galleibo-intel
parent 1ff0da2165
commit 9ae2905a76
2 changed files with 20 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -100,7 +100,6 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
local_network_in_distributed_training = self.global_network is not None and self.network_is_local
tuning_parameters.activation_function = self.activation_function
done_creating_input_placeholders = False
for network_idx in range(self.num_networks):
with tf.variable_scope('network_{}'.format(network_idx)):
@@ -111,18 +110,19 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
state_embedding = []
for idx, input_type in enumerate(self.tp.agent.input_types):
# get the class of the input embedder
self.input_embedders.append(self.get_input_embedder(input_type))
input_embedder = self.get_input_embedder(input_type)
self.input_embedders.append(input_embedder)
# in the case each head uses a different network, we still reuse the input placeholders
prev_network_input_placeholder = self.inputs[idx] if done_creating_input_placeholders else None
# create the input embedder instance and store the input placeholder and the embedding
input_placeholder, embedding = self.input_embedders[-1](prev_network_input_placeholder)
if len(self.inputs) < len(self.tp.agent.input_types):
# input placeholders are reused between networks. on the first network, store the placeholders
# generated by the input_embedders in self.inputs. on the rest of the networks, pass
# the existing input_placeholders into the input_embedders.
if network_idx == 0:
input_placeholder, embedding = input_embedder()
self.inputs.append(input_placeholder)
state_embedding.append(embedding)
else:
input_placeholder, embedding = input_embedder(self.inputs[idx])
done_creating_input_placeholders = True
state_embedding.append(embedding)
##############
# Middleware #