batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <casbon+gh@gmail.com>
@@ -38,6 +38,7 @@ The environments that were used for testing include:
|**[Clipped PPO](clipped_ppo)** |  |Mujoco | |
|**[DDPG](ddpg)** |  |Mujoco | |
|**[SAC](sac)** |  |Mujoco | |
|**[TD3](td3)** |  |Mujoco | |
|**[NEC](nec)** |  |Atari | |
|**[HER](ddpg_her)** |  |Fetch | |
|**[DFP](dfp)** |  |Doom | Doom Battle was not verified |
(Nine benchmark plot images were updated in this PR.)
@@ -34,9 +34,9 @@ from rl_coach.spaces import BoxActionSpace, GoalsSpace

 class DDPGCriticNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True),
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm),
                                            'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
         self.middleware_parameters = FCMiddlewareParameters()
         self.heads_parameters = [DDPGVHeadParameters()]
@@ -53,11 +53,11 @@ class DDPGCriticNetworkParameters(NetworkParameters):

 class DDPGActorNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
-        self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
-        self.heads_parameters = [DDPGActorHeadParameters()]
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm)}
+        self.middleware_parameters = FCMiddlewareParameters(batchnorm=use_batchnorm)
+        self.heads_parameters = [DDPGActorHeadParameters(batchnorm=use_batchnorm)]
         self.optimizer_type = 'Adam'
         self.batch_size = 64
         self.adam_optimizer_beta2 = 0.999
@@ -109,12 +109,12 @@ class DDPGAlgorithmParameters(AlgorithmParameters):

 class DDPGAgentParameters(AgentParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__(algorithm=DDPGAlgorithmParameters(),
                          exploration=OUProcessParameters(),
                          memory=EpisodicExperienceReplayParameters(),
-                         networks=OrderedDict([("actor", DDPGActorNetworkParameters()),
-                                               ("critic", DDPGCriticNetworkParameters())]))
+                         networks=OrderedDict([("actor", DDPGActorNetworkParameters(use_batchnorm=use_batchnorm)),
+                                               ("critic", DDPGCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

     @property
     def path(self):
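For readers following along, a minimal usage sketch (the snippet below is illustrative, not part of this diff; it assumes Coach's standard preset API): with this change, batchnorm in DDPG is opt-in, and the single `use_batchnorm` flag fans out to the actor's input embedder, middleware, and head, and to the critic's observation embedder.

```python
# Hypothetical preset snippet -- not part of this diff.
from rl_coach.agents.ddpg_agent import DDPGAgentParameters

agent_params = DDPGAgentParameters()                        # batchnorm disabled (new default)
agent_params_bn = DDPGAgentParameters(use_batchnorm=True)   # opt back in to the previous behaviour
```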
@@ -170,7 +170,9 @@ class DDPGAgent(ActorCriticAgent):
         # train the critic
         critic_inputs = copy.copy(batch.states(critic_keys))
         critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
-        result = critic.train_and_sync_networks(critic_inputs, TD_targets)
+
+        # the inputs are also needed when applying the gradients, so that batchnorm's update of the running mean and stddev works
+        result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)
         total_loss, losses, unclipped_grads = result[:3]

         # apply the gradients from the critic to the actor
@@ -179,11 +181,12 @@ class DDPGAgent(ActorCriticAgent):
                                            outputs=actor.online_network.weighted_gradients[0],
                                            initial_feed_dict=initial_feed_dict)

+        # the inputs are also needed when applying the gradients, so that batchnorm's update of the running mean and stddev works
         if actor.has_global:
-            actor.apply_gradients_to_global_network(gradients)
+            actor.apply_gradients_to_global_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
             actor.update_online_network()
         else:
-            actor.apply_gradients_to_online_network(gradients)
+            actor.apply_gradients_to_online_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))

         return total_loss, losses, unclipped_grads
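Why the inputs have to tag along: Coach trains in two phases, first accumulating gradients with a full feed, then applying them in a separate `sess.run`. Batchnorm's moving-average update ops read the input placeholders, so the apply step now needs a feed too. A standalone TF1 sketch of the failure mode (not Coach code):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
h = tf.layers.batch_normalization(x, training=True)  # registers moving-average update ops
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # sess.run(update_ops)                            # fails: the update ops read x, which is not fed
    sess.run(update_ops, feed_dict={x: [[0.0] * 8]})  # works once the inputs are provided
```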
@@ -22,7 +22,7 @@ from rl_coach.base_parameters import NetworkComponentParameters
 class HeadParameters(NetworkComponentParameters):
     def __init__(self, parameterized_class_name: str, activation_function: str = 'relu', name: str= 'head',
                  num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0,
-                 loss_weight: float=1.0, dense_layer=None):
+                 loss_weight: float=1.0, dense_layer=None, is_training=False):
         super().__init__(dense_layer=dense_layer)
         self.activation_function = activation_function
         self.name = name
@@ -30,6 +30,7 @@ class HeadParameters(NetworkComponentParameters):
         self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor
         self.loss_weight = loss_weight
         self.parameterized_class_name = parameterized_class_name
+        self.is_training = is_training

     @property
     def path(self):
@@ -124,31 +124,37 @@ class NetworkWrapper(object):
         if self.global_network:
             self.online_network.set_weights(self.global_network.get_weights(), rate)

-    def apply_gradients_to_global_network(self, gradients=None):
+    def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on the global network

         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
         if self.network_parameters.shared_optimizer:
-            self.global_network.apply_gradients(gradients)
+            self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
         else:
-            self.online_network.apply_gradients(gradients)
+            self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def apply_gradients_to_online_network(self, gradients=None):
+    def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on itself

         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
-        self.online_network.apply_gradients(gradients)
+        self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
+    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None,
+                                use_inputs_for_apply_gradients=False):
         """
         A generic training function that enables multi-threading training using a global network if necessary.
@@ -157,14 +163,20 @@ class NetworkWrapper(object):
         :param additional_fetches: Any additional tensor the user wants to fetch
         :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
                                    error of this sample. If it is not given, the sample losses won't be scaled
+        :param use_inputs_for_apply_gradients: Also pass the inputs when applying the gradients
+                                               (e.g. for incorporating batchnorm update ops)
         :return: The loss of the training iteration
         """
         result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
                                                           importance_weights=importance_weights, no_accumulation=True)
-        self.apply_gradients_and_sync_networks(reset_gradients=False)
+        if use_inputs_for_apply_gradients:
+            self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
+        else:
+            self.apply_gradients_and_sync_networks(reset_gradients=False)

         return result

-    def apply_gradients_and_sync_networks(self, reset_gradients=True):
+    def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inputs=None):
         """
         Applies the gradients accumulated in the online network to the global network or to itself and syncs the
         networks if necessary
@@ -173,17 +185,22 @@
                                 the network. this is useful when the accumulated gradients are overwritten instead
                                 of accumulated by the accumulate_gradients function. this allows reducing the time
                                 complexity of this function by around 10%
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         """
         if self.global_network:
-            self.apply_gradients_to_global_network()
+            self.apply_gradients_to_global_network(additional_inputs=additional_inputs)
             if reset_gradients:
                 self.online_network.reset_accumulated_gradients()
             self.update_online_network()
         else:
             if reset_gradients:
-                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
+                                                              additional_inputs=additional_inputs)
             else:
-                self.online_network.apply_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_gradients(self.online_network.accumulated_gradients,
+                                                    additional_inputs=additional_inputs)

     def parallel_prediction(self, network_input_tuples: List[Tuple]):
         """
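Taken together, the `NetworkWrapper` changes just thread one optional argument down the call chain. A comment-only sketch of the resulting flow (method names are from this diff; the exact internals of the feed merge appear in the architecture hunks below):

```python
# train_and_sync_networks(inputs, targets, use_inputs_for_apply_gradients=True)
#   -> apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
#     -> apply_gradients_to_global_network(additional_inputs=inputs)       # distributed case
#        or online_network.apply_gradients(..., additional_inputs=inputs)  # single-worker case
#          -> the architecture merges create_feed_dict(inputs) into the gradient
#             feed_dict before running the weight-update op
```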
@@ -270,8 +270,11 @@ class TensorFlowArchitecture(Architecture):
         elif self.network_is_trainable:
             # not any of the above but is trainable? -> create an operation for applying the gradients to
             # this network's weights
-            self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
-                zip(self.weights_placeholders, self.weights), global_step=self.global_step)
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name)
+
+            with tf.control_dependencies(update_ops):
+                self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
+                    zip(self.weights_placeholders, self.weights), global_step=self.global_step)

     def set_session(self, sess):
         self.sess = sess
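This is the standard TF1 batchnorm recipe: by making the `UPDATE_OPS` collection a control dependency of the weight-update op, a single run of that op also refreshes the moving mean and variance. A self-contained sketch of the same pattern (not Coach code):

```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
y = tf.placeholder(tf.float32, [None, 1])
out = tf.layers.dense(tf.layers.batch_normalization(x, training=True), 1)
loss = tf.losses.mean_squared_error(y, out)

# group the batchnorm statistics updates with the optimizer step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # one run now performs both the weight update and the statistics update
    sess.run(train_op, feed_dict={x: np.ones((2, 4)), y: np.zeros((2, 1))})
```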
@@ -414,13 +417,16 @@ class TensorFlowArchitecture(Architecture):

         return feed_dict

-    def apply_and_reset_gradients(self, gradients, scaler=1.):
+    def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights and resets the accumulation placeholder
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         """
-        self.apply_gradients(gradients, scaler)
+        self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs)
         self.reset_accumulated_gradients()

     def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False):
@@ -460,13 +466,16 @@ class TensorFlowArchitecture(Architecture):
         self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
         self.sess.run(self.release_init)

-    def apply_gradients(self, gradients, scaler=1.):
+    def apply_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them.
                        The gradients will be MULTIPLIED by this factor
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         """
         if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters):
             if hasattr(self, 'global_step') and not self.network_is_local:
                 self.sess.run(self.inc_step)
@@ -503,6 +512,8 @@ class TensorFlowArchitecture(Architecture):
             # async distributed training / distributed training with independent optimizer
             # / non-distributed training - just apply the gradients
             feed_dict = dict(zip(self.weights_placeholders, gradients))
+            if additional_inputs is not None:
+                feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)}
             self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)

             # release barrier
@@ -185,6 +185,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):

         embedder_path = embedder_params.path(emb_type)
         embedder_params_copy = copy.copy(embedder_params)
+        embedder_params_copy.is_training = self.is_training
         embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
         embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
         embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
@@ -204,6 +205,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         middleware_path = middleware_params.path
         middleware_params_copy = copy.copy(middleware_params)
         middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
+        middleware_params_copy.is_training = self.is_training
         module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
         return module
@@ -218,6 +220,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         head_path = head_params.path
         head_params_copy = copy.copy(head_params)
         head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
+        head_params_copy.is_training = self.is_training
         return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
             'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
             'head_idx': head_idx, 'is_local': self.network_is_local})
@@ -339,7 +342,11 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             head_count += 1

         # model weights
-        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
+        if not self.distributed_training or self.network_is_global:
+            self.weights = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.full_name) if
+                            'global_step' not in var.name]
+        else:
+            self.weights = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)]

         # Losses
         self.losses = tf.losses.get_losses(self.full_name)
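The weights-collection change matters for batchnorm because the moving statistics are non-trainable variables: collecting only `TRAINABLE_VARIABLES` would silently skip them when weights are copied between the online, target, and global networks. A standalone sketch (not Coach code) showing which collections batchnorm's variables land in:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
tf.layers.batch_normalization(x, training=True, name='bn')

trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='bn')
all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='bn')
print([v.name for v in trainable])  # gamma, beta
print([v.name for v in all_vars])   # gamma, beta, moving_mean, moving_variance
```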
@@ -26,9 +26,9 @@ from rl_coach.spaces import SpacesDefinition
 class DDPGActor(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
-                 batchnorm: bool=True, dense_layer=Dense):
+                 batchnorm: bool=True, dense_layer=Dense, is_training=False):
         super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
-                         dense_layer=dense_layer)
+                         dense_layer=dense_layer, is_training=is_training)
         self.name = 'ddpg_actor_head'
         self.return_type = ActionProbabilities
@@ -50,7 +50,7 @@ class DDPGActor(Head):
                                                          batchnorm=self.batchnorm,
                                                          activation_function=self.activation_function,
                                                          dropout_rate=0,
-                                                         is_training=False,
+                                                         is_training=self.is_training,
                                                          name="BatchnormActivationDropout_0")[-1]
         self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
@@ -40,7 +40,7 @@ class Head(object):
     """
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
-                 dense_layer=Dense):
+                 dense_layer=Dense, is_training=False):
         self.head_idx = head_idx
         self.network_name = network_name
         self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -64,6 +64,7 @@ class Head(object):
             self.dense_layer = Dense
         else:
             self.dense_layer = convert_layer_class(self.dense_layer)
+        self.is_training = is_training

     def __call__(self, input_layer):
         """
@@ -26,6 +26,9 @@ from rl_coach.architectures.tensorflow_components import utils
 def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name):
     layers = [input_layer]

+    # Rationale: passing a bool here would mean that batchnorm and/or dropout would never activate
+    assert not isinstance(is_training, bool)
+
     # batchnorm
     if batchnorm:
         layers.append(
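The assert encodes a design decision: `is_training` must be a graph element (a variable or placeholder), not a Python bool. A bool is frozen into the graph at construction time, while a tensor lets the same graph switch between training and inference at runtime. A minimal sketch of the intended pattern (assumed from the tests below, not Coach code):

```python
import tensorflow as tf

# non-trainable flag the graph reads at runtime, matching the pattern in the tests
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
x = tf.placeholder(tf.float32, [None, 4])
h = tf.layers.batch_normalization(x, training=is_training)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(is_training.assign(True))   # switch the same graph to training mode
    sess.run(is_training.assign(False))  # ...and back to inference
```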
@@ -17,7 +17,7 @@ from typing import Union, List

 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.layers import Dense
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
 from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.core_types import Middleware_FC_Embedding
@@ -18,7 +18,7 @@
 import numpy as np
 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.layers import Dense
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
 from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.core_types import Middleware_LSTM_Embedding
@@ -25,17 +25,20 @@ def test_embedder(reset):
     with pytest.raises(ValueError):
         embedder = ImageEmbedder(np.array([10, 100, 100, 100]), name="test")

-    # creating a simple image embedder
-    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test")
-
-    # make sure the ops were not created yet
-    assert len(tf.get_default_graph().get_operations()) == 0
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    pre_ops = len(tf.get_default_graph().get_operations())
+
+    # creating a simple image embedder
+    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", is_training=is_training)
+
+    # make sure that only the is_training ops were created
+    assert len(tf.get_default_graph().get_operations()) == pre_ops

     # call the embedder
     input_ph, output_ph = embedder()

     # make sure that now the ops were created
-    assert len(tf.get_default_graph().get_operations()) > 0
+    assert len(tf.get_default_graph().get_operations()) > pre_ops

     # try feeding a batch of one example
     input = np.random.rand(1, 100, 100, 10)
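A note on the test change: once `is_training` is a `tf.Variable`, the graph is no longer empty before the embedder is built, so the tests snapshot the op count (`pre_ops`) right after creating the variable and assert against that baseline instead of zero. A tiny standalone illustration of the technique:

```python
import tensorflow as tf

is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
pre_ops = len(tf.get_default_graph().get_operations())
assert pre_ops > 0                          # the variable itself already added ops

_ = tf.placeholder(tf.float32, [None, 3])   # stands in for "building the embedder"
assert len(tf.get_default_graph().get_operations()) > pre_ops
```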
@@ -55,7 +58,9 @@ def test_embedder(reset):
 @pytest.mark.unit_test
 def test_complex_embedder(reset):
     # creating a deep image embedder
-    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep)
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep,
+                             is_training=is_training)

     # call the embedder
     embedder()
@@ -71,8 +76,9 @@ def test_complex_embedder(reset):
 @pytest.mark.unit_test
 def test_activation_function(reset):
     # creating a deep image embedder with relu
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
     embedder = ImageEmbedder(np.array([100, 100, 10]), name="relu", scheme=EmbedderScheme.Deep,
-                             activation_function=tf.nn.relu)
+                             activation_function=tf.nn.relu, is_training=is_training)

     # call the embedder
     embedder()
@@ -86,7 +92,7 @@ def test_activation_function(reset):

     # creating a deep image embedder with tanh
     embedder_tanh = ImageEmbedder(np.array([100, 100, 10]), name="tanh", scheme=EmbedderScheme.Deep,
-                                  activation_function=tf.nn.tanh)
+                                  activation_function=tf.nn.tanh, is_training=is_training)

     # call the embedder
     embedder_tanh()
@@ -22,16 +22,19 @@ def test_embedder(reset):
         embedder = VectorEmbedder(np.array([10, 10]), name="test")

     # creating a simple vector embedder
-    embedder = VectorEmbedder(np.array([10]), name="test")
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    pre_ops = len(tf.get_default_graph().get_operations())
+
+    embedder = VectorEmbedder(np.array([10]), name="test", is_training=is_training)

     # make sure the ops were not created yet
-    assert len(tf.get_default_graph().get_operations()) == 0
+    assert len(tf.get_default_graph().get_operations()) == pre_ops

     # call the embedder
     input_ph, output_ph = embedder()

     # make sure that now the ops were created
-    assert len(tf.get_default_graph().get_operations()) > 0
+    assert len(tf.get_default_graph().get_operations()) > pre_ops

     # try feeding a batch of one example
     input = np.random.rand(1, 10)
@@ -51,7 +54,8 @@
 @pytest.mark.unit_test
 def test_complex_embedder(reset):
     # creating a deep vector embedder
-    embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep)
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep, is_training=is_training)

     # call the embedder
     embedder()
@@ -67,8 +71,9 @@ def test_complex_embedder(reset):
 @pytest.mark.unit_test
 def test_activation_function(reset):
     # creating a deep vector embedder with relu
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
     embedder = VectorEmbedder(np.array([10]), name="relu", scheme=EmbedderScheme.Deep,
-                              activation_function=tf.nn.relu)
+                              activation_function=tf.nn.relu, is_training=is_training)

     # call the embedder
     embedder()
@@ -82,7 +87,7 @@ def test_activation_function(reset):

     # creating a deep vector embedder with tanh
     embedder_tanh = VectorEmbedder(np.array([10]), name="tanh", scheme=EmbedderScheme.Deep,
-                                   activation_function=tf.nn.tanh)
+                                   activation_function=tf.nn.tanh, is_training=is_training)

     # call the embedder
     embedder_tanh()