
Create a dataset using an agent (#306)

Generate a dataset using an agent (allowing selection between this and a random dataset)
Gal Leibovich
2019-05-28 09:34:49 +03:00
committed by GitHub
parent 342b7184bc
commit 9e9c4fd332
26 changed files with 351 additions and 111 deletions
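
To illustrate the change, here is a preset-style sketch of how the new option might be used. The DQN agent, the CartPole environment and all step counts are illustrative assumptions rather than part of this commit; omitting experience_generating_agent_params falls back to collecting a random dataset during heatup.

# Illustrative sketch only -- agent type, environment and step counts are assumptions.
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.gym_environment import GymVectorEnvironment
from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters

# schedule for the experience generating agent (how long it interacts with the environment)
experience_generating_schedule_params = ScheduleParameters()
experience_generating_schedule_params.heatup_steps = EnvironmentSteps(1000)
experience_generating_schedule_params.improve_steps = TrainingSteps(5000)
experience_generating_schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
experience_generating_schedule_params.evaluation_steps = EnvironmentEpisodes(1)

# schedule for the actual batch RL training on the collected dataset (epoch based)
schedule_params = ScheduleParameters()
schedule_params.heatup_steps = EnvironmentSteps(0)
schedule_params.improve_steps = TrainingSteps(100)
schedule_params.steps_between_evaluation_periods = TrainingSteps(1)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)

graph_manager = BatchRLGraphManager(
    agent_params=DQNAgentParameters(),                        # the agent trained from the dataset
    experience_generating_agent_params=DQNAgentParameters(),  # omit (None) to collect a random dataset instead
    experience_generating_schedule_params=experience_generating_schedule_params,
    env_params=GymVectorEnvironment(level='CartPole-v0'),
    schedule_params=schedule_params,
    train_to_eval_ratio=0.8,
    reward_model_num_epochs=30)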

View File

@@ -18,15 +18,17 @@ from typing import Tuple, List, Union
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.agents.nec_agent import NECAgentParameters
from rl_coach.architectures.network_wrapper import NetworkWrapper
from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \
PresetValidationParameters
from rl_coach.core_types import RunPhase
from rl_coach.core_types import RunPhase, TotalStepsCounter, TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.environment import EnvironmentParameters, Environment
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen
from rl_coach.schedules import LinearSchedule
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import short_dynamic_import
@@ -35,26 +37,62 @@ from rl_coach.memories.episodic import EpisodicExperienceReplayParameters
from rl_coach.core_types import TimeTypes
# TODO build a tutorial for batch RL
class BatchRLGraphManager(BasicRLGraphManager):
"""
A batch RL graph manager creates scenario of learning from a dataset without a simulator.
A batch RL graph manager creates a scenario of learning from a dataset without a simulator.
If an environment is given (useful either for research purposes, or for experimenting with a toy problem before
actually working with a real dataset), we can use it to collect a dataset that will later be used to train the
actual agent. In this case, the dataset can be collected either by acting randomly in the environment (running
only in heatup), or by training a different agent in the environment and using its collected data as the dataset.
If experience generating agent parameters are given, we will instantiate that agent, train it on the environment,
and then use the data it collected to train the actual agent. Otherwise, we will collect a random dataset.
:param agent_params: the parameters of the agent to train using batch RL
:param env_params: [optional] environment parameters, for cases where we want to first collect a dataset
:param vis_params: visualization parameters
:param preset_validation_params: preset validation parameters, to be used for testing purposes
:param name: graph name
:param spaces_definition: when working with a dataset, we need to get a description of the actual state and action
spaces of the problem
:param reward_model_num_epochs: the number of epochs to go over the dataset for training a reward model for the
'direct method' and 'doubly robust' OPE methods.
:param train_to_eval_ratio: percentage of the data transitions to be used for training vs. evaluation. i.e. a value
of 0.8 means ~80% of the transitions will be used for training and ~20% will be used for
evaluation using OPE.
:param experience_generating_agent_params: [optional] parameters of an agent to be trained against an environment,
whose collected experience will be used to train the actual (other) agent
:param experience_generating_schedule_params: [optional] graph scheduling parameters for training the experience
generating agent
"""
def __init__(self, agent_params: AgentParameters, env_params: Union[EnvironmentParameters, None],
def __init__(self, agent_params: AgentParameters,
env_params: Union[EnvironmentParameters, None],
schedule_params: ScheduleParameters,
vis_params: VisualizationParameters = VisualizationParameters(),
preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
name='batch_rl_graph', spaces_definition: SpacesDefinition = None, reward_model_num_epochs: int = 100,
train_to_eval_ratio: float = 0.8):
train_to_eval_ratio: float = 0.8, experience_generating_agent_params: AgentParameters = None,
experience_generating_schedule_params: ScheduleParameters = None):
super().__init__(agent_params, env_params, schedule_params, vis_params, preset_validation_params, name)
self.is_batch_rl = True
self.time_metric = TimeTypes.Epoch
self.reward_model_num_epochs = reward_model_num_epochs
self.spaces_definition = spaces_definition
self.is_collecting_random_dataset = experience_generating_agent_params is None
# setting this here to make sure that, by default, train_to_eval_ratio gets a value < 1
# (its default value in the memory is 1)
self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio
# (its default value in the memory is 1, so as not to affect other non-batch-RL scenarios)
if self.is_collecting_random_dataset:
self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio
else:
experience_generating_agent_params.memory.train_to_eval_ratio = train_to_eval_ratio
self.experience_generating_agent_params = experience_generating_agent_params
self.experience_generating_agent = None
self.set_schedule_params(experience_generating_schedule_params)
self.schedule_params = schedule_params
def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
if self.env_params:
@@ -76,22 +114,41 @@ class BatchRLGraphManager(BasicRLGraphManager):
self.agent_params.task_parameters = task_parameters # TODO: this should probably be passed in a different way
self.agent_params.name = "agent"
self.agent_params.is_batch_rl_training = True
self.agent_params.network_wrappers['main'].should_get_softmax_probabilities = True
if 'reward_model' not in self.agent_params.network_wrappers:
# user hasn't defined params for the reward model. we will use the same params as used for the 'main'
# network.
self.agent_params.network_wrappers['reward_model'] = deepcopy(self.agent_params.network_wrappers['main'])
agent = short_dynamic_import(self.agent_params.path)(self.agent_params)
self.agent = short_dynamic_import(self.agent_params.path)(self.agent_params)
agents = {'agent': self.agent}
if not self.is_collecting_random_dataset:
self.experience_generating_agent_params.visualization.dump_csv = False
self.experience_generating_agent_params.task_parameters = task_parameters
self.experience_generating_agent_params.name = "experience_gen_agent"
self.experience_generating_agent_params.network_wrappers['main'].should_get_softmax_probabilities = True
# we need to set these manually, as they are usually set for us only for the default agent
self.experience_generating_agent_params.input_filter = self.agent_params.input_filter
self.experience_generating_agent_params.output_filter = self.agent_params.output_filter
self.experience_generating_agent = short_dynamic_import(
self.experience_generating_agent_params.path)(self.experience_generating_agent_params)
agents['experience_generating_agent'] = self.experience_generating_agent
if not env and not self.agent_params.memory.load_memory_from_file_path:
screen.warning("A BatchRLGraph requires setting a dataset to load into the agent's memory or alternatively "
"using an environment to create a (random) dataset from. This agent should only be used for "
"inference. ")
# set level manager
level_manager = LevelManager(agents=agent, environment=env, name="main_level",
# - although we will be using each agent separately, both agents have to be initialized together with the
# LevelManager, so that they are both properly initialized
level_manager = LevelManager(agents=agents,
environment=env, name="main_level",
spaces_definition=self.spaces_definition)
if env:
return [level_manager], [env]
else:
@@ -123,12 +180,34 @@ class BatchRLGraphManager(BasicRLGraphManager):
# an environment and a dataset to load from, we will use the environment only for evaluating the policy,
# and will not run heatup.
# heatup
if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
self.heatup(self.heatup_steps)
screen.log_title("Starting to improve an agent that collects experience, to be used for training the actual agent "
"in a batch RL fashion")
if self.is_collecting_random_dataset:
# heatup
if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
self.heatup(self.heatup_steps)
else:
# set the experience generating agent to train
self.level_managers[0].agents = {'experience_generating_agent': self.experience_generating_agent}
# collect a dataset using the experience generating agent
super().improve()
# set the acquired experience to the actual agent that we're going to train
self.agent.memory = self.experience_generating_agent.memory
# switch the graph scheduling parameters
self.set_schedule_params(self.schedule_params)
# set the actual agent to train
self.level_managers[0].agents = {'agent': self.agent}
# this agent never actually plays
self.level_managers[0].agents['agent'].ap.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)
# from this point onwards, the dataset cannot be changed anymore; this allows for performance improvements.
self.level_managers[0].agents['agent'].memory.freeze()
self.level_managers[0].agents['agent'].freeze_memory()
self.initialize_ope_models_and_stats()
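
Since the hunk above interleaves old and new lines, here is a condensed sketch of the new data-collection logic in improve(), reconstructed from the diff (so treat ordering details as assumptions):

# reconstructed sketch of the dataset-collection branch in improve() (details assumed from the diff above)
if self.is_collecting_random_dataset:
    # as before: fill the memory with random actions during heatup
    if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
        self.heatup(self.heatup_steps)
else:
    # 1. train the experience generating agent against the environment (using its own schedule)
    self.level_managers[0].agents = {'experience_generating_agent': self.experience_generating_agent}
    super().improve()
    # 2. hand its replay buffer to the actual agent as the batch RL dataset
    self.agent.memory = self.experience_generating_agent.memory
    # 3. restore the original schedule and train the actual agent offline
    self.set_schedule_params(self.schedule_params)
    self.level_managers[0].agents = {'agent': self.agent}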
@@ -141,15 +220,13 @@ class BatchRLGraphManager(BasicRLGraphManager):
# the outer most training loop
improve_steps_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
while self.total_steps_counters[RunPhase.TRAIN] < improve_steps_end:
# TODO if we have an environment, do we want the agent to train against it, and use the collected
# replay buffer as a dataset? (as opposed to what we currently have, where the dataset is built
# during heatup, and is composed of random actions)
# perform several steps of training
if self.steps_between_evaluation_periods.num_steps > 0:
with self.phase_context(RunPhase.TRAIN):
self.reset_internal_state(force_environment_reset=True)
steps_between_evaluation_periods_end = self.current_step_counter + self.steps_between_evaluation_periods
steps_between_evaluation_periods_end = self.current_step_counter + \
self.steps_between_evaluation_periods
while self.current_step_counter < steps_between_evaluation_periods_end:
self.train()
@@ -168,8 +245,8 @@ class BatchRLGraphManager(BasicRLGraphManager):
def initialize_ope_models_and_stats(self):
"""
:return:
Improve a reward model of the MDP, to be used for some of the off-policy evaluation (OPE) methods.
e.g. 'direct method' and 'doubly robust'.
"""
agent = self.level_managers[0].agents['agent']
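
For background (not part of this diff): 'direct method' and 'doubly robust' refer to the standard bandit-form OPE estimators, and the reward model trained here plays the role of the learned \hat{r}(s, a) in them, with \pi_e the evaluated policy, \pi_b the behavior policy that generated the dataset, and (s_i, a_i, r_i) the logged transitions:

\hat{V}_{\mathrm{DM}} = \frac{1}{n} \sum_{i=1}^{n} \sum_{a} \pi_e(a \mid s_i)\, \hat{r}(s_i, a)

\hat{V}_{\mathrm{DR}} = \frac{1}{n} \sum_{i=1}^{n} \left[ \sum_{a} \pi_e(a \mid s_i)\, \hat{r}(s_i, a) + \frac{\pi_e(a_i \mid s_i)}{\pi_b(a_i \mid s_i)} \bigl( r_i - \hat{r}(s_i, a_i) \bigr) \right]

The doubly robust estimator stays consistent if either the reward model or the behavior-policy probabilities are accurate, which is why the reward model is fit before off-policy evaluation is run.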
@@ -193,6 +270,3 @@ class BatchRLGraphManager(BasicRLGraphManager):
:return:
"""
self.level_managers[0].agents['agent'].run_off_policy_evaluation()

View File

@@ -95,10 +95,7 @@ class GraphManager(object):
self.level_managers = [] # type: List[LevelManager]
self.top_level_manager = None
self.environments = []
self.heatup_steps = schedule_params.heatup_steps
self.evaluation_steps = schedule_params.evaluation_steps
self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods
self.improve_steps = schedule_params.improve_steps
self.set_schedule_params(schedule_params)
self.visualization_parameters = vis_params
self.name = name
self.task_parameters = None
@@ -759,3 +756,14 @@ class GraphManager(object):
if hasattr(self, 'data_store_params'):
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()
def set_schedule_params(self, schedule_params: ScheduleParameters):
"""
Set schedule parameters for the graph.
:param schedule_params: the schedule params to set.
"""
self.heatup_steps = schedule_params.heatup_steps
self.evaluation_steps = schedule_params.evaluation_steps
self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods
self.improve_steps = schedule_params.improve_steps
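
A minimal sketch of the call pattern this helper enables, assuming graph_manager is an already-constructed BatchRLGraphManager and the two ScheduleParameters objects are the ones passed to its constructor (mirroring how the batch RL manager swaps schedules above):

# assumed names: graph_manager plus the two ScheduleParameters objects from the constructor sketch above
graph_manager.set_schedule_params(experience_generating_schedule_params)  # phase 1: collect experience against the environment
# ... improve() trains the experience generating agent and fills its replay buffer ...
graph_manager.set_schedule_params(schedule_params)                        # phase 2: train the actual agent from the dataset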