Gal Leibovich
2019-03-19 18:07:09 +02:00
committed by GitHub
parent 4a8451ff02
commit e3c7e526c7
38 changed files with 1003 additions and 87 deletions


@@ -18,6 +18,7 @@ from typing import Tuple, List
from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \
PresetValidationParameters
from rl_coach.environments.environment import EnvironmentParameters, Environment
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
from rl_coach.graph_managers.graph_manager import GraphManager, ScheduleParameters
from rl_coach.level_manager import LevelManager
from rl_coach.utils import short_dynamic_import
@@ -31,17 +32,28 @@ class BasicRLGraphManager(GraphManager):
def __init__(self, agent_params: AgentParameters, env_params: EnvironmentParameters,
schedule_params: ScheduleParameters,
vis_params: VisualizationParameters=VisualizationParameters(),
preset_validation_params: PresetValidationParameters = PresetValidationParameters()):
super().__init__('simple_rl_graph', schedule_params, vis_params)
preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
name='simple_rl_graph'):
super().__init__(name, schedule_params, vis_params)
self.agent_params = agent_params
self.env_params = env_params
self.preset_validation_params = preset_validation_params
self.agent_params.visualization = vis_params
if self.agent_params.input_filter is None:
self.agent_params.input_filter = env_params.default_input_filter()
if env_params is not None:
self.agent_params.input_filter = env_params.default_input_filter()
else:
# In cases where there is no environment (e.g. batch RL and imitation learning), there is nowhere to get
# a default filter from, so we fall back to a no-op filter. When there is no environment, the user is
# expected to define any required input/output filters in the preset.
self.agent_params.input_filter = NoInputFilter()
if self.agent_params.output_filter is None:
self.agent_params.output_filter = env_params.default_output_filter()
if env_params is not None:
self.agent_params.output_filter = env_params.default_output_filter()
else:
self.agent_params.output_filter = NoOutputFilter()
def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
# environment loading
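
With env_params=None, the graph manager now falls back to NoInputFilter()/NoOutputFilter(), so a preset that runs without an environment has to declare any required filtering on the agent parameters itself. A minimal sketch of such a preset snippet, assuming the rl_coach filter API (InputFilter.add_reward_filter, RewardRescaleFilter) behaves as elsewhere in the library; the DQN agent and the rescale factor are illustrative only:

from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.filters.filter import InputFilter, NoOutputFilter
from rl_coach.filters.reward.reward_rescale_filter import RewardRescaleFilter

# hypothetical preset snippet: filters are set explicitly since there is no
# environment to supply default_input_filter()/default_output_filter()
agent_params = DQNAgentParameters()
agent_params.input_filter = InputFilter()
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(20.))
agent_params.output_filter = NoOutputFilter()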


@@ -0,0 +1,180 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import deepcopy
from typing import Tuple, List, Union
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \
PresetValidationParameters
from rl_coach.core_types import RunPhase
from rl_coach.environments.environment import EnvironmentParameters, Environment
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import short_dynamic_import
from rl_coach.memories.episodic import EpisodicExperienceReplayParameters
from rl_coach.core_types import TimeTypes
class BatchRLGraphManager(BasicRLGraphManager):
"""
A batch RL graph manager sets up a scenario for learning from a dataset, without a simulator.
"""
def __init__(self, agent_params: AgentParameters, env_params: Union[EnvironmentParameters, None],
schedule_params: ScheduleParameters,
vis_params: VisualizationParameters = VisualizationParameters(),
preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
name='batch_rl_graph', spaces_definition: SpacesDefinition = None, reward_model_num_epochs: int = 100,
train_to_eval_ratio: float = 0.8):
super().__init__(agent_params, env_params, schedule_params, vis_params, preset_validation_params, name)
self.is_batch_rl = True
self.time_metric = TimeTypes.Epoch
self.reward_model_num_epochs = reward_model_num_epochs
self.spaces_definition = spaces_definition
# setting this here to make sure that, by default, train_to_eval_ratio gets a value < 1
# (its default value in the memory is 1)
self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio
def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
if self.env_params:
# environment loading
self.env_params.seed = task_parameters.seed
self.env_params.experiment_path = task_parameters.experiment_path
env = short_dynamic_import(self.env_params.path)(**self.env_params.__dict__,
visualization_parameters=self.visualization_parameters)
else:
env = None
# Only DQN variants are supported at this point.
assert(isinstance(self.agent_params, DQNAgentParameters))
# Only episodic memories are supported, as they are needed for evaluating the
# sequential doubly robust estimator.
assert(isinstance(self.agent_params.memory, EpisodicExperienceReplayParameters))
# agent loading
self.agent_params.task_parameters = task_parameters # TODO: this should probably be passed in a different way
self.agent_params.name = "agent"
self.agent_params.is_batch_rl_training = True
# If the user hasn't defined params for the reward model, we use the same params as the 'main' network.
if 'reward_model' not in self.agent_params.network_wrappers:
self.agent_params.network_wrappers['reward_model'] = deepcopy(self.agent_params.network_wrappers['main'])
agent = short_dynamic_import(self.agent_params.path)(self.agent_params)
if not env and not self.agent_params.memory.load_memory_from_file_path:
screen.warning("A BatchRLGraph requires setting a dataset to load into the agent's memory or alternatively "
"using an environment to create a (random) dataset from. This agent should only be used for "
"inference. ")
# set level manager
level_manager = LevelManager(agents=agent, environment=env, name="main_level",
spaces_definition=self.spaces_definition)
if env:
return [level_manager], [env]
else:
return [level_manager], []
def improve(self):
"""
The main loop of the run.
Defined in the following steps:
1. Heatup
2. Repeat:
2.1. Repeat:
2.1.1. Train
2.1.2. Possibly save checkpoint
2.2. Evaluate
:return: None
"""
self.verify_graph_was_created()
# initialize the network parameters from the global network
self.sync()
# TODO: there seems to be a bug in heatup where the last episode run is not fed into the ER. e.g. when asked
# for 1024 heatup steps, the last episode run brings the total to 1040 steps, but the ER will contain only
# 1014 steps - the last episode is not there. Is this a bug in my changes or also on master?
# Creating a dataset during the heatup phase is useful mainly for tutorial and debug purposes. If we have both
# an environment and a dataset to load from, we will use the environment only for evaluating the policy,
# and will not run heatup.
# heatup
if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
self.heatup(self.heatup_steps)
self.improve_reward_model()
# improve
if self.task_parameters.task_index is not None:
screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index))
else:
screen.log_title("Starting to improve {}".format(self.name))
# the outer most training loop
improve_steps_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
while self.total_steps_counters[RunPhase.TRAIN] < improve_steps_end:
# TODO: if we have an environment, do we want to have the agent train against it, and use the
# collected replay buffer as a dataset? (as opposed to what we currently have, where the dataset is built
# during heatup and is composed of random actions)
# perform several steps of training
if self.steps_between_evaluation_periods.num_steps > 0:
with self.phase_context(RunPhase.TRAIN):
self.reset_internal_state(force_environment_reset=True)
steps_between_evaluation_periods_end = self.current_step_counter + self.steps_between_evaluation_periods
while self.current_step_counter < steps_between_evaluation_periods_end:
self.train()
# The output of batch RL training is always a checkpoint of the trained agent, so we save a checkpoint
# every epoch, regardless of the user's command line arguments.
self.save_checkpoint()
# run off-policy evaluation estimators to evaluate the agent's performance against the dataset
self.run_off_policy_evaluation()
if self.env_params is not None and self.evaluate(self.evaluation_steps):
# although we are in a batch RL setting, we might still have a simulator (e.g. when demonstrating the
# batch RL use-case using one of the existing Coach environments), in which case we might want to
# evaluate against the simulator every now and then.
break
def improve_reward_model(self):
"""
:return:
"""
screen.log_title("Training a regression model for estimating MDP rewards")
self.level_managers[0].agents['agent'].improve_reward_model(epochs=self.reward_model_num_epochs)
def run_off_policy_evaluation(self):
"""
Run the off-policy evaluation estimators to evaluate the trained policy's performance against the dataset.
:return: None
"""
self.level_managers[0].agents['agent'].run_off_policy_evaluation()
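
To show how the new class is meant to be driven, here is a hedged preset sketch that trains a DQN agent from a pre-collected dataset with no simulator. The class and parameter names come from the code above; the module path, schedule values, dataset path, and the assumption that load_memory_from_file_path accepts a plain file path are illustrative rather than taken from this commit:

from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.memories.episodic import EpisodicExperienceReplayParameters

# schedule: in batch RL, the training steps between evaluations play the role of epochs over the dataset
schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(100)
schedule_params.steps_between_evaluation_periods = TrainingSteps(1)
schedule_params.evaluation_steps = EnvironmentEpisodes(10)
schedule_params.heatup_steps = EnvironmentSteps(0)

# DQN agent with an episodic replay buffer, matching the asserts in _create_graph()
agent_params = DQNAgentParameters()
agent_params.memory = EpisodicExperienceReplayParameters()
agent_params.memory.load_memory_from_file_path = 'replay_buffer.p'  # assumed: path to a stored episodic buffer

graph_manager = BatchRLGraphManager(agent_params=agent_params,
                                    env_params=None,  # no simulator; a SpacesDefinition may be needed instead
                                    schedule_params=schedule_params,
                                    reward_model_num_epochs=30,
                                    train_to_eval_ratio=0.8)

Calling graph_manager.improve() would then run the epoch loop above: train, save a checkpoint, and run the off-policy evaluation estimators once per epoch.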


@@ -38,6 +38,8 @@ from rl_coach.data_stores.data_store_impl import get_data_store as data_store_cr
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles
from rl_coach.core_types import TimeTypes
class ScheduleParameters(Parameters):
def __init__(self):
@@ -119,6 +121,8 @@ class GraphManager(object):
self.checkpoint_state_updater = None
self.graph_logger = Logger()
self.data_store = None
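# defaults for standard (online) RL; BatchRLGraphManager overrides these to True and TimeTypes.Epoch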
self.is_batch_rl = False
self.time_metric = TimeTypes.EpisodeNumber
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
self.graph_creation_time = time.time()
@@ -445,16 +449,17 @@ class GraphManager(object):
result = self.top_level_manager.step(None)
steps_end = self.environments[0].total_steps_counter
# add the diff between the total steps before and after stepping, such that environment initialization steps
# (like in Atari) will not be counted.
# We add at least one step so that even if no steps were made (in case no actions are taken in the training
# phase), the loop will end eventually.
self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)
if result.game_over:
self.handle_episode_ended()
self.reset_required = True
self.current_step_counter[EnvironmentSteps] += (steps_end - steps_begin)
# if no steps were made (which can happen when no actions are taken while in the TRAIN phase, either in
# batch RL or in imitation learning), we force-end the loop so that it will not continue forever.
if (steps_end - steps_begin) == 0:
break
def train_and_act(self, steps: StepMethod) -> None:
"""
Train the agent by doing several acting steps followed by several training steps continually
@@ -472,9 +477,9 @@ class GraphManager(object):
while self.current_step_counter < count_end:
# The actual number of steps being done on the environment
# is decided by the agent, though this inner loop always
# takes at least one step in the environment. Depending on
# internal counters and parameters, it doesn't always train
# or save checkpoints.
# takes at least one step in the environment (at the GraphManager level).
# The agent might also decide to skip acting altogether.
# Depending on internal counters and parameters, it doesn't always train or save checkpoints.
self.act(EnvironmentSteps(1))
self.train()
self.occasionally_save_checkpoint()
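
As a toy illustration (not rl_coach code) of why the counting change in act() matters: the old max(1, ...) rule silently inflates the EnvironmentSteps counter whenever an iteration takes no environment steps, while the new rule keeps the count exact and force-ends the loop at the first zero-step iteration.

def count_steps(steps_per_iteration, old_rule):
    total = 0
    for steps in steps_per_iteration:
        if old_rule:
            total += max(1, steps)  # phantom step counted when steps == 0
        else:
            total += steps
            if steps == 0:          # force-end, e.g. batch RL / imitation learning
                break
    return total

print(count_steps([4, 0, 0, 3], old_rule=True))   # 9 - two phantom steps included
print(count_steps([4, 0, 0, 3], old_rule=False))  # 4 - loop ends at the first zero-step iteration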