From e3c7e526c78e2cc039621ef58cf062cce3a65697 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Tue, 19 Mar 2019 18:07:09 +0200 Subject: [PATCH] Batch RL (#238) --- __init__.py | 0 docs_raw/source/usage.rst | 3 +- rl_coach/agents/agent.py | 79 ++++++-- rl_coach/agents/agent_interface.py | 9 + rl_coach/agents/bootstrapped_dqn_agent.py | 3 + rl_coach/agents/categorical_dqn_agent.py | 3 + rl_coach/agents/composite_agent.py | 1 + rl_coach/agents/ddqn_agent.py | 3 + rl_coach/agents/dqn_agent.py | 3 + rl_coach/agents/n_step_q_agent.py | 3 + rl_coach/agents/qr_dqn_agent.py | 3 + rl_coach/agents/rainbow_dqn_agent.py | 3 + rl_coach/agents/value_optimization_agent.py | 82 +++++++- .../tensorflow_components/heads/q_head.py | 5 + .../tensorflow_components/savers.py | 2 +- rl_coach/base_parameters.py | 19 +- rl_coach/core_types.py | 29 ++- .../dashboard_components/experiment_board.py | 5 + rl_coach/dashboard_components/globals.py | 8 +- rl_coach/dashboard_components/signals_file.py | 7 +- .../graph_managers/basic_rl_graph_manager.py | 22 ++- .../graph_managers/batch_rl_graph_manager.py | 180 ++++++++++++++++++ rl_coach/graph_managers/graph_manager.py | 23 ++- rl_coach/level_manager.py | 44 +++-- rl_coach/logger.py | 19 +- .../episodic/episodic_experience_replay.py | 149 ++++++++++++++- .../non_episodic/experience_replay.py | 36 +++- rl_coach/off_policy_evaluators/__init__.py | 15 ++ .../off_policy_evaluators/bandits/__init__.py | 15 ++ .../bandits/doubly_robust.py | 40 ++++ rl_coach/off_policy_evaluators/ope_manager.py | 124 ++++++++++++ rl_coach/off_policy_evaluators/rl/__init__.py | 0 .../rl/sequential_doubly_robust.py | 51 +++++ rl_coach/presets/CARLA_CIL.py | 3 +- rl_coach/presets/CartPole_DQN_BatchRL.py | 91 +++++++++ rl_coach/presets/Doom_Basic_BC.py | 3 +- rl_coach/presets/MontezumaRevenge_BC.py | 3 +- rl_coach/spaces.py | 2 +- 38 files changed, 1003 insertions(+), 87 deletions(-) create mode 100644 __init__.py create mode 100644 rl_coach/graph_managers/batch_rl_graph_manager.py create mode 100644 rl_coach/off_policy_evaluators/__init__.py create mode 100644 rl_coach/off_policy_evaluators/bandits/__init__.py create mode 100644 rl_coach/off_policy_evaluators/bandits/doubly_robust.py create mode 100644 rl_coach/off_policy_evaluators/ope_manager.py create mode 100644 rl_coach/off_policy_evaluators/rl/__init__.py create mode 100644 rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py create mode 100644 rl_coach/presets/CartPole_DQN_BatchRL.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docs_raw/source/usage.rst b/docs_raw/source/usage.rst index c32c07f..8ff520a 100644 --- a/docs_raw/source/usage.rst +++ b/docs_raw/source/usage.rst @@ -113,7 +113,8 @@ In Coach, this can be done in two steps - .. 
code-block:: python - coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=\"/replay_buffer.p\"' + from rl_coach.core_types import PickledReplayBuffer + coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=PickledReplayBuffer(\"/replay_buffer.p\"') Visualizations diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index e7706c7..28bdd82 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -34,6 +34,9 @@ from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace from rl_coach.utils import Signal, force_list from rl_coach.utils import dynamic_import_and_instantiate_module_from_params from rl_coach.memories.backend.memory_impl import get_memory_backend +from rl_coach.core_types import TimeTypes +from rl_coach.off_policy_evaluators.ope_manager import OpeManager +from rl_coach.core_types import PickledReplayBuffer, CsvDataset class Agent(AgentInterface): @@ -49,10 +52,10 @@ class Agent(AgentInterface): and self.ap.memory.shared_memory if self.shared_memory: self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad - self.name = agent_parameters.name self.parent = parent self.parent_level_manager = None - self.full_name_id = agent_parameters.full_name_id = self.name + # TODO this needs to be sorted out. Why the duplicates for the agent's name? + self.full_name_id = agent_parameters.full_name_id = self.name = agent_parameters.name if type(agent_parameters.task_parameters) == DistributedTaskParameters: screen.log_title("Creating agent - name: {} task id: {} (may take up to 30 seconds due to " @@ -84,9 +87,17 @@ class Agent(AgentInterface): self.memory.set_memory_backend(self.memory_backend) if agent_parameters.memory.load_memory_from_file_path: - screen.log_title("Loading replay buffer from pickle. Pickle path: {}" - .format(agent_parameters.memory.load_memory_from_file_path)) - self.memory.load(agent_parameters.memory.load_memory_from_file_path) + if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer): + screen.log_title("Loading a pickled replay buffer. Pickled file path: {}" + .format(agent_parameters.memory.load_memory_from_file_path.filepath)) + self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath) + elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset): + screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}" + .format(agent_parameters.memory.load_memory_from_file_path.filepath)) + self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path) + else: + raise ValueError('Trying to load a replay buffer using an unsupported method - {}. 
' + .format(agent_parameters.memory.load_memory_from_file_path)) if self.shared_memory and self.is_chief: self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory) @@ -147,6 +158,7 @@ class Agent(AgentInterface): self.total_steps_counter = 0 self.running_reward = None self.training_iteration = 0 + self.training_epoch = 0 self.last_target_network_update_step = 0 self.last_training_phase_step = 0 self.current_episode = self.ap.current_episode = 0 @@ -184,6 +196,7 @@ class Agent(AgentInterface): self.discounted_return = self.register_signal('Discounted Return') if isinstance(self.in_action_space, GoalsSpace): self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True) + # use seed if self.ap.task_parameters.seed is not None: random.seed(self.ap.task_parameters.seed) @@ -193,6 +206,9 @@ class Agent(AgentInterface): random.seed() np.random.seed() + # batch rl + self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None + @property def parent(self) -> 'LevelManager': """ @@ -228,6 +244,7 @@ class Agent(AgentInterface): format(graph_name=self.parent_level_manager.parent_graph_manager.name, level_name=self.parent_level_manager.name, agent_full_id='.'.join(self.full_name_id.split('/'))) + self.agent_logger.set_index_name(self.parent_level_manager.parent_graph_manager.time_metric.value.name) self.agent_logger.set_logger_filenames(self.ap.task_parameters.experiment_path, logger_prefix=logger_prefix, add_timestamp=True, task_id=self.task_id) if self.ap.visualization.dump_in_episode_signals: @@ -387,7 +404,8 @@ class Agent(AgentInterface): elif ending_evaluation: # we write to the next episode, because it could be that the current episode was already written # to disk and then we won't write it again - self.agent_logger.set_current_time(self.current_episode + 1) + self.agent_logger.set_current_time(self.get_current_time() + 1) + evaluation_reward = self.accumulated_rewards_across_evaluation_episodes / self.num_evaluation_episodes_completed self.agent_logger.create_signal_value( 'Evaluation Reward', evaluation_reward) @@ -471,8 +489,11 @@ class Agent(AgentInterface): :return: None """ # log all the signals to file - self.agent_logger.set_current_time(self.current_episode) + current_time = self.get_current_time() + self.agent_logger.set_current_time(current_time) self.agent_logger.create_signal_value('Training Iter', self.training_iteration) + self.agent_logger.create_signal_value('Episode #', self.current_episode) + self.agent_logger.create_signal_value('Epoch', self.training_epoch) self.agent_logger.create_signal_value('In Heatup', int(self._phase == RunPhase.HEATUP)) self.agent_logger.create_signal_value('ER #Transitions', self.call_memory('num_transitions')) self.agent_logger.create_signal_value('ER #Episodes', self.call_memory('length')) @@ -485,13 +506,17 @@ class Agent(AgentInterface): if self._phase == RunPhase.TRAIN else np.nan) self.agent_logger.create_signal_value('Update Target Network', 0, overwrite=False) - self.agent_logger.update_wall_clock_time(self.current_episode) + self.agent_logger.update_wall_clock_time(current_time) # The following signals are created with meaningful values only when an evaluation phase is completed. 
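For reference, the two descriptors this loading branch dispatches on are defined later in this patch in core_types.py, and a preset selects between them simply by what it assigns to the memory parameters. A minimal sketch of that configuration, with placeholder file paths:

    from rl_coach.agents.dqn_agent import DQNAgentParameters
    from rl_coach.core_types import CsvDataset, PickledReplayBuffer
    from rl_coach.memories.episodic import EpisodicExperienceReplayParameters

    agent_params = DQNAgentParameters()
    agent_params.memory = EpisodicExperienceReplayParameters()  # batch RL assumes an episodic memory

    # option 1: a replay buffer that was previously pickled to disk
    agent_params.memory.load_memory_from_file_path = PickledReplayBuffer('/path/to/replay_buffer.p')

    # option 2: a CSV dataset of logged transitions (see load_csv later in this patch for the expected columns)
    agent_params.memory.load_memory_from_file_path = CsvDataset('/path/to/dataset.csv', is_episodic=True)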
# Creating with default NaNs for any HEATUP/TRAIN/TEST episode which is not the last in an evaluation phase self.agent_logger.create_signal_value('Evaluation Reward', np.nan, overwrite=False) self.agent_logger.create_signal_value('Shaped Evaluation Reward', np.nan, overwrite=False) self.agent_logger.create_signal_value('Success Rate', np.nan, overwrite=False) + self.agent_logger.create_signal_value('Inverse Propensity Score', np.nan, overwrite=False) + self.agent_logger.create_signal_value('Direct Method Reward', np.nan, overwrite=False) + self.agent_logger.create_signal_value('Doubly Robust', np.nan, overwrite=False) + self.agent_logger.create_signal_value('Sequential Doubly Robust', np.nan, overwrite=False) for signal in self.episode_signals: self.agent_logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean()) @@ -500,8 +525,7 @@ class Agent(AgentInterface): self.agent_logger.create_signal_value("{}/Min".format(signal.name), signal.get_min()) # dump - if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0 \ - and self.current_episode > 0: + if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0: self.agent_logger.dump_output_csv() def handle_episode_ended(self) -> None: @@ -537,7 +561,8 @@ class Agent(AgentInterface): self.total_reward_in_current_episode >= self.spaces.reward.reward_success_threshold: self.num_successes_across_evaluation_episodes += 1 - if self.ap.visualization.dump_csv: + if self.ap.visualization.dump_csv and \ + self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber: self.update_log() if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high": @@ -651,18 +676,22 @@ class Agent(AgentInterface): """ loss = 0 if self._should_train(): + self.training_epoch += 1 for network in self.networks.values(): network.set_is_training(True) - for training_step in range(self.ap.algorithm.num_consecutive_training_steps): - # TODO: this should be network dependent - network_parameters = list(self.ap.network_wrappers.values())[0] + # TODO: this should be network dependent + network_parameters = list(self.ap.network_wrappers.values())[0] + # we either go sequentially through the entire replay buffer in the batch RL mode, + # or sample randomly for the basic RL case. + training_schedule = self.call_memory('get_shuffled_data_generator', network_parameters.batch_size) if \ + self.ap.is_batch_rl_training else [self.call_memory('sample', network_parameters.batch_size) for _ in + range(self.ap.algorithm.num_consecutive_training_steps)] + + for batch in training_schedule: # update counters self.training_iteration += 1 - - # sample a batch and train on it - batch = self.call_memory('sample', network_parameters.batch_size) if self.pre_network_filter is not None: batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False) @@ -673,6 +702,7 @@ class Agent(AgentInterface): batch = Batch(batch) total_loss, losses, unclipped_grads = self.learn_from_batch(batch) loss += total_loss + self.unclipped_grads.add_sample(unclipped_grads) # TODO: the learning rate decay should be done through the network instead of here @@ -697,6 +727,12 @@ class Agent(AgentInterface): if self.imitation: self.log_to_screen() + if self.ap.visualization.dump_csv and \ + self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.Epoch: + # in BatchRL, or imitation learning, the agent never acts, so we have to get the stats out here. 
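The training schedule introduced above is the core behavioral change for batch RL: instead of drawing random batches, each call to train() sweeps the training portion of the replay buffer once, i.e. one epoch. A sketch of how the new generator is consumed (the helper name and objects are placeholders; the memory call and learn_from_batch return values follow this patch, input filtering is omitted for brevity):

    from rl_coach.core_types import Batch

    def run_one_epoch(agent, batch_size):
        # one pass over the shuffled training portion of the replay buffer
        total_loss = 0
        for transitions in agent.call_memory('get_shuffled_data_generator', batch_size):
            batch = Batch(transitions)
            loss, losses, unclipped_grads = agent.learn_from_batch(batch)
            total_loss += loss
        return total_loss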
+ # we dump the data out every epoch + self.update_log() + for network in self.networks.values(): network.set_is_training(False) @@ -1034,3 +1070,12 @@ class Agent(AgentInterface): for network in self.networks.values(): savers.update(network.collect_savers(parent_path_suffix)) return savers + + def get_current_time(self): + pass + return { + TimeTypes.EpisodeNumber: self.current_episode, + TimeTypes.TrainingIteration: self.training_iteration, + TimeTypes.EnvironmentSteps: self.total_steps_counter, + TimeTypes.WallClockTime: self.agent_logger.get_current_wall_clock_time(), + TimeTypes.Epoch: self.training_epoch}[self.parent_level_manager.parent_graph_manager.time_metric] diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py index e8aba49..16c32a5 100644 --- a/rl_coach/agents/agent_interface.py +++ b/rl_coach/agents/agent_interface.py @@ -173,3 +173,12 @@ class AgentInterface(object): :return: None """ raise NotImplementedError("") + + def run_off_policy_evaluation(self) -> None: + """ + Run off-policy evaluation estimators to evaluate the trained policy performance against a dataset. + Should only be implemented for off-policy RL algorithms. + + :return: None + """ + raise NotImplementedError("") diff --git a/rl_coach/agents/bootstrapped_dqn_agent.py b/rl_coach/agents/bootstrapped_dqn_agent.py index bbbd242..e8ee7d8 100644 --- a/rl_coach/agents/bootstrapped_dqn_agent.py +++ b/rl_coach/agents/bootstrapped_dqn_agent.py @@ -64,6 +64,9 @@ class BootstrappedDQNAgent(ValueOptimizationAgent): q_st_plus_1 = result[:self.ap.exploration.architecture_num_q_heads] TD_targets = result[self.ap.exploration.architecture_num_q_heads:] + # add Q value samples for logging + self.q_values.add_sample(TD_targets) + # initialize with the current prediction so that we will # only update the action that we have actually done in this transition for i in range(self.ap.network_wrappers['main'].batch_size): diff --git a/rl_coach/agents/categorical_dqn_agent.py b/rl_coach/agents/categorical_dqn_agent.py index 1c4b30e..cfcbe9d 100644 --- a/rl_coach/agents/categorical_dqn_agent.py +++ b/rl_coach/agents/categorical_dqn_agent.py @@ -100,6 +100,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent): (self.networks['main'].online_network, batch.states(network_keys)) ]) + # add Q value samples for logging + self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets)) + # select the optimal actions for the next state target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1) m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size)) diff --git a/rl_coach/agents/composite_agent.py b/rl_coach/agents/composite_agent.py index a3cf747..89dfdf4 100644 --- a/rl_coach/agents/composite_agent.py +++ b/rl_coach/agents/composite_agent.py @@ -432,3 +432,4 @@ class CompositeAgent(AgentInterface): savers.update(agent.collect_savers( parent_path_suffix="{}.{}".format(parent_path_suffix, self.name))) return savers + diff --git a/rl_coach/agents/ddqn_agent.py b/rl_coach/agents/ddqn_agent.py index 5268e6d..7021f8e 100644 --- a/rl_coach/agents/ddqn_agent.py +++ b/rl_coach/agents/ddqn_agent.py @@ -50,6 +50,9 @@ class DDQNAgent(ValueOptimizationAgent): (self.networks['main'].online_network, batch.states(network_keys)) ]) + # add Q value samples for logging + self.q_values.add_sample(TD_targets) + # initialize with the current prediction so that we will # only update the action that we have actually done in this transition TD_errors = 
[] diff --git a/rl_coach/agents/dqn_agent.py b/rl_coach/agents/dqn_agent.py index b234e88..d6c05da 100644 --- a/rl_coach/agents/dqn_agent.py +++ b/rl_coach/agents/dqn_agent.py @@ -81,6 +81,9 @@ class DQNAgent(ValueOptimizationAgent): (self.networks['main'].online_network, batch.states(network_keys)) ]) + # add Q value samples for logging + self.q_values.add_sample(TD_targets) + # only update the action that we have actually done in this transition TD_errors = [] for i in range(self.ap.network_wrappers['main'].batch_size): diff --git a/rl_coach/agents/n_step_q_agent.py b/rl_coach/agents/n_step_q_agent.py index cc44891..21b9239 100644 --- a/rl_coach/agents/n_step_q_agent.py +++ b/rl_coach/agents/n_step_q_agent.py @@ -123,6 +123,9 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent): else: assert True, 'The available values for targets_horizon are: 1-Step, N-Step' + # add Q value samples for logging + self.q_values.add_sample(state_value_head_targets) + # train result = self.networks['main'].online_network.accumulate_gradients(batch.states(network_keys), [state_value_head_targets]) diff --git a/rl_coach/agents/qr_dqn_agent.py b/rl_coach/agents/qr_dqn_agent.py index 1b26d6d..d5cf3fd 100644 --- a/rl_coach/agents/qr_dqn_agent.py +++ b/rl_coach/agents/qr_dqn_agent.py @@ -88,6 +88,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent): (self.networks['main'].online_network, batch.states(network_keys)) ]) + # add Q value samples for logging + self.q_values.add_sample(self.get_q_values(current_quantiles)) + # get the optimal actions to take for the next states target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1) diff --git a/rl_coach/agents/rainbow_dqn_agent.py b/rl_coach/agents/rainbow_dqn_agent.py index d47d3a4..4973670 100644 --- a/rl_coach/agents/rainbow_dqn_agent.py +++ b/rl_coach/agents/rainbow_dqn_agent.py @@ -95,6 +95,9 @@ class RainbowDQNAgent(CategoricalDQNAgent): (self.networks['main'].online_network, batch.states(network_keys)) ]) + # add Q value samples for logging + self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets)) + # only update the action that we have actually done in this transition (using the Double-DQN selected actions) target_actions = ddqn_selected_actions m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size)) diff --git a/rl_coach/agents/value_optimization_agent.py b/rl_coach/agents/value_optimization_agent.py index 9771ae5..917435f 100644 --- a/rl_coach/agents/value_optimization_agent.py +++ b/rl_coach/agents/value_optimization_agent.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +from collections import OrderedDict from typing import Union import numpy as np from rl_coach.agents.agent import Agent -from rl_coach.core_types import ActionInfo, StateType +from rl_coach.core_types import ActionInfo, StateType, Batch +from rl_coach.logger import screen from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay from rl_coach.spaces import DiscreteActionSpace - +from copy import deepcopy ## This is an abstract agent - there is no learn_from_batch method ## @@ -79,10 +80,12 @@ class ValueOptimizationAgent(Agent): # this is for bootstrapped dqn if type(actions_q_values) == list and len(actions_q_values) > 0: actions_q_values = self.exploration_policy.last_action_values - actions_q_values = actions_q_values.squeeze() # store the q values statistics for logging self.q_values.add_sample(actions_q_values) + + actions_q_values = actions_q_values.squeeze() + for i, q_value in enumerate(actions_q_values): self.q_value_for_action[i].add_sample(q_value) @@ -96,3 +99,74 @@ class ValueOptimizationAgent(Agent): def learn_from_batch(self, batch): raise NotImplementedError("ValueOptimizationAgent is an abstract agent. Not to be used directly.") + + def run_off_policy_evaluation(self): + """ + Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on + an evaluation dataset, which was collected by another policy(ies). + :return: None + """ + assert self.ope_manager + dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to', + (self.call_memory('get_last_training_set_episode_id') + 1, + self.call_memory('num_complete_episodes'))) + if len(dataset_as_episodes) == 0: + raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. 
' + 'Consider decreasing its value.') + + ips, dm, dr, seq_dr = self.ope_manager.evaluate( + dataset_as_episodes=dataset_as_episodes, + batch_size=self.ap.network_wrappers['main'].batch_size, + discount_factor=self.ap.algorithm.discount, + reward_model=self.networks['reward_model'].online_network, + q_network=self.networks['main'].online_network, + network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys())) + + # get the estimators out to the screen + log = OrderedDict() + log['Epoch'] = self.training_epoch + log['IPS'] = ips + log['DM'] = dm + log['DR'] = dr + log['Sequential-DR'] = seq_dr + screen.log_dict(log, prefix='Off-Policy Evaluation') + + # get the estimators out to dashboard + self.agent_logger.set_current_time(self.get_current_time() + 1) + self.agent_logger.create_signal_value('Inverse Propensity Score', ips) + self.agent_logger.create_signal_value('Direct Method Reward', dm) + self.agent_logger.create_signal_value('Doubly Robust', dr) + self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr) + + def improve_reward_model(self, epochs: int): + """ + Train a reward model to be used by the doubly-robust estimator + + :param epochs: The total number of epochs to use for training a reward model + :return: None + """ + batch_size = self.ap.network_wrappers['reward_model'].batch_size + network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys() + + # this is fitted from the training dataset + for epoch in range(epochs): + loss = 0 + for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)): + batch = Batch(batch) + current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys)) + current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards() + loss += self.networks['reward_model'].train_and_sync_networks( + batch.states(network_keys), current_rewards_prediction_for_all_actions)[0] + # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0]) + + log = OrderedDict() + log['Epoch'] = epoch + log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size) + screen.log_dict(log, prefix='Training Reward Model') + + + + + + + diff --git a/rl_coach/architectures/tensorflow_components/heads/q_head.py b/rl_coach/architectures/tensorflow_components/heads/q_head.py index eedec5b..135639c 100644 --- a/rl_coach/architectures/tensorflow_components/heads/q_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/q_head.py @@ -50,6 +50,11 @@ class QHead(Head): # Standard Q Network self.output = self.dense_layer(self.num_actions)(input_layer, name='output') + # TODO add this to other Q heads. e.g. dueling. 
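The improve_reward_model method above trains the reward model as a plain regression on observed rewards: the network first predicts a reward for every action, then the entry of the action that was actually taken is overwritten with the observed reward, so only that entry contributes a non-zero error. A small numpy illustration of the target construction (shapes and values are made up for the example):

    import numpy as np

    batch_size, num_actions = 4, 2
    predicted_rewards = np.random.randn(batch_size, num_actions)  # reward_model prediction for all actions
    actions = np.array([0, 1, 1, 0])                              # batch.actions()
    rewards = np.array([1.0, 0.0, 1.0, 0.5])                      # batch.rewards()

    # targets equal the prediction everywhere except for the taken action, so the regression
    # loss only penalizes the predicted reward of the action that was actually taken
    targets = predicted_rewards.copy()
    targets[range(batch_size), actions] = rewards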
+ temperature = self.ap.network_wrappers[self.network_name].softmax_temperature + temperature_scaled_outputs = self.output / temperature + self.softmax = tf.nn.softmax(temperature_scaled_outputs, name="softmax") + def __str__(self): result = [ "Dense (num outputs = {})".format(self.num_actions) diff --git a/rl_coach/architectures/tensorflow_components/savers.py b/rl_coach/architectures/tensorflow_components/savers.py index 67c0c8b..ea5b1b8 100644 --- a/rl_coach/architectures/tensorflow_components/savers.py +++ b/rl_coach/architectures/tensorflow_components/savers.py @@ -42,7 +42,7 @@ class GlobalVariableSaver(Saver): self._variable_placeholders.append(variable_placeholder) self._variable_update_ops.append(v.assign(variable_placeholder)) - self._saver = tf.train.Saver(self._variables) + self._saver = tf.train.Saver(self._variables, max_to_keep=None) @property def path(self): diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index bfab69c..e462e2b 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -291,7 +291,8 @@ class NetworkParameters(Parameters): batch_size=32, replace_mse_with_huber_loss=False, create_target_network=False, - tensorflow_support=True): + tensorflow_support=True, + softmax_temperature=1): """ :param force_cpu: Force the neural networks to run on the CPU even if a GPU is available @@ -374,6 +375,8 @@ class NetworkParameters(Parameters): online network at will. :param tensorflow_support: A flag which specifies if the network is supported by the TensorFlow framework. + :param softmax_temperature: + If a softmax is present in the network head output, use this temperature """ super().__init__() self.framework = Frameworks.tensorflow @@ -404,17 +407,20 @@ class NetworkParameters(Parameters): self.heads_parameters = heads_parameters self.use_separate_networks_per_head = use_separate_networks_per_head self.optimizer_type = optimizer_type - self.optimizer_epsilon = optimizer_epsilon - self.adam_optimizer_beta1 = adam_optimizer_beta1 - self.adam_optimizer_beta2 = adam_optimizer_beta2 - self.rms_prop_optimizer_decay = rms_prop_optimizer_decay - self.batch_size = batch_size self.replace_mse_with_huber_loss = replace_mse_with_huber_loss self.create_target_network = create_target_network # Framework support self.tensorflow_support = tensorflow_support + # Hyper-Parameter values + self.optimizer_epsilon = optimizer_epsilon + self.adam_optimizer_beta1 = adam_optimizer_beta1 + self.adam_optimizer_beta2 = adam_optimizer_beta2 + self.rms_prop_optimizer_decay = rms_prop_optimizer_decay + self.batch_size = batch_size + self.softmax_temperature = softmax_temperature + class NetworkComponentParameters(Parameters): def __init__(self, dense_layer): @@ -544,6 +550,7 @@ class AgentParameters(Parameters): self.is_a_highest_level_agent = True self.is_a_lowest_level_agent = True self.task_parameters = None + self.is_batch_rl_training = False @property def path(self): diff --git a/rl_coach/core_types.py b/rl_coach/core_types.py index 90374a8..6321e9e 100644 --- a/rl_coach/core_types.py +++ b/rl_coach/core_types.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +from collections import namedtuple import copy from enum import Enum @@ -38,6 +38,17 @@ class GoalTypes(Enum): Measurements = 4 +Record = namedtuple('Record', ['name', 'label']) + + +class TimeTypes(Enum): + EpisodeNumber = Record(name='Episode #', label='Episode #') + TrainingIteration = Record(name='Training Iter', label='Training Iteration') + EnvironmentSteps = Record(name='Total steps', label='Total steps (per worker)') + WallClockTime = Record(name='Wall-Clock Time', label='Wall-Clock Time (minutes)') + Epoch = Record(name='Epoch', label='Epoch #') + + # step methods class StepMethod(object): @@ -249,6 +260,9 @@ class Transition(object): .format(self.info.keys(), new_info.keys())) self.info.update(new_info) + def update_info(self, new_info: Dict[str, Any]) -> None: + self.info.update(new_info) + def __copy__(self): new_transition = type(self)() new_transition.__dict__.update(self.__dict__) @@ -867,3 +881,16 @@ class SelectedPhaseOnlyDumpFilter(object): return True else: return False + + +# TODO move to a NamedTuple, once we move to Python3.6 +# https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple/34269877 +class CsvDataset(object): + def __init__(self, filepath: str, is_episodic: bool = True): + self.filepath = filepath + self.is_episodic = is_episodic + + +class PickledReplayBuffer(object): + def __init__(self, filepath: str): + self.filepath = filepath diff --git a/rl_coach/dashboard_components/experiment_board.py b/rl_coach/dashboard_components/experiment_board.py index e77ffb7..ffb40fd 100644 --- a/rl_coach/dashboard_components/experiment_board.py +++ b/rl_coach/dashboard_components/experiment_board.py @@ -273,6 +273,11 @@ def create_files_signal(files, use_dir_name=False): files_selector.value = filenames[0] selected_file = new_signal_files[0] + # update x axis according to the file's default x-axis (which is the index, and thus the first column) + idx = x_axis_options.index(new_signal_files[0].csv.columns[0]) + change_x_axis(idx) + x_axis_selector.active = idx + def display_files(files): pause_auto_update() diff --git a/rl_coach/dashboard_components/globals.py b/rl_coach/dashboard_components/globals.py index 397b182..495a92f 100644 --- a/rl_coach/dashboard_components/globals.py +++ b/rl_coach/dashboard_components/globals.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +from collections import OrderedDict import os from genericpath import isdir, isfile @@ -26,12 +26,14 @@ import tkinter as tk from tkinter import filedialog import colorsys +from rl_coach.core_types import TimeTypes + patches = {} signals_files = {} selected_file = None x_axis = ['Episode #'] -x_axis_options = ['Episode #', 'Total steps', 'Wall-Clock Time'] -x_axis_labels = ['Episode #', 'Total steps (per worker)', 'Wall-Clock Time (minutes)'] +x_axis_options = [time_type.value.name for time_type in TimeTypes] +x_axis_labels = [time_type.value.label for time_type in TimeTypes] current_color = 0 # spinner diff --git a/rl_coach/dashboard_components/signals_file.py b/rl_coach/dashboard_components/signals_file.py index d0a102c..764c29f 100644 --- a/rl_coach/dashboard_components/signals_file.py +++ b/rl_coach/dashboard_components/signals_file.py @@ -67,7 +67,12 @@ class SignalsFile(SignalsFileBase): for k, v in new_csv.isna().all().items(): if v and k not in x_axis_options: del new_csv[k] - new_csv.fillna(value=0, inplace=True) + + # only fill the missing values that the previous interploation did not deal with (usually the ones before the + # first update to the signal was made). We do so again by interpolation (averaging the previous and the next + # values) + new_csv = ((new_csv.fillna(method='bfill') + new_csv.fillna(method='ffill')) / 2).fillna(method='bfill').fillna( + method='ffill') self.csv = new_csv diff --git a/rl_coach/graph_managers/basic_rl_graph_manager.py b/rl_coach/graph_managers/basic_rl_graph_manager.py index 3eb9604..643fd84 100644 --- a/rl_coach/graph_managers/basic_rl_graph_manager.py +++ b/rl_coach/graph_managers/basic_rl_graph_manager.py @@ -18,6 +18,7 @@ from typing import Tuple, List from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \ PresetValidationParameters from rl_coach.environments.environment import EnvironmentParameters, Environment +from rl_coach.filters.filter import NoInputFilter, NoOutputFilter from rl_coach.graph_managers.graph_manager import GraphManager, ScheduleParameters from rl_coach.level_manager import LevelManager from rl_coach.utils import short_dynamic_import @@ -31,17 +32,28 @@ class BasicRLGraphManager(GraphManager): def __init__(self, agent_params: AgentParameters, env_params: EnvironmentParameters, schedule_params: ScheduleParameters, vis_params: VisualizationParameters=VisualizationParameters(), - preset_validation_params: PresetValidationParameters = PresetValidationParameters()): - super().__init__('simple_rl_graph', schedule_params, vis_params) + preset_validation_params: PresetValidationParameters = PresetValidationParameters(), + name='simple_rl_graph'): + super().__init__(name, schedule_params, vis_params) self.agent_params = agent_params self.env_params = env_params self.preset_validation_params = preset_validation_params - self.agent_params.visualization = vis_params + if self.agent_params.input_filter is None: - self.agent_params.input_filter = env_params.default_input_filter() + if env_params is not None: + self.agent_params.input_filter = env_params.default_input_filter() + else: + # In cases where there is no environment (e.g. batch-rl and imitation learning), there is nowhere to get + # a default filter from. So using a default no-filter. + # When there is no environment, the user is expected to define input/output filters (if required) using + # the preset. 
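The new dashboard fill above no longer zeroes out missing signal values; it approximates interior gaps by the average of the nearest previous and next observed values, and back/forward fills whatever remains at the edges. A standalone pandas illustration of the same expression:

    import numpy as np
    import pandas as pd

    s = pd.DataFrame({'signal': [np.nan, 1.0, np.nan, 3.0, np.nan]})

    # averaging the back-filled and forward-filled copies interpolates interior gaps;
    # the trailing bfill/ffill passes handle values before the first / after the last sample
    filled = ((s.fillna(method='bfill') + s.fillna(method='ffill')) / 2) \
        .fillna(method='bfill').fillna(method='ffill')
    # filled['signal'] -> [1.0, 1.0, 2.0, 3.0, 3.0]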
+ self.agent_params.input_filter = NoInputFilter() if self.agent_params.output_filter is None: - self.agent_params.output_filter = env_params.default_output_filter() + if env_params is not None: + self.agent_params.output_filter = env_params.default_output_filter() + else: + self.agent_params.output_filter = NoOutputFilter() def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]: # environment loading diff --git a/rl_coach/graph_managers/batch_rl_graph_manager.py b/rl_coach/graph_managers/batch_rl_graph_manager.py new file mode 100644 index 0000000..8bd90a0 --- /dev/null +++ b/rl_coach/graph_managers/batch_rl_graph_manager.py @@ -0,0 +1,180 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from copy import deepcopy +from typing import Tuple, List, Union + +from rl_coach.agents.dqn_agent import DQNAgentParameters +from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \ + PresetValidationParameters +from rl_coach.core_types import RunPhase +from rl_coach.environments.environment import EnvironmentParameters, Environment +from rl_coach.graph_managers.graph_manager import ScheduleParameters +from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager + +from rl_coach.level_manager import LevelManager +from rl_coach.logger import screen +from rl_coach.spaces import SpacesDefinition +from rl_coach.utils import short_dynamic_import + +from rl_coach.memories.episodic import EpisodicExperienceReplayParameters + +from rl_coach.core_types import TimeTypes + + +class BatchRLGraphManager(BasicRLGraphManager): + """ + A batch RL graph manager creates scenario of learning from a dataset without a simulator. 
+ """ + def __init__(self, agent_params: AgentParameters, env_params: Union[EnvironmentParameters, None], + schedule_params: ScheduleParameters, + vis_params: VisualizationParameters = VisualizationParameters(), + preset_validation_params: PresetValidationParameters = PresetValidationParameters(), + name='batch_rl_graph', spaces_definition: SpacesDefinition = None, reward_model_num_epochs: int = 100, + train_to_eval_ratio: float = 0.8): + + super().__init__(agent_params, env_params, schedule_params, vis_params, preset_validation_params, name) + self.is_batch_rl = True + self.time_metric = TimeTypes.Epoch + self.reward_model_num_epochs = reward_model_num_epochs + self.spaces_definition = spaces_definition + + # setting this here to make sure that, by default, train_to_eval_ratio gets a value < 1 + # (its default value in the memory is 1) + self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio + + def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]: + if self.env_params: + # environment loading + self.env_params.seed = task_parameters.seed + self.env_params.experiment_path = task_parameters.experiment_path + env = short_dynamic_import(self.env_params.path)(**self.env_params.__dict__, + visualization_parameters=self.visualization_parameters) + else: + env = None + + # Only DQN variants are supported at this point. + assert(isinstance(self.agent_params, DQNAgentParameters)) + # Only Episodic memories are supported, + # for evaluating the sequential doubly robust estimator + assert(isinstance(self.agent_params.memory, EpisodicExperienceReplayParameters)) + + # agent loading + self.agent_params.task_parameters = task_parameters # TODO: this should probably be passed in a different way + self.agent_params.name = "agent" + self.agent_params.is_batch_rl_training = True + + # user hasn't defined params for the reward model. we will use the same params as used for the 'main' network. + if 'reward_model' not in self.agent_params.network_wrappers: + self.agent_params.network_wrappers['reward_model'] = deepcopy(self.agent_params.network_wrappers['main']) + + agent = short_dynamic_import(self.agent_params.path)(self.agent_params) + + if not env and not self.agent_params.memory.load_memory_from_file_path: + screen.warning("A BatchRLGraph requires setting a dataset to load into the agent's memory or alternatively " + "using an environment to create a (random) dataset from. This agent should only be used for " + "inference. ") + # set level manager + level_manager = LevelManager(agents=agent, environment=env, name="main_level", + spaces_definition=self.spaces_definition) + + if env: + return [level_manager], [env] + else: + return [level_manager], [] + + def improve(self): + """ + The main loop of the run. + Defined in the following steps: + 1. Heatup + 2. Repeat: + 2.1. Repeat: + 2.1.1. Train + 2.1.2. Possibly save checkpoint + 2.2. Evaluate + :return: None + """ + + self.verify_graph_was_created() + + # initialize the network parameters from the global network + self.sync() + + # TODO a bug in heatup where the last episode run is not fed into the ER. e.g. asked for 1024 heatup steps, + # last ran episode ended increased the total to 1040 steps, but the ER will contain only 1014 steps. + # The last episode is not there. Is this a bug in my changes or also on master? + + # Creating a dataset during the heatup phase is useful mainly for tutorial and debug purposes. 
If we have both + # an environment and a dataset to load from, we will use the environment only for evaluating the policy, + # and will not run heatup. + + # heatup + if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path: + self.heatup(self.heatup_steps) + + self.improve_reward_model() + + # improve + if self.task_parameters.task_index is not None: + screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index)) + else: + screen.log_title("Starting to improve {}".format(self.name)) + + # the outer most training loop + improve_steps_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps + while self.total_steps_counters[RunPhase.TRAIN] < improve_steps_end: + # TODO if we have an environment, do we want to use it to have the agent train against, and use the + # collected replay buffer as a dataset? (as oppose to what we currently have, where the dataset is built + # during heatup, and is composed on random actions) + # perform several steps of training + if self.steps_between_evaluation_periods.num_steps > 0: + with self.phase_context(RunPhase.TRAIN): + self.reset_internal_state(force_environment_reset=True) + + steps_between_evaluation_periods_end = self.current_step_counter + self.steps_between_evaluation_periods + while self.current_step_counter < steps_between_evaluation_periods_end: + self.train() + + # the output of batch RL training is always a checkpoint of the trained agent. we always save a checkpoint, + # each epoch, regardless of the user's command line arguments. + self.save_checkpoint() + + # run off-policy evaluation estimators to evaluate the agent's performance against the dataset + self.run_off_policy_evaluation() + + if self.env_params is not None and self.evaluate(self.evaluation_steps): + # if we do have a simulator (although we are in a batch RL setting we might have a simulator, e.g. when + # demonstrating the batch RL use-case using one of the existing Coach environments), + # we might want to evaluate vs. the simulator every now and then. 
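A preset drives this improve() loop by instantiating the new graph manager directly; the diffstat lists rl_coach/presets/CartPole_DQN_BatchRL.py as a complete example. The sketch below only uses the constructor arguments shown in this patch, with illustrative values that are not copied from that preset:

    from rl_coach.agents.dqn_agent import DQNAgentParameters
    from rl_coach.base_parameters import VisualizationParameters
    from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
    from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
    from rl_coach.graph_managers.graph_manager import ScheduleParameters

    agent_params = DQNAgentParameters()

    schedule_params = ScheduleParameters()
    schedule_params.improve_steps = TrainingSteps(100)            # illustrative values
    schedule_params.steps_between_evaluation_periods = TrainingSteps(1)
    schedule_params.evaluation_steps = EnvironmentEpisodes(10)
    schedule_params.heatup_steps = EnvironmentSteps(1000)         # only used when an environment is attached

    graph_manager = BatchRLGraphManager(agent_params=agent_params,
                                        env_params=None,          # no simulator: learn purely from a dataset
                                        schedule_params=schedule_params,
                                        vis_params=VisualizationParameters(),
                                        # with env_params=None, also set a dataset on
                                        # agent_params.memory.load_memory_from_file_path and pass a
                                        # SpacesDefinition here (see the spaces sketch further below)
                                        spaces_definition=None,
                                        reward_model_num_epochs=30,
                                        train_to_eval_ratio=0.5)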
+ break + + def improve_reward_model(self): + """ + + :return: + """ + screen.log_title("Training a regression model for estimating MDP rewards") + self.level_managers[0].agents['agent'].improve_reward_model(epochs=self.reward_model_num_epochs) + + def run_off_policy_evaluation(self): + """ + Run off-policy evaluation estimators to evaluate the trained policy performance against the dataset + :return: + """ + self.level_managers[0].agents['agent'].run_off_policy_evaluation() + + + diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index d6a0d7a..d8618d7 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -38,6 +38,8 @@ from rl_coach.data_stores.data_store_impl import get_data_store as data_store_cr from rl_coach.memories.backend.memory_impl import get_memory_backend from rl_coach.data_stores.data_store import SyncFiles +from rl_coach.core_types import TimeTypes + class ScheduleParameters(Parameters): def __init__(self): @@ -119,6 +121,8 @@ class GraphManager(object): self.checkpoint_state_updater = None self.graph_logger = Logger() self.data_store = None + self.is_batch_rl = False + self.time_metric = TimeTypes.EpisodeNumber def create_graph(self, task_parameters: TaskParameters=TaskParameters()): self.graph_creation_time = time.time() @@ -445,16 +449,17 @@ class GraphManager(object): result = self.top_level_manager.step(None) steps_end = self.environments[0].total_steps_counter - # add the diff between the total steps before and after stepping, such that environment initialization steps - # (like in Atari) will not be counted. - # We add at least one step so that even if no steps were made (in case no actions are taken in the training - # phase), the loop will end eventually. - self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin) - if result.game_over: self.handle_episode_ended() self.reset_required = True + self.current_step_counter[EnvironmentSteps] += (steps_end - steps_begin) + + # if no steps were made (can happen when no actions are taken while in the TRAIN phase, either in batch RL + # or in imitation learning), we force end the loop, so that it will not continue forever. + if (steps_end - steps_begin) == 0: + break + def train_and_act(self, steps: StepMethod) -> None: """ Train the agent by doing several acting steps followed by several training steps continually @@ -472,9 +477,9 @@ class GraphManager(object): while self.current_step_counter < count_end: # The actual number of steps being done on the environment # is decided by the agent, though this inner loop always - # takes at least one step in the environment. Depending on - # internal counters and parameters, it doesn't always train - # or save checkpoints. + # takes at least one step in the environment (at the GraphManager level). + # The agent might also decide to skip acting altogether. + # Depending on internal counters and parameters, it doesn't always train or save checkpoints. 
self.act(EnvironmentSteps(1)) self.train() self.occasionally_save_checkpoint() diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py index 2bbfad2..945ab6c 100644 --- a/rl_coach/level_manager.py +++ b/rl_coach/level_manager.py @@ -40,9 +40,10 @@ class LevelManager(EnvironmentInterface): name: str, agents: Union['Agent', CompositeAgent, Dict[str, Union['Agent', CompositeAgent]]], environment: Union['LevelManager', Environment], - real_environment: Environment=None, - steps_limit: EnvironmentSteps=EnvironmentSteps(1), - should_reset_agent_state_after_time_limit_passes: bool=False + real_environment: Environment = None, + steps_limit: EnvironmentSteps = EnvironmentSteps(1), + should_reset_agent_state_after_time_limit_passes: bool = False, + spaces_definition: SpacesDefinition = None ): """ A level manager controls a single or multiple composite agents and a single environment. @@ -56,6 +57,7 @@ class LevelManager(EnvironmentInterface): :param steps_limit: the number of time steps to run when stepping the internal components :param should_reset_agent_state_after_time_limit_passes: reset the agent after stepping for steps_limit :param name: the level's name + :param spaces_definition: external definition of spaces for when we don't have an environment (e.g. batch-rl) """ super().__init__() @@ -85,9 +87,11 @@ class LevelManager(EnvironmentInterface): if not isinstance(self.steps_limit, EnvironmentSteps): raise ValueError("The num consecutive steps for acting must be defined in terms of environment steps") - self.build() + self.build(spaces_definition) + + # there are cases where we don't have an environment. e.g. in batch-rl or in imitation learning. + self.last_env_response = self.real_environment.last_env_response if self.real_environment else None - self.last_env_response = self.real_environment.last_env_response self.parent_graph_manager = None def handle_episode_ended(self) -> None: @@ -100,13 +104,13 @@ class LevelManager(EnvironmentInterface): def reset_internal_state(self, force_environment_reset: bool = False) -> EnvResponse: """ Reset the environment episode parameters - :param force_enviro nment_reset: in some cases, resetting the environment can be suppressed by the environment + :param force_environment_reset: in some cases, resetting the environment can be suppressed by the environment itself. This flag allows force the reset. :return: the environment response as returned in get_last_env_response """ [agent.reset_internal_state() for agent in self.agents.values()] self.reset_required = False - if self.real_environment.current_episode_steps_counter == 0: + if self.real_environment and self.real_environment.current_episode_steps_counter == 0: self.last_env_response = self.real_environment.last_env_response return self.last_env_response @@ -136,19 +140,27 @@ class LevelManager(EnvironmentInterface): """ return {k: ActionInfo(v) for k, v in self.get_random_action().items()} - def build(self) -> None: + def build(self, spaces_definition: SpacesDefinition = None) -> None: """ Build all the internal components of the level manager (composite agents and environment). + :param spaces_definition: external definition of spaces for when we don't have an environment (e.g. batch-rl) :return: None """ - # TODO: move the spaces definition class to the environment? 
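When there is no environment, the spaces that LevelManager.build() would normally read off the environment have to be supplied by the preset through the new spaces_definition argument. A sketch for a CartPole-like dataset follows; the SpacesDefinition keyword arguments match this patch, but the exact constructors of the individual space classes are assumptions rather than something shown here:

    from rl_coach.spaces import SpacesDefinition, StateSpace, VectorObservationSpace, \
        DiscreteActionSpace, RewardSpace

    spaces = SpacesDefinition(
        state=StateSpace({'observation': VectorObservationSpace(shape=4)}),  # 4 state features per row
        goal=None,
        action=DiscreteActionSpace(2),
        reward=RewardSpace(1))

    # passed as BatchRLGraphManager(..., spaces_definition=spaces) and threaded down to LevelManager.build()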
- action_space = self.environment.action_space - if isinstance(action_space, dict): # TODO: shouldn't be a dict - action_space = list(action_space.values())[0] - spaces = SpacesDefinition(state=self.real_environment.state_space, - goal=self.real_environment.goal_space, # in HRL the agent might want to override this - action=action_space, - reward=self.real_environment.reward_space) + if spaces_definition is None: + # normally the spaces are defined by the environment, and we only gather these here + action_space = self.environment.action_space + + if isinstance(action_space, dict): # TODO: shouldn't be a dict + action_space = list(action_space.values())[0] + + spaces = SpacesDefinition(state=self.real_environment.state_space, + goal=self.real_environment.goal_space, + # in HRL the agent might want to override this + action=action_space, + reward=self.real_environment.reward_space) + else: + spaces = spaces_definition + [agent.set_environment_parameters(spaces) for agent in self.agents.values()] def setup_logger(self) -> None: diff --git a/rl_coach/logger.py b/rl_coach/logger.py index 7a26c07..149b2d3 100644 --- a/rl_coach/logger.py +++ b/rl_coach/logger.py @@ -185,6 +185,9 @@ class BaseLogger(object): self.time = time def create_signal_value(self, signal_name, value, overwrite=True, time=None): + if self.index_name == signal_name: + return False # make sure that we don't create duplicate signals + if self.last_line_idx_written_to_csv != 0: assert signal_name in self.data.columns @@ -227,12 +230,15 @@ class BaseLogger(object): self.last_line_idx_written_to_csv = len(self.data.index) - def update_wall_clock_time(self, index): + def get_current_wall_clock_time(self): if self.start_time: - self.create_signal_value('Wall-Clock Time', time.time() - self.start_time, time=index) + return time.time() - self.start_time else: - self.create_signal_value('Wall-Clock Time', 0, time=index) self.start_time = time.time() + return 0 + + def update_wall_clock_time(self, index): + self.create_signal_value('Wall-Clock Time', self.get_current_wall_clock_time(), time=index) class EpisodeLogger(BaseLogger): @@ -263,10 +269,13 @@ class EpisodeLogger(BaseLogger): class Logger(BaseLogger): - def __init__(self): + def __init__(self, index_name='Episode #'): super().__init__() self.doc_path = '' - self.index_name = 'Episode #' + self.index_name = index_name + + def set_index_name(self, index_name): + self.index_name = index_name def set_logger_filenames(self, _experiments_path, logger_prefix='', task_id=None, add_timestamp=False, filename=''): self.experiments_path = _experiments_path diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py index fab05f7..bc1e72a 100644 --- a/rl_coach/memories/episodic/episodic_experience_replay.py +++ b/rl_coach/memories/episodic/episodic_experience_replay.py @@ -1,4 +1,5 @@ # +# # Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,14 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import ast -from typing import List, Tuple, Union, Dict, Any - +import pandas as pd +from typing import List, Tuple, Union import numpy as np +import random from rl_coach.core_types import Transition, Episode +from rl_coach.logger import screen from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters -from rl_coach.utils import ReaderWriterLock +from rl_coach.utils import ReaderWriterLock, ProgressBar +from rl_coach.core_types import CsvDataset class EpisodicExperienceReplayParameters(MemoryParameters): @@ -28,6 +33,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters): super().__init__() self.max_size = (MemoryGranularity.Transitions, 1000000) self.n_step = -1 + self.train_to_eval_ratio = 1 # for OPE we'll want a value < 1 @property def path(self): @@ -40,7 +46,9 @@ class EpisodicExperienceReplay(Memory): calculations of total return and other values that depend on the sequential behavior of the transitions in the episode. """ - def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1): + + def __init__(self, max_size: Tuple[MemoryGranularity, int] = (MemoryGranularity.Transitions, 1000000), n_step=-1, + train_to_eval_ratio: int = 1): """ :param max_size: the maximum number of transitions or episodes to hold in the memory """ @@ -52,8 +60,11 @@ class EpisodicExperienceReplay(Memory): self._num_transitions = 0 self._num_transitions_in_complete_episodes = 0 self.reader_writer_lock = ReaderWriterLock() + self.last_training_set_episode_id = None # used in batch-rl + self.last_training_set_transition_id = None # used in batch-rl + self.train_to_eval_ratio = train_to_eval_ratio # used in batch-rl - def length(self, lock: bool=False) -> int: + def length(self, lock: bool = False) -> int: """ Get the number of episodes in the ER (even if they are not complete) """ @@ -75,6 +86,9 @@ class EpisodicExperienceReplay(Memory): def num_transitions_in_complete_episodes(self): return self._num_transitions_in_complete_episodes + def get_last_training_set_episode_id(self): + return self.last_training_set_episode_id + def sample(self, size: int, is_consecutive_transitions=False) -> List[Transition]: """ Sample a batch of transitions from the replay buffer. If the requested size is larger than the number @@ -92,7 +106,7 @@ class EpisodicExperienceReplay(Memory): batch = self._buffer[episode_idx].transitions else: transition_idx = np.random.randint(size, self._buffer[episode_idx].length()) - batch = self._buffer[episode_idx].transitions[transition_idx-size:transition_idx] + batch = self._buffer[episode_idx].transitions[transition_idx - size:transition_idx] else: transitions_idx = np.random.randint(self.num_transitions_in_complete_episodes(), size=size) batch = [self.transitions[i] for i in transitions_idx] @@ -105,6 +119,79 @@ class EpisodicExperienceReplay(Memory): return batch + def get_episode_for_transition(self, transition: Transition) -> (int, Episode): + """ + Get the episode from which that transition came from. + :param transition: The transition to lookup the episode for + :return: (Episode number, the episode) or (-1, None) if could not find a matching episode. 
+ """ + + for i, episode in enumerate(self._buffer): + if transition in episode.transitions: + return i, episode + return -1, None + + def shuffle_episodes(self): + """ + Shuffle all the episodes in the replay buffer + :return: + """ + random.shuffle(self._buffer) + self.transitions = [t for e in self._buffer for t in e.transitions] + + def get_shuffled_data_generator(self, size: int) -> List[Transition]: + """ + Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs. + If the requested size is larger than the number of samples available in the replay buffer then the batch will + return empty. The last returned batch may be smaller than the size requested, to accommodate for all the + transitions in the replay buffer. + + :param size: the size of the batch to return + :return: a batch (list) of selected transitions from the replay buffer + """ + self.reader_writer_lock.lock_writing() + if self.last_training_set_transition_id is None: + if self.train_to_eval_ratio < 0 or self.train_to_eval_ratio >= 1: + raise ValueError('train_to_eval_ratio should be in the (0, 1] range.') + + transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())] + episode_num, episode = self.get_episode_for_transition(transition) + self.last_training_set_episode_id = episode_num + self.last_training_set_transition_id = \ + len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e]) + + shuffled_transition_indices = list(range(self.last_training_set_transition_id)) + random.shuffle(shuffled_transition_indices) + + # we deliberately drop some of the ending data which is left after dividing to batches of size `size` + # for i in range(math.ceil(len(shuffled_transition_indices) / size)): + for i in range(int(len(shuffled_transition_indices) / size)): + sample_data = [self.transitions[j] for j in shuffled_transition_indices[i * size: (i + 1) * size]] + self.reader_writer_lock.release_writing() + + yield sample_data + + def get_all_complete_episodes_transitions(self) -> List[Transition]: + """ + Get all the transitions from all the complete episodes in the buffer + :return: a list of transitions + """ + return self.transitions[:self.num_transitions_in_complete_episodes()] + + def get_all_complete_episodes(self) -> List[Episode]: + """ + Get all the transitions from all the complete episodes in the buffer + :return: a list of transitions + """ + return self.get_all_complete_episodes_from_to(0, self.num_complete_episodes()) + + def get_all_complete_episodes_from_to(self, start_episode_id, end_episode_id) -> List[Episode]: + """ + Get all the transitions from all the complete episodes in the buffer matching the given episode range + :return: a list of transitions + """ + return self._buffer[start_episode_id:end_episode_id] + def _enforce_max_length(self) -> None: """ Make sure that the size of the replay buffer does not pass the maximum size allowed. @@ -188,7 +275,7 @@ class EpisodicExperienceReplay(Memory): self.reader_writer_lock.release_writing_and_reading() - def store_episode(self, episode: Episode, lock: bool=True) -> None: + def store_episode(self, episode: Episode, lock: bool = True) -> None: """ Store a new episode in the memory. 
:param episode: the new episode to store @@ -211,7 +298,7 @@ class EpisodicExperienceReplay(Memory): if lock: self.reader_writer_lock.release_writing_and_reading() - def get_episode(self, episode_index: int, lock: bool=True) -> Union[None, Episode]: + def get_episode(self, episode_index: int, lock: bool = True) -> Union[None, Episode]: """ Returns the episode in the given index. If the episode does not exist, returns None instead. :param episode_index: the index of the episode to return @@ -256,7 +343,7 @@ class EpisodicExperienceReplay(Memory): self.reader_writer_lock.release_writing_and_reading() # for API compatibility - def get(self, episode_index: int, lock: bool=True) -> Union[None, Episode]: + def get(self, episode_index: int, lock: bool = True) -> Union[None, Episode]: """ Returns the episode in the given index. If the episode does not exist, returns None instead. :param episode_index: the index of the episode to return @@ -315,3 +402,47 @@ class EpisodicExperienceReplay(Memory): self.reader_writer_lock.release_writing() return mean + + def load_csv(self, csv_dataset: CsvDataset) -> None: + """ + Restore the replay buffer contents from a csv file. + The csv file is assumed to include a list of transitions. + :param csv_dataset: A construct which holds the dataset parameters + """ + df = pd.read_csv(csv_dataset.filepath) + if len(df) > self.max_size[1]: + screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is " + "bigger than the max size of the replay buffer ({}). The excessive transitions will " + "not be stored.".format(len(df), self.max_size[1])) + + episode_ids = df['episode_id'].unique() + progress_bar = ProgressBar(len(episode_ids)) + state_columns = [col for col in df.columns if col.startswith('state_feature')] + + for e_id in episode_ids: + progress_bar.update(e_id) + df_episode_transitions = df[df['episode_id'] == e_id] + episode = Episode() + for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(), + df_episode_transitions[1:].iterrows()): + state = np.array([current_transition[col] for col in state_columns]) + next_state = np.array([next_transition[col] for col in state_columns]) + + episode.insert( + Transition(state={'observation': state}, + action=current_transition['action'], reward=current_transition['reward'], + next_state={'observation': next_state}, game_over=False, + info={'all_action_probabilities': + ast.literal_eval(current_transition['all_action_probabilities'])})) + + # Set the last transition to end the episode + if csv_dataset.is_episodic: + episode.get_last_transition().game_over = True + + self.store_episode(episode) + + # close the progress bar + progress_bar.update(len(episode_ids)) + progress_bar.close() + + self.shuffle_episodes() diff --git a/rl_coach/memories/non_episodic/experience_replay.py b/rl_coach/memories/non_episodic/experience_replay.py index b3a2043..f47d9b6 100644 --- a/rl_coach/memories/non_episodic/experience_replay.py +++ b/rl_coach/memories/non_episodic/experience_replay.py @@ -14,10 +14,10 @@ # limitations under the License. # -from typing import List, Tuple, Union, Dict, Any +from typing import List, Tuple, Union import pickle -import sys -import time +import random +import math import numpy as np @@ -72,7 +72,6 @@ class ExperienceReplay(Memory): Sample a batch of transitions form the replay buffer. If the requested size is larger than the number of samples available in the replay buffer then the batch will return empty. 
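The load_csv method above implies a flat schema for the CSV dataset: one row per transition, an episode_id column grouping rows into episodes, state_feature_* columns for the observation, plus action, reward, and a stringified all_action_probabilities list describing the behavior policy. A tiny example of writing such a file (column names follow the loader; the data itself is made up):

    import pandas as pd

    df = pd.DataFrame({
        'episode_id': [0, 0, 0],
        'state_feature_0': [0.1, 0.2, 0.3],
        'state_feature_1': [1.0, 0.9, 0.8],
        'action': [0, 1, 0],
        'reward': [1.0, 1.0, 0.0],
        # probabilities the behavior policy assigned to each action (parsed with ast.literal_eval)
        'all_action_probabilities': ['[0.5, 0.5]', '[0.7, 0.3]', '[0.5, 0.5]'],
    })
    df.to_csv('/path/to/dataset.csv', index=False)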
:param size: the size of the batch to sample - :param beta: the beta parameter used for importance sampling :return: a batch (list) of selected transitions from the replay buffer """ self.reader_writer_lock.lock_writing() @@ -92,6 +91,32 @@ class ExperienceReplay(Memory): self.reader_writer_lock.release_writing() return batch + def get_shuffled_data_generator(self, size: int) -> List[Transition]: + """ + Get an generator for iterating through the shuffled replay buffer, for processing the data in epochs. + If the requested size is larger than the number of samples available in the replay buffer then the batch will + return empty. The last returned batch may be smaller than the size requested, to accommodate for all the + transitions in the replay buffer. + + :param size: the size of the batch to return + :return: a batch (list) of selected transitions from the replay buffer + """ + self.reader_writer_lock.lock_writing() + shuffled_transition_indices = list(range(len(self.transitions))) + random.shuffle(shuffled_transition_indices) + + # we deliberately drop some of the ending data which is left after dividing to batches of size `size` + # for i in range(math.ceil(len(shuffled_transition_indices) / size)): + for i in range(int(len(shuffled_transition_indices) / size)): + sample_data = [self.transitions[j] for j in shuffled_transition_indices[i * size: (i + 1) * size]] + self.reader_writer_lock.release_writing() + + yield sample_data + + ## usage example + # for o in random_seq_generator(list(range(10)), 4): + # print(o) + def _enforce_max_length(self) -> None: """ Make sure that the size of the replay buffer does not pass the maximum size allowed. @@ -215,7 +240,7 @@ class ExperienceReplay(Memory): with open(file_path, 'wb') as file: pickle.dump(self.transitions, file) - def load(self, file_path: str) -> None: + def load_pickled(self, file_path: str) -> None: """ Restore the replay buffer contents from a pickle file. The pickle file is assumed to include a list of transitions. @@ -238,3 +263,4 @@ class ExperienceReplay(Memory): progress_bar.update(transition_idx) progress_bar.close() + diff --git a/rl_coach/off_policy_evaluators/__init__.py b/rl_coach/off_policy_evaluators/__init__.py new file mode 100644 index 0000000..9a6e67d --- /dev/null +++ b/rl_coach/off_policy_evaluators/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/rl_coach/off_policy_evaluators/bandits/__init__.py b/rl_coach/off_policy_evaluators/bandits/__init__.py new file mode 100644 index 0000000..9a6e67d --- /dev/null +++ b/rl_coach/off_policy_evaluators/bandits/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
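A minimal usage sketch for the epoch-style iteration that get_shuffled_data_generator enables (an illustration, not code from this patch; it assumes the standard ExperienceReplay constructor and the Transition fields used elsewhere in this change):

    import numpy as np
    from rl_coach.core_types import Transition
    from rl_coach.memories.memory import MemoryGranularity
    from rl_coach.memories.non_episodic.experience_replay import ExperienceReplay

    # fill a small buffer with random transitions
    memory = ExperienceReplay(max_size=(MemoryGranularity.Transitions, 1000))
    for _ in range(500):
        memory.store(Transition(state={'observation': np.random.rand(4)}, action=0, reward=1.0,
                                next_state={'observation': np.random.rand(4)}, game_over=False))

    # one epoch over the shuffled buffer in mini-batches of 128 transitions;
    # the last 500 % 128 = 116 transitions are dropped, as noted in the comment above
    for i, mini_batch in enumerate(memory.get_shuffled_data_generator(128)):
        print(i, len(mini_batch))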
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/rl_coach/off_policy_evaluators/bandits/doubly_robust.py b/rl_coach/off_policy_evaluators/bandits/doubly_robust.py new file mode 100644 index 0000000..1f1535e --- /dev/null +++ b/rl_coach/off_policy_evaluators/bandits/doubly_robust.py @@ -0,0 +1,40 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np + + +class DoublyRobust(object): + + @staticmethod + def evaluate(ope_shared_stats: 'OpeSharedStats') -> tuple: + """ + Run the off-policy evaluator to get a score for the goodness of the new policy, based on the dataset, + which was collected using other policy(ies). + + Papers: + https://arxiv.org/abs/1103.4601 + https://arxiv.org/pdf/1612.01205 (some more clearer explanations) + + :return: the evaluation score + """ + + ips = np.mean(ope_shared_stats.rho_all_dataset * ope_shared_stats.all_rewards) + dm = np.mean(ope_shared_stats.all_v_values_reward_model_based) + dr = np.mean(ope_shared_stats.rho_all_dataset * + (ope_shared_stats.all_rewards - ope_shared_stats.all_reward_model_rewards[ + range(len(ope_shared_stats.all_actions)), ope_shared_stats.all_actions])) + dm + + return ips, dm, dr diff --git a/rl_coach/off_policy_evaluators/ope_manager.py b/rl_coach/off_policy_evaluators/ope_manager.py new file mode 100644 index 0000000..1a64621 --- /dev/null +++ b/rl_coach/off_policy_evaluators/ope_manager.py @@ -0,0 +1,124 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
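To make the three estimators concrete, here is a self-contained toy calculation with made-up numbers that mirrors the formulas in evaluate above rather than calling it: IPS reweights the logged rewards by the new-to-old policy probability ratio, DM trusts the reward model alone, and DR corrects the model estimate with the reweighted residual.

    import numpy as np

    rewards = np.array([1.0, 0.0, 1.0, 1.0])            # logged rewards
    actions = np.array([0, 1, 1, 0])                     # logged actions
    old_probs = np.array([0.5, 0.5, 0.4, 0.6])           # behavior policy prob. of each logged action
    new_policy_probs = np.array([[0.7, 0.3],             # new policy's prob. for every action
                                 [0.2, 0.8],
                                 [0.1, 0.9],
                                 [0.6, 0.4]])
    reward_model = np.array([[0.9, 0.2],                 # reward model's predicted reward per action
                             [0.1, 0.3],
                             [0.5, 0.8],
                             [0.7, 0.1]])

    idx = np.arange(len(actions))
    rho = new_policy_probs[idx, actions] / old_probs     # importance weights (rho_all_dataset)

    ips = np.mean(rho * rewards)                                        # inverse propensity scoring
    dm = np.mean(np.sum(new_policy_probs * reward_model, axis=1))       # direct method
    dr = np.mean(rho * (rewards - reward_model[idx, actions])) + dm     # doubly robust
    print(ips, dm, dr)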
+# +from collections import namedtuple + +import numpy as np +from typing import List + +from rl_coach.architectures.architecture import Architecture +from rl_coach.core_types import Episode, Batch +from rl_coach.off_policy_evaluators.bandits.doubly_robust import DoublyRobust +from rl_coach.off_policy_evaluators.rl.sequential_doubly_robust import SequentialDoublyRobust + +from rl_coach.core_types import Transition + +OpeSharedStats = namedtuple("OpeSharedStats", ['all_reward_model_rewards', 'all_policy_probs', + 'all_v_values_reward_model_based', 'all_rewards', 'all_actions', + 'all_old_policy_probs', 'new_policy_prob', 'rho_all_dataset']) +OpeEstimation = namedtuple("OpeEstimation", ['ips', 'dm', 'dr', 'seq_dr']) + + +class OpeManager(object): + def __init__(self): + self.doubly_robust = DoublyRobust() + self.sequential_doubly_robust = SequentialDoublyRobust() + + @staticmethod + def _prepare_ope_shared_stats(dataset_as_transitions: List[Transition], batch_size: int, + reward_model: Architecture, q_network: Architecture, + network_keys: List) -> OpeSharedStats: + """ + Do the preparations needed for different estimators. + Some of the calcuations are shared, so we centralize all the work here. + + :param dataset_as_transitions: The evaluation dataset in the form of transitions. + :param batch_size: The batch size to use. + :param reward_model: A reward model to be used by DR + :param q_network: The Q network whose its policy we evaluate. + :param network_keys: The network keys used for feeding the neural networks. + :return: + """ + # IPS + all_reward_model_rewards, all_policy_probs, all_old_policy_probs = [], [], [] + all_v_values_reward_model_based, all_v_values_q_model_based, all_rewards, all_actions = [], [], [], [] + + for i in range(int(len(dataset_as_transitions) / batch_size) + 1): + batch = dataset_as_transitions[i * batch_size: (i + 1) * batch_size] + batch_for_inference = Batch(batch) + + all_reward_model_rewards.append(reward_model.predict( + batch_for_inference.states(network_keys))) + + # TODO can we get rid of the 'output_heads[0]', and have some way of a cleaner API? + q_values, sm_values = q_network.predict(batch_for_inference.states(network_keys), + outputs=[q_network.output_heads[0].output, + q_network.output_heads[0].softmax]) + # TODO why is this needed? 
+ q_values = q_values[0] + + all_policy_probs.append(sm_values) + all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * all_reward_model_rewards[-1], axis=1)) + all_v_values_q_model_based.append(np.sum(all_policy_probs[-1] * q_values, axis=1)) + all_rewards.append(batch_for_inference.rewards()) + all_actions.append(batch_for_inference.actions()) + all_old_policy_probs.append(batch_for_inference.info('all_action_probabilities') + [range(len(batch_for_inference.actions())), batch_for_inference.actions()]) + + for j, t in enumerate(batch): + t.update_info({ + 'q_value': q_values[j], + 'softmax_policy_prob': all_policy_probs[-1][j], + 'v_value_q_model_based': all_v_values_q_model_based[-1][j], + + }) + + all_reward_model_rewards = np.concatenate(all_reward_model_rewards, axis=0) + all_policy_probs = np.concatenate(all_policy_probs, axis=0) + all_v_values_reward_model_based = np.concatenate(all_v_values_reward_model_based, axis=0) + all_rewards = np.concatenate(all_rewards, axis=0) + all_actions = np.concatenate(all_actions, axis=0) + all_old_policy_probs = np.concatenate(all_old_policy_probs, axis=0) + + # generate model probabilities + new_policy_prob = all_policy_probs[np.arange(all_actions.shape[0]), all_actions] + rho_all_dataset = new_policy_prob / all_old_policy_probs + + return OpeSharedStats(all_reward_model_rewards, all_policy_probs, all_v_values_reward_model_based, + all_rewards, all_actions, all_old_policy_probs, new_policy_prob, rho_all_dataset) + + def evaluate(self, dataset_as_episodes: List[Episode], batch_size: int, discount_factor: float, + reward_model: Architecture, q_network: Architecture, network_keys: List) -> OpeEstimation: + """ + Run all the OPEs and get estimations of the current policy performance based on the evaluation dataset. + + :param dataset_as_episodes: The evaluation dataset. + :param batch_size: Batch size to use for the estimators. + :param discount_factor: The standard RL discount factor. + :param reward_model: A reward model to be used by DR + :param q_network: The Q network whose its policy we evaluate. + :param network_keys: The network keys used for feeding the neural networks. + + :return: An OpeEstimation tuple which groups together all the OPE estimations + """ + # TODO this seems kind of slow, review performance + dataset_as_transitions = [t for e in dataset_as_episodes for t in e.transitions] + ope_shared_stats = self._prepare_ope_shared_stats(dataset_as_transitions, batch_size, reward_model, + q_network, network_keys) + + ips, dm, dr = self.doubly_robust.evaluate(ope_shared_stats) + seq_dr = self.sequential_doubly_robust.evaluate(dataset_as_episodes, discount_factor) + return OpeEstimation(ips, dm, dr, seq_dr) + diff --git a/rl_coach/off_policy_evaluators/rl/__init__.py b/rl_coach/off_policy_evaluators/rl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py b/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py new file mode 100644 index 0000000..633a747 --- /dev/null +++ b/rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
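The file below applies, per episode, the doubly-robust recursion of the sequential estimator. Stripped of the Coach data structures, the accumulation it performs looks like this toy, made-up-numbers version; the final score is then the mean of episode_seq_dr over all evaluation episodes.

    discount = 0.98
    # per-step tuples: (rho, reward, v_value_q_model_based, q_value of the taken action)
    episode_steps = [(1.2, 1.0, 0.9, 0.8),
                     (0.7, 1.0, 1.1, 1.0),
                     (1.0, 0.0, 0.5, 0.6)]

    episode_seq_dr = 0.0
    for rho, reward, v_value, q_value in episode_steps:
        # same update as in SequentialDoublyRobust.evaluate below
        episode_seq_dr = v_value + rho * (reward + discount * episode_seq_dr - q_value)
    print(episode_seq_dr)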
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import List +import numpy as np + +from rl_coach.core_types import Episode + + +class SequentialDoublyRobust(object): + + @staticmethod + def evaluate(dataset_as_episodes: List[Episode], discount_factor: float) -> float: + """ + Run the off-policy evaluator to get a score for the goodness of the new policy, based on the dataset, + which was collected using other policy(ies). + + Paper: https://arxiv.org/pdf/1511.03722.pdf + + :return: the evaluation score + """ + + # Sequential Doubly Robust + per_episode_seq_dr = [] + + for episode in dataset_as_episodes: + episode_seq_dr = 0 + for transition in episode.transitions: + rho = transition.info['softmax_policy_prob'][transition.action] / \ + transition.info['all_action_probabilities'][transition.action] + episode_seq_dr = transition.info['v_value_q_model_based'] + rho * (transition.reward + discount_factor + * episode_seq_dr - + transition.info['q_value'][ + transition.action]) + per_episode_seq_dr.append(episode_seq_dr) + + seq_dr = np.array(per_episode_seq_dr).mean() + + return seq_dr diff --git a/rl_coach/presets/CARLA_CIL.py b/rl_coach/presets/CARLA_CIL.py index 4d5efff..a2151c2 100644 --- a/rl_coach/presets/CARLA_CIL.py +++ b/rl_coach/presets/CARLA_CIL.py @@ -25,6 +25,7 @@ from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.schedules import ConstantSchedule from rl_coach.spaces import ImageObservationSpace from rl_coach.utilities.carla_dataset_to_replay_buffer import create_dataset +from rl_coach.core_types import PickledReplayBuffer #################### # Graph Scheduling # @@ -130,7 +131,7 @@ agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0) # use the following command line to download and extract the CARLA dataset: # python rl_coach/utilities/carla_dataset_to_replay_buffer.py -agent_params.memory.load_memory_from_file_path = "./datasets/carla_train_set_replay_buffer.p" +agent_params.memory.load_memory_from_file_path = PickledReplayBuffer("./datasets/carla_train_set_replay_buffer.p") agent_params.memory.state_key_with_the_class_index = 'high_level_command' agent_params.memory.num_classes = 4 diff --git a/rl_coach/presets/CartPole_DQN_BatchRL.py b/rl_coach/presets/CartPole_DQN_BatchRL.py new file mode 100644 index 0000000..ac020b0 --- /dev/null +++ b/rl_coach/presets/CartPole_DQN_BatchRL.py @@ -0,0 +1,91 @@ +from copy import deepcopy + +from rl_coach.agents.ddqn_agent import DDQNAgentParameters +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters +from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps +from rl_coach.environments.gym_environment import GymVectorEnvironment +from rl_coach.filters.filter import InputFilter +from rl_coach.filters.reward import RewardRescaleFilter +from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager +from rl_coach.graph_managers.graph_manager import ScheduleParameters +from rl_coach.memories.memory import MemoryGranularity +from rl_coach.schedules import LinearSchedule +from rl_coach.memories.episodic import 
EpisodicExperienceReplayParameters + +DATASET_SIZE = 40000 + +#################### +# Graph Scheduling # +#################### + +schedule_params = ScheduleParameters() +schedule_params.improve_steps = TrainingSteps(10000000000) +schedule_params.steps_between_evaluation_periods = TrainingSteps(1) +schedule_params.evaluation_steps = EnvironmentEpisodes(10) +schedule_params.heatup_steps = EnvironmentSteps(DATASET_SIZE) + +######### +# Agent # +######### +# TODO add a preset which uses a dataset to train a BatchRL graph. e.g. save a cartpole dataset in a csv format. +agent_params = DDQNAgentParameters() +agent_params.network_wrappers['main'].batch_size = 1024 + +# DQN params +# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(100) + +# For making this become Fitted Q-Iteration we can keep the targets constant for the entire dataset size - +agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps( + DATASET_SIZE / agent_params.network_wrappers['main'].batch_size) + +agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0) +agent_params.algorithm.discount = 0.98 +# agent_params.algorithm.discount = 1.0 + + +# NN configuration +agent_params.network_wrappers['main'].learning_rate = 0.0001 +agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False +agent_params.network_wrappers['main'].l2_regularization = 0.0001 +agent_params.network_wrappers['main'].softmax_temperature = 0.2 +# agent_params.network_wrappers['main'].learning_rate_decay_rate = 0.95 +# agent_params.network_wrappers['main'].learning_rate_decay_steps = int(DATASET_SIZE / +# agent_params.network_wrappers['main'].batch_size) + +# reward model params +agent_params.network_wrappers['reward_model'] = deepcopy(agent_params.network_wrappers['main']) +agent_params.network_wrappers['reward_model'].learning_rate = 0.0001 +agent_params.network_wrappers['reward_model'].l2_regularization = 0 + +# ER size +agent_params.memory = EpisodicExperienceReplayParameters() +agent_params.memory.max_size = (MemoryGranularity.Transitions, DATASET_SIZE) + + +# E-Greedy schedule +agent_params.exploration.epsilon_schedule = LinearSchedule(0, 0, 10000) +agent_params.exploration.evaluation_epsilon = 0 + + +agent_params.input_filter = InputFilter() +agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.)) + +################ +# Environment # +################ +env_params = GymVectorEnvironment(level='CartPole-v0') + +######## +# Test # +######## +preset_validation_params = PresetValidationParameters() +preset_validation_params.test = True +preset_validation_params.min_reward_threshold = 150 +preset_validation_params.max_episodes_to_achieve_reward = 250 + +graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params, + schedule_params=schedule_params, + vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1), + preset_validation_params=preset_validation_params, + reward_model_num_epochs=50, + train_to_eval_ratio=0.8) diff --git a/rl_coach/presets/Doom_Basic_BC.py b/rl_coach/presets/Doom_Basic_BC.py index bbbd9e0..73ba814 100644 --- a/rl_coach/presets/Doom_Basic_BC.py +++ b/rl_coach/presets/Doom_Basic_BC.py @@ -5,6 +5,7 @@ from rl_coach.environments.doom_environment import DoomEnvironmentParameters from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.schedules import LinearSchedule 
+from rl_coach.core_types import PickledReplayBuffer #################### # Graph Scheduling # @@ -29,7 +30,7 @@ agent_params.exploration.evaluation_epsilon = 0 agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0) agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False agent_params.network_wrappers['main'].batch_size = 120 -agent_params.memory.load_memory_from_file_path = 'datasets/doom_basic.p' +agent_params.memory.load_memory_from_file_path = PickledReplayBuffer('datasets/doom_basic.p') ############### diff --git a/rl_coach/presets/MontezumaRevenge_BC.py b/rl_coach/presets/MontezumaRevenge_BC.py index 951c22b..6aceac5 100644 --- a/rl_coach/presets/MontezumaRevenge_BC.py +++ b/rl_coach/presets/MontezumaRevenge_BC.py @@ -5,6 +5,7 @@ from rl_coach.environments.gym_environment import Atari from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.memories.memory import MemoryGranularity +from rl_coach.core_types import PickledReplayBuffer #################### # Graph Scheduling # @@ -25,7 +26,7 @@ agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000) # agent_params.memory.discount = 0.99 agent_params.algorithm.discount = 0.99 agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0) -agent_params.memory.load_memory_from_file_path = 'datasets/montezuma_revenge.p' +agent_params.memory.load_memory_from_file_path = PickledReplayBuffer('datasets/montezuma_revenge.p') ############### # Environment # diff --git a/rl_coach/spaces.py b/rl_coach/spaces.py index f4ef11e..503598c 100644 --- a/rl_coach/spaces.py +++ b/rl_coach/spaces.py @@ -648,7 +648,7 @@ class SpacesDefinition(object): """ def __init__(self, state: StateSpace, - goal: ObservationSpace, + goal: Union[ObservationSpace, None], action: ActionSpace, reward: RewardSpace): self.state = state
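Two usage notes on the Batch RL pieces above. First, the Fitted-Q-Iteration-style target sync in CartPole_DQN_BatchRL works out to copying the online weights to the target network every DATASET_SIZE / batch_size = 40000 / 1024 ≈ 39 training steps, i.e. roughly once per pass over the dataset; the preset is launched like any other with coach -p CartPole_DQN_BatchRL. Second, the same load_memory_from_file_path hook that the presets above point at a PickledReplayBuffer also accepts the new CsvDataset wrapper. A hedged sketch follows: the keyword argument name is an assumption based on the filepath and is_episodic attributes used in load_csv, and the file path is illustrative.

    from rl_coach.core_types import CsvDataset

    # load an offline dataset in the CSV layout that EpisodicExperienceReplay.load_csv expects
    agent_params.memory.load_memory_from_file_path = CsvDataset('datasets/cartpole_offline_dataset.csv',
                                                                is_episodic=True)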