diff --git a/architectures/tensorflow_components/architecture.py b/architectures/tensorflow_components/architecture.py index 41f7216..3474ff4 100644 --- a/architectures/tensorflow_components/architecture.py +++ b/architectures/tensorflow_components/architecture.py @@ -321,22 +321,6 @@ class TensorFlowArchitecture(Architecture): return output - # def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None): - # """ - # Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients - # :param additional_fetches: Optional tensors to fetch during the training process - # :param inputs: The input for the network - # :param targets: The targets corresponding to the input batch - # :param scaler: A scaling factor that allows rescaling the gradients before applying them - # :return: The loss of the network - # """ - # if additional_fetches is None: - # additional_fetches = [] - # force_list(additional_fetches) - # loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches) - # self.apply_and_reset_gradients(self.accumulated_gradients, scaler) - # return loss - def get_weights(self): """ :return: a list of tensors containing the network weights for each layer