mirror of
https://github.com/gryf/coach.git
synced 2026-02-07 16:55:48 +01:00
Cleanup imports.
Until now, most of the modules were importing all of another module's objects (variables, classes, functions, other imports) into their own namespace, which potentially could (and did) cause unintentional use of classes or methods that had only been imported indirectly. With this patch, all the star imports were replaced with an import of the top-level module that provides the desired class or function. Besides, all imports were sorted (where possible) in the way PEP 8 [1] suggests - first come imports from the standard library, then third-party imports (like numpy, tensorflow, etc.) and finally coach modules. All of those sections are separated by one empty line. [1] https://www.python.org/dev/peps/pep-0008/#imports
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,19 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from architectures.architecture import *
|
||||
from logger import failed_imports
|
||||
try:
|
||||
from architectures.tensorflow_components.general_network import *
|
||||
from architectures.tensorflow_components.architecture import *
|
||||
except ImportError:
|
||||
failed_imports.append("TensorFlow")
|
||||
import logger
|
||||
|
||||
try:
|
||||
from architectures.neon_components.general_network import *
|
||||
from architectures.neon_components.architecture import *
|
||||
from architectures.tensorflow_components import general_network as ts_gn
|
||||
from architectures.tensorflow_components import architecture as ts_arch
|
||||
except ImportError:
|
||||
failed_imports.append("Neon")
|
||||
logger.failed_imports.append("TensorFlow")
|
||||
|
||||
from architectures.network_wrapper import *
|
||||
try:
|
||||
from architectures.neon_components import general_network as neon_gn
|
||||
from architectures.neon_components import architecture as neon_arch
|
||||
except ImportError:
|
||||
logger.failed_imports.append("Neon")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,8 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from configurations import Preset
|
||||
|
||||
|
||||
class Architecture(object):
|
||||
def __init__(self, tuning_parameters, name=""):
|
||||
@@ -73,4 +71,4 @@ class Architecture(object):
|
||||
pass
|
||||
|
||||
def set_variable_value(self, assign_op, value, placeholder=None):
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,19 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import copy
|
||||
from ngraph.frontends.neon import *
|
||||
import ngraph as ng
|
||||
from architectures.architecture import *
|
||||
import numpy as np
|
||||
from utils import *
|
||||
|
||||
from architectures import architecture
|
||||
import utils
|
||||
|
||||
|
||||
class NeonArchitecture(Architecture):
|
||||
class NeonArchitecture(architecture.Architecture):
|
||||
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
|
||||
Architecture.__init__(self, tuning_parameters, name)
|
||||
architecture.Architecture.__init__(self, tuning_parameters, name)
|
||||
assert tuning_parameters.agent.neon_support, 'Neon is not supported for this agent'
|
||||
self.clip_error = tuning_parameters.clip_gradients
|
||||
self.total_loss = None
|
||||
@@ -113,8 +110,8 @@ class NeonArchitecture(Architecture):
|
||||
def accumulate_gradients(self, inputs, targets):
|
||||
# Neon doesn't currently allow separating the grads calculation and grad apply operations
|
||||
# so this feature is not currently available. instead we do a full training iteration
|
||||
inputs = force_list(inputs)
|
||||
targets = force_list(targets)
|
||||
inputs = utils.force_list(inputs)
|
||||
targets = utils.force_list(targets)
|
||||
|
||||
for idx, input in enumerate(inputs):
|
||||
inputs[idx] = input.swapaxes(0, -1)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,10 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import ngraph.frontends.neon as neon
|
||||
import ngraph as ng
|
||||
from ngraph.util.names import name_scope
|
||||
import ngraph.frontends.neon as neon
|
||||
import ngraph.util.names as ngraph_names
|
||||
|
||||
|
||||
class InputEmbedder(object):
|
||||
@@ -31,7 +30,7 @@ class InputEmbedder(object):
|
||||
self.output = None
|
||||
|
||||
def __call__(self, prev_input_placeholder=None):
|
||||
with name_scope(self.get_name()):
|
||||
with ngraph_names.name_scope(self.get_name()):
|
||||
# create the input axes
|
||||
axes = []
|
||||
if len(self.input_size) == 2:
|
||||
|
||||
@@ -13,15 +13,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import ngraph as ng
|
||||
from ngraph.frontends import neon
|
||||
from ngraph.util import names as ngraph_names
|
||||
|
||||
from architectures.neon_components.embedders import *
|
||||
from architectures.neon_components.heads import *
|
||||
from architectures.neon_components.middleware import *
|
||||
from architectures.neon_components.architecture import *
|
||||
from configurations import InputTypes, OutputTypes, MiddlewareTypes
|
||||
from architectures.neon_components import architecture
|
||||
from architectures.neon_components import embedders
|
||||
from architectures.neon_components import middleware
|
||||
from architectures.neon_components import heads
|
||||
import configurations as conf
|
||||
|
||||
|
||||
class GeneralNeonNetwork(NeonArchitecture):
|
||||
class GeneralNeonNetwork(architecture.NeonArchitecture):
|
||||
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
|
||||
self.global_network = global_network
|
||||
self.network_is_local = network_is_local
|
||||
@@ -34,7 +37,7 @@ class GeneralNeonNetwork(NeonArchitecture):
|
||||
self.activation_function = self.get_activation_function(
|
||||
tuning_parameters.agent.hidden_layers_activation_function)
|
||||
|
||||
NeonArchitecture.__init__(self, tuning_parameters, name, global_network, network_is_local)
|
||||
architecture.NeonArchitecture.__init__(self, tuning_parameters, name, global_network, network_is_local)
|
||||
|
||||
def get_activation_function(self, activation_function_string):
|
||||
activation_functions = {
|
||||
@@ -53,36 +56,36 @@ class GeneralNeonNetwork(NeonArchitecture):
|
||||
# the observation can be either an image or a vector
|
||||
def get_observation_embedding(with_timestep=False):
|
||||
if self.input_height > 1:
|
||||
return ImageEmbedder((self.input_depth, self.input_height, self.input_width), self.batch_size,
|
||||
name="observation")
|
||||
return embedders.ImageEmbedder((self.input_depth, self.input_height, self.input_width), self.batch_size,
|
||||
name="observation")
|
||||
else:
|
||||
return VectorEmbedder((self.input_depth, self.input_width + int(with_timestep)), self.batch_size,
|
||||
name="observation")
|
||||
return embedders.VectorEmbedder((self.input_depth, self.input_width + int(with_timestep)), self.batch_size,
|
||||
name="observation")
|
||||
|
||||
input_mapping = {
|
||||
InputTypes.Observation: get_observation_embedding(),
|
||||
InputTypes.Measurements: VectorEmbedder(self.measurements_size, self.batch_size, name="measurements"),
|
||||
InputTypes.GoalVector: VectorEmbedder(self.measurements_size, self.batch_size, name="goal_vector"),
|
||||
InputTypes.Action: VectorEmbedder((self.num_actions,), self.batch_size, name="action"),
|
||||
InputTypes.TimedObservation: get_observation_embedding(with_timestep=True),
|
||||
conf.InputTypes.Observation: get_observation_embedding(),
|
||||
conf.InputTypes.Measurements: embedders.VectorEmbedder(self.measurements_size, self.batch_size, name="measurements"),
|
||||
conf.InputTypes.GoalVector: embedders.VectorEmbedder(self.measurements_size, self.batch_size, name="goal_vector"),
|
||||
conf.InputTypes.Action: embedders.VectorEmbedder((self.num_actions,), self.batch_size, name="action"),
|
||||
conf.InputTypes.TimedObservation: get_observation_embedding(with_timestep=True),
|
||||
}
|
||||
return input_mapping[embedder_type]
|
||||
|
||||
def get_middleware_embedder(self, middleware_type):
|
||||
return {MiddlewareTypes.LSTM: None, # LSTM over Neon is currently not supported in Coach
|
||||
MiddlewareTypes.FC: FC_Embedder}.get(middleware_type)(self.activation_function)
|
||||
return {conf.MiddlewareTypes.LSTM: None, # LSTM over Neon is currently not supported in Coach
|
||||
conf.MiddlewareTypes.FC: middleware.FC_Embedder}.get(middleware_type)(self.activation_function)
|
||||
|
||||
def get_output_head(self, head_type, head_idx, loss_weight=1.):
|
||||
output_mapping = {
|
||||
OutputTypes.Q: QHead,
|
||||
OutputTypes.DuelingQ: DuelingQHead,
|
||||
OutputTypes.V: None, # Policy Optimization algorithms over Neon are currently not supported in Coach
|
||||
OutputTypes.Pi: None, # Policy Optimization algorithms over Neon are currently not supported in Coach
|
||||
OutputTypes.MeasurementsPrediction: None, # DFP over Neon is currently not supported in Coach
|
||||
OutputTypes.DNDQ: None, # NEC over Neon is currently not supported in Coach
|
||||
OutputTypes.NAF: None, # NAF over Neon is currently not supported in Coach
|
||||
OutputTypes.PPO: None, # PPO over Neon is currently not supported in Coach
|
||||
OutputTypes.PPO_V: None # PPO over Neon is currently not supported in Coach
|
||||
conf.OutputTypes.Q: heads.QHead,
|
||||
conf.OutputTypes.DuelingQ: heads.DuelingQHead,
|
||||
conf.OutputTypes.V: None, # Policy Optimization algorithms over Neon are currently not supported in Coach
|
||||
conf.OutputTypes.Pi: None, # Policy Optimization algorithms over Neon are currently not supported in Coach
|
||||
conf.OutputTypes.MeasurementsPrediction: None, # DFP over Neon is currently not supported in Coach
|
||||
conf.OutputTypes.DNDQ: None, # NEC over Neon is currently not supported in Coach
|
||||
conf.OutputTypes.NAF: None, # NAF over Neon is currently not supported in Coach
|
||||
conf.OutputTypes.PPO: None, # PPO over Neon is currently not supported in Coach
|
||||
conf.OutputTypes.PPO_V: None # PPO over Neon is currently not supported in Coach
|
||||
}
|
||||
return output_mapping[head_type](self.tp, head_idx, loss_weight, self.network_is_local)
|
||||
|
||||
@@ -104,7 +107,7 @@ class GeneralNeonNetwork(NeonArchitecture):
|
||||
done_creating_input_placeholders = False
|
||||
|
||||
for network_idx in range(self.num_networks):
|
||||
with name_scope('network_{}'.format(network_idx)):
|
||||
with ngraph_names.name_scope('network_{}'.format(network_idx)):
|
||||
####################
|
||||
# Input Embeddings #
|
||||
####################
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,13 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import ngraph as ng
|
||||
from ngraph.util.names import name_scope
|
||||
import ngraph.frontends.neon as neon
|
||||
import numpy as np
|
||||
from utils import force_list
|
||||
from architectures.neon_components.losses import *
|
||||
from ngraph.frontends import neon
|
||||
from ngraph.util import names as ngraph_names
|
||||
|
||||
import utils
|
||||
from architectures.neon_components import losses
|
||||
|
||||
|
||||
class Head(object):
|
||||
@@ -30,7 +29,7 @@ class Head(object):
|
||||
self.loss = []
|
||||
self.loss_type = []
|
||||
self.regularizations = []
|
||||
self.loss_weight = force_list(loss_weight)
|
||||
self.loss_weight = utils.force_list(loss_weight)
|
||||
self.weights_init = neon.GlorotInit()
|
||||
self.biases_init = neon.ConstantInit()
|
||||
self.target = []
|
||||
@@ -44,15 +43,15 @@ class Head(object):
|
||||
:param input_layer: the input to the graph
|
||||
:return: the output of the last layer and the target placeholder
|
||||
"""
|
||||
with name_scope(self.get_name()):
|
||||
with ngraph_names.name_scope(self.get_name()):
|
||||
self._build_module(input_layer)
|
||||
|
||||
self.output = force_list(self.output)
|
||||
self.target = force_list(self.target)
|
||||
self.input = force_list(self.input)
|
||||
self.loss_type = force_list(self.loss_type)
|
||||
self.loss = force_list(self.loss)
|
||||
self.regularizations = force_list(self.regularizations)
|
||||
self.output = utils.force_list(self.output)
|
||||
self.target = utils.force_list(self.target)
|
||||
self.input = utils.force_list(self.input)
|
||||
self.loss_type = utils.force_list(self.loss_type)
|
||||
self.loss = utils.force_list(self.loss)
|
||||
self.regularizations = utils.force_list(self.regularizations)
|
||||
if self.is_local:
|
||||
self.set_loss()
|
||||
|
||||
@@ -106,7 +105,7 @@ class QHead(Head):
|
||||
if tuning_parameters.agent.replace_mse_with_huber_loss:
|
||||
raise Exception("huber loss is not supported in neon")
|
||||
else:
|
||||
self.loss_type = mean_squared_error
|
||||
self.loss_type = losses.mean_squared_error
|
||||
|
||||
def _build_module(self, input_layer):
|
||||
# Standard Q Network
|
||||
@@ -159,7 +158,7 @@ class MeasurementsPredictionHead(Head):
|
||||
if tuning_parameters.agent.replace_mse_with_huber_loss:
|
||||
raise Exception("huber loss is not supported in neon")
|
||||
else:
|
||||
self.loss_type = mean_squared_error
|
||||
self.loss_type = losses.mean_squared_error
|
||||
|
||||
def _build_module(self, input_layer):
|
||||
# This is almost exactly the same as Dueling Network but we predict the future measurements for each action
|
||||
@@ -167,7 +166,7 @@ class MeasurementsPredictionHead(Head):
|
||||
multistep_measurements_size = self.measurements_size[0] * self.num_predicted_steps_ahead
|
||||
|
||||
# actions expectation tower (expectation stream) - E
|
||||
with name_scope("expectation_stream"):
|
||||
with ngraph_names.name_scope("expectation_stream"):
|
||||
expectation_stream = neon.Sequential([
|
||||
neon.Affine(nout=256, activation=neon.Rectlin(),
|
||||
weight_init=self.weights_init, bias_init=self.biases_init),
|
||||
@@ -176,7 +175,7 @@ class MeasurementsPredictionHead(Head):
|
||||
])(input_layer)
|
||||
|
||||
# action fine differences tower (action stream) - A
|
||||
with name_scope("action_stream"):
|
||||
with ngraph_names.name_scope("action_stream"):
|
||||
action_stream_unnormalized = neon.Sequential([
|
||||
neon.Affine(nout=256, activation=neon.Rectlin(),
|
||||
weight_init=self.weights_init, bias_init=self.biases_init),
|
||||
@@ -191,4 +190,3 @@ class MeasurementsPredictionHead(Head):
|
||||
|
||||
# merge to future measurements predictions
|
||||
self.output = repeated_expectation_stream + action_stream
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,15 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import ngraph as ng
|
||||
import ngraph.frontends.neon as neon
|
||||
from ngraph.util.names import name_scope
|
||||
import numpy as np
|
||||
from ngraph.util import names as ngraph_names
|
||||
|
||||
|
||||
def mean_squared_error(targets, outputs, weights=1.0, scope=""):
|
||||
with name_scope(scope):
|
||||
with ngraph_names.name_scope(scope):
|
||||
# TODO: reduce mean over the action axis
|
||||
loss = ng.squared_L2(targets - outputs)
|
||||
weighted_loss = loss * weights
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,11 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import ngraph as ng
|
||||
import ngraph.frontends.neon as neon
|
||||
from ngraph.util.names import name_scope
|
||||
import numpy as np
|
||||
from ngraph.util import names as ngraph_names
|
||||
|
||||
|
||||
class MiddlewareEmbedder(object):
|
||||
@@ -30,7 +27,7 @@ class MiddlewareEmbedder(object):
|
||||
self.activation_function = activation_function
|
||||
|
||||
def __call__(self, input_layer):
|
||||
with name_scope(self.get_name()):
|
||||
with ngraph_names.name_scope(self.get_name()):
|
||||
self.input = input_layer
|
||||
self._build_module()
|
||||
|
||||
|
||||
@@ -13,20 +13,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import collections
|
||||
|
||||
from collections import OrderedDict
|
||||
from configurations import Preset, Frameworks
|
||||
from logger import *
|
||||
import configurations as conf
|
||||
import logger
|
||||
try:
|
||||
import tensorflow as tf
|
||||
from architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
||||
from architectures.tensorflow_components import general_network as tf_net #import GeneralTensorFlowNetwork
|
||||
except ImportError:
|
||||
failed_imports.append("TensorFlow")
|
||||
logger.failed_imports.append("TensorFlow")
|
||||
|
||||
try:
|
||||
from architectures.neon_components.general_network import GeneralNeonNetwork
|
||||
from architectures.neon_components import general_network as neon_net
|
||||
except ImportError:
|
||||
failed_imports.append("Neon")
|
||||
logger.failed_imports.append("Neon")
|
||||
|
||||
|
||||
class NetworkWrapper(object):
|
||||
@@ -50,12 +51,12 @@ class NetworkWrapper(object):
|
||||
self.name = name
|
||||
self.sess = tuning_parameters.sess
|
||||
|
||||
if self.tp.framework == Frameworks.TensorFlow:
|
||||
general_network = GeneralTensorFlowNetwork
|
||||
elif self.tp.framework == Frameworks.Neon:
|
||||
general_network = GeneralNeonNetwork
|
||||
if self.tp.framework == conf.Frameworks.TensorFlow:
|
||||
general_network = tf_net.GeneralTensorFlowNetwork
|
||||
elif self.tp.framework == conf.Frameworks.Neon:
|
||||
general_network = neon_net.GeneralNeonNetwork
|
||||
else:
|
||||
raise Exception("{} Framework is not supported".format(Frameworks().to_string(self.tp.framework)))
|
||||
raise Exception("{} Framework is not supported".format(conf.Frameworks().to_string(self.tp.framework)))
|
||||
|
||||
# Global network - the main network shared between threads
|
||||
self.global_network = None
|
||||
@@ -77,13 +78,13 @@ class NetworkWrapper(object):
|
||||
self.target_network = general_network(tuning_parameters, '{}/target'.format(name),
|
||||
network_is_local=True)
|
||||
|
||||
if not self.tp.distributed and self.tp.framework == Frameworks.TensorFlow:
|
||||
if not self.tp.distributed and self.tp.framework == conf.Frameworks.TensorFlow:
|
||||
variables_to_restore = tf.global_variables()
|
||||
variables_to_restore = [v for v in variables_to_restore if '/online' in v.name]
|
||||
self.model_saver = tf.train.Saver(variables_to_restore)
|
||||
if self.tp.sess and self.tp.checkpoint_restore_dir:
|
||||
checkpoint = tf.train.latest_checkpoint(self.tp.checkpoint_restore_dir)
|
||||
screen.log_title("Loading checkpoint: {}".format(checkpoint))
|
||||
logger.screen.log_title("Loading checkpoint: {}".format(checkpoint))
|
||||
self.model_saver.restore(self.tp.sess, checkpoint)
|
||||
self.update_target_network()
|
||||
|
||||
@@ -178,8 +179,8 @@ class NetworkWrapper(object):
|
||||
def save_model(self, model_id):
|
||||
saved_model_path = self.model_saver.save(self.tp.sess, os.path.join(self.tp.save_model_dir,
|
||||
str(model_id) + '.ckpt'))
|
||||
screen.log_dict(
|
||||
OrderedDict([
|
||||
logger.screen.log_dict(
|
||||
collections.OrderedDict([
|
||||
("Saving model", saved_model_path),
|
||||
]),
|
||||
prefix="Checkpoint"
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
#
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from architectures.architecture import Architecture
|
||||
from utils import force_list, squeeze_list
|
||||
from configurations import Preset, MiddlewareTypes
|
||||
from architectures import architecture
|
||||
import configurations as conf
|
||||
import utils
|
||||
|
||||
def variable_summaries(var):
|
||||
"""Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
|
||||
@@ -37,14 +36,14 @@ def variable_summaries(var):
|
||||
tf.summary.scalar('min', tf.reduce_min(var))
|
||||
tf.summary.histogram('histogram', var)
|
||||
|
||||
class TensorFlowArchitecture(Architecture):
|
||||
class TensorFlowArchitecture(architecture.Architecture):
|
||||
def __init__(self, tuning_parameters, name="", global_network=None, network_is_local=True):
|
||||
"""
|
||||
:param tuning_parameters: The parameters used for running the algorithm
|
||||
:type tuning_parameters: Preset
|
||||
:param name: The name of the network
|
||||
"""
|
||||
Architecture.__init__(self, tuning_parameters, name)
|
||||
architecture.Architecture.__init__(self, tuning_parameters, name)
|
||||
self.middleware_embedder = None
|
||||
self.network_is_local = network_is_local
|
||||
assert tuning_parameters.agent.tensorflow_support, 'TensorFlow is not supported for this agent'
|
||||
@@ -174,7 +173,7 @@ class TensorFlowArchitecture(Architecture):
|
||||
feed_dict = self._feed_dict(inputs)
|
||||
|
||||
# feed targets
|
||||
targets = force_list(targets)
|
||||
targets = utils.force_list(targets)
|
||||
for placeholder_idx, target in enumerate(targets):
|
||||
feed_dict[self.targets[placeholder_idx]] = target
|
||||
|
||||
@@ -186,13 +185,13 @@ class TensorFlowArchitecture(Architecture):
|
||||
else:
|
||||
fetches.append(self.tensor_gradients)
|
||||
fetches += [self.total_loss, self.losses]
|
||||
if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
|
||||
if self.tp.agent.middleware_type == conf.MiddlewareTypes.LSTM:
|
||||
fetches.append(self.middleware_embedder.state_out)
|
||||
additional_fetches_start_idx = len(fetches)
|
||||
fetches += additional_fetches
|
||||
|
||||
# feed the lstm state if necessary
|
||||
if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
|
||||
if self.tp.agent.middleware_type == conf.MiddlewareTypes.LSTM:
|
||||
# we can't always assume that we are starting from scratch here can we?
|
||||
feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
|
||||
feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
|
||||
@@ -206,7 +205,7 @@ class TensorFlowArchitecture(Architecture):
|
||||
|
||||
# extract the fetches
|
||||
norm_unclipped_grads, grads, total_loss, losses = result[:4]
|
||||
if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
|
||||
if self.tp.agent.middleware_type == conf.MiddlewareTypes.LSTM:
|
||||
(self.curr_rnn_c_in, self.curr_rnn_h_in) = result[4]
|
||||
fetched_tensors = []
|
||||
if len(additional_fetches) > 0:
|
||||
@@ -308,7 +307,7 @@ class TensorFlowArchitecture(Architecture):
|
||||
if outputs is None:
|
||||
outputs = self.outputs
|
||||
|
||||
if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
|
||||
if self.tp.agent.middleware_type == conf.MiddlewareTypes.LSTM:
|
||||
feed_dict[self.middleware_embedder.c_in] = self.curr_rnn_c_in
|
||||
feed_dict[self.middleware_embedder.h_in] = self.curr_rnn_h_in
|
||||
|
||||
@@ -317,7 +316,7 @@ class TensorFlowArchitecture(Architecture):
|
||||
output = self.tp.sess.run(outputs, feed_dict)
|
||||
|
||||
if squeeze_output:
|
||||
output = squeeze_list(output)
|
||||
output = utils.squeeze_list(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from configurations import EmbedderComplexity
|
||||
|
||||
|
||||
|
||||
@@ -13,15 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import tensorflow as tf
|
||||
|
||||
from architectures.tensorflow_components.embedders import *
|
||||
from architectures.tensorflow_components.heads import *
|
||||
from architectures.tensorflow_components.middleware import *
|
||||
from architectures.tensorflow_components.architecture import *
|
||||
from configurations import InputTypes, OutputTypes, MiddlewareTypes
|
||||
from architectures.tensorflow_components import architecture
|
||||
from architectures.tensorflow_components import embedders
|
||||
from architectures.tensorflow_components import middleware
|
||||
from architectures.tensorflow_components import heads
|
||||
import configurations as conf
|
||||
|
||||
|
||||
class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
class GeneralTensorFlowNetwork(architecture.TensorFlowArchitecture):
|
||||
"""
|
||||
A generalized version of all possible networks implemented using tensorflow.
|
||||
"""
|
||||
@@ -37,7 +38,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
self.activation_function = self.get_activation_function(
|
||||
tuning_parameters.agent.hidden_layers_activation_function)
|
||||
|
||||
TensorFlowArchitecture.__init__(self, tuning_parameters, name, global_network, network_is_local)
|
||||
architecture.TensorFlowArchitecture.__init__(self, tuning_parameters, name, global_network, network_is_local)
|
||||
|
||||
def get_activation_function(self, activation_function_string):
|
||||
activation_functions = {
|
||||
@@ -56,37 +57,37 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
# the observation can be either an image or a vector
|
||||
def get_observation_embedding(with_timestep=False):
|
||||
if self.input_height > 1:
|
||||
return ImageEmbedder((self.input_height, self.input_width, self.input_depth), name="observation",
|
||||
input_rescaler=self.tp.agent.input_rescaler)
|
||||
return embedders.ImageEmbedder((self.input_height, self.input_width, self.input_depth), name="observation",
|
||||
input_rescaler=self.tp.agent.input_rescaler)
|
||||
else:
|
||||
return VectorEmbedder((self.input_width + int(with_timestep), self.input_depth), name="observation")
|
||||
return embedders.VectorEmbedder((self.input_width + int(with_timestep), self.input_depth), name="observation")
|
||||
|
||||
input_mapping = {
|
||||
InputTypes.Observation: get_observation_embedding(),
|
||||
InputTypes.Measurements: VectorEmbedder(self.measurements_size, name="measurements"),
|
||||
InputTypes.GoalVector: VectorEmbedder(self.measurements_size, name="goal_vector"),
|
||||
InputTypes.Action: VectorEmbedder((self.num_actions,), name="action"),
|
||||
InputTypes.TimedObservation: get_observation_embedding(with_timestep=True),
|
||||
conf.InputTypes.Observation: get_observation_embedding(),
|
||||
conf.InputTypes.Measurements: embedders.VectorEmbedder(self.measurements_size, name="measurements"),
|
||||
conf.InputTypes.GoalVector: embedders.VectorEmbedder(self.measurements_size, name="goal_vector"),
|
||||
conf.InputTypes.Action: embedders.VectorEmbedder((self.num_actions,), name="action"),
|
||||
conf.InputTypes.TimedObservation: get_observation_embedding(with_timestep=True),
|
||||
}
|
||||
return input_mapping[embedder_type]
|
||||
|
||||
def get_middleware_embedder(self, middleware_type):
|
||||
return {MiddlewareTypes.LSTM: LSTM_Embedder,
|
||||
MiddlewareTypes.FC: FC_Embedder}.get(middleware_type)(self.activation_function)
|
||||
return {conf.MiddlewareTypes.LSTM: middleware.LSTM_Embedder,
|
||||
conf.MiddlewareTypes.FC: middleware.FC_Embedder}.get(middleware_type)(self.activation_function)
|
||||
|
||||
def get_output_head(self, head_type, head_idx, loss_weight=1.):
|
||||
output_mapping = {
|
||||
OutputTypes.Q: QHead,
|
||||
OutputTypes.DuelingQ: DuelingQHead,
|
||||
OutputTypes.V: VHead,
|
||||
OutputTypes.Pi: PolicyHead,
|
||||
OutputTypes.MeasurementsPrediction: MeasurementsPredictionHead,
|
||||
OutputTypes.DNDQ: DNDQHead,
|
||||
OutputTypes.NAF: NAFHead,
|
||||
OutputTypes.PPO: PPOHead,
|
||||
OutputTypes.PPO_V: PPOVHead,
|
||||
OutputTypes.CategoricalQ: CategoricalQHead,
|
||||
OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
|
||||
conf.OutputTypes.Q: heads.QHead,
|
||||
conf.OutputTypes.DuelingQ: heads.DuelingQHead,
|
||||
conf.OutputTypes.V: heads.VHead,
|
||||
conf.OutputTypes.Pi: heads.PolicyHead,
|
||||
conf.OutputTypes.MeasurementsPrediction: heads.MeasurementsPredictionHead,
|
||||
conf.OutputTypes.DNDQ: heads.DNDQHead,
|
||||
conf.OutputTypes.NAF: heads.NAFHead,
|
||||
conf.OutputTypes.PPO: heads.PPOHead,
|
||||
conf.OutputTypes.PPO_V: heads.PPOVHead,
|
||||
conf.OutputTypes.CategoricalQ: heads.CategoricalQHead,
|
||||
conf.OutputTypes.QuantileRegressionQ: heads.QuantileRegressionQHead
|
||||
}
|
||||
return output_mapping[head_type](self.tp, head_idx, loss_weight, self.network_is_local)
|
||||
|
||||
|
||||
@@ -13,10 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from utils import force_list
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
# Used to initialize weights for policy and value output layers
|
||||
@@ -36,7 +36,7 @@ class Head(object):
|
||||
self.loss = []
|
||||
self.loss_type = []
|
||||
self.regularizations = []
|
||||
self.loss_weight = force_list(loss_weight)
|
||||
self.loss_weight = utils.force_list(loss_weight)
|
||||
self.target = []
|
||||
self.input = []
|
||||
self.is_local = is_local
|
||||
@@ -50,12 +50,12 @@ class Head(object):
|
||||
with tf.variable_scope(self.get_name(), initializer=tf.contrib.layers.xavier_initializer()):
|
||||
self._build_module(input_layer)
|
||||
|
||||
self.output = force_list(self.output)
|
||||
self.target = force_list(self.target)
|
||||
self.input = force_list(self.input)
|
||||
self.loss_type = force_list(self.loss_type)
|
||||
self.loss = force_list(self.loss)
|
||||
self.regularizations = force_list(self.regularizations)
|
||||
self.output = utils.force_list(self.output)
|
||||
self.target = utils.force_list(self.target)
|
||||
self.input = utils.force_list(self.input)
|
||||
self.loss_type = utils.force_list(self.loss_type)
|
||||
self.loss = utils.force_list(self.loss)
|
||||
self.regularizations = utils.force_list(self.regularizations)
|
||||
if self.is_local:
|
||||
self.set_loss()
|
||||
self._post_build()
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
@@ -79,4 +78,4 @@ class SharedRunningStats(object):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._shape
|
||||
return self._shape
|
||||
|
||||
Reference in New Issue
Block a user