mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
temp commit
This commit is contained in:
@@ -48,7 +48,7 @@ class TensorFlowArchitecture(Architecture):
|
||||
self.network_is_local = network_is_local
|
||||
assert tuning_parameters.agent.tensorflow_support, 'TensorFlow is not supported for this agent'
|
||||
self.sess = tuning_parameters.sess
|
||||
self.inputs = []
|
||||
self.inputs = {}
|
||||
self.outputs = []
|
||||
self.targets = []
|
||||
self.losses = []
|
||||
@@ -106,7 +106,8 @@ class TensorFlowArchitecture(Architecture):
|
||||
# gradients of the outputs w.r.t. the inputs
|
||||
# at the moment, this is only used by ddpg
|
||||
if len(self.outputs) == 1:
|
||||
self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs]
|
||||
# TODO: convert gradients_with_respect_to_inputs into dictionary?
|
||||
self.gradients_wrt_inputs = [tf.gradients(self.outputs[0], input_ph) for input_ph in self.inputs.values()]
|
||||
self.gradients_weights_ph = tf.placeholder('float32', self.outputs[0].shape, 'output_gradient_weights')
|
||||
self.weighted_gradients = tf.gradients(self.outputs[0], self.trainable_weights, self.gradients_weights_ph)
|
||||
|
||||
@@ -169,9 +170,8 @@ class TensorFlowArchitecture(Architecture):
|
||||
# feed inputs
|
||||
if additional_fetches is None:
|
||||
additional_fetches = []
|
||||
inputs = force_list(inputs)
|
||||
|
||||
feed_dict = dict(zip(self.inputs, inputs))
|
||||
feed_dict = self._feed_dict(inputs)
|
||||
|
||||
# feed targets
|
||||
targets = force_list(targets)
|
||||
@@ -266,6 +266,12 @@ class TensorFlowArchitecture(Architecture):
|
||||
while self.tp.sess.run(self.release_counter) % self.tp.num_threads != 0:
|
||||
time.sleep(0.00001)
|
||||
|
||||
def _feed_dict(self, inputs):
|
||||
return {
|
||||
self.inputs[input_name]: input_value
|
||||
for input_name, input_value in inputs.items()
|
||||
}
|
||||
|
||||
def predict(self, inputs, outputs=None):
|
||||
"""
|
||||
Run a forward pass of the network using the given input
|
||||
@@ -275,8 +281,8 @@ class TensorFlowArchitecture(Architecture):
|
||||
|
||||
WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
|
||||
"""
|
||||
|
||||
feed_dict = dict(zip(self.inputs, force_list(inputs)))
|
||||
# TODO: rename self.inputs -> self.input_placeholders
|
||||
feed_dict = self._feed_dict(inputs)
|
||||
if outputs is None:
|
||||
outputs = self.outputs
|
||||
|
||||
@@ -290,21 +296,21 @@ class TensorFlowArchitecture(Architecture):
|
||||
|
||||
return squeeze_list(output)
|
||||
|
||||
def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None):
|
||||
"""
|
||||
Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients
|
||||
:param additional_fetches: Optional tensors to fetch during the training process
|
||||
:param inputs: The input for the network
|
||||
:param targets: The targets corresponding to the input batch
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
:return: The loss of the network
|
||||
"""
|
||||
if additional_fetches is None:
|
||||
additional_fetches = []
|
||||
force_list(additional_fetches)
|
||||
loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches)
|
||||
self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
|
||||
return loss
|
||||
# def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None):
|
||||
# """
|
||||
# Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients
|
||||
# :param additional_fetches: Optional tensors to fetch during the training process
|
||||
# :param inputs: The input for the network
|
||||
# :param targets: The targets corresponding to the input batch
|
||||
# :param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
# :return: The loss of the network
|
||||
# """
|
||||
# if additional_fetches is None:
|
||||
# additional_fetches = []
|
||||
# force_list(additional_fetches)
|
||||
# loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches)
|
||||
# self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
|
||||
# return loss
|
||||
|
||||
def get_weights(self):
|
||||
"""
|
||||
|
||||
@@ -112,7 +112,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
####################
|
||||
|
||||
state_embedding = []
|
||||
for idx, input_type in enumerate(self.tp.agent.input_types):
|
||||
for input_name, input_type in self.tp.agent.input_types.items():
|
||||
# get the class of the input embedder
|
||||
input_embedder = self.get_input_embedder(input_type)
|
||||
self.input_embedders.append(input_embedder)
|
||||
@@ -122,9 +122,9 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
# the existing input_placeholders into the input_embedders.
|
||||
if network_idx == 0:
|
||||
input_placeholder, embedding = input_embedder()
|
||||
self.inputs.append(input_placeholder)
|
||||
self.inputs[input_name] = input_placeholder
|
||||
else:
|
||||
input_placeholder, embedding = input_embedder(self.inputs[idx])
|
||||
input_placeholder, embedding = input_embedder(self.inputs[input_name])
|
||||
|
||||
state_embedding.append(embedding)
|
||||
|
||||
@@ -159,13 +159,15 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
|
||||
# build the head
|
||||
if self.network_is_local:
|
||||
output, target_placeholder, input_placeholder = self.output_heads[-1](head_input)
|
||||
output, target_placeholder, input_placeholders = self.output_heads[-1](head_input)
|
||||
self.targets.extend(target_placeholder)
|
||||
else:
|
||||
output, input_placeholder = self.output_heads[-1](head_input)
|
||||
output, input_placeholders = self.output_heads[-1](head_input)
|
||||
|
||||
self.outputs.extend(output)
|
||||
self.inputs.extend(input_placeholder)
|
||||
# TODO: use head names as well
|
||||
for placeholder_index, input_placeholder in enumerate(input_placeholders):
|
||||
self.inputs['output_{}_{}'.format(head_idx, placeholder_index)] = input_placeholder
|
||||
|
||||
# Losses
|
||||
self.losses = tf.losses.get_losses(self.name)
|
||||
|
||||
@@ -250,7 +250,7 @@ class MeasurementsPredictionHead(Head):
|
||||
name='output')
|
||||
action_stream = tf.reshape(action_stream,
|
||||
(tf.shape(action_stream)[0], self.num_actions, self.multi_step_measurements_size))
|
||||
action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keep_dims=True)
|
||||
action_stream = action_stream - tf.reduce_mean(action_stream, reduction_indices=1, keepdims=True)
|
||||
|
||||
# merge to future measurements predictions
|
||||
self.output = tf.add(expectation_stream, action_stream, name='output')
|
||||
@@ -302,7 +302,7 @@ class DNDQHead(Head):
|
||||
square_diff = tf.square(dnd_embeddings - tf.expand_dims(input_layer, 1))
|
||||
distances = tf.reduce_sum(square_diff, axis=2) + [self.l2_norm_added_delta]
|
||||
weights = 1.0 / distances
|
||||
normalised_weights = weights / tf.reduce_sum(weights, axis=1, keep_dims=True)
|
||||
normalised_weights = weights / tf.reduce_sum(weights, axis=1, keepdims=True)
|
||||
return tf.reduce_sum(dnd_values * normalised_weights, axis=1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user