Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
This commit is contained in:
Committed by: Gal Leibovich
Parent: 559969d3dd
Commit: 87a7848b0a
@@ -20,6 +20,7 @@ from rl_coach.base_parameters import Frameworks, AgentParameters
|
||||
from rl_coach.logger import failed_imports
|
||||
from rl_coach.saver import SaverCollection
|
||||
from rl_coach.spaces import SpacesDefinition
|
||||
from rl_coach.utils import force_list
|
||||
try:
|
||||
import tensorflow as tf
|
||||
from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
|
||||
@@ -53,52 +54,55 @@ class NetworkWrapper(object):
|
||||
|
||||
if self.network_parameters.framework == Frameworks.tensorflow:
|
||||
if "tensorflow" not in failed_imports:
|
||||
general_network = GeneralTensorFlowNetwork
|
||||
general_network = GeneralTensorFlowNetwork.construct
|
||||
else:
|
||||
raise Exception('Install tensorflow before using it as framework')
|
||||
elif self.network_parameters.framework == Frameworks.mxnet:
|
||||
if "mxnet" not in failed_imports:
|
||||
general_network = GeneralMxnetNetwork
|
||||
general_network = GeneralMxnetNetwork.construct
|
||||
else:
|
||||
raise Exception('Install mxnet before using it as framework')
|
||||
else:
|
||||
raise Exception("{} Framework is not supported"
|
||||
.format(Frameworks().to_string(self.network_parameters.framework)))
|
||||
|
||||
with tf.variable_scope("{}/{}".format(self.ap.full_name_id, name)):
|
||||
variable_scope = "{}/{}".format(self.ap.full_name_id, name)
|
||||
|
||||
# Global network - the main network shared between threads
|
||||
self.global_network = None
|
||||
if self.has_global:
|
||||
# we assign the parameters of this network on the parameters server
|
||||
with tf.device(replicated_device):
|
||||
self.global_network = general_network(agent_parameters=agent_parameters,
|
||||
name='{}/global'.format(name),
|
||||
global_network=None,
|
||||
network_is_local=False,
|
||||
spaces=spaces,
|
||||
network_is_trainable=True)
|
||||
# Global network - the main network shared between threads
|
||||
self.global_network = None
|
||||
if self.has_global:
|
||||
# we assign the parameters of this network on the parameters server
|
||||
self.global_network = general_network(variable_scope=variable_scope,
|
||||
devices=force_list(replicated_device),
|
||||
agent_parameters=agent_parameters,
|
||||
name='{}/global'.format(name),
|
||||
global_network=None,
|
||||
network_is_local=False,
|
||||
spaces=spaces,
|
||||
network_is_trainable=True)
|
||||
|
||||
# Online network - local copy of the main network used for playing
|
||||
self.online_network = None
|
||||
with tf.device(worker_device):
|
||||
self.online_network = general_network(agent_parameters=agent_parameters,
|
||||
name='{}/online'.format(name),
|
||||
global_network=self.global_network,
|
||||
network_is_local=True,
|
||||
spaces=spaces,
|
||||
network_is_trainable=True)
|
||||
# Online network - local copy of the main network used for playing
|
||||
self.online_network = None
|
||||
self.online_network = general_network(variable_scope=variable_scope,
|
||||
devices=force_list(worker_device),
|
||||
agent_parameters=agent_parameters,
|
||||
name='{}/online'.format(name),
|
||||
global_network=self.global_network,
|
||||
network_is_local=True,
|
||||
spaces=spaces,
|
||||
network_is_trainable=True)
|
||||
|
||||
# Target network - a local, slow updating network used for stabilizing the learning
|
||||
self.target_network = None
|
||||
if self.has_target:
|
||||
with tf.device(worker_device):
|
||||
self.target_network = general_network(agent_parameters=agent_parameters,
|
||||
name='{}/target'.format(name),
|
||||
global_network=self.global_network,
|
||||
network_is_local=True,
|
||||
spaces=spaces,
|
||||
network_is_trainable=False)
|
||||
# Target network - a local, slow updating network used for stabilizing the learning
|
||||
self.target_network = None
|
||||
if self.has_target:
|
||||
self.target_network = general_network(variable_scope=variable_scope,
|
||||
devices=force_list(worker_device),
|
||||
agent_parameters=agent_parameters,
|
||||
name='{}/target'.format(name),
|
||||
global_network=self.global_network,
|
||||
network_is_local=True,
|
||||
spaces=spaces,
|
||||
network_is_trainable=False)
|
||||
|
||||
def sync(self):
|
||||
"""
|
||||
@@ -198,26 +202,6 @@ class NetworkWrapper(object):
|
||||
"""
|
||||
return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
|
||||
|
||||
def get_local_variables(self):
|
||||
"""
|
||||
Get all the variables that are local to the thread
|
||||
|
||||
:return: a list of all the variables that are local to the thread
|
||||
"""
|
||||
local_variables = [v for v in tf.local_variables() if self.online_network.name in v.name]
|
||||
if self.has_target:
|
||||
local_variables += [v for v in tf.local_variables() if self.target_network.name in v.name]
|
||||
return local_variables
|
||||
|
||||
def get_global_variables(self):
|
||||
"""
|
||||
Get all the variables that are shared between threads
|
||||
|
||||
:return: a list of all the variables that are shared between threads
|
||||
"""
|
||||
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
|
||||
return global_variables
|
||||
|
||||
def set_is_training(self, state: bool):
|
||||
"""
|
||||
Set the phase of the network between training and testing
|
||||
|
||||
Reference in New Issue
Block a user