mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Moved tf.variable_scope and tf.device calls to framework-specific architecture (#136)
This commit is contained in:
committed by
Gal Leibovich
parent
559969d3dd
commit
87a7848b0a
@@ -26,7 +26,7 @@ from six.moves import range
|
||||
|
||||
from rl_coach.agents.agent_interface import AgentInterface
|
||||
from rl_coach.architectures.network_wrapper import NetworkWrapper
|
||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters, Frameworks
|
||||
from rl_coach.base_parameters import AgentParameters, Device, DeviceType, DistributedTaskParameters, Frameworks
|
||||
from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
|
||||
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
|
||||
from rl_coach.logger import screen, Logger, EpisodeLogger
|
||||
@@ -98,14 +98,18 @@ class Agent(AgentInterface):
|
||||
self.has_global = True
|
||||
self.replicated_device = agent_parameters.task_parameters.device
|
||||
self.worker_device = "/job:worker/task:{}".format(self.task_id)
|
||||
if agent_parameters.task_parameters.use_cpu:
|
||||
self.worker_device += "/cpu:0"
|
||||
else:
|
||||
self.worker_device += "/device:GPU:0"
|
||||
else:
|
||||
self.has_global = False
|
||||
self.replicated_device = None
|
||||
self.worker_device = ""
|
||||
if agent_parameters.task_parameters.use_cpu:
|
||||
self.worker_device += "/cpu:0"
|
||||
else:
|
||||
self.worker_device += "/device:GPU:0"
|
||||
if agent_parameters.task_parameters.use_cpu:
|
||||
self.worker_device = Device(DeviceType.CPU)
|
||||
else:
|
||||
self.worker_device = [Device(DeviceType.GPU, i)
|
||||
for i in range(agent_parameters.task_parameters.num_gpu)]
|
||||
|
||||
# filters
|
||||
self.input_filter = self.ap.input_filter
|
||||
|
||||
Reference in New Issue
Block a user