From fde73ced1397292cb8b821e8ed528e4cca9f8df5 Mon Sep 17 00:00:00 2001
From: Ajay Deshpande
Date: Thu, 15 Nov 2018 08:38:58 -0800
Subject: [PATCH] Simulating the act on the trainer. (#65)

* Remove the use of daemon threads for Redis subscribe.

* Emulate act and observe on trainer side to update internal vars.
---
 rl_coach/agents/agent.py                     | 73 +++++++++++++++++--
 rl_coach/agents/agent_interface.py           | 32 +++++++-
 rl_coach/agents/clipped_ppo_agent.py         |  7 +-
 rl_coach/agents/ppo_agent.py                 |  7 +-
 rl_coach/base_parameters.py                  |  3 +
 rl_coach/coach.py                            | 18 ++---
 .../observation_normalization_filter.py      |  1 +
 rl_coach/graph_managers/graph_manager.py     | 50 ++++++++++++-
 rl_coach/level_manager.py                    | 28 ++++++-
 rl_coach/memories/backend/memory.py          |  2 +-
 rl_coach/memories/backend/redis.py           | 36 ++++-----
 rl_coach/rollout_worker.py                   | 12 +--
 rl_coach/training_worker.py                  |  7 ++
 13 files changed, 221 insertions(+), 55 deletions(-)

diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index b629b69..fac4325 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -80,9 +80,7 @@ class Agent(AgentInterface):

         if hasattr(self.ap.memory, 'memory_backend_params'):
             self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
-            if self.ap.memory.memory_backend_params.run_type == 'trainer':
-                self.memory_backend.subscribe(self)
-            else:
+            if self.ap.memory.memory_backend_params.run_type != 'trainer':
                 self.memory.set_memory_backend(self.memory_backend)

         if agent_parameters.memory.load_memory_from_file_path:
@@ -583,14 +581,14 @@ class Agent(AgentInterface):
                             "EnvironmentSteps or TrainingSteps. Instead it is {}".format(step_method.__class__))
         return should_update

-    def _should_train(self, wait_for_full_episode=False) -> bool:
+    def _should_train(self):
         """
         Determine if we should start a training phase according to the number of steps passed since the last training

         :return: boolean: True if we should start a training phase
         """

-        should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)
+        should_update = self._should_train_helper()

         step_method = self.ap.algorithm.num_consecutive_playing_steps

@@ -602,8 +600,8 @@ class Agent(AgentInterface):

         return should_update

-    def _should_train_helper(self, wait_for_full_episode=False):
-
+    def _should_train_helper(self):
+        wait_for_full_episode = self.ap.algorithm.act_for_full_episodes
         step_method = self.ap.algorithm.num_consecutive_playing_steps

         if step_method.__class__ == EnvironmentEpisodes:
@@ -922,5 +920,66 @@ class Agent(AgentInterface):
         for network in self.networks.values():
             network.sync()

+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
+        """
+        This emulates the observe using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Given a transition obtained from the rollout worker, store it for later use and update the internal
+        episode statistics (total reward, shaped reward and step bookkeeping).
+        :param transition: a Transition containing the action, reward and game over flag from the rollout worker
+        :return: the game over flag taken from the transition
+        """
+
+        # if we are in the first step in the episode, then we don't have a next state and a reward and thus no
+        # transition yet, and therefore we don't need to store anything in the memory.
+        # also we did not reach the goal yet.
+        if self.current_episode_steps_counter == 0:
+            # initialize the current state
+            return transition.game_over
+        else:
+            # sum up the total shaped reward
+            self.total_shaped_reward_in_current_episode += transition.reward
+            self.total_reward_in_current_episode += transition.reward
+            self.shaped_reward.add_sample(transition.reward)
+            self.reward.add_sample(transition.reward)
+
+        # create and store the transition
+        if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
+            # for episodic memories we keep the transitions in a local buffer until the episode is ended.
+            # for regular memories we insert the transitions directly to the memory
+            self.current_episode_buffer.insert(transition)
+            if not isinstance(self.memory, EpisodicExperienceReplay) \
+                    and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
+                self.call_memory('store', transition)
+
+        if self.ap.visualization.dump_in_episode_signals:
+            self.update_step_in_episode_log()
+
+        return transition.game_over
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Rather than deciding on a new action, the action recorded in the transition is replayed to update internal state.
+        :return: the action info of the action taken by the rollout worker
+        """
+        if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
+            # This agent never plays while training (e.g. behavioral cloning)
+            return None
+
+        # count steps (only when training or if we are in the evaluation worker)
+        if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
+            self.total_steps_counter += 1
+            self.current_episode_steps_counter += 1
+
+        self.last_action_info = transition.action
+
+        return self.last_action_info
+
     def get_success_rate(self) -> float:
         return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py
index cfbd361..968fa43 100644
--- a/rl_coach/agents/agent_interface.py
+++ b/rl_coach/agents/agent_interface.py
@@ -18,7 +18,7 @@ from typing import Union, List, Dict

 import numpy as np

-from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType
+from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition


 class AgentInterface(object):
@@ -123,3 +123,33 @@ class AgentInterface(object):
         :return: None
         """
         raise NotImplementedError("")
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_observe_on_trainer(self, transition: Transition) -> bool:
+        """
+        This emulates the observe using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Gets a transition from the rollout worker.
+        Processes this information for later use. For example, store the transition in memory.
+        The action info (a class containing any info the agent wants to store regarding its action decision process) is
+        stored by the agent itself when deciding on the action.
+        :param transition: a Transition obtained from the rollout worker
+        :return: a done signal which is based on the agent's knowledge. This can be different from the done signal from
+                 the environment. For example, an agent can decide to finish the episode each time it gets some
+                 intrinsic reward
+        """
+        raise NotImplementedError("")
+
+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Instead of deciding on a new action, the action stored in the transition is replayed, so that the
+        internal state of the agent on the trainer stays in sync with the rollout worker that generated
+        the transition.
+        :return: the action info of the action taken by the rollout worker
+        """
+        raise NotImplementedError("")
diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py
index 6a6b8f8..9fa8d72 100644
--- a/rl_coach/agents/clipped_ppo_agent.py
+++ b/rl_coach/agents/clipped_ppo_agent.py
@@ -112,6 +112,7 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
         self.optimization_epochs = 10
         self.normalization_stats = None
         self.clipping_decay_schedule = ConstantSchedule(1)
+        self.act_for_full_episodes = True


 class ClippedPPOAgentParameters(AgentParameters):
@@ -294,11 +295,8 @@ class ClippedPPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')

-    def _should_train_helper(self, wait_for_full_episode=True):
-        return super()._should_train_helper(True)
-
     def train(self):
-        if self._should_train(wait_for_full_episode=True):
+        if self._should_train():
             for network in self.networks.values():
                 network.set_is_training(True)

@@ -334,3 +332,4 @@ class ClippedPPOAgent(ActorCriticAgent):
     def choose_action(self, curr_state):
         self.ap.algorithm.clipping_decay_schedule.step()
         return super().choose_action(curr_state)
+
diff --git a/rl_coach/agents/ppo_agent.py b/rl_coach/agents/ppo_agent.py
index a2dabbb..f45800a 100644
--- a/rl_coach/agents/ppo_agent.py
+++ b/rl_coach/agents/ppo_agent.py
@@ -121,6 +121,7 @@ class PPOAlgorithmParameters(AlgorithmParameters):
         self.use_kl_regularization = True
         self.beta_entropy = 0.01
         self.num_consecutive_playing_steps = EnvironmentSteps(5000)
+        self.act_for_full_episodes = True


 class PPOAgentParameters(AgentParameters):
@@ -354,12 +355,9 @@ class PPOAgent(ActorCriticAgent):
         # clean memory
         self.call_memory('clean')

-    def _should_train_helper(self, wait_for_full_episode=True):
-        return super()._should_train_helper(True)
-
     def train(self):
         loss = 0
-        if self._should_train(wait_for_full_episode=True):
+        if self._should_train():
             for network in self.networks.values():
                 network.set_is_training(True)

@@ -391,3 +389,4 @@ class PPOAgent(ActorCriticAgent):
     def get_prediction(self, states):
         tf_input_state = self.prepare_batch_for_inference(states, "actor")
         return self.networks['actor'].online_network.predict(tf_input_state)
+
diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py
index c4a5775..03b27e9 100644
--- a/rl_coach/base_parameters.py
+++ b/rl_coach/base_parameters.py
@@ -171,6 +171,9 @@ class AlgorithmParameters(Parameters):
         # Distributed Coach params
         self.distributed_coach_synchronization_type = None

+        # Should the workers wait for a full episode
+        self.act_for_full_episodes = False
+

 class PresetValidationParameters(Parameters):
     def __init__(self,
diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 53c1f70..8db1cd0 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -287,15 +287,15 @@ class CoachLauncher(object):
     def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
         """
         Returns a Namespace object with all the user-specified configuration options needed to launch.
-        This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by 
+        This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by
         another method that gets its configuration from elsewhere.
         An equivalent method however must return an identically structured Namespace object, which conforms to the
         structure defined by get_argument_parser.

-        This method parses the arguments that the user entered, does some basic validation, and 
+        This method parses the arguments that the user entered, does some basic validation, and
         modification of user-specified values in short form to be more explicit.

-        :param parser: a parser object which implicitly defines the format of the Namespace that 
+        :param parser: a parser object which implicitly defines the format of the Namespace that
         is expected to be returned.
         :return: the parsed arguments as a Namespace
         """
@@ -333,26 +333,26 @@ class CoachLauncher(object):
             args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
         except Error as e:
             screen.error("Error when reading distributed Coach config file: {}".format(e))
-        
+
         if args.image == '':
             screen.error("Image cannot be empty.")
-        
+
         data_store_choices = ['s3']
         if args.data_store not in data_store_choices:
             screen.warning("{} data store is unsupported.".format(args.data_store))
             screen.error("Supported data stores are {}.".format(data_store_choices))
-        
+
         memory_backend_choices = ['redispubsub']
         if args.memory_backend not in memory_backend_choices:
             screen.warning("{} memory backend is not supported.".format(args.memory_backend))
             screen.error("Supported memory backends are {}.".format(memory_backend_choices))
-        
+
         if args.s3_bucket_name == '':
             screen.error("S3 bucket name cannot be empty.")
-        
+
         if args.s3_creds_file == '':
             args.s3_creds_file = None
-        
+
         if args.play and args.distributed_coach:
             screen.error("Playing is not supported in distributed Coach.")

diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py
index 2cd8ac2..6ecc057 100644
--- a/rl_coach/filters/observation/observation_normalization_filter.py
+++ b/rl_coach/filters/observation/observation_normalization_filter.py
@@ -70,6 +70,7 @@ class ObservationNormalizationFilter(ObservationFilter):
         return self.running_observation_stats.normalize(observations)

     def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
+        self.running_observation_stats.create_ops(shape=input_observation_space.shape, clip_values=(self.clip_min, self.clip_max))
         return input_observation_space


diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 557393c..2ec5085 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -27,13 +27,14 @@ from rl_coach.base_parameters import iterable_to_items, TaskParameters, Distribu
     Parameters, PresetValidationParameters
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
-    StepMethod
+    StepMethod, Transition
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
 from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store
 from rl_coach.orchestrators.kubernetes_orchestrator import RunType
+from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles


@@ -398,9 +399,10 @@ class GraphManager(object):
         [environment.reset_internal_state(force_environment_reset) for environment in self.environments]
         [manager.reset_internal_state() for manager in self.level_managers]

-    def act(self, steps: PlayingStepsType) -> None:
+    def act(self, steps: PlayingStepsType, wait_for_full_episodes=False) -> None:
         """
         Do several steps of acting on the environment
+        :param wait_for_full_episodes: if set, act for at least `steps`, but make sure that the last episode is complete
         :param steps: the number of steps as a tuple of steps time and steps count
         """
         self.verify_graph_was_created()
@@ -412,7 +414,8 @@

         # perform several steps of playing
         count_end = self.current_step_counter + steps
-        while self.current_step_counter < count_end:
+        result = None
+        while self.current_step_counter < count_end or (wait_for_full_episodes and result is not None and not result.game_over):
             # reset the environment if the previous episode was terminated
             if self.reset_required:
                 self.reset_internal_state()
@@ -624,5 +627,46 @@
     def should_train(self) -> bool:
         return any([manager.should_train() for manager in self.level_managers])

+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_act_on_trainer(self, steps: PlayingStepsType, transition: Transition) -> None:
+        """
+        This emulates the act using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Do several steps of acting on the environment
+        :param steps: the number of steps as a tuple of steps time and steps count
+        """
+        self.verify_graph_was_created()
+
+        # perform several steps of playing
+        count_end = self.current_step_counter + steps
+        while self.current_step_counter < count_end:
+            # reset the environment if the previous episode was terminated
+            if self.reset_required:
+                self.reset_internal_state()
+
+            steps_begin = self.environments[0].total_steps_counter
+            self.top_level_manager.emulate_step_on_trainer(transition)
+            steps_end = self.environments[0].total_steps_counter
+
+            # add the diff between the total steps before and after stepping, such that environment initialization steps
+            # (like in Atari) will not be counted.
+            # We add at least one step so that even if no steps were made (in case no actions are taken in the training
+            # phase), the loop will end eventually.
+            self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
+
+            if transition.game_over:
+                self.handle_episode_ended()
+                self.reset_required = True
+
+    def fetch_from_worker(self, num_steps=0):
+        if hasattr(self, 'memory_backend'):
+            for transition in self.memory_backend.fetch(num_steps):
+                self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
+
+    def setup_memory_backend(self) -> None:
+        if hasattr(self.agent_params.memory, 'memory_backend_params'):
+            self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
+
     def should_stop(self) -> bool:
         return all([manager.should_stop() for manager in self.level_managers])
diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py
index 962fa20..3df2766 100644
--- a/rl_coach/level_manager.py
+++ b/rl_coach/level_manager.py
@@ -17,7 +17,7 @@ import copy
 from typing import Union, Dict

 from rl_coach.agents.composite_agent import CompositeAgent
-from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps
+from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition
 from rl_coach.environments.environment import Environment
 from rl_coach.environments.environment_interface import EnvironmentInterface
 from rl_coach.spaces import ActionSpace, SpacesDefinition
@@ -264,5 +264,31 @@
     def should_train(self) -> bool:
         return any([agent._should_train_helper() for agent in self.agents.values()])

+    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
+    # an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
+    def emulate_step_on_trainer(self, transition: Transition) -> None:
+        """
+        This emulates a step using the transition obtained from the rollout worker on the training worker
+        in case of distributed training.
+        Run a single step of following the behavioral scheme set for this environment.
+        :param transition: the transition obtained from the rollout worker, used to update the agents held
+        in this level.
+        :return: None
+        """
+
+        if self.reset_required:
+            self.reset_internal_state()
+
+        acting_agent = list(self.agents.values())[0]
+
+        # for i in range(self.steps_limit.num_steps):
+        # let the agent observe the result and decide if it wants to terminate the episode
+        done = acting_agent.emulate_observe_on_trainer(transition)
+        acting_agent.emulate_act_on_trainer(transition)
+
+        if done:
+            self.handle_episode_ended()
+            self.reset_required = True
+
     def should_stop(self) -> bool:
         return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])
diff --git a/rl_coach/memories/backend/memory.py b/rl_coach/memories/backend/memory.py
index 4199cf6..da8f8ce 100644
--- a/rl_coach/memories/backend/memory.py
+++ b/rl_coach/memories/backend/memory.py
@@ -32,5 +32,5 @@ class MemoryBackend(object):
     def store_episode(self, obj):
         raise NotImplemented("Not yet implemented")

-    def subscribe(self, memory):
+    def fetch(self, num_steps=0):
         raise NotImplemented("Not yet implemented")
diff --git a/rl_coach/memories/backend/redis.py b/rl_coach/memories/backend/redis.py
index b9fb5ae..0e405a6 100644
--- a/rl_coach/memories/backend/redis.py
+++ b/rl_coach/memories/backend/redis.py
@@ -2,13 +2,11 @@
 import redis
 import pickle
 import uuid
-import threading
 import time

 from kubernetes import client
 from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
 from rl_coach.core_types import Transition, Episode
-from rl_coach.core_types import RunPhase


 class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
@@ -131,42 +129,40 @@ class RedisPubSubBackend(MemoryBackend):
     def sample(self, size):
         pass

+    def fetch(self, num_steps=0):
+        return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_steps=num_steps)
+
     def subscribe(self, agent):
-        redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
-        redis_sub.daemon = True
-        redis_sub.start()
+        redis_sub = RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
+        return redis_sub

     def get_endpoint(self):
         return {'redis_address': self.params.redis_address, 'redis_port': self.params.redis_port}


-class RedisSub(threading.Thread):
-
-    def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
+class RedisSub(object):
+    def __init__(self, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
         super().__init__()
         self.redis_connection = redis.Redis(redis_address, redis_port)
         self.pubsub = self.redis_connection.pubsub()
         self.subscriber = None
-        self.agent = agent
         self.channel = channel
         self.subscriber = self.pubsub.subscribe(self.channel)

-    def run(self):
+    def run(self, num_steps=0):
+        steps = 0
         for message in self.pubsub.listen():
-            if message and 'data' in message and self.agent.phase != RunPhase.TEST or self.agent.ap.task_parameters.evaluate_only:
-                if self.agent.phase == RunPhase.TEST:
-                    print(self.agent.phase)
+            if message and 'data' in message:
                 try:
                     obj = pickle.loads(message['data'])
                     if type(obj) == Transition:
-                        self.agent.total_steps_counter += 1
-                        self.agent.current_episode_steps_counter += 1
-                        self.agent.call_memory('store', obj)
+                        steps += 1
+                        yield obj
                     elif type(obj) == Episode:
-                        self.agent.current_episode_buffer = obj
-                        self.agent.total_steps_counter += len(obj.transitions)
-                        self.agent.current_episode_steps_counter += len(obj.transitions)
-                        self.agent.handle_episode_ended()
+                        steps += len(obj.transitions)
+                        yield from obj.transitions
                 except Exception:
                     continue
+            if num_steps > 0 and steps >= num_steps:
+                break

diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index 184ccdc..b53d290 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -12,7 +12,7 @@ import os
 import math

 from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
-from rl_coach.core_types import EnvironmentSteps, RunPhase
+from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
 from google.protobuf import text_format
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
 from rl_coach.data_stores.data_store import SyncFiles
@@ -81,21 +81,23 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
     task_parameters = TaskParameters()
     task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir

-    time.sleep(30)
+
     graph_manager.create_graph(task_parameters)

     with graph_manager.phase_context(RunPhase.TRAIN):
-        error_compensation = 100

         last_checkpoint = 0

-        act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
+        act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)

         for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):

            if should_stop(checkpoint_dir):
                break

-            graph_manager.act(EnvironmentSteps(num_steps=act_steps))
+            if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps:
+                graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
+            elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
+                graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))

             new_checkpoint = get_latest_checkpoint(checkpoint_dir)

diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py
index fbb5640..ac2923a 100644
--- a/rl_coach/training_worker.py
+++ b/rl_coach/training_worker.py
@@ -31,7 +31,14 @@ def training_worker(graph_manager, checkpoint_dir):
     # evaluation offset
     eval_offset = 1

+    graph_manager.setup_memory_backend()
+
     while(steps < graph_manager.improve_steps.num_steps):
+
+        graph_manager.phase = core_types.RunPhase.TRAIN
+        graph_manager.fetch_from_worker(num_steps=graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)
+        graph_manager.phase = core_types.RunPhase.UNDEFINED
+
         if graph_manager.should_train():
             steps += 1
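
For reviewers, here is a rough sketch of how the trainer-side pieces introduced above are intended to fit together: the training worker attaches the memory backend, pulls transitions published by the rollout workers over Redis, and replays them through the emulate_* methods before deciding whether to train. The method and attribute names follow the patch; the core_types import, the helper function name and the elided training/checkpointing details are assumptions for illustration, not part of the diff.

# Illustrative only - a condensed view of the trainer loop after this patch.
# Training and checkpoint handling are intentionally elided.
from rl_coach import core_types  # assumed import, matching the core_types.RunPhase usage above


def simplified_training_loop(graph_manager):  # hypothetical helper, not in the patch
    steps = 0

    # Attach the Redis-backed memory so fetch_from_worker() can pull transitions.
    graph_manager.setup_memory_backend()

    while steps < graph_manager.improve_steps.num_steps:
        # Replay rollout-worker transitions; emulate_act/observe_on_trainer update the
        # agent's episode buffer, reward statistics and step counters on the trainer.
        graph_manager.phase = core_types.RunPhase.TRAIN
        graph_manager.fetch_from_worker(
            num_steps=graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)
        graph_manager.phase = core_types.RunPhase.UNDEFINED

        if graph_manager.should_train():
            steps += 1
            # ... training, evaluation and checkpoint publishing would follow here (not shown) ...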