mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Numpy shared running stats (#97)
This commit is contained in:
@@ -25,7 +25,7 @@ from six.moves import range
|
||||
|
||||
from rl_coach.agents.agent_interface import AgentInterface
|
||||
from rl_coach.architectures.network_wrapper import NetworkWrapper
|
||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
|
||||
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters, Frameworks
|
||||
from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
|
||||
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
|
||||
from rl_coach.logger import screen, Logger, EpisodeLogger
|
||||
@@ -110,14 +110,27 @@ class Agent(AgentInterface):
|
||||
self.output_filter = self.ap.output_filter
|
||||
self.pre_network_filter = self.ap.pre_network_filter
|
||||
device = self.replicated_device if self.replicated_device else self.worker_device
|
||||
|
||||
# TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
|
||||
# is removed, and Redis is used for sharing data between local workers.
|
||||
# Filters MoW will be split between different configurations
|
||||
# 1. Distributed coach synchronization type (=distributed across multiple nodes) - Redis based data sharing + numpy arithmetic backend
|
||||
# 2. Distributed TF (=distributed on a single node, using distributed TF) - TF for both data sharing and arithmetic backend
|
||||
# 3. Single worker (=both TF and Mxnet) - no data sharing needed + numpy arithmetic backend
|
||||
|
||||
if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type:
|
||||
self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||
self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||
self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
|
||||
self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
|
||||
self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
|
||||
self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
|
||||
elif (type(agent_parameters.task_parameters) == DistributedTaskParameters and
|
||||
agent_parameters.task_parameters.framework_type == Frameworks.tensorflow):
|
||||
self.input_filter.set_device(device, mode='tf')
|
||||
self.output_filter.set_device(device, mode='tf')
|
||||
self.pre_network_filter.set_device(device, mode='tf')
|
||||
else:
|
||||
self.input_filter.set_device(device)
|
||||
self.output_filter.set_device(device)
|
||||
self.pre_network_filter.set_device(device)
|
||||
self.input_filter.set_device(device, mode='numpy')
|
||||
self.output_filter.set_device(device, mode='numpy')
|
||||
self.pre_network_filter.set_device(device, mode='numpy')
|
||||
|
||||
# initialize all internal variables
|
||||
self._phase = RunPhase.HEATUP
|
||||
|
||||
Reference in New Issue
Block a user