1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

temp commit

This commit is contained in:
Zach Dwiel
2018-02-16 09:35:58 -05:00
parent 16c5032735
commit 85afb86893
14 changed files with 244 additions and 127 deletions

View File

@@ -112,7 +112,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
####################
state_embedding = []
for idx, input_type in enumerate(self.tp.agent.input_types):
for input_name, input_type in self.tp.agent.input_types.items():
# get the class of the input embedder
input_embedder = self.get_input_embedder(input_type)
self.input_embedders.append(input_embedder)
@@ -122,9 +122,9 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
# the existing input_placeholders into the input_embedders.
if network_idx == 0:
input_placeholder, embedding = input_embedder()
self.inputs.append(input_placeholder)
self.inputs[input_name] = input_placeholder
else:
input_placeholder, embedding = input_embedder(self.inputs[idx])
input_placeholder, embedding = input_embedder(self.inputs[input_name])
state_embedding.append(embedding)
@@ -159,13 +159,15 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
# build the head
if self.network_is_local:
output, target_placeholder, input_placeholder = self.output_heads[-1](head_input)
output, target_placeholder, input_placeholders = self.output_heads[-1](head_input)
self.targets.extend(target_placeholder)
else:
output, input_placeholder = self.output_heads[-1](head_input)
output, input_placeholders = self.output_heads[-1](head_input)
self.outputs.extend(output)
self.inputs.extend(input_placeholder)
# TODO: use head names as well
for placeholder_index, input_placeholder in enumerate(input_placeholders):
self.inputs['output_{}_{}'.format(head_idx, placeholder_index)] = input_placeholder
# Losses
self.losses = tf.losses.get_losses(self.name)