1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)

This commit is contained in:
Sina Afrooze
2018-11-22 12:52:22 -08:00
committed by Gal Leibovich
parent 559969d3dd
commit 87a7848b0a
11 changed files with 219 additions and 91 deletions

View File

@@ -26,7 +26,7 @@ from six.moves import range
from rl_coach.agents.agent_interface import AgentInterface
from rl_coach.architectures.network_wrapper import NetworkWrapper
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters, Frameworks
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, DistributedTaskParameters, Frameworks
from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
from rl_coach.logger import screen, Logger, EpisodeLogger
@@ -98,14 +98,18 @@ class Agent(AgentInterface):
self.has_global = True
self.replicated_device = agent_parameters.task_parameters.device
self.worker_device = "/job:worker/task:{}".format(self.task_id)
if agent_parameters.task_parameters.use_cpu:
self.worker_device += "/cpu:0"
else:
self.worker_device += "/device:GPU:0"
else:
self.has_global = False
self.replicated_device = None
self.worker_device = ""
if agent_parameters.task_parameters.use_cpu:
self.worker_device += "/cpu:0"
else:
self.worker_device += "/device:GPU:0"
if agent_parameters.task_parameters.use_cpu:
self.worker_device = Device(DeviceType.CPU)
else:
self.worker_device = [Device(DeviceType.GPU, i)
for i in range(agent_parameters.task_parameters.num_gpu)]
# filters
self.input_filter = self.ap.input_filter