1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Numpy shared running stats (#97)

This commit is contained in:
Gal Leibovich
2018-11-18 14:46:40 +02:00
committed by GitHub
parent e1fa6e9681
commit 6caf721d1c
10 changed files with 226 additions and 114 deletions

View File

@@ -25,7 +25,7 @@ from six.moves import range
from rl_coach.agents.agent_interface import AgentInterface
from rl_coach.architectures.network_wrapper import NetworkWrapper
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters, Frameworks
from rl_coach.core_types import RunPhase, PredictionType, EnvironmentEpisodes, ActionType, Batch, Episode, StateType
from rl_coach.core_types import Transition, ActionInfo, TrainingSteps, EnvironmentSteps, EnvResponse
from rl_coach.logger import screen, Logger, EpisodeLogger
@@ -110,14 +110,27 @@ class Agent(AgentInterface):
self.output_filter = self.ap.output_filter
self.pre_network_filter = self.ap.pre_network_filter
device = self.replicated_device if self.replicated_device else self.worker_device
# TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
# is removed, and Redis is used for sharing data between local workers.
# Filters MoW will be split between different configurations
# 1. Distributed coach synchronization type (=distributed across multiple nodes) - Redis based data sharing + numpy arithmetic backend
# 2. Distributed TF (=distributed on a single node, using distributed TF) - TF for both data sharing and arithmetic backend
# 3. Single worker (=both TF and Mxnet) - no data sharing needed + numpy arithmetic backend
if hasattr(self.ap.memory, 'memory_backend_params') and self.ap.algorithm.distributed_coach_synchronization_type:
self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params)
self.input_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
self.output_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
self.pre_network_filter.set_device(device, memory_backend_params=self.ap.memory.memory_backend_params, mode='numpy')
elif (type(agent_parameters.task_parameters) == DistributedTaskParameters and
agent_parameters.task_parameters.framework_type == Frameworks.tensorflow):
self.input_filter.set_device(device, mode='tf')
self.output_filter.set_device(device, mode='tf')
self.pre_network_filter.set_device(device, mode='tf')
else:
self.input_filter.set_device(device)
self.output_filter.set_device(device)
self.pre_network_filter.set_device(device)
self.input_filter.set_device(device, mode='numpy')
self.output_filter.set_device(device, mode='numpy')
self.pre_network_filter.set_device(device, mode='numpy')
# initialize all internal variables
self._phase = RunPhase.HEATUP