Simulating the act on the trainer. (#65)

* Remove the use of daemon threads for Redis subscribe.
* Emulate act and observe on trainer side to update internal vars.

Committed by Balaji Subramaniam
Parent: fe6857eabd
Commit: fde73ced13
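
The flow introduced here: rollout workers publish Transitions (or whole Episodes) to the Redis-backed memory backend, and the training worker pulls them with GraphManager.fetch_from_worker(), which replays each one through emulate_act_on_trainer() / emulate_observe_on_trainer() so the trainer's step counters, episode buffers and replay memory stay in sync without a daemon subscriber thread. A minimal sketch of the resulting trainer-side loop, assuming an already-constructed graph_manager; the training/checkpointing tail of the loop is elided, as in the training_worker hunk below:

from rl_coach import core_types

def run_trainer_loop(graph_manager):
    # Sketch only: mirrors training_worker() from this diff, assuming graph_manager
    # was created with memory_backend_params pointing at the Redis pub/sub backend.
    graph_manager.setup_memory_backend()
    steps = 0
    while steps < graph_manager.improve_steps.num_steps:
        # Pull transitions published by the rollout workers; fetch_from_worker()
        # feeds each one through emulate_act_on_trainer(EnvironmentSteps(1), transition).
        graph_manager.phase = core_types.RunPhase.TRAIN
        graph_manager.fetch_from_worker(
            num_steps=graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)
        graph_manager.phase = core_types.RunPhase.UNDEFINED

        if graph_manager.should_train():
            steps += 1
            # training and checkpoint publishing continue here as in training_worker()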
@@ -80,9 +80,7 @@ class Agent(AgentInterface):
         if hasattr(self.ap.memory, 'memory_backend_params'):
             self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)

-            if self.ap.memory.memory_backend_params.run_type == 'trainer':
-                self.memory_backend.subscribe(self)
-            else:
+            if self.ap.memory.memory_backend_params.run_type != 'trainer':
                 self.memory.set_memory_backend(self.memory_backend)

         if agent_parameters.memory.load_memory_from_file_path:
@@ -583,14 +581,14 @@ class Agent(AgentInterface):
                              "EnvironmentSteps or TrainingSteps. Instead it is {}".format(step_method.__class__))
         return should_update

-    def _should_train(self, wait_for_full_episode=False) -> bool:
+    def _should_train(self):
         """
         Determine if we should start a training phase according to the number of steps passed since the last training

         :return: boolean: True if we should start a training phase
         """

-        should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)
+        should_update = self._should_train_helper()

         step_method = self.ap.algorithm.num_consecutive_playing_steps

@@ -602,8 +600,8 @@ class Agent(AgentInterface):

         return should_update

-    def _should_train_helper(self, wait_for_full_episode=False):
+    def _should_train_helper(self):
+        wait_for_full_episode = self.ap.algorithm.act_for_full_episodes
         step_method = self.ap.algorithm.num_consecutive_playing_steps

         if step_method.__class__ == EnvironmentEpisodes:
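
With this hunk, whether training waits for a complete episode is no longer threaded through keyword arguments; _should_train_helper() reads it from the algorithm parameters. A hedged sketch of how an algorithm opts in (ClippedPPOAgentParameters is used only as an example; its algorithm parameters set the flag to True in a later hunk anyway):

from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters

agent_params = ClippedPPOAgentParameters()
# Read by Agent._should_train_helper() instead of the removed
# wait_for_full_episode keyword argument.
agent_params.algorithm.act_for_full_episodes = True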
@@ -922,5 +920,66 @@ class Agent(AgentInterface):
         for network in self.networks.values():
             network.sync()

+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
+        """
+        This emulates the observe using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Given a response from the environment, distill the observation from it and store it for later use.
+        The response should be a dictionary containing the performed action, the new observation and measurements,
+        the reward, a game over flag and any additional information necessary.
+        :return:
+        """
+
+        # if we are in the first step in the episode, then we don't have a next state and a reward and thus no
+        # transition yet, and therefore we don't need to store anything in the memory.
+        # also we did not reach the goal yet.
+        if self.current_episode_steps_counter == 0:
+            # initialize the current state
+            return transition.game_over
+        else:
+            # sum up the total shaped reward
+            self.total_shaped_reward_in_current_episode += transition.reward
+            self.total_reward_in_current_episode += transition.reward
+            self.shaped_reward.add_sample(transition.reward)
+            self.reward.add_sample(transition.reward)
+
+            # create and store the transition
+            if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
+                # for episodic memories we keep the transitions in a local buffer until the episode is ended.
+                # for regular memories we insert the transitions directly to the memory
+                self.current_episode_buffer.insert(transition)
+                if not isinstance(self.memory, EpisodicExperienceReplay) \
+                        and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
+                    self.call_memory('store', transition)
+
+            if self.ap.visualization.dump_in_episode_signals:
+                self.update_step_in_episode_log()
+
+        return transition.game_over
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Given the agents current knowledge, decide on the next action to apply to the environment
+        :return: an action and a dictionary containing any additional info from the action decision process
+        """
+        if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
+            # This agent never plays while training (e.g. behavioral cloning)
+            return None
+
+        # count steps (only when training or if we are in the evaluation worker)
+        if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
+            self.total_steps_counter += 1
+            self.current_episode_steps_counter += 1
+
+        self.last_action_info = transition.action
+
+        return self.last_action_info
+
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
@@ -18,7 +18,7 @@ from typing import Union, List, Dict

 import numpy as np

-from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType
+from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition


 class AgentInterface(object):
@@ -123,3 +123,33 @@ class AgentInterface(object):
         :return: None
         """
         raise NotImplementedError("")
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Gets a response from the environment.
+        Processes this information for later use. For example, create a transition and store it in memory.
+        The action info (a class containing any info the agent wants to store regarding its action decision process) is
+        stored by the agent itself when deciding on the action.
+        :param env_response: a EnvResponse containing the response from the environment
+        :return: a done signal which is based on the agent knowledge. This can be different from the done signal from
+                 the environment. For example, an agent can decide to finish the episode each time it gets some
+                 intrinsic reward
+        """
+        raise NotImplementedError("")
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Get a decision of the next action to take.
+        The action is dependent on the current state which the agent holds from resetting the environment or from
+        the observe function.
+        :return: A tuple containing the actual action and additional info on the action
+        """
+        raise NotImplementedError("")
@@ -112,6 +112,7 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
         self.optimization_epochs = 10
         self.normalization_stats = None
         self.clipping_decay_schedule = ConstantSchedule(1)
+        self.act_for_full_episodes = True


 class ClippedPPOAgentParameters(AgentParameters):
@@ -294,11 +295,8 @@ class ClippedPPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')

-    def _should_train_helper(self, wait_for_full_episode=True):
-        return super()._should_train_helper(True)
-
     def train(self):
-        if self._should_train(wait_for_full_episode=True):
+        if self._should_train():
             for network in self.networks.values():
                 network.set_is_training(True)

@@ -334,3 +332,4 @@ class ClippedPPOAgent(ActorCriticAgent):
     def choose_action(self, curr_state):
         self.ap.algorithm.clipping_decay_schedule.step()
         return super().choose_action(curr_state)
+
@@ -121,6 +121,7 @@ class PPOAlgorithmParameters(AlgorithmParameters):
         self.use_kl_regularization = True
         self.beta_entropy = 0.01
         self.num_consecutive_playing_steps = EnvironmentSteps(5000)
+        self.act_for_full_episodes = True


 class PPOAgentParameters(AgentParameters):
@@ -354,12 +355,9 @@ class PPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')

-    def _should_train_helper(self, wait_for_full_episode=True):
-        return super()._should_train_helper(True)
-
     def train(self):
         loss = 0
-        if self._should_train(wait_for_full_episode=True):
+        if self._should_train():
             for network in self.networks.values():
                 network.set_is_training(True)

@@ -391,3 +389,4 @@ class PPOAgent(ActorCriticAgent):
     def get_prediction(self, states):
         tf_input_state = self.prepare_batch_for_inference(states, "actor")
         return self.networks['actor'].online_network.predict(tf_input_state)
+
@@ -171,6 +171,9 @@ class AlgorithmParameters(Parameters):
         # Distributed Coach params
         self.distributed_coach_synchronization_type = None

+        # Should the workers wait for full episode
+        self.act_for_full_episodes = False
+

 class PresetValidationParameters(Parameters):
     def __init__(self,
@@ -70,6 +70,7 @@ class ObservationNormalizationFilter(ObservationFilter):
         return self.running_observation_stats.normalize(observations)

     def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
+
         self.running_observation_stats.create_ops(shape=input_observation_space.shape,
                                                   clip_values=(self.clip_min, self.clip_max))
         return input_observation_space
@@ -27,13 +27,14 @@ from rl_coach.base_parameters import iterable_to_items, TaskParameters, Distribu
     Parameters, PresetValidationParameters
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
-    StepMethod
+    StepMethod, Transition
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
 from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store
 from rl_coach.orchestrators.kubernetes_orchestrator import RunType
+from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles


@@ -398,9 +399,10 @@ class GraphManager(object):
         [environment.reset_internal_state(force_environment_reset) for environment in self.environments]
         [manager.reset_internal_state() for manager in self.level_managers]

-    def act(self, steps: PlayingStepsType) -> None:
+    def act(self, steps: PlayingStepsType, wait_for_full_episodes=False) -> None:
         """
         Do several steps of acting on the environment
+        :param wait_for_full_episodes: if set, act for at least `steps`, but make sure that the last episode is complete
         :param steps: the number of steps as a tuple of steps time and steps count
         """
         self.verify_graph_was_created()
@@ -412,7 +414,8 @@ class GraphManager(object):

         # perform several steps of playing
         count_end = self.current_step_counter + steps
-        while self.current_step_counter < count_end:
+        result = None
+        while self.current_step_counter < count_end or (wait_for_full_episodes and result is not None and not result.game_over):
             # reset the environment if the previous episode was terminated
             if self.reset_required:
                 self.reset_internal_state()
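
A short usage sketch of the extended act() signature (the graph_manager instance is assumed to already exist): with wait_for_full_episodes set, the loop above keeps stepping past the requested count until the episode in progress reports game_over.

from rl_coach.core_types import EnvironmentSteps

# Act for at least 1000 environment steps, but never stop mid-episode.
graph_manager.act(EnvironmentSteps(1000), wait_for_full_episodes=True)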
@@ -624,5 +627,46 @@ class GraphManager(object):
     def should_train(self) -> bool:
         return any([manager.should_train() for manager in self.level_managers])

+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, steps: PlayingStepsType, transition: Transition) -> None:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Do several steps of acting on the environment
+        :param steps: the number of steps as a tuple of steps time and steps count
+        """
+        self.verify_graph_was_created()
+
+        # perform several steps of playing
+        count_end = self.current_step_counter + steps
+        while self.current_step_counter < count_end:
+            # reset the environment if the previous episode was terminated
+            if self.reset_required:
+                self.reset_internal_state()
+
+            steps_begin = self.environments[0].total_steps_counter
+            self.top_level_manager.emulate_step_on_trainer(transition)
+            steps_end = self.environments[0].total_steps_counter
+
+            # add the diff between the total steps before and after stepping, such that environment initialization steps
+            # (like in Atari) will not be counted.
+            # We add at least one step so that even if no steps were made (in case no actions are taken in the training
+            # phase), the loop will end eventually.
+            self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
+
+            if transition.game_over:
+                self.handle_episode_ended()
+                self.reset_required = True
+
+    def fetch_from_worker(self, num_steps=0):
+        if hasattr(self, 'memory_backend'):
+            for transition in self.memory_backend.fetch(num_steps):
+                self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
+
+    def setup_memory_backend(self) -> None:
+        if hasattr(self.agent_params.memory, 'memory_backend_params'):
+            self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
+
     def should_stop(self) -> bool:
         return all([manager.should_stop() for manager in self.level_managers])
@@ -17,7 +17,7 @@ import copy
 from typing import Union, Dict

 from rl_coach.agents.composite_agent import CompositeAgent
-from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps
+from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition
 from rl_coach.environments.environment import Environment
 from rl_coach.environments.environment_interface import EnvironmentInterface
 from rl_coach.spaces import ActionSpace, SpacesDefinition
@@ -264,5 +264,31 @@ class LevelManager(EnvironmentInterface):
     def should_train(self) -> bool:
         return any([agent._should_train_helper() for agent in self.agents.values()])

+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_step_on_trainer(self, transition: Transition) -> None:
+        """
+        This emulates a step using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Run a single step of following the behavioral scheme set for this environment.
+        :param action: the action to apply to the agents held in this level, before beginning following
+                       the scheme.
+        :return: None
+        """
+
+        if self.reset_required:
+            self.reset_internal_state()
+
+        acting_agent = list(self.agents.values())[0]
+
+        # for i in range(self.steps_limit.num_steps):
+        # let the agent observe the result and decide if it wants to terminate the episode
+        done = acting_agent.emulate_observe_on_trainer(transition)
+        acting_agent.emulate_act_on_trainer(transition)
+
+        if done:
+            self.handle_episode_ended()
+            self.reset_required = True
+
     def should_stop(self) -> bool:
         return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])
@@ -32,5 +32,5 @@ class MemoryBackend(object):
     def store_episode(self, obj):
         raise NotImplemented("Not yet implemented")

-    def subscribe(self, memory):
+    def fetch(self, num_steps=0):
         raise NotImplemented("Not yet implemented")
@@ -2,13 +2,11 @@
 import redis
 import pickle
 import uuid
-import threading
 import time
 from kubernetes import client

 from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
 from rl_coach.core_types import Transition, Episode
-from rl_coach.core_types import RunPhase


 class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
@@ -131,42 +129,40 @@ class RedisPubSubBackend(MemoryBackend):
     def sample(self, size):
         pass

+    def fetch(self, num_steps=0):
+        return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_steps=num_steps)
+
     def subscribe(self, agent):
-        redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
-        redis_sub.daemon = True
-        redis_sub.start()
+        redis_sub = RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
+        return redis_sub

     def get_endpoint(self):
         return {'redis_address': self.params.redis_address,
                 'redis_port': self.params.redis_port}


-class RedisSub(threading.Thread):
-    def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
+class RedisSub(object):
+    def __init__(self, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
         super().__init__()
         self.redis_connection = redis.Redis(redis_address, redis_port)
         self.pubsub = self.redis_connection.pubsub()
         self.subscriber = None
-        self.agent = agent
         self.channel = channel
         self.subscriber = self.pubsub.subscribe(self.channel)

-    def run(self):
+    def run(self, num_steps=0):
+        steps = 0
         for message in self.pubsub.listen():
-            if message and 'data' in message and self.agent.phase != RunPhase.TEST or self.agent.ap.task_parameters.evaluate_only:
-                if self.agent.phase == RunPhase.TEST:
-                    print(self.agent.phase)
+            if message and 'data' in message:
                 try:
                     obj = pickle.loads(message['data'])
                     if type(obj) == Transition:
-                        self.agent.total_steps_counter += 1
-                        self.agent.current_episode_steps_counter += 1
-                        self.agent.call_memory('store', obj)
+                        steps += 1
+                        yield obj
                     elif type(obj) == Episode:
-                        self.agent.current_episode_buffer = obj
-                        self.agent.total_steps_counter += len(obj.transitions)
-                        self.agent.current_episode_steps_counter += len(obj.transitions)
-                        self.agent.handle_episode_ended()
+                        steps += len(obj.transitions)
+                        yield from obj.transitions
                 except Exception:
                     continue
+                if num_steps > 0 and steps >= num_steps:
+                    break
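
RedisSub.run() is now a plain generator instead of a daemon thread: it lazily yields Transitions (unpacking an Episode into its transitions), counts how many it has produced, and stops once num_steps is reached. A hedged consumption sketch; the module path and the store_transition callback are assumptions, and in the actual trainer flow the consumer is GraphManager.fetch_from_worker():

from rl_coach.memories.backend.redis import RedisSub  # assumed module path

sub = RedisSub(redis_address="localhost", redis_port=6379, channel="PubsubChannel")
# The generator blocks on pubsub.listen() between messages and returns
# after yielding 100 transitions.
for transition in sub.run(num_steps=100):
    store_transition(transition)  # hypothetical consumer callback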
@@ -12,7 +12,7 @@ import os
 import math

 from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
-from rl_coach.core_types import EnvironmentSteps, RunPhase
+from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
 from google.protobuf import text_format
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
 from rl_coach.data_stores.data_store import SyncFiles
@@ -81,21 +81,23 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):

     task_parameters = TaskParameters()
     task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
-    time.sleep(30)
     graph_manager.create_graph(task_parameters)
     with graph_manager.phase_context(RunPhase.TRAIN):
-        error_compensation = 100

         last_checkpoint = 0

-        act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
+        act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)

         for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):

             if should_stop(checkpoint_dir):
                 break

-            graph_manager.act(EnvironmentSteps(num_steps=act_steps))
+            if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps:
+                graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episode=graph_manager.agent_params.algorithm.act_for_full_episodes)
+            elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
+                graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))

             new_checkpoint = get_latest_checkpoint(checkpoint_dir)

@@ -31,7 +31,14 @@ def training_worker(graph_manager, checkpoint_dir):
     # evaluation offset
     eval_offset = 1

+    graph_manager.setup_memory_backend()
+
     while(steps < graph_manager.improve_steps.num_steps):

+        graph_manager.phase = core_types.RunPhase.TRAIN
+        graph_manager.fetch_from_worker(num_steps=graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)
+        graph_manager.phase = core_types.RunPhase.UNDEFINED
+
         if graph_manager.should_train():
             steps += 1
