1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

fix more agents

This commit is contained in:
Zach Dwiel
2018-02-16 20:06:51 -05:00
parent 98f57a0d87
commit 8248caf35e
6 changed files with 52 additions and 42 deletions

View File

@@ -296,16 +296,16 @@ class TensorFlowArchitecture(Architecture):
return feed_dict
def predict(self, inputs, outputs=None):
def predict(self, inputs, outputs=None, squeeze_output=True):
"""
Run a forward pass of the network using the given input
:param inputs: The input for the network
:param outputs: The output for the network, defaults to self.outputs
:param squeeze_output: call squeeze_list on output
:return: The network output
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
"""
# TODO: rename self.inputs -> self.input_placeholders
feed_dict = self._feed_dict(inputs)
if outputs is None:
outputs = self.outputs
@@ -318,7 +318,10 @@ class TensorFlowArchitecture(Architecture):
else:
output = self.tp.sess.run(outputs, feed_dict)
return squeeze_list(output)
if squeeze_output:
output = squeeze_list(output)
return output
# def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None):
# """