mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
This commit is contained in:
committed by
Gal Leibovich
parent
559969d3dd
commit
87a7848b0a
@@ -15,7 +15,7 @@
|
||||
#
|
||||
|
||||
import copy
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@@ -25,8 +25,9 @@ from rl_coach.architectures.head_parameters import HeadParameters
|
||||
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
|
||||
from rl_coach.architectures.tensorflow_components import utils
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
|
||||
from rl_coach.core_types import PredictionType
|
||||
from rl_coach.logger import screen
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
||||
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
|
||||
|
||||
@@ -35,6 +36,62 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
"""
|
||||
A generalized version of all possible networks implemented using tensorflow.
|
||||
"""
|
||||
# dictionary of variable-scope name to variable-scope object to prevent tensorflow from
|
||||
# creating a new auxiliary variable scope even when name is properly specified
|
||||
variable_scopes_dict = dict()
|
||||
|
||||
@staticmethod
|
||||
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
|
||||
"""
|
||||
Construct a network class using the provided variable scope and on requested devices
|
||||
:param variable_scope: string specifying variable scope under which to create network variables
|
||||
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
|
||||
:param args: all other arguments for class initializer
|
||||
:param kwargs: all other keyword arguments for class initializer
|
||||
:return: a GeneralTensorFlowNetwork object
|
||||
"""
|
||||
if len(devices) > 1:
|
||||
screen.warning("Tensorflow implementation only support a single device. Using {}".format(devices[0]))
|
||||
|
||||
def construct_on_device():
|
||||
with tf.device(GeneralTensorFlowNetwork._tf_device(devices[0])):
|
||||
return GeneralTensorFlowNetwork(*args, **kwargs)
|
||||
|
||||
# If variable_scope is in our dictionary, then this is not the first time that this variable_scope
|
||||
# is being used with construct(). So to avoid TF adding an incrementing number to the end of the
|
||||
# variable_scope to uniquify it, we have to both pass the previous variable_scope object to the new
|
||||
# variable_scope() call and also recover the name space using name_scope
|
||||
if variable_scope in GeneralTensorFlowNetwork.variable_scopes_dict:
|
||||
variable_scope = GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope]
|
||||
with tf.variable_scope(variable_scope, auxiliary_name_scope=False) as vs:
|
||||
with tf.name_scope(vs.original_name_scope):
|
||||
return construct_on_device()
|
||||
else:
|
||||
with tf.variable_scope(variable_scope, auxiliary_name_scope=True) as vs:
|
||||
# Add variable_scope object to dictionary for next call to construct
|
||||
GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope] = vs
|
||||
return construct_on_device()
|
||||
|
||||
@staticmethod
|
||||
def _tf_device(device: Union[str, Device]) -> str:
|
||||
"""
|
||||
Convert device to tensorflow-specific device representation
|
||||
:param device: either a specific string (used in distributed mode) which is returned without
|
||||
any change or a Device type
|
||||
:return: tensorflow-specific string for device
|
||||
"""
|
||||
if isinstance(device, str):
|
||||
return device
|
||||
elif isinstance(device, Device):
|
||||
if device.device_type == DeviceType.CPU:
|
||||
return "/cpu:0"
|
||||
elif device.device_type == DeviceType.GPU:
|
||||
return "/device:GPU:{}".format(device.index)
|
||||
else:
|
||||
raise ValueError("Invalid device_type: {}".format(device.device_type))
|
||||
else:
|
||||
raise ValueError("Invalid device instance type: {}".format(type(device)))
|
||||
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str,
|
||||
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user