1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)

This commit is contained in:
Sina Afrooze
2018-11-22 12:52:22 -08:00
committed by Gal Leibovich
parent 559969d3dd
commit 87a7848b0a
11 changed files with 219 additions and 91 deletions

View File

@@ -20,6 +20,7 @@ from rl_coach.base_parameters import Frameworks, AgentParameters
from rl_coach.logger import failed_imports
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list
try:
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
@@ -53,52 +54,55 @@ class NetworkWrapper(object):
if self.network_parameters.framework == Frameworks.tensorflow:
if "tensorflow" not in failed_imports:
general_network = GeneralTensorFlowNetwork
general_network = GeneralTensorFlowNetwork.construct
else:
raise Exception('Install tensorflow before using it as framework')
elif self.network_parameters.framework == Frameworks.mxnet:
if "mxnet" not in failed_imports:
general_network = GeneralMxnetNetwork
general_network = GeneralMxnetNetwork.construct
else:
raise Exception('Install mxnet before using it as framework')
else:
raise Exception("{} Framework is not supported"
.format(Frameworks().to_string(self.network_parameters.framework)))
with tf.variable_scope("{}/{}".format(self.ap.full_name_id, name)):
variable_scope = "{}/{}".format(self.ap.full_name_id, name)
# Global network - the main network shared between threads
self.global_network = None
if self.has_global:
# we assign the parameters of this network on the parameters server
with tf.device(replicated_device):
self.global_network = general_network(agent_parameters=agent_parameters,
name='{}/global'.format(name),
global_network=None,
network_is_local=False,
spaces=spaces,
network_is_trainable=True)
# Global network - the main network shared between threads
self.global_network = None
if self.has_global:
# we assign the parameters of this network on the parameters server
self.global_network = general_network(variable_scope=variable_scope,
devices=force_list(replicated_device),
agent_parameters=agent_parameters,
name='{}/global'.format(name),
global_network=None,
network_is_local=False,
spaces=spaces,
network_is_trainable=True)
# Online network - local copy of the main network used for playing
self.online_network = None
with tf.device(worker_device):
self.online_network = general_network(agent_parameters=agent_parameters,
name='{}/online'.format(name),
global_network=self.global_network,
network_is_local=True,
spaces=spaces,
network_is_trainable=True)
# Online network - local copy of the main network used for playing
self.online_network = None
self.online_network = general_network(variable_scope=variable_scope,
devices=force_list(worker_device),
agent_parameters=agent_parameters,
name='{}/online'.format(name),
global_network=self.global_network,
network_is_local=True,
spaces=spaces,
network_is_trainable=True)
# Target network - a local, slow updating network used for stabilizing the learning
self.target_network = None
if self.has_target:
with tf.device(worker_device):
self.target_network = general_network(agent_parameters=agent_parameters,
name='{}/target'.format(name),
global_network=self.global_network,
network_is_local=True,
spaces=spaces,
network_is_trainable=False)
# Target network - a local, slow updating network used for stabilizing the learning
self.target_network = None
if self.has_target:
self.target_network = general_network(variable_scope=variable_scope,
devices=force_list(worker_device),
agent_parameters=agent_parameters,
name='{}/target'.format(name),
global_network=self.global_network,
network_is_local=True,
spaces=spaces,
network_is_trainable=False)
def sync(self):
"""
@@ -198,26 +202,6 @@ class NetworkWrapper(object):
"""
return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
def get_local_variables(self):
    """
    Collect the thread-local TensorFlow variables that belong to this wrapper's local networks.
    A variable is selected when the network's name occurs as a substring of the variable's name.
    The online network is always included; the target network only when one exists.
    :return: a list of the matching local variables (online network matches first, then target)
    """
    # networks whose variables live on the local worker
    local_networks = [self.online_network]
    if self.has_target:
        local_networks.append(self.target_network)

    collected = []
    for network in local_networks:
        # query the collection once per network, and keep only variables named after it
        collected += [var for var in tf.local_variables() if network.name in var.name]
    return collected
def get_global_variables(self):
    """
    Get all the variables that are shared between threads, i.e. the TensorFlow global variables
    whose name contains the global network's name as a substring.
    :return: a list of all the variables that are shared between threads. The list is empty when
             this wrapper has no global network.
    """
    # the global network is optional (it is left as None when the wrapper has no global copy),
    # so there is nothing shared to report in that case instead of failing on `None.name`
    if self.global_network is None:
        return []

    global_name = self.global_network.name
    return [v for v in tf.global_variables() if global_name in v.name]
def set_is_training(self, state: bool):
"""
Set the phase of the network between training and testing