Mirror of https://github.com/gryf/coach.git

Simulating the act on the trainer. (#65)

* Remove the use of daemon threads for Redis subscribe.
* Emulate act and observe on trainer side to update internal vars.
commit fde73ced13 (parent fe6857eabd)
Author: Ajay Deshpande
Committed by: Balaji Subramaniam
Date: 2018-11-15 08:38:58 -08:00

13 changed files with 221 additions and 55 deletions
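
For orientation, the trainer-side loop these changes enable looks roughly like the sketch below. It is a hedged condensation of the training_worker and graph_manager hunks further down in this diff: the function name and surrounding loop are illustrative, and it assumes a fully created GraphManager whose agent memory uses a backend configured with run_type == 'trainer'.

# Hedged sketch of the trainer-side flow (condensed from the training_worker changes below;
# checkpoint publishing and evaluation are omitted).
from rl_coach.core_types import RunPhase

def run_trainer(graph_manager, improve_steps):
    graph_manager.setup_memory_backend()   # build the memory backend (e.g. Redis pub/sub)
    steps = 0
    while steps < improve_steps:
        graph_manager.phase = RunPhase.TRAIN
        # pull transitions published by the rollout workers; each one is replayed through
        # emulate_act_on_trainer -> emulate_step_on_trainer -> emulate_observe_on_trainer
        graph_manager.fetch_from_worker(
            num_steps=graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)
        graph_manager.phase = RunPhase.UNDEFINED
        if graph_manager.should_train():
            steps += 1   # the actual training step is unchanged by this commit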

View File

@@ -80,9 +80,7 @@ class Agent(AgentInterface):
if hasattr(self.ap.memory, 'memory_backend_params'):
self.memory_backend = get_memory_backend(self.ap.memory.memory_backend_params)
- if self.ap.memory.memory_backend_params.run_type == 'trainer':
- self.memory_backend.subscribe(self)
- else:
+ if self.ap.memory.memory_backend_params.run_type != 'trainer':
self.memory.set_memory_backend(self.memory_backend)
if agent_parameters.memory.load_memory_from_file_path:
@@ -583,14 +581,14 @@ class Agent(AgentInterface):
"EnvironmentSteps or TrainingSteps. Instead it is {}".format(step_method.__class__))
return should_update
- def _should_train(self, wait_for_full_episode=False) -> bool:
+ def _should_train(self):
"""
Determine if we should start a training phase according to the number of steps passed since the last training
:return: boolean: True if we should start a training phase
"""
- should_update = self._should_train_helper(wait_for_full_episode=wait_for_full_episode)
+ should_update = self._should_train_helper()
step_method = self.ap.algorithm.num_consecutive_playing_steps
@@ -602,8 +600,8 @@ class Agent(AgentInterface):
return should_update
- def _should_train_helper(self, wait_for_full_episode=False):
+ def _should_train_helper(self):
+ wait_for_full_episode = self.ap.algorithm.act_for_full_episodes
step_method = self.ap.algorithm.num_consecutive_playing_steps
if step_method.__class__ == EnvironmentEpisodes:
@@ -922,5 +920,66 @@ class Agent(AgentInterface):
for network in self.networks.values():
network.sync()
# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
def emulate_observe_on_trainer(self, transition: Transition) -> bool:
"""
This emulates observe() on the training worker, using the transition obtained from the rollout worker,
in the case of distributed training.
Given the transition, accumulate its reward, buffer it in the current episode and, for non-episodic
memories, store it directly in the memory.
:return: the transition's game over flag
"""
# if we are in the first step of the episode, then we don't have a next state or a reward yet, and thus no
# transition to store in the memory.
# we also have not reached the goal yet.
if self.current_episode_steps_counter == 0:
# initialize the current state
return transition.game_over
else:
# sum up the total shaped reward
self.total_shaped_reward_in_current_episode += transition.reward
self.total_reward_in_current_episode += transition.reward
self.shaped_reward.add_sample(transition.reward)
self.reward.add_sample(transition.reward)
# create and store the transition
if self.phase in [RunPhase.TRAIN, RunPhase.HEATUP]:
# for episodic memories we keep the transitions in a local buffer until the episode is ended.
# for regular memories we insert the transitions directly to the memory
self.current_episode_buffer.insert(transition)
if not isinstance(self.memory, EpisodicExperienceReplay) \
and not self.ap.algorithm.store_transitions_only_when_episodes_are_terminated:
self.call_memory('store', transition)
if self.ap.visualization.dump_in_episode_signals:
self.update_step_in_episode_log()
return transition.game_over
# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
"""
This emulates act() on the training worker, using the transition obtained from the rollout worker,
in the case of distributed training.
Instead of deciding on a new action, update the step counters and record the action that the rollout
worker already took.
:return: the action info recorded for the rollout worker's action
"""
if self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0:
# This agent never plays while training (e.g. behavioral cloning)
return None
# count steps (only when training or if we are in the evaluation worker)
if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
self.total_steps_counter += 1
self.current_episode_steps_counter += 1
self.last_action_info = transition.action
return self.last_action_info
def get_success_rate(self) -> float:
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
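
Taken together, the two emulate_* methods added above are meant to be driven once per fetched transition, observe first and then act, mirroring LevelManager.emulate_step_on_trainer later in this diff. A hedged sketch (the helper name is illustrative; `agent` is the single acting agent of a level):

# Hedged sketch: replay one fetched transition through the new agent-side emulation methods.
def replay_transition_on_trainer(agent, transition):
    # observe first: accumulate rewards, buffer/store the transition, get the done signal
    done = agent.emulate_observe_on_trainer(transition)
    # then act: only bumps step counters and records the action chosen by the rollout worker
    agent.emulate_act_on_trainer(transition)
    return done   # the caller ends the episode (handle_episode_ended) when this is True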

View File

@@ -18,7 +18,7 @@ from typing import Union, List, Dict
import numpy as np
- from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType
+ from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, PredictionType, ActionType, Transition
class AgentInterface(object):
@@ -123,3 +123,33 @@ class AgentInterface(object):
:return: None
"""
raise NotImplementedError("")
# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
def emulate_observe_on_trainer(self, transition: Transition) -> bool:
"""
This emulates observe() on the training worker, using a transition obtained from the rollout worker,
in the case of distributed training.
Gets a response from the environment.
Processes this information for later use. For example, create a transition and store it in memory.
The action info (a class containing any info the agent wants to store regarding its action decision process) is
stored by the agent itself when deciding on the action.
:param transition: a Transition obtained from the rollout worker
:return: a done signal which is based on the agent knowledge. This can be different from the done signal from
the environment. For example, an agent can decide to finish the episode each time it gets some
intrinsic reward
"""
raise NotImplementedError("")
# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
def emulate_act_on_trainer(self, transition: Transition) -> ActionInfo:
"""
This emulates act() on the training worker, using a transition obtained from the rollout worker,
in the case of distributed training.
Get a decision of the next action to take.
The action is dependent on the current state which the agent holds from resetting the environment or from
the observe function.
:return: A tuple containing the actual action and additional info on the action
"""
raise NotImplementedError("")

View File

@@ -112,6 +112,7 @@ class ClippedPPOAlgorithmParameters(AlgorithmParameters):
self.optimization_epochs = 10
self.normalization_stats = None
self.clipping_decay_schedule = ConstantSchedule(1)
self.act_for_full_episodes = True
class ClippedPPOAgentParameters(AgentParameters):
@@ -294,11 +295,8 @@ class ClippedPPOAgent(ActorCriticAgent):
# clean memory
self.call_memory('clean')
- def _should_train_helper(self, wait_for_full_episode=True):
- return super()._should_train_helper(True)
def train(self):
- if self._should_train(wait_for_full_episode=True):
+ if self._should_train():
for network in self.networks.values():
network.set_is_training(True)
@@ -334,3 +332,4 @@ class ClippedPPOAgent(ActorCriticAgent):
def choose_action(self, curr_state):
self.ap.algorithm.clipping_decay_schedule.step()
return super().choose_action(curr_state)

View File

@@ -121,6 +121,7 @@ class PPOAlgorithmParameters(AlgorithmParameters):
self.use_kl_regularization = True
self.beta_entropy = 0.01
self.num_consecutive_playing_steps = EnvironmentSteps(5000)
self.act_for_full_episodes = True
class PPOAgentParameters(AgentParameters):
@@ -354,12 +355,9 @@ class PPOAgent(ActorCriticAgent):
# clean memory
self.call_memory('clean')
- def _should_train_helper(self, wait_for_full_episode=True):
- return super()._should_train_helper(True)
def train(self):
loss = 0
- if self._should_train(wait_for_full_episode=True):
+ if self._should_train():
for network in self.networks.values():
network.set_is_training(True)
@@ -391,3 +389,4 @@ class PPOAgent(ActorCriticAgent):
def get_prediction(self, states):
tf_input_state = self.prepare_batch_for_inference(states, "actor")
return self.networks['actor'].online_network.predict(tf_input_state)

View File

@@ -171,6 +171,9 @@ class AlgorithmParameters(Parameters):
# Distributed Coach params
self.distributed_coach_synchronization_type = None
# Should the workers wait for full episode
self.act_for_full_episodes = False
class PresetValidationParameters(Parameters):
def __init__(self,

View File

@@ -287,15 +287,15 @@ class CoachLauncher(object):
def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
"""
Returns a Namespace object with all the user-specified configuration options needed to launch.
This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by
another method that gets its configuration from elsewhere. An equivalent method however must
return an identically structured Namespace object, which conforms to the structure defined by
get_argument_parser.
This method parses the arguments that the user entered, does some basic validation, and
modification of user-specified values in short form to be more explicit.
:param parser: a parser object which implicitly defines the format of the Namespace that
is expected to be returned.
:return: the parsed arguments as a Namespace
"""
@@ -333,26 +333,26 @@ class CoachLauncher(object):
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
except Error as e:
screen.error("Error when reading distributed Coach config file: {}".format(e))
if args.image == '':
screen.error("Image cannot be empty.")
data_store_choices = ['s3']
if args.data_store not in data_store_choices:
screen.warning("{} data store is unsupported.".format(args.data_store))
screen.error("Supported data stores are {}.".format(data_store_choices))
memory_backend_choices = ['redispubsub']
if args.memory_backend not in memory_backend_choices:
screen.warning("{} memory backend is not supported.".format(args.memory_backend))
screen.error("Supported memory backends are {}.".format(memory_backend_choices))
if args.s3_bucket_name == '':
screen.error("S3 bucket name cannot be empty.")
if args.s3_creds_file == '':
args.s3_creds_file = None
if args.play and args.distributed_coach:
screen.error("Playing is not supported in distributed Coach.")

View File

@@ -70,6 +70,7 @@ class ObservationNormalizationFilter(ObservationFilter):
return self.running_observation_stats.normalize(observations)
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
self.running_observation_stats.create_ops(shape=input_observation_space.shape,
clip_values=(self.clip_min, self.clip_max))
return input_observation_space

View File

@@ -27,13 +27,14 @@ from rl_coach.base_parameters import iterable_to_items, TaskParameters, Distribu
Parameters, PresetValidationParameters
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
EnvironmentSteps, \
- StepMethod
+ StepMethod, Transition
from rl_coach.environments.environment import Environment
from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles
@@ -398,9 +399,10 @@ class GraphManager(object):
[environment.reset_internal_state(force_environment_reset) for environment in self.environments]
[manager.reset_internal_state() for manager in self.level_managers]
- def act(self, steps: PlayingStepsType) -> None:
+ def act(self, steps: PlayingStepsType, wait_for_full_episodes=False) -> None:
"""
Do several steps of acting on the environment
:param wait_for_full_episodes: if set, act for at least `steps`, but make sure that the last episode is complete
:param steps: the number of steps as a tuple of steps time and steps count
"""
self.verify_graph_was_created()
@@ -412,7 +414,8 @@ class GraphManager(object):
# perform several steps of playing
count_end = self.current_step_counter + steps
- while self.current_step_counter < count_end:
+ result = None
+ while self.current_step_counter < count_end or (wait_for_full_episodes and result is not None and not result.game_over):
# reset the environment if the previous episode was terminated
if self.reset_required:
self.reset_internal_state()
@@ -624,5 +627,46 @@ class GraphManager(object):
def should_train(self) -> bool:
return any([manager.should_train() for manager in self.level_managers])
# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
def emulate_act_on_trainer(self, steps: PlayingStepsType, transition: Transition) -> None:
"""
This emulates act() on the training worker, using the transition obtained from the rollout worker,
in the case of distributed training.
Do several steps of acting, driven by the given transition instead of the environment
:param steps: the number of steps as a tuple of steps time and steps count
:param transition: the transition fetched from the rollout worker
"""
self.verify_graph_was_created()
# perform several steps of playing
count_end = self.current_step_counter + steps
while self.current_step_counter < count_end:
# reset the environment if the previous episode was terminated
if self.reset_required:
self.reset_internal_state()
steps_begin = self.environments[0].total_steps_counter
self.top_level_manager.emulate_step_on_trainer(transition)
steps_end = self.environments[0].total_steps_counter
# add the diff between the total steps before and after stepping, such that environment initialization steps
# (like in Atari) will not be counted.
# We add at least one step so that even if no steps were made (in case no actions are taken in the training
# phase), the loop will end eventually.
self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
if transition.game_over:
self.handle_episode_ended()
self.reset_required = True
def fetch_from_worker(self, num_steps=0):
if hasattr(self, 'memory_backend'):
for transition in self.memory_backend.fetch(num_steps):
self.emulate_act_on_trainer(EnvironmentSteps(1), transition)
def setup_memory_backend(self) -> None:
if hasattr(self.agent_params.memory, 'memory_backend_params'):
self.memory_backend = get_memory_backend(self.agent_params.memory.memory_backend_params)
def should_stop(self) -> bool:
return all([manager.should_stop() for manager in self.level_managers])

View File

@@ -17,7 +17,7 @@ import copy
from typing import Union, Dict
from rl_coach.agents.composite_agent import CompositeAgent
- from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps
+ from rl_coach.core_types import EnvResponse, ActionInfo, RunPhase, ActionType, EnvironmentSteps, Transition
from rl_coach.environments.environment import Environment
from rl_coach.environments.environment_interface import EnvironmentInterface
from rl_coach.spaces import ActionSpace, SpacesDefinition
@@ -264,5 +264,31 @@ class LevelManager(EnvironmentInterface):
def should_train(self) -> bool:
return any([agent._should_train_helper() for agent in self.agents.values()])
# TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
# an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
def emulate_step_on_trainer(self, transition: Transition) -> None:
"""
This emulates a step using the transition obtained from the rollout worker on the training worker
in case of distributed training.
Run a single step of following the behavioral scheme set for this environment.
:param transition: the transition obtained from the rollout worker, to be observed by the acting agent of this level
:return: None
"""
if self.reset_required:
self.reset_internal_state()
acting_agent = list(self.agents.values())[0]
# for i in range(self.steps_limit.num_steps):
# let the agent observe the result and decide if it wants to terminate the episode
done = acting_agent.emulate_observe_on_trainer(transition)
acting_agent.emulate_act_on_trainer(transition)
if done:
self.handle_episode_ended()
self.reset_required = True
def should_stop(self) -> bool:
return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])

View File

@@ -32,5 +32,5 @@ class MemoryBackend(object):
def store_episode(self, obj):
raise NotImplemented("Not yet implemented")
- def subscribe(self, memory):
+ def fetch(self, num_steps=0):
raise NotImplemented("Not yet implemented")
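
The base class thus moves from a push-style subscribe(agent) to a pull-style fetch(num_steps) that concrete backends implement as a generator of transitions. A minimal consumer, as a hedged sketch (the function name is illustrative; `backend` is any concrete MemoryBackend such as the Redis one below):

# Hedged sketch: drain roughly `num_steps` transitions from a concrete MemoryBackend.
def drain(backend, num_steps=100):
    transitions = []
    for transition in backend.fetch(num_steps=num_steps):   # generator of Transition objects
        transitions.append(transition)
    return transitions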

View File

@@ -2,13 +2,11 @@
import redis
import pickle
import uuid
- import threading
import time
from kubernetes import client
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.core_types import Transition, Episode
- from rl_coach.core_types import RunPhase
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
@@ -131,42 +129,40 @@ class RedisPubSubBackend(MemoryBackend):
def sample(self, size):
pass
def fetch(self, num_steps=0):
return RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel).run(num_steps=num_steps)
def subscribe(self, agent):
- redis_sub = RedisSub(agent, redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
- redis_sub.daemon = True
- redis_sub.start()
redis_sub = RedisSub(redis_address=self.params.redis_address, redis_port=self.params.redis_port, channel=self.params.channel)
return redis_sub
def get_endpoint(self):
return {'redis_address': self.params.redis_address,
'redis_port': self.params.redis_port}
- class RedisSub(threading.Thread):
- def __init__(self, agent, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
+ class RedisSub(object):
+ def __init__(self, redis_address: str = "localhost", redis_port: int=6379, channel: str = "PubsubChannel"):
super().__init__()
self.redis_connection = redis.Redis(redis_address, redis_port)
self.pubsub = self.redis_connection.pubsub()
self.subscriber = None
- self.agent = agent
self.channel = channel
self.subscriber = self.pubsub.subscribe(self.channel)
- def run(self):
+ def run(self, num_steps=0):
+ steps = 0
for message in self.pubsub.listen():
- if message and 'data' in message and self.agent.phase != RunPhase.TEST or self.agent.ap.task_parameters.evaluate_only:
- if self.agent.phase == RunPhase.TEST:
- print(self.agent.phase)
+ if message and 'data' in message:
try:
obj = pickle.loads(message['data'])
if type(obj) == Transition:
- self.agent.total_steps_counter += 1
- self.agent.current_episode_steps_counter += 1
- self.agent.call_memory('store', obj)
+ steps += 1
+ yield obj
elif type(obj) == Episode:
- self.agent.current_episode_buffer = obj
- self.agent.total_steps_counter += len(obj.transitions)
- self.agent.current_episode_steps_counter += len(obj.transitions)
- self.agent.handle_episode_ended()
+ steps += len(obj.transitions)
+ yield from obj.transitions
except Exception:
continue
+ if num_steps > 0 and steps >= num_steps:
+ break
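
For context, the round trip this generator supports could look roughly as follows. The publishing side is not part of this diff; it is assumed (as the backend's store()/store_episode() would do) to publish pickled Transition/Episode objects on the same channel, and the Transition constructor arguments shown are only illustrative. It needs a reachable Redis server and RedisSub from the module above in scope.

# Hedged sketch of the pub/sub round trip (publish side and Transition arguments are assumptions).
import pickle
import numpy as np
import redis
from rl_coach.core_types import Transition

sub = RedisSub(redis_address="localhost", redis_port=6379, channel="PubsubChannel")  # subscribes on init
publisher = redis.Redis("localhost", 6379)
publisher.publish("PubsubChannel", pickle.dumps(
    Transition(state={'observation': np.zeros(1)}, action=0, reward=1.0, game_over=False)))
for transition in sub.run(num_steps=1):   # yields unpickled objects until num_steps is reached
    print(transition.reward)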

View File

@@ -12,7 +12,7 @@ import os
import math
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
- from rl_coach.core_types import EnvironmentSteps, RunPhase
+ from rl_coach.core_types import EnvironmentSteps, RunPhase, EnvironmentEpisodes
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles
@@ -81,21 +81,23 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
task_parameters = TaskParameters()
task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
time.sleep(30)
graph_manager.create_graph(task_parameters)
with graph_manager.phase_context(RunPhase.TRAIN):
- error_compensation = 100
last_checkpoint = 0
- act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
+ act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)/num_workers)
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
if should_stop(checkpoint_dir):
break
- graph_manager.act(EnvironmentSteps(num_steps=act_steps))
+ if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps:
+ graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
+ elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
+ graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))
new_checkpoint = get_latest_checkpoint(checkpoint_dir)

View File

@@ -31,7 +31,14 @@ def training_worker(graph_manager, checkpoint_dir):
# evaluation offset
eval_offset = 1
graph_manager.setup_memory_backend()
while(steps < graph_manager.improve_steps.num_steps):
graph_manager.phase = core_types.RunPhase.TRAIN
graph_manager.fetch_from_worker(num_steps=graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps)
graph_manager.phase = core_types.RunPhase.UNDEFINED
if graph_manager.should_train():
steps += 1