mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
fix more agents
This commit is contained in:
@@ -296,16 +296,16 @@ class TensorFlowArchitecture(Architecture):
|
||||
|
||||
return feed_dict
|
||||
|
||||
def predict(self, inputs, outputs=None):
|
||||
def predict(self, inputs, outputs=None, squeeze_output=True):
|
||||
"""
|
||||
Run a forward pass of the network using the given input
|
||||
:param inputs: The input for the network
|
||||
:param outputs: The output for the network, defaults to self.outputs
|
||||
:param squeeze_output: call squeeze_list on output
|
||||
:return: The network output
|
||||
|
||||
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
|
||||
"""
|
||||
# TODO: rename self.inputs -> self.input_placeholders
|
||||
feed_dict = self._feed_dict(inputs)
|
||||
if outputs is None:
|
||||
outputs = self.outputs
|
||||
@@ -318,7 +318,10 @@ class TensorFlowArchitecture(Architecture):
|
||||
else:
|
||||
output = self.tp.sess.run(outputs, feed_dict)
|
||||
|
||||
return squeeze_list(output)
|
||||
if squeeze_output:
|
||||
output = squeeze_list(output)
|
||||
|
||||
return output
|
||||
|
||||
# def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None):
|
||||
# """
|
||||
|
||||
Reference in New Issue
Block a user