Mirror of https://github.com/gryf/coach.git
Adding initial interface for backend and redis pubsub (#19)
* Adding initial interface for backend and redis pubsub
* Addressing comments, adding super in all memories
* Removing distributed experience replay
Committed by: zach dwiel
Parent: a54ef2757f
Commit: 6b2de6ba6d
@@ -33,6 +33,7 @@ from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperi
 from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace, AttentionActionSpace
 from rl_coach.utils import Signal, force_list
 from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
+from rl_coach.memories.backend.memory_impl import get_memory_backend


 class Agent(AgentInterface):
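The added import brings in get_memory_backend, a factory that turns the memory backend parameters into a concrete backend object. Below is a minimal sketch of what such an interface and factory could look like; the class names, the store/subscribe signatures, and the registry field are assumptions, not the commit's actual code.

# Minimal sketch of a memory-backend interface and factory; all names and
# signatures below are assumptions for illustration only.
class MemoryBackendSketch(object):
    def store(self, transition):
        raise NotImplementedError  # rollout side: push one transition out

    def subscribe(self, memory):
        raise NotImplementedError  # trainer side: feed transitions into `memory`


class NoOpBackend(MemoryBackendSketch):
    def store(self, transition):
        pass

    def subscribe(self, memory):
        pass


def get_memory_backend(params):
    # Dispatch on an (assumed) field of the parameters object.
    registry = {'noop': NoOpBackend}
    return registry[getattr(params, 'store_type', 'noop')]()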
@@ -76,6 +77,14 @@ class Agent(AgentInterface):
         # modules
         self.memory = dynamic_import_and_instantiate_module_from_params(self.ap.memory)

+        if hasattr(self.ap.memory, 'memory_backend_params'):
+            self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
+
+            if self.ap.memory.memory_backend_params.run_type == 'trainer':
+                self.memory_backend.subscribe(self.memory)
+            else:
+                self.memory.set_memory_backend(self.memory_backend)
+
         if agent_parameters.memory.load_memory_from_file_path:
             screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
                              .format(agent_parameters.memory.load_memory_from_file_path))
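The constructor wires up the two sides of the pub/sub link: the trainer process subscribes the backend to its replay memory so that published transitions land in it, while rollout workers attach the backend to their memory so that newly stored transitions get published. A hedged sketch of a Redis pub/sub backend along these lines follows; the channel name, pickle serialization, and threading model are assumptions rather than the commit's implementation.

import pickle
import threading

import redis


class RedisPubSubBackendSketch(object):
    def __init__(self, host='localhost', port=6379, channel='transitions'):
        self.redis = redis.Redis(host=host, port=port)
        self.channel = channel

    def store(self, transition):
        # Rollout-worker side: publish each new transition to the channel.
        self.redis.publish(self.channel, pickle.dumps(transition))

    def subscribe(self, memory):
        # Trainer side: keep pumping published transitions into the replay memory.
        def _listen():
            pubsub = self.redis.pubsub()
            pubsub.subscribe(self.channel)
            for message in pubsub.listen():
                if message['type'] == 'message':
                    memory.store(pickle.loads(message['data']))

        threading.Thread(target=_listen, daemon=True).start()

This assumes a reachable Redis server and a replay memory exposing a store(transition) method, as Coach memories do.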
@@ -534,6 +543,9 @@ class Agent(AgentInterface):
         Determine if we should start a training phase according to the number of steps passed since the last training
         :return: boolean: True if we should start a training phase
         """
+
+        if hasattr(self.ap.memory, 'memory_backend_params'):
+            self.total_steps_counter = self.call_memory('num_transitions')
         step_method = self.ap.algorithm.num_consecutive_playing_steps
         if step_method.__class__ == EnvironmentEpisodes:
             should_update = (self.current_episode - self.last_training_phase_step) >= step_method.num_steps
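The refresh above presumably reflects that, with a shared memory backend, the trainer process does not step the environment itself: its notion of progress is how many transitions have arrived in the replay memory, so total_steps_counter is taken from call_memory('num_transitions').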
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +27,6 @@ from rl_coach.architectures.tensorflow_components.embedders.embedder import Inpu
 from rl_coach.core_types import EnvironmentSteps
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
 from rl_coach.memories.non_episodic.experience_replay import ExperienceReplayParameters
-from rl_coach.memories.non_episodic.distributed_experience_replay import DistributedExperienceReplayParameters
 from rl_coach.schedules import LinearSchedule


@@ -51,20 +50,6 @@ class DQNNetworkParameters(NetworkParameters):
         self.create_target_network = True


-class DQNAgentParametersDistributed(AgentParameters):
-    def __init__(self):
-        super().__init__(algorithm=DQNAlgorithmParameters(),
-                         exploration=EGreedyParameters(),
-                         memory=DistributedExperienceReplayParameters(),
-                         networks={"main": DQNNetworkParameters()})
-        self.exploration.epsilon_schedule = LinearSchedule(1, 0.1, 1000000)
-        self.exploration.evaluation_epsilon = 0.05
-
-    @property
-    def path(self):
-        return 'rl_coach.agents.dqn_agent:DQNAgent'
-
-
 class DQNAgentParameters(AgentParameters):
     def __init__(self):
         super().__init__(algorithm=DQNAlgorithmParameters(),
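With DQNAgentParametersDistributed and the distributed replay import removed, a distributed run would presumably be expressed by attaching memory backend parameters to the standard replay memory rather than through a dedicated parameters class. A hedged sketch of that usage; the MemoryBackendParamsSketch name and its fields are assumptions.

from rl_coach.agents.dqn_agent import DQNAgentParameters


# Hypothetical holder for backend settings; name and fields are assumptions.
class MemoryBackendParamsSketch(object):
    def __init__(self, run_type):
        self.run_type = run_type  # 'trainer' on the learner, anything else on rollout workers


agent_params = DQNAgentParameters()
# Attach backend parameters to the ordinary replay memory (assumed usage; compare
# the hasattr(self.ap.memory, 'memory_backend_params') check in Agent above).
agent_params.memory.memory_backend_params = MemoryBackendParamsSketch(run_type='trainer')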