mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00
coach/architectures/network_wrapper.py
Roman Dobosz 1b095aeeca Cleanup imports.
Until now, most of the modules were importing all of a module's objects
(variables, classes, functions, other imports) into the module namespace,
which could be (and was) a cause of unintentional use of classes or
methods that were imported indirectly.

With this patch, all the star imports were substituted with an import of
the top-level module which provides the desired class or function.

Besides that, all imports were sorted (where possible) in the way PEP 8[1]
suggests: first come imports from the standard library, then third-party
imports (like numpy, tensorflow, etc.) and finally coach modules.
All of those sections are separated by one empty line.

[1] https://www.python.org/dev/peps/pep-0008/#imports
2018-04-13 09:58:40 +02:00
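
As an illustration of the change described above (the module and function
names below are hypothetical, not taken from the patch):

# before: a star import pulls every name, including indirect
# imports, into this module's namespace
from utils import *
result = force_list(value)

# after: import the providing module and qualify the name
import utils
result = utils.force_list(value)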


#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import os

import configurations as conf
import logger

try:
    import tensorflow as tf
    from architectures.tensorflow_components import general_network as tf_net
except ImportError:
    logger.failed_imports.append("TensorFlow")

try:
    from architectures.neon_components import general_network as neon_net
except ImportError:
    logger.failed_imports.append("Neon")


class NetworkWrapper(object):
    """
    Contains multiple networks and manages syncing and gradient updates
    between them.
    """

    def __init__(self, tuning_parameters, has_target, has_global, name,
                 replicated_device=None, worker_device=None):
        """
        :param tuning_parameters: the tuning parameters of the current run
        :type tuning_parameters: Preset
        :param has_target: whether to create a target network
        :param has_global: whether to create a global network shared between workers
        :param name: the name prefix for the wrapped networks
        :param replicated_device: the device to place the global network on
        :param worker_device: the device to place the online and target networks on
        """
        self.tp = tuning_parameters
        self.has_target = has_target
        self.has_global = has_global
        self.name = name
        self.sess = tuning_parameters.sess

        if self.tp.framework == conf.Frameworks.TensorFlow:
            general_network = tf_net.GeneralTensorFlowNetwork
        elif self.tp.framework == conf.Frameworks.Neon:
            general_network = neon_net.GeneralNeonNetwork
        else:
            raise Exception("{} Framework is not supported".format(
                conf.Frameworks().to_string(self.tp.framework)))

        # Global network - the main network shared between threads
        self.global_network = None
        if self.has_global:
            with tf.device(replicated_device):
                self.global_network = general_network(
                    tuning_parameters, '{}/global'.format(name),
                    network_is_local=False)

        # Online network - local copy of the main network used for playing
        self.online_network = None
        with tf.device(worker_device):
            self.online_network = general_network(
                tuning_parameters, '{}/online'.format(name),
                self.global_network, network_is_local=True)

        # Target network - a local, slow-updating network used for stabilizing the learning
        self.target_network = None
        if self.has_target:
            with tf.device(worker_device):
                self.target_network = general_network(
                    tuning_parameters, '{}/target'.format(name),
                    network_is_local=True)

        if not self.tp.distributed and self.tp.framework == conf.Frameworks.TensorFlow:
            # Checkpoint only the online network's variables
            variables_to_restore = tf.global_variables()
            variables_to_restore = [v for v in variables_to_restore if '/online' in v.name]
            self.model_saver = tf.train.Saver(variables_to_restore)
            if self.tp.sess and self.tp.checkpoint_restore_dir:
                checkpoint = tf.train.latest_checkpoint(self.tp.checkpoint_restore_dir)
                logger.screen.log_title("Loading checkpoint: {}".format(checkpoint))
                self.model_saver.restore(self.tp.sess, checkpoint)
                self.update_target_network()

    def sync(self):
        """
        Initializes the weights of the networks to match each other
        :return:
        """
        self.update_online_network()
        self.update_target_network()

    def update_target_network(self, rate=1.0):
        """
        Copy weights: online network >>> target network
        :param rate: the rate of copying the weights - 1 for copying exactly
        """
        if self.target_network:
            self.target_network.set_weights(self.online_network.get_weights(), rate)

    def update_online_network(self, rate=1.0):
        """
        Copy weights: global network >>> online network
        :param rate: the rate of copying the weights - 1 for copying exactly
        """
        if self.global_network:
            self.online_network.set_weights(self.global_network.get_weights(), rate)

    def apply_gradients_to_global_network(self):
        """
        Apply gradients from the online network on the global network
        :return:
        """
        self.global_network.apply_gradients(self.online_network.accumulated_gradients)

    def apply_gradients_to_online_network(self):
        """
        Apply gradients from the online network on itself
        :return:
        """
        self.online_network.apply_gradients(self.online_network.accumulated_gradients)

    def train_and_sync_networks(self, inputs, targets, additional_fetches=None):
        """
        A generic training function that enables multi-threaded training using a global network if necessary.
        :param inputs: The inputs for the network.
        :param targets: The targets corresponding to the given inputs
        :param additional_fetches: Any additional tensor the user wants to fetch
        :return: The loss of the training iteration
        """
        # Avoid a mutable default argument; fall back to an empty fetch list.
        if additional_fetches is None:
            additional_fetches = []
        result = self.online_network.accumulate_gradients(inputs, targets,
                                                          additional_fetches=additional_fetches)
        self.apply_gradients_and_sync_networks()
        return result

    def apply_gradients_and_sync_networks(self):
        """
        Applies the gradients accumulated in the online network to the global network or to itself and syncs the
        networks if necessary
        """
        if self.global_network:
            # Distributed case: push the accumulated gradients to the shared
            # global network, then pull its updated weights back locally.
            self.apply_gradients_to_global_network()
            self.online_network.reset_accumulated_gradients()
            self.update_online_network()
        else:
            # Single-worker case: apply the gradients directly to the online network.
            self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)

    def get_local_variables(self):
        """
        Get all the variables that are local to the thread
        :return: a list of all the variables that are local to the thread
        """
        local_variables = [v for v in tf.global_variables() if self.online_network.name in v.name]
        if self.has_target:
            local_variables += [v for v in tf.global_variables() if self.target_network.name in v.name]
        return local_variables

    def get_global_variables(self):
        """
        Get all the variables that are shared between threads
        :return: a list of all the variables that are shared between threads
        """
        global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
        return global_variables

    def set_session(self, sess):
        self.sess = sess
        self.online_network.sess = sess
        if self.global_network:
            self.global_network.sess = sess
        if self.target_network:
            self.target_network.sess = sess

    def save_model(self, model_id):
        saved_model_path = self.model_saver.save(
            self.tp.sess, os.path.join(self.tp.save_model_dir, str(model_id) + '.ckpt'))
        logger.screen.log_dict(
            collections.OrderedDict([
                ("Saving model", saved_model_path),
            ]),
            prefix="Checkpoint"
        )
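
A minimal usage sketch (not part of the file above): it assumes a
Preset-style tuning_parameters object with a live session, plus inputs and
targets batches shaped for the underlying network, all of which normally
come from the surrounding coach code.

# Hypothetical single-worker setup: a target network, but no global network.
network = NetworkWrapper(tuning_parameters, has_target=True, has_global=False,
                         name='main')
network.set_session(tuning_parameters.sess)

# Train on one batch; with no global network, the accumulated gradients
# are applied directly to the online network.
loss = network.train_and_sync_networks(inputs, targets)

# Periodically copy the online weights into the slow-updating target network.
network.update_target_network()

# Persist the online network's weights (available in non-distributed
# TensorFlow runs, where model_saver is created).
network.save_model(0)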