diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index 1e93262..c7b755e 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -685,7 +685,10 @@ class Agent(AgentInterface): """ loss = 0 if self._should_train(): - self.training_epoch += 1 + if self.ap.is_batch_rl_training: + # when training an agent for generating a dataset in batch-rl, we don't want it to be counted as part of + # the training epochs. we only care for training epochs in batch-rl anyway. + self.training_epoch += 1 for network in self.networks.values(): network.set_is_training(True) @@ -1047,3 +1050,11 @@ class Agent(AgentInterface): TimeTypes.EnvironmentSteps: self.total_steps_counter, TimeTypes.WallClockTime: self.agent_logger.get_current_wall_clock_time(), TimeTypes.Epoch: self.training_epoch}[self.parent_level_manager.parent_graph_manager.time_metric] + + def freeze_memory(self): + """ + Shuffle episodes in the memory and freeze it to make sure that no extra data is being pushed anymore. + :return: None + """ + self.call_memory('shuffle_episodes') + self.call_memory('freeze') diff --git a/rl_coach/agents/categorical_dqn_agent.py b/rl_coach/agents/categorical_dqn_agent.py index 2bc4f5c..34b0ec8 100644 --- a/rl_coach/agents/categorical_dqn_agent.py +++ b/rl_coach/agents/categorical_dqn_agent.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from typing import Union import numpy as np @@ -83,13 +82,22 @@ class CategoricalDQNAgent(ValueOptimizationAgent): # prediction's format is (batch,actions,atoms) def get_all_q_values_for_states(self, states: StateType): + q_values = None if self.exploration_policy.requires_action_values(): q_values = self.get_prediction(states, outputs=[self.networks['main'].online_network.output_heads[0].q_values]) - else: - q_values = None + return q_values + def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType): + actions_q_values, softmax_probabilities = None, None + if self.exploration_policy.requires_action_values(): + outputs = [self.networks['main'].online_network.output_heads[0].q_values, + self.networks['main'].online_network.output_heads[0].softmax] + actions_q_values, softmax_probabilities = self.get_prediction(states, outputs=outputs) + + return actions_q_values, softmax_probabilities + def learn_from_batch(self, batch): network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys() diff --git a/rl_coach/agents/dfp_agent.py b/rl_coach/agents/dfp_agent.py index cbce242..83c7412 100644 --- a/rl_coach/agents/dfp_agent.py +++ b/rl_coach/agents/dfp_agent.py @@ -182,7 +182,7 @@ class DFPAgent(Agent): action_values = None # choose action according to the exploration policy and the current phase (evaluating or training the agent) - action = self.exploration_policy.get_action(action_values) + action, _ = self.exploration_policy.get_action(action_values) if action_values is not None: action_values = action_values.squeeze() diff --git a/rl_coach/agents/dqn_agent.py b/rl_coach/agents/dqn_agent.py index c30b385..6689b31 100644 --- a/rl_coach/agents/dqn_agent.py +++ b/rl_coach/agents/dqn_agent.py @@ -49,6 +49,7 @@ class DQNNetworkParameters(NetworkParameters): self.batch_size = 32 self.replace_mse_with_huber_loss = True self.create_target_network = True + self.should_get_softmax_probabilities = False class DQNAgentParameters(AgentParameters): diff --git a/rl_coach/agents/nec_agent.py b/rl_coach/agents/nec_agent.py index 60f5301..3e381b3 100644 --- 
a/rl_coach/agents/nec_agent.py +++ b/rl_coach/agents/nec_agent.py @@ -16,7 +16,7 @@ import os import pickle -from typing import Union +from typing import Union, List import numpy as np @@ -40,6 +40,7 @@ class NECNetworkParameters(NetworkParameters): self.middleware_parameters = FCMiddlewareParameters() self.heads_parameters = [DNDQHeadParameters()] self.optimizer_type = 'Adam' + self.should_get_softmax_probabilities = False class NECAlgorithmParameters(AlgorithmParameters): @@ -166,11 +167,25 @@ class NECAgent(ValueOptimizationAgent): return super().act() - def get_all_q_values_for_states(self, states: StateType): + def get_all_q_values_for_states(self, states: StateType, additional_outputs: List = None): # we need to store the state embeddings regardless if the action is random or not - return self.get_prediction(states) + return self.get_prediction_and_update_embeddings(states) - def get_prediction(self, states): + def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType): + # get the actions q values and the state embedding + embedding, actions_q_values, softmax_probabilities = self.networks['main'].online_network.predict( + self.prepare_batch_for_inference(states, 'main'), + outputs=[self.networks['main'].online_network.state_embedding, + self.networks['main'].online_network.output_heads[0].output, + self.networks['main'].online_network.output_heads[0].softmax] + ) + if self.phase != RunPhase.TEST: + # store the state embedding for inserting it to the DND later + self.current_episode_state_embeddings.append(embedding.squeeze()) + actions_q_values = actions_q_values[0][0] + return actions_q_values, softmax_probabilities + + def get_prediction_and_update_embeddings(self, states): # get the actions q values and the state embedding embedding, actions_q_values = self.networks['main'].online_network.predict( self.prepare_batch_for_inference(states, 'main'), diff --git a/rl_coach/agents/policy_optimization_agent.py b/rl_coach/agents/policy_optimization_agent.py index 43dc3db..74c39d9 100644 --- a/rl_coach/agents/policy_optimization_agent.py +++ b/rl_coach/agents/policy_optimization_agent.py @@ -147,7 +147,7 @@ class PolicyOptimizationAgent(Agent): if isinstance(self.spaces.action, DiscreteActionSpace): # DISCRETE action_probabilities = np.array(action_values).squeeze() - action = self.exploration_policy.get_action(action_probabilities) + action, _ = self.exploration_policy.get_action(action_probabilities) action_info = ActionInfo(action=action, all_action_probabilities=action_probabilities) diff --git a/rl_coach/agents/qr_dqn_agent.py b/rl_coach/agents/qr_dqn_agent.py index 1e4042a..4975523 100644 --- a/rl_coach/agents/qr_dqn_agent.py +++ b/rl_coach/agents/qr_dqn_agent.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +from copy import copy from typing import Union import numpy as np @@ -79,6 +79,17 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent): actions_q_values = None return actions_q_values + # prediction's format is (batch,actions,atoms) + def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType): + actions_q_values, softmax_probabilities = None, None + if self.exploration_policy.requires_action_values(): + outputs = copy(self.networks['main'].online_network.outputs) + outputs.append(self.networks['main'].online_network.output_heads[0].softmax) + quantile_values, softmax_probabilities = self.get_prediction(states, outputs) + actions_q_values = self.get_q_values(quantile_values) + + return actions_q_values, softmax_probabilities + def learn_from_batch(self, batch): network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys() diff --git a/rl_coach/agents/value_optimization_agent.py b/rl_coach/agents/value_optimization_agent.py index 3a3ef8a..dcab620 100644 --- a/rl_coach/agents/value_optimization_agent.py +++ b/rl_coach/agents/value_optimization_agent.py @@ -14,7 +14,7 @@ # limitations under the License. # from collections import OrderedDict -from typing import Union +from typing import Union, List import numpy as np @@ -24,7 +24,8 @@ from rl_coach.filters.filter import NoInputFilter from rl_coach.logger import screen from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay from rl_coach.spaces import DiscreteActionSpace -from copy import deepcopy +from copy import deepcopy, copy + ## This is an abstract agent - there is no learn_from_batch method ## @@ -35,6 +36,12 @@ class ValueOptimizationAgent(Agent): self.q_values = self.register_signal("Q") self.q_value_for_action = {} + # currently we use softmax action probabilities only in batch-rl, + # but we might want to extend this later at some point. 
+ self.should_get_softmax_probabilities = \ + hasattr(self.ap.network_wrappers['main'], 'should_get_softmax_probabilities') and \ + self.ap.network_wrappers['main'].should_get_softmax_probabilities + def init_environment_dependent_modules(self): super().init_environment_dependent_modules() if isinstance(self.spaces.action, DiscreteActionSpace): @@ -45,12 +52,21 @@ class ValueOptimizationAgent(Agent): # Algorithms for which q_values are calculated from predictions will override this function def get_all_q_values_for_states(self, states: StateType): + actions_q_values = None if self.exploration_policy.requires_action_values(): actions_q_values = self.get_prediction(states) - else: - actions_q_values = None + return actions_q_values + def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType): + actions_q_values, softmax_probabilities = None, None + if self.exploration_policy.requires_action_values(): + outputs = copy(self.networks['main'].online_network.outputs) + outputs.append(self.networks['main'].online_network.output_heads[0].softmax) + + actions_q_values, softmax_probabilities = self.get_prediction(states, outputs=outputs) + return actions_q_values, softmax_probabilities + def get_prediction(self, states, outputs=None): return self.networks['main'].online_network.predict(self.prepare_batch_for_inference(states, 'main'), outputs=outputs) @@ -72,10 +88,19 @@ class ValueOptimizationAgent(Agent): ).format(policy.__class__.__name__)) def choose_action(self, curr_state): - actions_q_values = self.get_all_q_values_for_states(curr_state) + if self.should_get_softmax_probabilities: + actions_q_values, softmax_probabilities = \ + self.get_all_q_values_for_states_and_softmax_probabilities(curr_state) + else: + actions_q_values = self.get_all_q_values_for_states(curr_state) # choose action according to the exploration policy and the current phase (evaluating or training the agent) - action = self.exploration_policy.get_action(actions_q_values) + action, action_probabilities = self.exploration_policy.get_action(actions_q_values) + if self.should_get_softmax_probabilities and softmax_probabilities is not None: + # override the exploration policy's generated probabilities when an action was taken + # with the agent's actual policy + action_probabilities = softmax_probabilities + self._validate_action(self.exploration_policy, action) if actions_q_values is not None: @@ -87,15 +112,18 @@ class ValueOptimizationAgent(Agent): self.q_values.add_sample(actions_q_values) actions_q_values = actions_q_values.squeeze() + action_probabilities = action_probabilities.squeeze() for i, q_value in enumerate(actions_q_values): self.q_value_for_action[i].add_sample(q_value) action_info = ActionInfo(action=action, action_value=actions_q_values[action], - max_action_value=np.max(actions_q_values)) + max_action_value=np.max(actions_q_values), + all_action_probabilities=action_probabilities) + else: - action_info = ActionInfo(action=action) + action_info = ActionInfo(action=action, all_action_probabilities=action_probabilities) return action_info diff --git a/rl_coach/exploration_policies/additive_noise.py b/rl_coach/exploration_policies/additive_noise.py index 5f89889..e6ccbad 100644 --- a/rl_coach/exploration_policies/additive_noise.py +++ b/rl_coach/exploration_policies/additive_noise.py @@ -17,14 +17,17 @@ from typing import List import numpy as np +import scipy.stats from rl_coach.core_types import RunPhase, ActionType -from rl_coach.exploration_policies.exploration_policy import 
ExplorationPolicy, ExplorationParameters +from rl_coach.exploration_policies.exploration_policy import ContinuousActionExplorationPolicy, ExplorationParameters from rl_coach.schedules import Schedule, LinearSchedule from rl_coach.spaces import ActionSpace, BoxActionSpace # TODO: consider renaming to gaussian sampling + + class AdditiveNoiseParameters(ExplorationParameters): def __init__(self): super().__init__() @@ -36,7 +39,7 @@ class AdditiveNoiseParameters(ExplorationParameters): return 'rl_coach.exploration_policies.additive_noise:AdditiveNoise' -class AdditiveNoise(ExplorationPolicy): +class AdditiveNoise(ContinuousActionExplorationPolicy): """ AdditiveNoise is an exploration policy intended for continuous action spaces. It takes the action from the agent and adds a Gaussian distributed noise to it. The amount of noise added to the action follows the noise amount that diff --git a/rl_coach/exploration_policies/boltzmann.py b/rl_coach/exploration_policies/boltzmann.py index fd12561..e4f20ce 100644 --- a/rl_coach/exploration_policies/boltzmann.py +++ b/rl_coach/exploration_policies/boltzmann.py @@ -19,7 +19,7 @@ from typing import List import numpy as np from rl_coach.core_types import RunPhase, ActionType -from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters +from rl_coach.exploration_policies.exploration_policy import DiscreteActionExplorationPolicy, ExplorationParameters from rl_coach.schedules import Schedule from rl_coach.spaces import ActionSpace @@ -34,8 +34,7 @@ class BoltzmannParameters(ExplorationParameters): return 'rl_coach.exploration_policies.boltzmann:Boltzmann' - -class Boltzmann(ExplorationPolicy): +class Boltzmann(DiscreteActionExplorationPolicy): """ The Boltzmann exploration policy is intended for discrete action spaces. 
It assumes that each of the possible actions has some value assigned to it (such as the Q value), and uses a softmax function to convert these values @@ -50,7 +49,7 @@ class Boltzmann(ExplorationPolicy): super().__init__(action_space) self.temperature_schedule = temperature_schedule - def get_action(self, action_values: List[ActionType]) -> ActionType: + def get_action(self, action_values: List[ActionType]) -> (ActionType, List[float]): if self.phase == RunPhase.TRAIN: self.temperature_schedule.step() # softmax calculation @@ -59,7 +58,8 @@ class Boltzmann(ExplorationPolicy): # make sure probs sum to 1 probabilities[-1] = 1 - np.sum(probabilities[:-1]) # choose actions according to the probabilities - return np.random.choice(range(self.action_space.shape), p=probabilities) + action = np.random.choice(range(self.action_space.shape), p=probabilities) + return action, probabilities def get_control_param(self): return self.temperature_schedule.current_value diff --git a/rl_coach/exploration_policies/categorical.py b/rl_coach/exploration_policies/categorical.py index 1bd5ee6..511c608 100644 --- a/rl_coach/exploration_policies/categorical.py +++ b/rl_coach/exploration_policies/categorical.py @@ -19,7 +19,7 @@ from typing import List import numpy as np from rl_coach.core_types import RunPhase, ActionType -from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters +from rl_coach.exploration_policies.exploration_policy import DiscreteActionExplorationPolicy, ExplorationParameters from rl_coach.spaces import ActionSpace @@ -29,7 +29,7 @@ class CategoricalParameters(ExplorationParameters): return 'rl_coach.exploration_policies.categorical:Categorical' -class Categorical(ExplorationPolicy): +class Categorical(DiscreteActionExplorationPolicy): """ Categorical exploration policy is intended for discrete action spaces. It expects the action values to represent a probability distribution over the action, from which a single action will be sampled. 
@@ -42,13 +42,18 @@ class Categorical(ExplorationPolicy): """ super().__init__(action_space) - def get_action(self, action_values: List[ActionType]) -> ActionType: + def get_action(self, action_values: List[ActionType]) -> (ActionType, List[float]): if self.phase == RunPhase.TRAIN: # choose actions according to the probabilities - return np.random.choice(self.action_space.actions, p=action_values) + action = np.random.choice(self.action_space.actions, p=action_values) + return action, action_values else: # take the action with the highest probability - return np.argmax(action_values) + action = np.argmax(action_values) + one_hot_action_probabilities = np.zeros(len(self.action_space.actions)) + one_hot_action_probabilities[action] = 1 + + return action, one_hot_action_probabilities def get_control_param(self): return 0 diff --git a/rl_coach/exploration_policies/e_greedy.py b/rl_coach/exploration_policies/e_greedy.py index 1cb9072..fde73b3 100644 --- a/rl_coach/exploration_policies/e_greedy.py +++ b/rl_coach/exploration_policies/e_greedy.py @@ -20,8 +20,7 @@ import numpy as np from rl_coach.core_types import RunPhase, ActionType from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters -from rl_coach.exploration_policies.exploration_policy import ExplorationParameters -from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy +from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ExplorationPolicy from rl_coach.schedules import Schedule, LinearSchedule from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace from rl_coach.utils import dynamic_import_and_instantiate_module_from_params @@ -82,26 +81,32 @@ class EGreedy(ExplorationPolicy): epsilon = self.evaluation_epsilon if self.phase == RunPhase.TEST else self.epsilon_schedule.current_value return self.current_random_value >= epsilon - def get_action(self, action_values: List[ActionType]) -> ActionType: + def get_action(self, action_values: List[ActionType]) -> (ActionType, List[float]): epsilon = self.evaluation_epsilon if self.phase == RunPhase.TEST else self.epsilon_schedule.current_value if isinstance(self.action_space, DiscreteActionSpace): - top_action = np.argmax(action_values) if self.current_random_value < epsilon: chosen_action = self.action_space.sample() + probabilities = np.full(len(self.action_space.actions), + 1. 
/ (self.action_space.high[0] - self.action_space.low[0] + 1)) else: - chosen_action = top_action + chosen_action = np.argmax(action_values) + + # one-hot probabilities vector + probabilities = np.zeros(len(self.action_space.actions)) + probabilities[chosen_action] = 1 + + self.step_epsilon() + return chosen_action, probabilities + else: if self.current_random_value < epsilon and self.phase == RunPhase.TRAIN: chosen_action = self.action_space.sample() else: chosen_action = self.continuous_exploration_policy.get_action(action_values) - # step the epsilon schedule and generate a new random value for next time - if self.phase == RunPhase.TRAIN: - self.epsilon_schedule.step() - self.current_random_value = np.random.rand() - return chosen_action + self.step_epsilon() + return chosen_action def get_control_param(self): if isinstance(self.action_space, DiscreteActionSpace): @@ -113,3 +118,9 @@ class EGreedy(ExplorationPolicy): super().change_phase(phase) if isinstance(self.action_space, BoxActionSpace): self.continuous_exploration_policy.change_phase(phase) + + def step_epsilon(self): + # step the epsilon schedule and generate a new random value for next time + if self.phase == RunPhase.TRAIN: + self.epsilon_schedule.step() + self.current_random_value = np.random.rand() diff --git a/rl_coach/exploration_policies/exploration_policy.py b/rl_coach/exploration_policies/exploration_policy.py index 16093a0..a345895 100644 --- a/rl_coach/exploration_policies/exploration_policy.py +++ b/rl_coach/exploration_policies/exploration_policy.py @@ -18,7 +18,7 @@ from typing import List from rl_coach.base_parameters import Parameters from rl_coach.core_types import RunPhase, ActionType -from rl_coach.spaces import ActionSpace +from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace, GoalsSpace class ExplorationParameters(Parameters): @@ -54,14 +54,10 @@ class ExplorationPolicy(object): Given a list of values corresponding to each action, choose one actions according to the exploration policy :param action_values: A list of action values - :return: The chosen action + :return: The chosen action, + The probability of the action (if available, otherwise 1 for absolute certainty in the action) """ - if self.__class__ == ExplorationPolicy: - raise ValueError("The ExplorationPolicy class is an abstract class and should not be used directly. " - "Please set the exploration parameters to point to an inheriting class like EGreedy or " - "AdditiveNoise") - else: - raise ValueError("The get_action function should be overridden in the inheriting exploration class") + raise NotImplementedError() def change_phase(self, phase): """ @@ -82,3 +78,42 @@ class ExplorationPolicy(object): def get_control_param(self): return 0 + + +class DiscreteActionExplorationPolicy(ExplorationPolicy): + """ + A discrete action exploration policy. 
+ """ + def __init__(self, action_space: ActionSpace): + """ + :param action_space: the action space used by the environment + """ + assert isinstance(action_space, DiscreteActionSpace) + super().__init__(action_space) + + def get_action(self, action_values: List[ActionType]) -> (ActionType, List): + """ + Given a list of values corresponding to each action, + choose one actions according to the exploration policy + :param action_values: A list of action values + :return: The chosen action, + The probabilities of actions to select from (if not available a one-hot vector) + """ + if self.__class__ == ExplorationPolicy: + raise ValueError("The ExplorationPolicy class is an abstract class and should not be used directly. " + "Please set the exploration parameters to point to an inheriting class like EGreedy or " + "AdditiveNoise") + else: + raise ValueError("The get_action function should be overridden in the inheriting exploration class") + + +class ContinuousActionExplorationPolicy(ExplorationPolicy): + """ + A continuous action exploration policy. + """ + def __init__(self, action_space: ActionSpace): + """ + :param action_space: the action space used by the environment + """ + assert isinstance(action_space, BoxActionSpace) or isinstance(action_space, GoalsSpace) + super().__init__(action_space) diff --git a/rl_coach/exploration_policies/greedy.py b/rl_coach/exploration_policies/greedy.py index 8abe030..4e809f9 100644 --- a/rl_coach/exploration_policies/greedy.py +++ b/rl_coach/exploration_policies/greedy.py @@ -19,7 +19,7 @@ from typing import List import numpy as np from rl_coach.core_types import ActionType -from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters +from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ExplorationPolicy from rl_coach.spaces import ActionSpace, DiscreteActionSpace, BoxActionSpace @@ -41,9 +41,12 @@ class Greedy(ExplorationPolicy): """ super().__init__(action_space) - def get_action(self, action_values: List[ActionType]) -> ActionType: + def get_action(self, action_values: List[ActionType]): if type(self.action_space) == DiscreteActionSpace: - return np.argmax(action_values) + action = np.argmax(action_values) + one_hot_action_probabilities = np.zeros(len(self.action_space.actions)) + one_hot_action_probabilities[action] = 1 + return action, one_hot_action_probabilities if type(self.action_space) == BoxActionSpace: return action_values diff --git a/rl_coach/exploration_policies/ou_process.py b/rl_coach/exploration_policies/ou_process.py index 28daae1..2fa7140 100644 --- a/rl_coach/exploration_policies/ou_process.py +++ b/rl_coach/exploration_policies/ou_process.py @@ -19,12 +19,13 @@ from typing import List import numpy as np from rl_coach.core_types import RunPhase, ActionType -from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters +from rl_coach.exploration_policies.exploration_policy import ContinuousActionExplorationPolicy, ExplorationParameters from rl_coach.spaces import ActionSpace, BoxActionSpace, GoalsSpace # Based on on the description in: # https://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab + class OUProcessParameters(ExplorationParameters): def __init__(self): super().__init__() @@ -39,7 +40,7 @@ class OUProcessParameters(ExplorationParameters): # Ornstein-Uhlenbeck process -class OUProcess(ExplorationPolicy): +class OUProcess(ContinuousActionExplorationPolicy): """ OUProcess 
exploration policy is intended for continuous action spaces, and selects the action according to an Ornstein-Uhlenbeck process. The Ornstein-Uhlenbeck process implements the action as a Gaussian process, where @@ -56,10 +57,6 @@ class OUProcess(ExplorationPolicy): self.state = np.zeros(self.action_space.shape) self.dt = dt - if not (isinstance(action_space, BoxActionSpace) or isinstance(action_space, GoalsSpace)): - raise ValueError("OU process exploration works only for continuous controls." - "The given action space is of type: {}".format(action_space.__class__.__name__)) - def reset(self): self.state = np.zeros(self.action_space.shape) diff --git a/rl_coach/exploration_policies/parameter_noise.py b/rl_coach/exploration_policies/parameter_noise.py index 7854329..377895e 100644 --- a/rl_coach/exploration_policies/parameter_noise.py +++ b/rl_coach/exploration_policies/parameter_noise.py @@ -59,9 +59,13 @@ class ParameterNoise(ExplorationPolicy): self.network_params = network_params self._replace_network_dense_layers() - def get_action(self, action_values: List[ActionType]) -> ActionType: + def get_action(self, action_values: List[ActionType]): if type(self.action_space) == DiscreteActionSpace: - return np.argmax(action_values) + action = np.argmax(action_values) + one_hot_action_probabilities = np.zeros(len(self.action_space.actions)) + one_hot_action_probabilities[action] = 1 + + return action, one_hot_action_probabilities elif type(self.action_space) == BoxActionSpace: action_values_mean = action_values[0].squeeze() action_values_std = action_values[1].squeeze() diff --git a/rl_coach/exploration_policies/truncated_normal.py b/rl_coach/exploration_policies/truncated_normal.py index 396f348..91848ed 100644 --- a/rl_coach/exploration_policies/truncated_normal.py +++ b/rl_coach/exploration_policies/truncated_normal.py @@ -20,7 +20,7 @@ import numpy as np from scipy.stats import truncnorm from rl_coach.core_types import RunPhase, ActionType -from rl_coach.exploration_policies.exploration_policy import ExplorationPolicy, ExplorationParameters +from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ContinuousActionExplorationPolicy from rl_coach.schedules import Schedule, LinearSchedule from rl_coach.spaces import ActionSpace, BoxActionSpace @@ -38,7 +38,7 @@ class TruncatedNormalParameters(ExplorationParameters): return 'rl_coach.exploration_policies.truncated_normal:TruncatedNormal' -class TruncatedNormal(ExplorationPolicy): +class TruncatedNormal(ContinuousActionExplorationPolicy): """ The TruncatedNormal exploration policy is intended for continuous action spaces. 
It samples the action from a normal distribution, where the mean action is given by the agent, and the standard deviation can be given in t diff --git a/rl_coach/graph_managers/batch_rl_graph_manager.py b/rl_coach/graph_managers/batch_rl_graph_manager.py index e930fd6..5c9da2e 100644 --- a/rl_coach/graph_managers/batch_rl_graph_manager.py +++ b/rl_coach/graph_managers/batch_rl_graph_manager.py @@ -18,15 +18,17 @@ from typing import Tuple, List, Union from rl_coach.agents.dqn_agent import DQNAgentParameters from rl_coach.agents.nec_agent import NECAgentParameters +from rl_coach.architectures.network_wrapper import NetworkWrapper from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \ PresetValidationParameters -from rl_coach.core_types import RunPhase +from rl_coach.core_types import RunPhase, TotalStepsCounter, TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import EnvironmentParameters, Environment from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager from rl_coach.level_manager import LevelManager from rl_coach.logger import screen +from rl_coach.schedules import LinearSchedule from rl_coach.spaces import SpacesDefinition from rl_coach.utils import short_dynamic_import @@ -35,26 +37,62 @@ from rl_coach.memories.episodic import EpisodicExperienceReplayParameters from rl_coach.core_types import TimeTypes +# TODO build a tutorial for batch RL class BatchRLGraphManager(BasicRLGraphManager): """ - A batch RL graph manager creates scenario of learning from a dataset without a simulator. + A batch RL graph manager creates a scenario of learning from a dataset without a simulator. + + If an environment is given (useful either for research purposes, or for experimenting with a toy problem before + actually working with a real dataset), we can use it in order to collect a dataset to later be used to train the + actual agent. The dataset, in this case, can be collected either by randomly acting in the environment + (only running in heatup), or alternatively by training a different agent in the environment and using its collected + data as a dataset. If experience generating agent parameters are given, we will instantiate this agent and use it + to train on the environment, and then use the collected dataset to train the actual agent. Otherwise, we will + collect a random dataset. + :param agent_params: the parameters of the agent to train using batch RL + :param env_params: [optional] environment parameters, for cases where we want to first collect a dataset + :param vis_params: visualization parameters + :param preset_validation_params: preset validation parameters, to be used for testing purposes + :param name: graph name + :param spaces_definition: when working with a dataset, we need to get a description of the actual state and action + spaces of the problem + :param reward_model_num_epochs: the number of epochs to go over the dataset for training a reward model for the + 'direct method' and 'doubly robust' OPE methods. + :param train_to_eval_ratio: percentage of the data transitions to be used for training vs. evaluation, e.g. a value + of 0.8 means ~80% of the transitions will be used for training and ~20% will be used for + evaluation using OPE. + :param experience_generating_agent_params: [optional] parameters of an agent to be trained against
an environment, whose + collected experience will be used to train the actual (other) agent + :param experience_generating_schedule_params: [optional] graph scheduling parameters for training the experience + generating agent """ - def __init__(self, agent_params: AgentParameters, env_params: Union[EnvironmentParameters, None], + def __init__(self, agent_params: AgentParameters, + env_params: Union[EnvironmentParameters, None], schedule_params: ScheduleParameters, vis_params: VisualizationParameters = VisualizationParameters(), preset_validation_params: PresetValidationParameters = PresetValidationParameters(), name='batch_rl_graph', spaces_definition: SpacesDefinition = None, reward_model_num_epochs: int = 100, - train_to_eval_ratio: float = 0.8): + train_to_eval_ratio: float = 0.8, experience_generating_agent_params: AgentParameters = None, + experience_generating_schedule_params: ScheduleParameters = None): super().__init__(agent_params, env_params, schedule_params, vis_params, preset_validation_params, name) self.is_batch_rl = True self.time_metric = TimeTypes.Epoch self.reward_model_num_epochs = reward_model_num_epochs self.spaces_definition = spaces_definition + self.is_collecting_random_dataset = experience_generating_agent_params is None # setting this here to make sure that, by default, train_to_eval_ratio gets a value < 1 - # (its default value in the memory is 1) - self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio + # (its default value in the memory is 1, so as not to affect other non-batch-rl scenarios) + if self.is_collecting_random_dataset: + self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio + else: + experience_generating_agent_params.memory.train_to_eval_ratio = train_to_eval_ratio + self.experience_generating_agent_params = experience_generating_agent_params + self.experience_generating_agent = None + + self.set_schedule_params(experience_generating_schedule_params) + self.schedule_params = schedule_params def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]: if self.env_params: @@ -76,22 +114,41 @@ class BatchRLGraphManager(BasicRLGraphManager): self.agent_params.task_parameters = task_parameters # TODO: this should probably be passed in a different way self.agent_params.name = "agent" self.agent_params.is_batch_rl_training = True + self.agent_params.network_wrappers['main'].should_get_softmax_probabilities = True if 'reward_model' not in self.agent_params.network_wrappers: # user hasn't defined params for the reward model. we will use the same params as used for the 'main' # network.
self.agent_params.network_wrappers['reward_model'] = deepcopy(self.agent_params.network_wrappers['main']) - agent = short_dynamic_import(self.agent_params.path)(self.agent_params) + self.agent = short_dynamic_import(self.agent_params.path)(self.agent_params) + agents = {'agent': self.agent} + + if not self.is_collecting_random_dataset: + self.experience_generating_agent_params.visualization.dump_csv = False + self.experience_generating_agent_params.task_parameters = task_parameters + self.experience_generating_agent_params.name = "experience_gen_agent" + self.experience_generating_agent_params.network_wrappers['main'].should_get_softmax_probabilities = True + + # we need to set these manually as these are usually being set for us only for the default agent + self.experience_generating_agent_params.input_filter = self.agent_params.input_filter + self.experience_generating_agent_params.output_filter = self.agent_params.output_filter + + self.experience_generating_agent = short_dynamic_import( + self.experience_generating_agent_params.path)(self.experience_generating_agent_params) + + agents['experience_generating_agent'] = self.experience_generating_agent if not env and not self.agent_params.memory.load_memory_from_file_path: screen.warning("A BatchRLGraph requires setting a dataset to load into the agent's memory or alternatively " "using an environment to create a (random) dataset from. This agent should only be used for " "inference. ") # set level manager - level_manager = LevelManager(agents=agent, environment=env, name="main_level", + # - although we will be using each agent separately, we have to have both agents initialized together with the + # LevelManager, so to have them both properly initialized + level_manager = LevelManager(agents=agents, + environment=env, name="main_level", spaces_definition=self.spaces_definition) - if env: return [level_manager], [env] else: @@ -123,12 +180,34 @@ class BatchRLGraphManager(BasicRLGraphManager): # an environment and a dataset to load from, we will use the environment only for evaluating the policy, # and will not run heatup. - # heatup - if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path: - self.heatup(self.heatup_steps) + screen.log_title("Starting to improve an agent collecting experience to use for training the actual agent in a " + "Batch RL fashion") + + if self.is_collecting_random_dataset: + # heatup + if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path: + self.heatup(self.heatup_steps) + else: + # set the experience generating agent to train + self.level_managers[0].agents = {'experience_generating_agent': self.experience_generating_agent} + + # collect a dataset using the experience generating agent + super().improve() + + # set the acquired experience to the actual agent that we're going to train + self.agent.memory = self.experience_generating_agent.memory + + # switch the graph scheduling parameters + self.set_schedule_params(self.schedule_params) + + # set the actual agent to train + self.level_managers[0].agents = {'agent': self.agent} + + # this agent never actually plays + self.level_managers[0].agents['agent'].ap.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0) # from this point onwards, the dataset cannot be changed anymore. Allows for performance improvements. 
- self.level_managers[0].agents['agent'].memory.freeze() + self.level_managers[0].agents['agent'].freeze_memory() self.initialize_ope_models_and_stats() @@ -141,15 +220,13 @@ class BatchRLGraphManager(BasicRLGraphManager): # the outer most training loop improve_steps_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps while self.total_steps_counters[RunPhase.TRAIN] < improve_steps_end: - # TODO if we have an environment, do we want to use it to have the agent train against, and use the - # collected replay buffer as a dataset? (as oppose to what we currently have, where the dataset is built - # during heatup, and is composed on random actions) # perform several steps of training if self.steps_between_evaluation_periods.num_steps > 0: with self.phase_context(RunPhase.TRAIN): self.reset_internal_state(force_environment_reset=True) - steps_between_evaluation_periods_end = self.current_step_counter + self.steps_between_evaluation_periods + steps_between_evaluation_periods_end = self.current_step_counter + \ + self.steps_between_evaluation_periods while self.current_step_counter < steps_between_evaluation_periods_end: self.train() @@ -168,8 +245,8 @@ class BatchRLGraphManager(BasicRLGraphManager): def initialize_ope_models_and_stats(self): """ - - :return: + Improve a reward model of the MDP, to be used for some of the off-policy evaluation (OPE) methods. + e.g. 'direct method' and 'doubly robust'. """ agent = self.level_managers[0].agents['agent'] @@ -193,6 +270,3 @@ class BatchRLGraphManager(BasicRLGraphManager): :return: """ self.level_managers[0].agents['agent'].run_off_policy_evaluation() - - - diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index aef08cf..bc29398 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -95,10 +95,7 @@ class GraphManager(object): self.level_managers = [] # type: List[LevelManager] self.top_level_manager = None self.environments = [] - self.heatup_steps = schedule_params.heatup_steps - self.evaluation_steps = schedule_params.evaluation_steps - self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods - self.improve_steps = schedule_params.improve_steps + self.set_schedule_params(schedule_params) self.visualization_parameters = vis_params self.name = name self.task_parameters = None @@ -759,3 +756,14 @@ class GraphManager(object): if hasattr(self, 'data_store_params'): data_store = self.get_data_store(self.data_store_params) data_store.save_to_store() + + def set_schedule_params(self, schedule_params: ScheduleParameters): + """ + Set schedule parameters for the graph. + + :param schedule_params: the schedule params to set. + """ + self.heatup_steps = schedule_params.heatup_steps + self.evaluation_steps = schedule_params.evaluation_steps + self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods + self.improve_steps = schedule_params.improve_steps diff --git a/rl_coach/memories/episodic/episodic_experience_replay.py b/rl_coach/memories/episodic/episodic_experience_replay.py index 836756c..a7dddbd 100644 --- a/rl_coach/memories/episodic/episodic_experience_replay.py +++ b/rl_coach/memories/episodic/episodic_experience_replay.py @@ -452,8 +452,6 @@ class EpisodicExperienceReplay(Memory): progress_bar.update(len(episode_ids)) progress_bar.close() - self.shuffle_episodes() - def freeze(self): """ Freezing the replay buffer does not allow any new transitions to be added to the memory. 
diff --git a/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py b/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py index c97ad3c..7d009a8 100644 --- a/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py +++ b/rl_coach/off_policy_evaluators/rl/weighted_importance_sampling.py @@ -20,7 +20,7 @@ from rl_coach.core_types import Episode class WeightedImportanceSampling(object): -# TODO rename and add PDIS +# TODO add PDIS @staticmethod def evaluate(evaluation_dataset_as_episodes: List[Episode]) -> float: """ diff --git a/rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py b/rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py index b758464..91a777c 100644 --- a/rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py +++ b/rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py @@ -1,4 +1,5 @@ from copy import deepcopy +from rl_coach.agents.dqn_agent import DQNAgentParameters from rl_coach.architectures.tensorflow_components.layers import Dense from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters @@ -45,10 +46,10 @@ agent_params.network_wrappers['main'].batch_size = 128 agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps( DATASET_SIZE / agent_params.network_wrappers['main'].batch_size) # -# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps( -# 3) +agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps( + 100) +# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100) -agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0) agent_params.algorithm.discount = 0.98 # can use either a kNN or a NN based model for predicting which actions not to max over in the bellman equation @@ -79,17 +80,49 @@ agent_params.network_wrappers['imitation_model'].middleware_parameters.scheme = # ER size agent_params.memory = EpisodicExperienceReplayParameters() -agent_params.memory.max_size = (MemoryGranularity.Transitions, DATASET_SIZE) - # E-Greedy schedule agent_params.exploration.epsilon_schedule = LinearSchedule(0, 0, 10000) agent_params.exploration.evaluation_epsilon = 0 - +# Input filtering agent_params.input_filter = InputFilter() agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.)) + + + +# Experience Generating Agent parameters +experience_generating_agent_params = DQNAgentParameters() + +# schedule parameters +experience_generating_schedule_params = ScheduleParameters() +experience_generating_schedule_params.heatup_steps = EnvironmentSteps(1000) +experience_generating_schedule_params.improve_steps = TrainingSteps( + DATASET_SIZE - experience_generating_schedule_params.heatup_steps.num_steps) +experience_generating_schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10) +experience_generating_schedule_params.evaluation_steps = EnvironmentEpisodes(1) + +# DQN params +experience_generating_agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(100) +experience_generating_agent_params.algorithm.discount = 0.99 +experience_generating_agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(1) + +# NN configuration +experience_generating_agent_params.network_wrappers['main'].learning_rate = 0.00025 +experience_generating_agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False + +# ER size +experience_generating_agent_params.memory = EpisodicExperienceReplayParameters() 
+experience_generating_agent_params.memory.max_size = \ + (MemoryGranularity.Transitions, + experience_generating_schedule_params.heatup_steps.num_steps + + experience_generating_schedule_params.improve_steps.num_steps) + +# E-Greedy schedule +experience_generating_agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) + + ################ # Environment # ################ @@ -101,11 +134,14 @@ env_params = GymVectorEnvironment(level='CartPole-v0') preset_validation_params = PresetValidationParameters() preset_validation_params.test = True preset_validation_params.min_reward_threshold = 150 -preset_validation_params.max_episodes_to_achieve_reward = 2000 +preset_validation_params.max_episodes_to_achieve_reward = 50 -graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params, +graph_manager = BatchRLGraphManager(agent_params=agent_params, + experience_generating_agent_params=experience_generating_agent_params, + experience_generating_schedule_params=experience_generating_schedule_params, + env_params=env_params, schedule_params=schedule_params, vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1), preset_validation_params=preset_validation_params, reward_model_num_epochs=30, - train_to_eval_ratio=0.8) + train_to_eval_ratio=0.4) diff --git a/rl_coach/tests/exploration_policies/test_additive_noise.py b/rl_coach/tests/exploration_policies/test_additive_noise.py index a32124b..b735515 100644 --- a/rl_coach/tests/exploration_policies/test_additive_noise.py +++ b/rl_coach/tests/exploration_policies/test_additive_noise.py @@ -16,10 +16,6 @@ def test_init(): action_space = DiscreteActionSpace(3) noise_schedule = LinearSchedule(1.0, 1.0, 1000) - # additive noise doesn't work for discrete controls - with pytest.raises(ValueError): - policy = AdditiveNoise(action_space, noise_schedule, 0) - # additive noise requires a bounded range for the actions action_space = BoxActionSpace(np.array([10])) with pytest.raises(ValueError): diff --git a/rl_coach/tests/exploration_policies/test_e_greedy.py b/rl_coach/tests/exploration_policies/test_e_greedy.py index 76d1f36..ded851c 100644 --- a/rl_coach/tests/exploration_policies/test_e_greedy.py +++ b/rl_coach/tests/exploration_policies/test_e_greedy.py @@ -21,14 +21,14 @@ def test_get_action(): # verify that test phase gives greedy actions (evaluation_epsilon = 0) policy.change_phase(RunPhase.TEST) for i in range(100): - best_action = policy.get_action(np.array([10, 20, 30])) + best_action, _ = policy.get_action(np.array([10, 20, 30])) assert best_action == 2 # verify that train phase gives uniform actions (exploration = 1) policy.change_phase(RunPhase.TRAIN) counters = np.array([0, 0, 0]) for i in range(30000): - best_action = policy.get_action(np.array([10, 20, 30])) + best_action, _ = policy.get_action(np.array([10, 20, 30])) counters[best_action] += 1 assert np.all(counters > 9500) # this is noisy so we allow 5% error diff --git a/rl_coach/tests/exploration_policies/test_greedy.py b/rl_coach/tests/exploration_policies/test_greedy.py index ced5efb..5b978b5 100644 --- a/rl_coach/tests/exploration_policies/test_greedy.py +++ b/rl_coach/tests/exploration_policies/test_greedy.py @@ -15,7 +15,7 @@ def test_get_action(): action_space = DiscreteActionSpace(3) policy = Greedy(action_space) - best_action = policy.get_action(np.array([10, 20, 30])) + best_action, _ = policy.get_action(np.array([10, 20, 30])) assert best_action == 2 # continuous control diff --git 
a/rl_coach/tests/exploration_policies/test_ou_process.py b/rl_coach/tests/exploration_policies/test_ou_process.py index 2918e0c..a31e7c0 100644 --- a/rl_coach/tests/exploration_policies/test_ou_process.py +++ b/rl_coach/tests/exploration_policies/test_ou_process.py @@ -16,10 +16,6 @@ def test_init(): # discrete control action_space = DiscreteActionSpace(3) - # OU process doesn't work for discrete controls - with pytest.raises(ValueError): - policy = OUProcess(action_space, mu=0, theta=0.1, sigma=0.2, dt=0.01) - @pytest.mark.unit_test def test_get_action():
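# Illustrative sketch (not part of the patch): the exploration-policy API changed above so that
# get_action() now returns an (action, probabilities) tuple instead of a bare action. Based on the
# updated test_greedy.py, for a discrete action space the greedy policy also reports a one-hot
# probability vector over the actions.
import numpy as np

from rl_coach.exploration_policies.greedy import Greedy
from rl_coach.spaces import DiscreteActionSpace

action_space = DiscreteActionSpace(3)
policy = Greedy(action_space)

# the argmax action is returned together with a one-hot vector marking the chosen action
action, probabilities = policy.get_action(np.array([10, 20, 30]))
assert action == 2
assert np.allclose(probabilities, [0.0, 0.0, 1.0])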