1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-15 13:35:55 +01:00

Adding initial interface for backend and redis pubsub (#19)

* Adding initial interface for backend and redis pubsub

* Addressing comments, adding super in all memories

* Removing distributed experience replay
This commit is contained in:
Ajay Deshpande
2018-10-03 15:07:48 -07:00
committed by zach dwiel
parent a54ef2757f
commit 6b2de6ba6d
21 changed files with 459 additions and 444 deletions

View File

@@ -33,6 +33,7 @@ from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperi
from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
from rl_coach.memories.backend.memory_impl import get_memory_backend
class Agent(AgentInterface):
@@ -76,6 +77,14 @@ class Agent(AgentInterface):
# modules
self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)
if hasattr(self.ap.memory, 'memory_backend_params'):
self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
if self.ap.memory.memory_backend_params.run_type == 'trainer':
self.memory_backend.subscribe(self.memory)
else:
self.memory.set_memory_backend(self.memory_backend)
if agent_parameters.memory.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(agent_parameters.memory.load_memory_from_file_path))
@@ -534,6 +543,9 @@ class Agent(AgentInterface):
Determine if we should start a training phase according to the number of steps passed since the last training
:return: boolean: True if we should start a training phase
"""
if hasattr(self.ap.memory, 'memory_backend_params'):
self.total_steps_counter = self.call_memory('num_transitions')
step_method = self.ap.algorithm.num_consecutive_playing_steps
if step_method.__class__ == EnvironmentEpisodes:
should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps

View File

@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -27,7 +27,6 @@ from rl_coach.architectures.tensorflow_components.embedders.embedder import Inpu
from rl_coach.core_types import EnvironmentSteps
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
from rl_coach.memories.non_episodic.distributed_experience_replay import DistributedExperienceReplayParameters
from rl_coach.schedules import LinearSchedule
@@ -51,20 +50,6 @@ class DQNNetworkParameters(NetworkParameters):
self.create_target_network = True
class DQNAgentParametersDistributed(AgentParameters):
def __init__(self):
super().__init__(algorithm=DQNAlgorithmParameters(),
exploration=EGreedyParameters(),
memory=DistributedExperienceReplayParameters(),
networks={"main": DQNNetworkParameters()})
self.exploration.epsilon_schedule = LinearSchedule(1, 0.1, 1000000)
self.exploration.evaluation_epsilon = 0.05
@property
def path(self):
return 'rl_coach.agents.dqn_agent:DQNAgent'
class DQNAgentParameters(AgentParameters):
def __init__(self):
super().__init__(algorithm=DQNAlgorithmParameters(),