1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)

This commit is contained in:
Sina Afrooze
2018-11-22 12:52:22 -08:00
committed by Gal Leibovich
parent 559969d3dd
commit 87a7848b0a
11 changed files with 219 additions and 91 deletions

View File

@@ -15,7 +15,7 @@
#
import copy
from typing import Dict
from typing import Dict, List, Union
import numpy as np
import tensorflow as tf
@@ -25,8 +25,9 @@ from rl_coach.architectures.head_parameters import HeadParameters
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
from rl_coach.architectures.tensorflow_components import utils
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType
from rl_coach.core_types import PredictionType
from rl_coach.logger import screen
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
@@ -35,6 +36,62 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
"""
A generalized version of all possible networks implemented using tensorflow.
"""
# dictionary of variable-scope name to variable-scope object to prevent tensorflow from
# creating a new auxiliary variable scope even when name is properly specified
variable_scopes_dict = dict()
@staticmethod
def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
"""
Construct a network class using the provided variable scope and on requested devices
:param variable_scope: string specifying variable scope under which to create network variables
:param devices: list of devices (can be list of Device objects, or string for TF distributed)
:param args: all other arguments for class initializer
:param kwargs: all other keyword arguments for class initializer
:return: a GeneralTensorFlowNetwork object
"""
if len(devices) > 1:
screen.warning("Tensorflow implementation only support a single device. Using {}".format(devices[0]))
def construct_on_device():
with tf.device(GeneralTensorFlowNetwork._tf_device(devices[0])):
return GeneralTensorFlowNetwork(*args, **kwargs)
# If variable_scope is in our dictionary, then this is not the first time that this variable_scope
# is being used with construct(). So to avoid TF adding an incrementing number to the end of the
# variable_scope to uniquify it, we have to both pass the previous variable_scope object to the new
# variable_scope() call and also recover the name space using name_scope
if variable_scope in GeneralTensorFlowNetwork.variable_scopes_dict:
variable_scope = GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope]
with tf.variable_scope(variable_scope, auxiliary_name_scope=False) as vs:
with tf.name_scope(vs.original_name_scope):
return construct_on_device()
else:
with tf.variable_scope(variable_scope, auxiliary_name_scope=True) as vs:
# Add variable_scope object to dictionary for next call to construct
GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope] = vs
return construct_on_device()
@staticmethod
def _tf_device(device: Union[str, Device]) -> str:
"""
Convert device to tensorflow-specific device representation
:param device: either a specific string (used in distributed mode) which is returned without
any change or a Device type
:return: tensorflow-specific string for device
"""
if isinstance(device, str):
return device
elif isinstance(device, Device):
if device.device_type == DeviceType.CPU:
return "/cpu:0"
elif device.device_type == DeviceType.GPU:
return "/device:GPU:{}".format(device.index)
else:
raise ValueError("Invalid device_type: {}".format(device.device_type))
else:
raise ValueError("Invalid device instance type: {}".format(type(device)))
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str,
global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
"""