mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
temp commit
This commit is contained in:
@@ -112,7 +112,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
####################
|
||||
|
||||
state_embedding = []
|
||||
for idx, input_type in enumerate(self.tp.agent.input_types):
|
||||
for input_name, input_type in self.tp.agent.input_types.items():
|
||||
# get the class of the input embedder
|
||||
input_embedder = self.get_input_embedder(input_type)
|
||||
self.input_embedders.append(input_embedder)
|
||||
@@ -122,9 +122,9 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
# the existing input_placeholders into the input_embedders.
|
||||
if network_idx == 0:
|
||||
input_placeholder, embedding = input_embedder()
|
||||
self.inputs.append(input_placeholder)
|
||||
self.inputs[input_name] = input_placeholder
|
||||
else:
|
||||
input_placeholder, embedding = input_embedder(self.inputs[idx])
|
||||
input_placeholder, embedding = input_embedder(self.inputs[input_name])
|
||||
|
||||
state_embedding.append(embedding)
|
||||
|
||||
@@ -159,13 +159,15 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
|
||||
# build the head
|
||||
if self.network_is_local:
|
||||
output, target_placeholder, input_placeholder = self.output_heads[-1](head_input)
|
||||
output, target_placeholder, input_placeholders = self.output_heads[-1](head_input)
|
||||
self.targets.extend(target_placeholder)
|
||||
else:
|
||||
output, input_placeholder = self.output_heads[-1](head_input)
|
||||
output, input_placeholders = self.output_heads[-1](head_input)
|
||||
|
||||
self.outputs.extend(output)
|
||||
self.inputs.extend(input_placeholder)
|
||||
# TODO: use head names as well
|
||||
for placeholder_index, input_placeholder in enumerate(input_placeholders):
|
||||
self.inputs['output_{}_{}'.format(head_idx, placeholder_index)] = input_placeholder
|
||||
|
||||
# Losses
|
||||
self.losses = tf.losses.get_losses(self.name)
|
||||
|
||||
Reference in New Issue
Block a user