1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

temp commit

This commit is contained in:
Zach Dwiel
2018-02-16 09:35:58 -05:00
parent 16c5032735
commit 85afb86893
14 changed files with 244 additions and 127 deletions

View File

@@ -48,7 +48,7 @@ class TensorFlowArchitecture(Architecture):
self.network_is_local = network_is_local
assert tuning_parameters.agent.tensorflow_support, 'TensorFlow is not supported for this agent'
self.sess = tuning_parameters.sess
self.inputs = []
self.inputs = {}
self.outputs = []
self.targets = []
self.losses = []
@@ -106,7 +106,8 @@ class TensorFlowArchitecture(Architecture):
# gradients of the outputs w.r.t. the inputs
# at the moment, this is only used by ddpg
if len(self.outputs) == 1:
self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs]
# TODO: convert gradients_with_respect_to_inputs into dictionary?
self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs.values()]
self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights')
self.weighted_gradients = tf.gradients(self.outputs[0], self.trainable_weights, self.gradients_weights_ph)
@@ -169,9 +170,8 @@ class TensorFlowArchitecture(Architecture):
# feed inputs
if additional_fetches is None:
additional_fetches = []
inputs = force_list(inputs)
feed_dict = dict(zip(self.inputs, inputs))
feed_dict = self._feed_dict(inputs)
# feed targets
targets = force_list(targets)
@@ -266,6 +266,12 @@ class TensorFlowArchitecture(Architecture):
while self.tp.sess.run(self.release_counter) % self.tp.num_threads != 0:
time.sleep(0.00001)
def _feed_dict(self, inputs):
    """
    Build the feed dictionary for a session run from a dictionary of named inputs
    :param inputs: A dictionary mapping input names to the values to feed. Every name must be a key of
                   self.inputs, otherwise the lookup raises a KeyError
    :return: A dictionary mapping the self.inputs entry stored under each name to the value given for that name
    """
    feed = {}
    for name, value in inputs.items():
        # self.inputs is keyed by input name; its entry is what gets fed
        feed[self.inputs[name]] = value
    return feed
def predict(self, inputs, outputs=None):
"""
Run a forward pass of the network using the given input
@@ -275,8 +281,8 @@ class TensorFlowArchitecture(Architecture):
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
"""
feed_dict = dict(zip(self.inputs, force_list(inputs)))
# TODO: rename self.inputs -> self.input_placeholders
feed_dict = self._feed_dict(inputs)
if outputs is None:
outputs = self.outputs
@@ -290,21 +296,21 @@ class TensorFlowArchitecture(Architecture):
return squeeze_list(output)
def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None):
    """
    Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients
    :param inputs: The input for the network
    :param targets: The targets corresponding to the input batch
    :param scaler: A scaling factor that allows rescaling the gradients before applying them
    :param additional_fetches: Optional tensors to fetch during the training process
    :return: The value returned by accumulate_gradients (documented as the loss of the network)
    """
    if additional_fetches is None:
        additional_fetches = []
    # the result of force_list must be bound back, as is done for inputs and targets elsewhere in this
    # class; calling it as a bare statement discards the result
    # NOTE(review): presumably force_list wraps a non-list value in a list - confirm against its definition
    additional_fetches = force_list(additional_fetches)
    # accumulate the gradients for this batch, then apply them and reset the accumulators
    loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches)
    self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
    return loss
# def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None):
# """
# Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients
# :param additional_fetches: Optional tensors to fetch during the training process
# :param inputs: The input for the network
# :param targets: The targets corresponding to the input batch
# :param scaler: A scaling factor that allows rescaling the gradients before applying them
# :return: The loss of the network
# """
# if additional_fetches is None:
# additional_fetches = []
# force_list(additional_fetches)
# loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches)
# self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
# return loss
def get_weights(self):
"""