1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Adding initial interface for backend and redis pubsub (#19)

* Adding initial interface for backend and redis pubsub

* Addressing comments, adding super in all memories

* Removing distributed experience replay
This commit is contained in:
Ajay Deshpande
2018-10-03 15:07:48 -07:00
committed by zach dwiel
parent a54ef2757f
commit 6b2de6ba6d
21 changed files with 459 additions and 444 deletions

View File

@@ -33,6 +33,7 @@ from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperi
from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
from rl_coach.memories.backend.memory_impl import get_memory_backend
class Agent(AgentInterface):
@@ -76,6 +77,14 @@ class Agent(AgentInterface):
# modules
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
if hasattr(self.ap.memory, 'memory_backend_params'):
self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
if self.ap.memory.memory_backend_params.run_type == 'trainer':
self.memory_backend.subscribe(self.memory)
else:
self.memory.set_memory_backend(self.memory_backend)
if agent_parameters.memory.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(agent_parameters.memory.load_memory_from_file_path))
@@ -534,6 +543,9 @@ class Agent(AgentInterface):
Determine if we should start a training phase according to the number of steps passed since the last training
:return: boolean: True if we should start a training phase
"""
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
step_method = self.ap.algorithm.num_consecutive_playing_steps
if step_method.__class__ == EnvironmentEpisodes:
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps