Batch RL (#238)
New files added in this commit:
  __init__.py (new file, 0 lines)
  __init__.py (new file, 0 lines)
@@ -113,7 +113,8 @@ In Coach, this can be done in two steps -

.. code-block:: python

coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=\"<experiment dir>/replay_buffer.p\"'
from rl_coach.core_types import PickledReplayBuffer
coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=PickledReplayBuffer(\"<experiment dir>/replay_buffer.p\")'

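The same wrappers can also be used directly inside a preset instead of on the command line. A minimal sketch, assuming a preset that already defines ``agent_params`` (the variable name and file paths are illustrative, not taken from this commit):

.. code-block:: python

   from rl_coach.core_types import PickledReplayBuffer, CsvDataset

   # a replay buffer previously dumped by Coach as a pickle
   agent_params.memory.load_memory_from_file_path = \
       PickledReplayBuffer("<experiment dir>/replay_buffer.p")

   # or: an episodic dataset stored as a CSV file
   agent_params.memory.load_memory_from_file_path = \
       CsvDataset("<dataset dir>/transitions.csv", is_episodic=True)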
Visualizations

@@ -34,6 +34,9 @@ from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.core_types import TimeTypes
from rl_coach.off_policy_evaluators.ope_manager import OpeManager
from rl_coach.core_types import PickledReplayBuffer, CsvDataset


class Agent(AgentInterface):
@@ -49,10 +52,10 @@ class Agent(AgentInterface):
and self.ap.memory.shared_memory
if self.shared_memory:
self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad
self.name = agent_parameters.name
self.parent = parent
self.parent_level_manager = None
self.full_name_id = agent_parameters.full_name_id = self.name
# TODO this needs to be sorted out. Why the duplicates for the agent's name?
self.full_name_id = agent_parameters.full_name_id = self.name = agent_parameters.name

if type(agent_parameters.task_parameters) == DistributedTaskParameters:
screen.log_title("Creating agent - name: {} task id: {} (may take up to 30 seconds due to "
@@ -84,9 +87,17 @@ class Agent(AgentInterface):
self.memory.set_memory_backend(self.memory_backend)

if agent_parameters.memory.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(agent_parameters.memory.load_memory_from_file_path))
self.memory.load(agent_parameters.memory.load_memory_from_file_path)
if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
else:
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
.format(agent_parameters.memory.load_memory_from_file_path))

if self.shared_memory and self.is_chief:
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
@@ -147,6 +158,7 @@ class Agent(AgentInterface):
self.total_steps_counter = 0
self.running_reward = None
self.training_iteration = 0
self.training_epoch = 0
self.last_target_network_update_step = 0
self.last_training_phase_step = 0
self.current_episode = self.ap.current_episode = 0
@@ -184,6 +196,7 @@ class Agent(AgentInterface):
self.discounted_return = self.register_signal('Discounted Return')
if isinstance(self.in_action_space, GoalsSpace):
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)

# use seed
if self.ap.task_parameters.seed is not None:
random.seed(self.ap.task_parameters.seed)
@@ -193,6 +206,9 @@ class Agent(AgentInterface):
random.seed()
np.random.seed()

# batch rl
self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None

@property
def parent(self) -> 'LevelManager':
"""
@@ -228,6 +244,7 @@ class Agent(AgentInterface):
format(graph_name=self.parent_level_manager.parent_graph_manager.name,
level_name=self.parent_level_manager.name,
agent_full_id='.'.join(self.full_name_id.split('/')))
self.agent_logger.set_index_name(self.parent_level_manager.parent_graph_manager.time_metric.value.name)
self.agent_logger.set_logger_filenames(self.ap.task_parameters.experiment_path, logger_prefix=logger_prefix,
add_timestamp=True, task_id=self.task_id)
if self.ap.visualization.dump_in_episode_signals:
@@ -387,7 +404,8 @@ class Agent(AgentInterface):
elif ending_evaluation:
# we write to the next episode, because it could be that the current episode was already written
# to disk and then we won't write it again
self.agent_logger.set_current_time(self.current_episode + 1)
self.agent_logger.set_current_time(self.get_current_time() + 1)

evaluation_reward = self.accumulated_rewards_across_evaluation_episodes / self.num_evaluation_episodes_completed
self.agent_logger.create_signal_value(
'Evaluation Reward', evaluation_reward)
@@ -471,8 +489,11 @@ class Agent(AgentInterface):
:return: None
"""
# log all the signals to file
self.agent_logger.set_current_time(self.current_episode)
current_time = self.get_current_time()
self.agent_logger.set_current_time(current_time)
self.agent_logger.create_signal_value('Training Iter', self.training_iteration)
self.agent_logger.create_signal_value('Episode #', self.current_episode)
self.agent_logger.create_signal_value('Epoch', self.training_epoch)
self.agent_logger.create_signal_value('In Heatup', int(self._phase == RunPhase.HEATUP))
self.agent_logger.create_signal_value('ER #Transitions', self.call_memory('num_transitions'))
self.agent_logger.create_signal_value('ER #Episodes', self.call_memory('length'))
@@ -485,13 +506,17 @@ class Agent(AgentInterface):
if self._phase == RunPhase.TRAIN else np.nan)

self.agent_logger.create_signal_value('Update Target Network', 0, overwrite=False)
self.agent_logger.update_wall_clock_time(self.current_episode)
self.agent_logger.update_wall_clock_time(current_time)

# The following signals are created with meaningful values only when an evaluation phase is completed.
# Creating with default NaNs for any HEATUP/TRAIN/TEST episode which is not the last in an evaluation phase
self.agent_logger.create_signal_value('Evaluation Reward', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Shaped Evaluation Reward', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Success Rate', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Inverse Propensity Score', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Direct Method Reward', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Doubly Robust', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Sequential Doubly Robust', np.nan, overwrite=False)

for signal in self.episode_signals:
self.agent_logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
@@ -500,8 +525,7 @@ class Agent(AgentInterface):
self.agent_logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())

# dump
if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0 \
and self.current_episode > 0:
if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0:
self.agent_logger.dump_output_csv()

def handle_episode_ended(self) -> None:
@@ -537,7 +561,8 @@ class Agent(AgentInterface):
self.total_reward_in_current_episode >= self.spaces.reward.reward_success_threshold:
self.num_successes_across_evaluation_episodes += 1

if self.ap.visualization.dump_csv:
if self.ap.visualization.dump_csv and \
self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber:
self.update_log()

if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
@@ -651,18 +676,22 @@ class Agent(AgentInterface):
"""
loss = 0
if self._should_train():
self.training_epoch += 1
for network in self.networks.values():
network.set_is_training(True)

for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
# TODO: this should be network dependent
network_parameters = list(self.ap.network_wrappers.values())[0]
# TODO: this should be network dependent
network_parameters = list(self.ap.network_wrappers.values())[0]

# we either go sequentially through the entire replay buffer in the batch RL mode,
# or sample randomly for the basic RL case.
training_schedule = self.call_memory('get_shuffled_data_generator', network_parameters.batch_size) if \
self.ap.is_batch_rl_training else [self.call_memory('sample', network_parameters.batch_size) for _ in
range(self.ap.algorithm.num_consecutive_training_steps)]

for batch in training_schedule:
# update counters
self.training_iteration += 1

# sample a batch and train on it
batch = self.call_memory('sample', network_parameters.batch_size)
if self.pre_network_filter is not None:
batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
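The comment in the hunk above ("we either go sequentially through the entire replay buffer in the batch RL mode, or sample randomly for the basic RL case") is the core of the change to train(). A rough, self-contained sketch of the two schedules, written outside of Coach's class structure (``memory``, ``batch_size``, ``num_steps`` and ``is_batch_rl_training`` are placeholder names, not identifiers from this commit):

.. code-block:: python

   if is_batch_rl_training:
       # batch RL / imitation: one full pass over the stored dataset per epoch, in shuffled order
       training_schedule = memory.get_shuffled_data_generator(batch_size)
   else:
       # online RL: a fixed number of independently sampled mini-batches
       training_schedule = [memory.sample(batch_size) for _ in range(num_steps)]

   for batch in training_schedule:
       pass  # filter the batch, wrap it in Batch(), and call learn_from_batch(), as in the diff above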
@@ -673,6 +702,7 @@ class Agent(AgentInterface):
batch = Batch(batch)
total_loss, losses, unclipped_grads = self.learn_from_batch(batch)
loss += total_loss

self.unclipped_grads.add_sample(unclipped_grads)

# TODO: the learning rate decay should be done through the network instead of here
@@ -697,6 +727,12 @@ class Agent(AgentInterface):
if self.imitation:
self.log_to_screen()

if self.ap.visualization.dump_csv and \
self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.Epoch:
# in BatchRL, or imitation learning, the agent never acts, so we have to get the stats out here.
# we dump the data out every epoch
self.update_log()

for network in self.networks.values():
network.set_is_training(False)

@@ -1034,3 +1070,12 @@ class Agent(AgentInterface):
for network in self.networks.values():
savers.update(network.collect_savers(parent_path_suffix))
return savers

def get_current_time(self):
return {
TimeTypes.EpisodeNumber: self.current_episode,
TimeTypes.TrainingIteration: self.training_iteration,
TimeTypes.EnvironmentSteps: self.total_steps_counter,
TimeTypes.WallClockTime: self.agent_logger.get_current_wall_clock_time(),
TimeTypes.Epoch: self.training_epoch}[self.parent_level_manager.parent_graph_manager.time_metric]

@@ -173,3 +173,12 @@ class AgentInterface(object):
:return: None
"""
raise NotImplementedError("")

def run_off_policy_evaluation(self) -> None:
"""
Run off-policy evaluation estimators to evaluate the trained policy performance against a dataset.
Should only be implemented for off-policy RL algorithms.

:return: None
"""
raise NotImplementedError("")

@@ -64,6 +64,9 @@ class BootstrappedDQNAgent(ValueOptimizationAgent):
q_st_plus_1 = result[:self.ap.exploration.architecture_num_q_heads]
TD_targets = result[self.ap.exploration.architecture_num_q_heads:]

# add Q value samples for logging
self.q_values.add_sample(TD_targets)

# initialize with the current prediction so that we will
# only update the action that we have actually done in this transition
for i in range(self.ap.network_wrappers['main'].batch_size):

@@ -100,6 +100,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))

# select the optimal actions for the next state
target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))

@@ -432,3 +432,4 @@ class CompositeAgent(AgentInterface):
savers.update(agent.collect_savers(
parent_path_suffix="{}.{}".format(parent_path_suffix, self.name)))
return savers


@@ -50,6 +50,9 @@ class DDQNAgent(ValueOptimizationAgent):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(TD_targets)

# initialize with the current prediction so that we will
# only update the action that we have actually done in this transition
TD_errors = []

@@ -81,6 +81,9 @@ class DQNAgent(ValueOptimizationAgent):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(TD_targets)

# only update the action that we have actually done in this transition
TD_errors = []
for i in range(self.ap.network_wrappers['main'].batch_size):

@@ -123,6 +123,9 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
else:
assert True, 'The available values for targets_horizon are: 1-Step, N-Step'

# add Q value samples for logging
self.q_values.add_sample(state_value_head_targets)

# train
result = self.networks['main'].online_network.accumulate_gradients(batch.states(network_keys), [state_value_head_targets])


@@ -88,6 +88,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(self.get_q_values(current_quantiles))

# get the optimal actions to take for the next states
target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)


@@ -95,6 +95,9 @@ class RainbowDQNAgent(CategoricalDQNAgent):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))

# only update the action that we have actually done in this transition (using the Double-DQN selected actions)
target_actions = ddqn_selected_actions
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import OrderedDict
from typing import Union

import numpy as np

from rl_coach.agents.agent import Agent
from rl_coach.core_types import ActionInfo, StateType
from rl_coach.core_types import ActionInfo, StateType, Batch
from rl_coach.logger import screen
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
from rl_coach.spaces import DiscreteActionSpace

from copy import deepcopy

## This is an abstract agent - there is no learn_from_batch method ##

@@ -79,10 +80,12 @@ class ValueOptimizationAgent(Agent):
# this is for bootstrapped dqn
if type(actions_q_values) == list and len(actions_q_values) > 0:
actions_q_values = self.exploration_policy.last_action_values
actions_q_values = actions_q_values.squeeze()

# store the q values statistics for logging
self.q_values.add_sample(actions_q_values)

actions_q_values = actions_q_values.squeeze()

for i, q_value in enumerate(actions_q_values):
self.q_value_for_action[i].add_sample(q_value)

@@ -96,3 +99,74 @@ class ValueOptimizationAgent(Agent):

def learn_from_batch(self, batch):
raise NotImplementedError("ValueOptimizationAgent is an abstract agent. Not to be used directly.")

def run_off_policy_evaluation(self):
"""
Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on
an evaluation dataset, which was collected by another policy(ies).
:return: None
"""
assert self.ope_manager
dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
(self.call_memory('get_last_training_set_episode_id') + 1,
self.call_memory('num_complete_episodes')))
if len(dataset_as_episodes) == 0:
raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
'Consider decreasing its value.')

ips, dm, dr, seq_dr = self.ope_manager.evaluate(
dataset_as_episodes=dataset_as_episodes,
batch_size=self.ap.network_wrappers['main'].batch_size,
discount_factor=self.ap.algorithm.discount,
reward_model=self.networks['reward_model'].online_network,
q_network=self.networks['main'].online_network,
network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))

# get the estimators out to the screen
log = OrderedDict()
log['Epoch'] = self.training_epoch
log['IPS'] = ips
log['DM'] = dm
log['DR'] = dr
log['Sequential-DR'] = seq_dr
screen.log_dict(log, prefix='Off-Policy Evaluation')

# get the estimators out to dashboard
self.agent_logger.set_current_time(self.get_current_time() + 1)
self.agent_logger.create_signal_value('Inverse Propensity Score', ips)
self.agent_logger.create_signal_value('Direct Method Reward', dm)
self.agent_logger.create_signal_value('Doubly Robust', dr)
self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)

def improve_reward_model(self, epochs: int):
"""
Train a reward model to be used by the doubly-robust estimator

:param epochs: The total number of epochs to use for training a reward model
:return: None
"""
batch_size = self.ap.network_wrappers['reward_model'].batch_size
network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()

# this is fitted from the training dataset
for epoch in range(epochs):
loss = 0
for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
batch = Batch(batch)
current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
loss += self.networks['reward_model'].train_and_sync_networks(
batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
# print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])

log = OrderedDict()
log['Epoch'] = epoch
log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
screen.log_dict(log, prefix='Training Reward Model')

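The OpeManager used above is not part of this diff, so the estimator math is not shown here. For orientation only, a simplified one-step (bandit-style) version of the three non-sequential estimators that run_off_policy_evaluation() logs - inverse propensity scoring (IPS), the direct method (DM) based on the learned reward model, and doubly robust (DR) which combines the two - might look as follows. The sequential DR estimator additionally propagates the correction along each episode with discounting and is omitted; all names below are illustrative and do not come from this commit:

.. code-block:: python

   import numpy as np

   def simple_one_step_ope(rewards, actions, behavior_probs, target_probs, reward_model_preds):
       # rewards: [N], actions: [N] ints,
       # behavior_probs / target_probs / reward_model_preds: [N, num_actions]
       idx = np.arange(len(rewards))
       # importance weights between the evaluated policy and the policy that logged the data
       rho = target_probs[idx, actions] / behavior_probs[idx, actions]
       ips = np.mean(rho * rewards)
       # direct method: expected reward under the evaluated policy, using the fitted reward model
       dm = np.mean(np.sum(target_probs * reward_model_preds, axis=1))
       # doubly robust: direct method plus an importance-weighted correction of the model's error
       dr = dm + np.mean(rho * (rewards - reward_model_preds[idx, actions]))
       return ips, dm, dr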
@@ -50,6 +50,11 @@ class QHead(Head):
# Standard Q Network
self.output = self.dense_layer(self.num_actions)(input_layer, name='output')

# TODO add this to other Q heads. e.g. dueling.
temperature = self.ap.network_wrappers[self.network_name].softmax_temperature
temperature_scaled_outputs = self.output / temperature
self.softmax = tf.nn.softmax(temperature_scaled_outputs, name="softmax")

def __str__(self):
result = [
"Dense (num outputs = {})".format(self.num_actions)
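The new softmax output on the Q head turns the Q values into a temperature-controlled action distribution, which is presumably what the off-policy evaluators use as the evaluated policy's action probabilities (the OpeManager code is not in this diff, so that link is an assumption). A minimal NumPy sketch of the operation being added here:

.. code-block:: python

   import numpy as np

   def softmax_with_temperature(q_values, temperature=1.0):
       # temperature=1 is a plain softmax; higher temperatures flatten the distribution
       scaled = q_values / temperature
       scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
       exp = np.exp(scaled)
       return exp / exp.sum(axis=-1, keepdims=True)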
@@ -42,7 +42,7 @@ class GlobalVariableSaver(Saver):
self._variable_placeholders.append(variable_placeholder)
self._variable_update_ops.append(v.assign(variable_placeholder))

self._saver = tf.train.Saver(self._variables)
self._saver = tf.train.Saver(self._variables, max_to_keep=None)

@property
def path(self):

@@ -291,7 +291,8 @@ class NetworkParameters(Parameters):
batch_size=32,
replace_mse_with_huber_loss=False,
create_target_network=False,
tensorflow_support=True):
tensorflow_support=True,
softmax_temperature=1):
"""
:param force_cpu:
Force the neural networks to run on the CPU even if a GPU is available
@@ -374,6 +375,8 @@ class NetworkParameters(Parameters):
online network at will.
:param tensorflow_support:
A flag which specifies if the network is supported by the TensorFlow framework.
:param softmax_temperature:
If a softmax is present in the network head output, use this temperature
"""
super().__init__()
self.framework = Frameworks.tensorflow
@@ -404,17 +407,20 @@ class NetworkParameters(Parameters):
self.heads_parameters = heads_parameters
self.use_separate_networks_per_head = use_separate_networks_per_head
self.optimizer_type = optimizer_type
self.optimizer_epsilon = optimizer_epsilon
self.adam_optimizer_beta1 = adam_optimizer_beta1
self.adam_optimizer_beta2 = adam_optimizer_beta2
self.rms_prop_optimizer_decay = rms_prop_optimizer_decay
self.batch_size = batch_size
self.replace_mse_with_huber_loss = replace_mse_with_huber_loss
self.create_target_network = create_target_network

# Framework support
self.tensorflow_support = tensorflow_support

# Hyper-Parameter values
self.optimizer_epsilon = optimizer_epsilon
self.adam_optimizer_beta1 = adam_optimizer_beta1
self.adam_optimizer_beta2 = adam_optimizer_beta2
self.rms_prop_optimizer_decay = rms_prop_optimizer_decay
self.batch_size = batch_size
self.softmax_temperature = softmax_temperature


class NetworkComponentParameters(Parameters):
def __init__(self, dense_layer):
@@ -544,6 +550,7 @@ class AgentParameters(Parameters):
self.is_a_highest_level_agent = True
self.is_a_lowest_level_agent = True
self.task_parameters = None
self.is_batch_rl_training = False

@property
def path(self):

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import namedtuple

import copy
from enum import Enum
@@ -38,6 +38,17 @@ class GoalTypes(Enum):
Measurements = 4


Record = namedtuple('Record', ['name', 'label'])


class TimeTypes(Enum):
EpisodeNumber = Record(name='Episode #', label='Episode #')
TrainingIteration = Record(name='Training Iter', label='Training Iteration')
EnvironmentSteps = Record(name='Total steps', label='Total steps (per worker)')
WallClockTime = Record(name='Wall-Clock Time', label='Wall-Clock Time (minutes)')
Epoch = Record(name='Epoch', label='Epoch #')


# step methods

class StepMethod(object):
@@ -249,6 +260,9 @@ class Transition(object):
.format(self.info.keys(), new_info.keys()))
self.info.update(new_info)

def update_info(self, new_info: Dict[str, Any]) -> None:
self.info.update(new_info)

def __copy__(self):
new_transition = type(self)()
new_transition.__dict__.update(self.__dict__)
@@ -867,3 +881,16 @@ class SelectedPhaseOnlyDumpFilter(object):
return True
else:
return False


# TODO move to a NamedTuple, once we move to Python3.6
# https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple/34269877
class CsvDataset(object):
def __init__(self, filepath: str, is_episodic: bool = True):
self.filepath = filepath
self.is_episodic = is_episodic


class PickledReplayBuffer(object):
def __init__(self, filepath: str):
self.filepath = filepath

@@ -273,6 +273,11 @@ def create_files_signal(files, use_dir_name=False):
files_selector.value = filenames[0]
selected_file = new_signal_files[0]

# update x axis according to the file's default x-axis (which is the index, and thus the first column)
idx = x_axis_options.index(new_signal_files[0].csv.columns[0])
change_x_axis(idx)
x_axis_selector.active = idx


def display_files(files):
pause_auto_update()

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import OrderedDict

import os
from genericpath import isdir, isfile
@@ -26,12 +26,14 @@ import tkinter as tk
from tkinter import filedialog
import colorsys

from rl_coach.core_types import TimeTypes

patches = {}
signals_files = {}
selected_file = None
x_axis = ['Episode #']
x_axis_options = ['Episode #', 'Total steps', 'Wall-Clock Time']
x_axis_labels = ['Episode #', 'Total steps (per worker)', 'Wall-Clock Time (minutes)']
x_axis_options = [time_type.value.name for time_type in TimeTypes]
x_axis_labels = [time_type.value.label for time_type in TimeTypes]
current_color = 0

# spinner

@@ -67,7 +67,12 @@ class SignalsFile(SignalsFileBase):
for k, v in new_csv.isna().all().items():
if v and k not in x_axis_options:
del new_csv[k]
new_csv.fillna(value=0, inplace=True)

# only fill the missing values that the previous interpolation did not deal with (usually the ones before the
# first update to the signal was made). We do so again by interpolation (averaging the previous and the next
# values)
new_csv = ((new_csv.fillna(method='bfill') + new_csv.fillna(method='ffill')) / 2).fillna(method='bfill').fillna(
method='ffill')

self.csv = new_csv


@@ -18,6 +18,7 @@ from typing import Tuple, List
from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \
PresetValidationParameters
from rl_coach.environments.environment import EnvironmentParameters, Environment
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
from rl_coach.graph_managers.graph_manager import GraphManager, ScheduleParameters
from rl_coach.level_manager import LevelManager
from rl_coach.utils import short_dynamic_import
@@ -31,17 +32,28 @@ class BasicRLGraphManager(GraphManager):
def __init__(self, agent_params: AgentParameters, env_params: EnvironmentParameters,
schedule_params: ScheduleParameters,
vis_params: VisualizationParameters=VisualizationParameters(),
preset_validation_params: PresetValidationParameters = PresetValidationParameters()):
super().__init__('simple_rl_graph', schedule_params, vis_params)
preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
name='simple_rl_graph'):
super().__init__(name, schedule_params, vis_params)
self.agent_params = agent_params
self.env_params = env_params
self.preset_validation_params = preset_validation_params

self.agent_params.visualization = vis_params

if self.agent_params.input_filter is None:
self.agent_params.input_filter = env_params.default_input_filter()
if env_params is not None:
self.agent_params.input_filter = env_params.default_input_filter()
else:
# In cases where there is no environment (e.g. batch-rl and imitation learning), there is nowhere to get
# a default filter from. So using a default no-filter.
# When there is no environment, the user is expected to define input/output filters (if required) using
# the preset.
self.agent_params.input_filter = NoInputFilter()
if self.agent_params.output_filter is None:
self.agent_params.output_filter = env_params.default_output_filter()
if env_params is not None:
self.agent_params.output_filter = env_params.default_output_filter()
else:
self.agent_params.output_filter = NoOutputFilter()

def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
# environment loading

rl_coach/graph_managers/batch_rl_graph_manager.py (new file, 180 lines)
@@ -0,0 +1,180 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import deepcopy
from typing import Tuple, List, Union

from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.base_parameters import AgentParameters, VisualizationParameters, TaskParameters, \
PresetValidationParameters
from rl_coach.core_types import RunPhase
from rl_coach.environments.environment import EnvironmentParameters, Environment
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager

from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import short_dynamic_import

from rl_coach.memories.episodic import EpisodicExperienceReplayParameters

from rl_coach.core_types import TimeTypes


class BatchRLGraphManager(BasicRLGraphManager):
"""
A batch RL graph manager creates a scenario of learning from a dataset, without a simulator.
"""
def __init__(self, agent_params: AgentParameters, env_params: Union[EnvironmentParameters, None],
schedule_params: ScheduleParameters,
vis_params: VisualizationParameters = VisualizationParameters(),
preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
name='batch_rl_graph', spaces_definition: SpacesDefinition = None, reward_model_num_epochs: int = 100,
train_to_eval_ratio: float = 0.8):

super().__init__(agent_params, env_params, schedule_params, vis_params, preset_validation_params, name)
self.is_batch_rl = True
self.time_metric = TimeTypes.Epoch
self.reward_model_num_epochs = reward_model_num_epochs
self.spaces_definition = spaces_definition

# setting this here to make sure that, by default, train_to_eval_ratio gets a value < 1
# (its default value in the memory is 1)
self.agent_params.memory.train_to_eval_ratio = train_to_eval_ratio

def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[LevelManager], List[Environment]]:
if self.env_params:
# environment loading
self.env_params.seed = task_parameters.seed
self.env_params.experiment_path = task_parameters.experiment_path
env = short_dynamic_import(self.env_params.path)(**self.env_params.__dict__,
visualization_parameters=self.visualization_parameters)
else:
env = None

# Only DQN variants are supported at this point.
assert(isinstance(self.agent_params, DQNAgentParameters))
# Only Episodic memories are supported,
# for evaluating the sequential doubly robust estimator
assert(isinstance(self.agent_params.memory, EpisodicExperienceReplayParameters))

# agent loading
self.agent_params.task_parameters = task_parameters  # TODO: this should probably be passed in a different way
self.agent_params.name = "agent"
self.agent_params.is_batch_rl_training = True

# user hasn't defined params for the reward model. we will use the same params as used for the 'main' network.
if 'reward_model' not in self.agent_params.network_wrappers:
self.agent_params.network_wrappers['reward_model'] = deepcopy(self.agent_params.network_wrappers['main'])

agent = short_dynamic_import(self.agent_params.path)(self.agent_params)

if not env and not self.agent_params.memory.load_memory_from_file_path:
screen.warning("A BatchRLGraph requires setting a dataset to load into the agent's memory or alternatively "
"using an environment to create a (random) dataset from. This agent should only be used for "
"inference. ")
# set level manager
level_manager = LevelManager(agents=agent, environment=env, name="main_level",
spaces_definition=self.spaces_definition)

if env:
return [level_manager], [env]
else:
return [level_manager], []

def improve(self):
"""
The main loop of the run.
Defined in the following steps:
1. Heatup
2. Repeat:
2.1. Repeat:
2.1.1. Train
2.1.2. Possibly save checkpoint
2.2. Evaluate
:return: None
"""

self.verify_graph_was_created()

# initialize the network parameters from the global network
self.sync()

# TODO a bug in heatup where the last episode run is not fed into the ER. e.g. asked for 1024 heatup steps,
# last ran episode ended increased the total to 1040 steps, but the ER will contain only 1014 steps.
# The last episode is not there. Is this a bug in my changes or also on master?

# Creating a dataset during the heatup phase is useful mainly for tutorial and debug purposes. If we have both
# an environment and a dataset to load from, we will use the environment only for evaluating the policy,
# and will not run heatup.

# heatup
if self.env_params is not None and not self.agent_params.memory.load_memory_from_file_path:
self.heatup(self.heatup_steps)

self.improve_reward_model()

# improve
if self.task_parameters.task_index is not None:
screen.log_title("Starting to improve {} task index {}".format(self.name, self.task_parameters.task_index))
else:
screen.log_title("Starting to improve {}".format(self.name))

# the outer most training loop
improve_steps_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
while self.total_steps_counters[RunPhase.TRAIN] < improve_steps_end:
# TODO if we have an environment, do we want to use it to have the agent train against, and use the
# collected replay buffer as a dataset? (as opposed to what we currently have, where the dataset is built
# during heatup, and is composed of random actions)
# perform several steps of training
if self.steps_between_evaluation_periods.num_steps > 0:
with self.phase_context(RunPhase.TRAIN):
self.reset_internal_state(force_environment_reset=True)

steps_between_evaluation_periods_end = self.current_step_counter + self.steps_between_evaluation_periods
while self.current_step_counter < steps_between_evaluation_periods_end:
self.train()

# the output of batch RL training is always a checkpoint of the trained agent. we always save a checkpoint,
# each epoch, regardless of the user's command line arguments.
self.save_checkpoint()

# run off-policy evaluation estimators to evaluate the agent's performance against the dataset
self.run_off_policy_evaluation()

if self.env_params is not None and self.evaluate(self.evaluation_steps):
# if we do have a simulator (although we are in a batch RL setting we might have a simulator, e.g. when
# demonstrating the batch RL use-case using one of the existing Coach environments),
# we might want to evaluate vs. the simulator every now and then.
break

def improve_reward_model(self):
"""

:return:
"""
screen.log_title("Training a regression model for estimating MDP rewards")
self.level_managers[0].agents['agent'].improve_reward_model(epochs=self.reward_model_num_epochs)

def run_off_policy_evaluation(self):
"""
Run off-policy evaluation estimators to evaluate the trained policy performance against the dataset
:return:
"""
self.level_managers[0].agents['agent'].run_off_policy_evaluation()

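For orientation, a preset that exercises this graph manager might look roughly as follows. This sketch is not taken from the commit: the environment, schedule values and hyper-parameters are illustrative, and it uses the "environment + heatup" path described in improve() (a random-action dataset is collected during heatup, and the simulator is then only used for periodic evaluation):

.. code-block:: python

   from rl_coach.agents.dqn_agent import DQNAgentParameters
   from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
   from rl_coach.environments.gym_environment import GymVectorEnvironment
   from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
   from rl_coach.graph_managers.graph_manager import ScheduleParameters
   from rl_coach.memories.episodic import EpisodicExperienceReplayParameters

   # a few training epochs, each followed by a checkpoint and an off-policy evaluation round
   schedule_params = ScheduleParameters()
   schedule_params.improve_steps = TrainingSteps(100)
   schedule_params.steps_between_evaluation_periods = TrainingSteps(1)
   schedule_params.evaluation_steps = EnvironmentEpisodes(10)
   schedule_params.heatup_steps = EnvironmentSteps(10000)  # random-action dataset collected up front

   agent_params = DQNAgentParameters()
   agent_params.memory = EpisodicExperienceReplayParameters()  # episodic memory is asserted above

   graph_manager = BatchRLGraphManager(agent_params=agent_params,
                                       env_params=GymVectorEnvironment(level='CartPole-v0'),
                                       schedule_params=schedule_params,
                                       reward_model_num_epochs=30,
                                       train_to_eval_ratio=0.5)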
@@ -38,6 +38,8 @@ from rl_coach.data_stores.data_store_impl import get_data_store as data_store_cr
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles

from rl_coach.core_types import TimeTypes


class ScheduleParameters(Parameters):
def __init__(self):
@@ -119,6 +121,8 @@ class GraphManager(object):
self.checkpoint_state_updater = None
self.graph_logger = Logger()
self.data_store = None
self.is_batch_rl = False
self.time_metric = TimeTypes.EpisodeNumber

def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
self.graph_creation_time = time.time()
@@ -445,16 +449,17 @@ class GraphManager(object):
result = self.top_level_manager.step(None)
steps_end = self.environments[0].total_steps_counter

# add the diff between the total steps before and after stepping, such that environment initialization steps
# (like in Atari) will not be counted.
# We add at least one step so that even if no steps were made (in case no actions are taken in the training
# phase), the loop will end eventually.
self.current_step_counter[EnvironmentSteps] += max(1, steps_end - steps_begin)

if result.game_over:
self.handle_episode_ended()
self.reset_required = True

self.current_step_counter[EnvironmentSteps] += (steps_end - steps_begin)

# if no steps were made (can happen when no actions are taken while in the TRAIN phase, either in batch RL
# or in imitation learning), we force end the loop, so that it will not continue forever.
if (steps_end - steps_begin) == 0:
break

def train_and_act(self, steps: StepMethod) -> None:
"""
Train the agent by doing several acting steps followed by several training steps continually
@@ -472,9 +477,9 @@ class GraphManager(object):
while self.current_step_counter < count_end:
# The actual number of steps being done on the environment
# is decided by the agent, though this inner loop always
# takes at least one step in the environment. Depending on
# internal counters and parameters, it doesn't always train
# or save checkpoints.
# takes at least one step in the environment (at the GraphManager level).
# The agent might also decide to skip acting altogether.
# Depending on internal counters and parameters, it doesn't always train or save checkpoints.
self.act(EnvironmentSteps(1))
self.train()
self.occasionally_save_checkpoint()

@@ -40,9 +40,10 @@ class LevelManager(EnvironmentInterface):
name: str,
agents: Union['Agent', CompositeAgent, Dict[str, Union['Agent', CompositeAgent]]],
environment: Union['LevelManager', Environment],
real_environment: Environment=None,
steps_limit: EnvironmentSteps=EnvironmentSteps(1),
should_reset_agent_state_after_time_limit_passes: bool=False
real_environment: Environment = None,
steps_limit: EnvironmentSteps = EnvironmentSteps(1),
should_reset_agent_state_after_time_limit_passes: bool = False,
spaces_definition: SpacesDefinition = None
):
"""
A level manager controls a single or multiple composite agents and a single environment.
@@ -56,6 +57,7 @@ class LevelManager(EnvironmentInterface):
:param steps_limit: the number of time steps to run when stepping the internal components
:param should_reset_agent_state_after_time_limit_passes: reset the agent after stepping for steps_limit
:param name: the level's name
:param spaces_definition: external definition of spaces for when we don't have an environment (e.g. batch-rl)
"""
super().__init__()

@@ -85,9 +87,11 @@ class LevelManager(EnvironmentInterface):

if not isinstance(self.steps_limit, EnvironmentSteps):
raise ValueError("The num consecutive steps for acting must be defined in terms of environment steps")
self.build()
self.build(spaces_definition)

# there are cases where we don't have an environment. e.g. in batch-rl or in imitation learning.
self.last_env_response = self.real_environment.last_env_response if self.real_environment else None

self.last_env_response = self.real_environment.last_env_response
self.parent_graph_manager = None

def handle_episode_ended(self) -> None:
@@ -100,13 +104,13 @@ class LevelManager(EnvironmentInterface):
def reset_internal_state(self, force_environment_reset: bool = False) -> EnvResponse:
"""
Reset the environment episode parameters
:param force_enviro nment_reset: in some cases, resetting the environment can be suppressed by the environment
:param force_environment_reset: in some cases, resetting the environment can be suppressed by the environment
itself. This flag allows forcing the reset.
:return: the environment response as returned in get_last_env_response
"""
[agent.reset_internal_state() for agent in self.agents.values()]
self.reset_required = False
if self.real_environment.current_episode_steps_counter == 0:
if self.real_environment and self.real_environment.current_episode_steps_counter == 0:
self.last_env_response = self.real_environment.last_env_response
return self.last_env_response

@@ -136,19 +140,27 @@ class LevelManager(EnvironmentInterface):
"""
return {k: ActionInfo(v) for k, v in self.get_random_action().items()}

def build(self) -> None:
def build(self, spaces_definition: SpacesDefinition = None) -> None:
"""
Build all the internal components of the level manager (composite agents and environment).
:param spaces_definition: external definition of spaces for when we don't have an environment (e.g. batch-rl)
:return: None
"""
# TODO: move the spaces definition class to the environment?
action_space = self.environment.action_space
if isinstance(action_space, dict):  # TODO: shouldn't be a dict
action_space = list(action_space.values())[0]
spaces = SpacesDefinition(state=self.real_environment.state_space,
goal=self.real_environment.goal_space,  # in HRL the agent might want to override this
action=action_space,
reward=self.real_environment.reward_space)
if spaces_definition is None:
# normally the spaces are defined by the environment, and we only gather these here
action_space = self.environment.action_space

if isinstance(action_space, dict):  # TODO: shouldn't be a dict
action_space = list(action_space.values())[0]

spaces = SpacesDefinition(state=self.real_environment.state_space,
goal=self.real_environment.goal_space,
# in HRL the agent might want to override this
action=action_space,
reward=self.real_environment.reward_space)
else:
spaces = spaces_definition

[agent.set_environment_parameters(spaces) for agent in self.agents.values()]

def setup_logger(self) -> None:

@@ -185,6 +185,9 @@ class BaseLogger(object):
self.time = time

def create_signal_value(self, signal_name, value, overwrite=True, time=None):
if self.index_name == signal_name:
return False  # make sure that we don't create duplicate signals

if self.last_line_idx_written_to_csv != 0:
assert signal_name in self.data.columns

@@ -227,12 +230,15 @@ class BaseLogger(object):

self.last_line_idx_written_to_csv = len(self.data.index)

def update_wall_clock_time(self, index):
def get_current_wall_clock_time(self):
if self.start_time:
self.create_signal_value('Wall-Clock Time', time.time() - self.start_time, time=index)
return time.time() - self.start_time
else:
self.create_signal_value('Wall-Clock Time', 0, time=index)
self.start_time = time.time()
return 0

def update_wall_clock_time(self, index):
self.create_signal_value('Wall-Clock Time', self.get_current_wall_clock_time(), time=index)


class EpisodeLogger(BaseLogger):
@@ -263,10 +269,13 @@ class EpisodeLogger(BaseLogger):


class Logger(BaseLogger):
def __init__(self):
def __init__(self, index_name='Episode #'):
super().__init__()
self.doc_path = ''
self.index_name = 'Episode #'
self.index_name = index_name

def set_index_name(self, index_name):
self.index_name = index_name

def set_logger_filenames(self, _experiments_path, logger_prefix='', task_id=None, add_timestamp=False, filename=''):
self.experiments_path = _experiments_path

@@ -1,4 +1,5 @@
#
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,14 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ast

from typing import List, Tuple, Union, Dict, Any

import pandas as pd
from typing import List, Tuple, Union
import numpy as np
import random

from rl_coach.core_types import Transition, Episode
from rl_coach.logger import screen
from rl_coach.memories.memory import Memory, MemoryGranularity, MemoryParameters
from rl_coach.utils import ReaderWriterLock
from rl_coach.utils import ReaderWriterLock, ProgressBar
from rl_coach.core_types import CsvDataset


class EpisodicExperienceReplayParameters(MemoryParameters):
@@ -28,6 +33,7 @@ class EpisodicExperienceReplayParameters(MemoryParameters):
super().__init__()
self.max_size = (MemoryGranularity.Transitions, 1000000)
self.n_step = -1
self.train_to_eval_ratio = 1  # for OPE we'll want a value < 1

@property
def path(self):
@@ -40,7 +46,9 @@ class EpisodicExperienceReplay(Memory):
calculations of total return and other values that depend on the sequential behavior of the transitions
in the episode.
"""
def __init__(self, max_size: Tuple[MemoryGranularity, int]=(MemoryGranularity.Transitions, 1000000), n_step=-1):

def __init__(self, max_size: Tuple[MemoryGranularity, int] = (MemoryGranularity.Transitions, 1000000), n_step=-1,
train_to_eval_ratio: int = 1):
"""
:param max_size: the maximum number of transitions or episodes to hold in the memory
"""
@@ -52,8 +60,11 @@ class EpisodicExperienceReplay(Memory):
self._num_transitions = 0
self._num_transitions_in_complete_episodes = 0
self.reader_writer_lock = ReaderWriterLock()
self.last_training_set_episode_id = None  # used in batch-rl
self.last_training_set_transition_id = None  # used in batch-rl
self.train_to_eval_ratio = train_to_eval_ratio  # used in batch-rl

def length(self, lock: bool=False) -> int:
def length(self, lock: bool = False) -> int:
"""
Get the number of episodes in the ER (even if they are not complete)
"""
@@ -75,6 +86,9 @@ class EpisodicExperienceReplay(Memory):
def num_transitions_in_complete_episodes(self):
return self._num_transitions_in_complete_episodes

def get_last_training_set_episode_id(self):
return self.last_training_set_episode_id

def sample(self, size: int, is_consecutive_transitions=False) -> List[Transition]:
"""
Sample a batch of transitions from the replay buffer. If the requested size is larger than the number
@@ -92,7 +106,7 @@ class EpisodicExperienceReplay(Memory):
batch = self._buffer[episode_idx].transitions
else:
transition_idx = np.random.randint(size, self._buffer[episode_idx].length())
batch = self._buffer[episode_idx].transitions[transition_idx-size:transition_idx]
batch = self._buffer[episode_idx].transitions[transition_idx - size:transition_idx]
else:
transitions_idx = np.random.randint(self.num_transitions_in_complete_episodes(), size=size)
batch = [self.transitions[i] for i in transitions_idx]
@@ -105,6 +119,79 @@ class EpisodicExperienceReplay(Memory):

return batch

def get_episode_for_transition(self, transition: Transition) -> (int, Episode):
"""
Get the episode that this transition came from.
:param transition: The transition to lookup the episode for
:return: (Episode number, the episode) or (-1, None) if could not find a matching episode.
"""

for i, episode in enumerate(self._buffer):
if transition in episode.transitions:
return i, episode
return -1, None

def shuffle_episodes(self):
"""
Shuffle all the episodes in the replay buffer
:return:
"""
random.shuffle(self._buffer)
self.transitions = [t for e in self._buffer for t in e.transitions]

def get_shuffled_data_generator(self, size: int) -> List[Transition]:
"""
Get a generator for iterating through the shuffled replay buffer, for processing the data in epochs.
If the requested size is larger than the number of samples available in the replay buffer then the batch will
return empty. The last returned batch may be smaller than the size requested, to accommodate for all the
transitions in the replay buffer.

:param size: the size of the batch to return
:return: a batch (list) of selected transitions from the replay buffer
"""
self.reader_writer_lock.lock_writing()
if self.last_training_set_transition_id is None:
if self.train_to_eval_ratio < 0 or self.train_to_eval_ratio >= 1:
raise ValueError('train_to_eval_ratio should be in the (0, 1) range.')

transition = self.transitions[round(self.train_to_eval_ratio * self.num_transitions_in_complete_episodes())]
episode_num, episode = self.get_episode_for_transition(transition)
self.last_training_set_episode_id = episode_num
self.last_training_set_transition_id = \
len([t for e in self.get_all_complete_episodes_from_to(0, self.last_training_set_episode_id + 1) for t in e])

shuffled_transition_indices = list(range(self.last_training_set_transition_id))
random.shuffle(shuffled_transition_indices)

# we deliberately drop some of the ending data which is left after dividing to batches of size `size`
# for i in range(math.ceil(len(shuffled_transition_indices) / size)):
for i in range(int(len(shuffled_transition_indices) / size)):
sample_data = [self.transitions[j] for j in shuffled_transition_indices[i * size: (i + 1) * size]]
self.reader_writer_lock.release_writing()

yield sample_data

def get_all_complete_episodes_transitions(self) -> List[Transition]:
"""
Get all the transitions from all the complete episodes in the buffer
:return: a list of transitions
"""
return self.transitions[:self.num_transitions_in_complete_episodes()]

def get_all_complete_episodes(self) -> List[Episode]:
"""
Get all the transitions from all the complete episodes in the buffer
:return: a list of transitions
"""
return self.get_all_complete_episodes_from_to(0, self.num_complete_episodes())

def get_all_complete_episodes_from_to(self, start_episode_id, end_episode_id) -> List[Episode]:
"""
Get all the transitions from all the complete episodes in the buffer matching the given episode range
:return: a list of transitions
"""
return self._buffer[start_episode_id:end_episode_id]

def _enforce_max_length(self) -> None:
"""
Make sure that the size of the replay buffer does not pass the maximum size allowed.
@@ -188,7 +275,7 @@ class EpisodicExperienceReplay(Memory):

self.reader_writer_lock.release_writing_and_reading()

def store_episode(self, episode: Episode, lock: bool=True) -> None:
def store_episode(self, episode: Episode, lock: bool = True) -> None:
"""
Store a new episode in the memory.
:param episode: the new episode to store
@@ -211,7 +298,7 @@ class EpisodicExperienceReplay(Memory):
if lock:
self.reader_writer_lock.release_writing_and_reading()

def get_episode(self, episode_index: int, lock: bool=True) -> Union[None, Episode]:
def get_episode(self, episode_index: int, lock: bool = True) -> Union[None, Episode]:
"""
Returns the episode in the given index. If the episode does not exist, returns None instead.
:param episode_index: the index of the episode to return
@@ -256,7 +343,7 @@ class EpisodicExperienceReplay(Memory):
self.reader_writer_lock.release_writing_and_reading()

# for API compatibility
def get(self, episode_index: int, lock: bool=True) -> Union[None, Episode]:
def get(self, episode_index: int, lock: bool = True) -> Union[None, Episode]:
"""
Returns the episode in the given index. If the episode does not exist, returns None instead.
:param episode_index: the index of the episode to return
@@ -315,3 +402,47 @@ class EpisodicExperienceReplay(Memory):

self.reader_writer_lock.release_writing()
return mean

    def load_csv(self, csv_dataset: CsvDataset) -> None:
        """
        Restore the replay buffer contents from a csv file.
        The csv file is assumed to include a list of transitions.
        :param csv_dataset: A construct which holds the dataset parameters
        """
        df = pd.read_csv(csv_dataset.filepath)
        if len(df) > self.max_size[1]:
            screen.warning("Warning! The number of transitions to load into the replay buffer ({}) is "
                           "bigger than the max size of the replay buffer ({}). The excessive transitions will "
                           "not be stored.".format(len(df), self.max_size[1]))

        episode_ids = df['episode_id'].unique()
        progress_bar = ProgressBar(len(episode_ids))
        state_columns = [col for col in df.columns if col.startswith('state_feature')]

        for e_id in episode_ids:
            progress_bar.update(e_id)
            df_episode_transitions = df[df['episode_id'] == e_id]
            episode = Episode()
            for (_, current_transition), (_, next_transition) in zip(df_episode_transitions[:-1].iterrows(),
                                                                     df_episode_transitions[1:].iterrows()):
                state = np.array([current_transition[col] for col in state_columns])
                next_state = np.array([next_transition[col] for col in state_columns])

                episode.insert(
                    Transition(state={'observation': state},
                               action=current_transition['action'], reward=current_transition['reward'],
                               next_state={'observation': next_state}, game_over=False,
                               info={'all_action_probabilities':
                                     ast.literal_eval(current_transition['all_action_probabilities'])}))

            # Set the last transition to end the episode
            if csv_dataset.is_episodic:
                episode.get_last_transition().game_over = True

            self.store_episode(episode)

        # close the progress bar
        progress_bar.update(len(episode_ids))
        progress_bar.close()

        self.shuffle_episodes()
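        # usage sketch (illustrative assumptions: `memory` is an EpisodicExperienceReplay instance and the CSV
        # file provides the columns read above - 'episode_id', 'action', 'reward', 'all_action_probabilities'
        # and one 'state_feature_<i>' column per observation dimension):
        # from rl_coach.core_types import CsvDataset
        # memory.load_csv(CsvDataset('<path to dataset>.csv', is_episodic=True))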

@@ -14,10 +14,10 @@
# limitations under the License.
#

from typing import List, Tuple, Union, Dict, Any
from typing import List, Tuple, Union
import pickle
import sys
import time
import random
import math

import numpy as np

@@ -72,7 +72,6 @@ class ExperienceReplay(Memory):
        Sample a batch of transitions from the replay buffer. If the requested size is larger than the number
        of samples available in the replay buffer then the batch will return empty.
        :param size: the size of the batch to sample
        :param beta: the beta parameter used for importance sampling
        :return: a batch (list) of selected transitions from the replay buffer
        """
        self.reader_writer_lock.lock_writing()

@@ -92,6 +91,32 @@ class ExperienceReplay(Memory):
        self.reader_writer_lock.release_writing()
        return batch

    def get_shuffled_data_generator(self, size: int) -> List[Transition]:
        """
        Get a generator for iterating through the shuffled replay buffer, for processing the data in epochs.
        If the requested size is larger than the number of samples available in the replay buffer, no batch is
        returned. Leftover transitions that do not fill a whole batch of size `size` are dropped.

        :param size: the size of the batch to return
        :return: a batch (list) of selected transitions from the replay buffer
        """
        self.reader_writer_lock.lock_writing()
        shuffled_transition_indices = list(range(len(self.transitions)))
        random.shuffle(shuffled_transition_indices)

        # we deliberately drop the leftover transitions that do not fill a whole batch of size `size`
        # for i in range(math.ceil(len(shuffled_transition_indices) / size)):
        for i in range(int(len(shuffled_transition_indices) / size)):
            sample_data = [self.transitions[j] for j in shuffled_transition_indices[i * size: (i + 1) * size]]
            self.reader_writer_lock.release_writing()

            yield sample_data

        # usage example:
        # for batch in memory.get_shuffled_data_generator(size=4):
        #     print(batch)

    def _enforce_max_length(self) -> None:
        """
        Make sure that the size of the replay buffer does not pass the maximum size allowed.
@@ -215,7 +240,7 @@ class ExperienceReplay(Memory):
        with open(file_path, 'wb') as file:
            pickle.dump(self.transitions, file)

    def load(self, file_path: str) -> None:
    def load_pickled(self, file_path: str) -> None:
        """
        Restore the replay buffer contents from a pickle file.
        The pickle file is assumed to include a list of transitions.
@@ -238,3 +263,4 @@ class ExperienceReplay(Memory):
            progress_bar.update(transition_idx)

        progress_bar.close()
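
        # round-trip sketch (illustrative assumptions: `memory` is an ExperienceReplay instance and the
        # pickling code shown in the hunk above belongs to its `save` method):
        # memory.save('replay_buffer.p')           # pickles memory.transitions to disk
        # memory.load_pickled('replay_buffer.p')   # restores the transitions into the buffer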

15  rl_coach/off_policy_evaluators/__init__.py  Normal file
@@ -0,0 +1,15 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

15  rl_coach/off_policy_evaluators/bandits/__init__.py  Normal file
@@ -0,0 +1,15 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

40  rl_coach/off_policy_evaluators/bandits/doubly_robust.py  Normal file
@@ -0,0 +1,40 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np


class DoublyRobust(object):

    @staticmethod
    def evaluate(ope_shared_stats: 'OpeSharedStats') -> tuple:
        """
        Run the off-policy evaluator to get a score for the goodness of the new policy, based on a dataset
        which was collected using other policies.

        Papers:
        https://arxiv.org/abs/1103.4601
        https://arxiv.org/pdf/1612.01205 (with clearer explanations)

        :return: the evaluation score
        """

        ips = np.mean(ope_shared_stats.rho_all_dataset * ope_shared_stats.all_rewards)
        dm = np.mean(ope_shared_stats.all_v_values_reward_model_based)
        dr = np.mean(ope_shared_stats.rho_all_dataset *
                     (ope_shared_stats.all_rewards - ope_shared_stats.all_reward_model_rewards[
                         range(len(ope_shared_stats.all_actions)), ope_shared_stats.all_actions])) + dm

        return ips, dm, dr
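
# Standalone toy illustration (not part of the diff) of the three estimators computed above, on a
# hypothetical 3-sample, 2-action bandit dataset. All names below are made up for the example:
import numpy as np

rewards = np.array([1.0, 0.0, 1.0])                                      # observed rewards
actions = np.array([0, 1, 0])                                            # logged actions
old_policy_prob = np.array([0.5, 0.5, 0.5])                              # behavior policy prob. of each logged action
new_policy_probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.9, 0.1]])        # evaluated policy probs per action
reward_model_rewards = np.array([[0.9, 0.1], [0.2, 0.6], [0.8, 0.3]])    # reward model predictions per action

rho = new_policy_probs[np.arange(3), actions] / old_policy_prob          # importance weights
ips = np.mean(rho * rewards)                                             # Inverse Propensity Scoring
dm = np.mean(np.sum(new_policy_probs * reward_model_rewards, axis=1))    # Direct Method
dr = dm + np.mean(rho * (rewards - reward_model_rewards[np.arange(3), actions]))  # Doubly Robust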

124  rl_coach/off_policy_evaluators/ope_manager.py  Normal file
@@ -0,0 +1,124 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import namedtuple

import numpy as np
from typing import List

from rl_coach.architectures.architecture import Architecture
from rl_coach.core_types import Episode, Batch
from rl_coach.off_policy_evaluators.bandits.doubly_robust import DoublyRobust
from rl_coach.off_policy_evaluators.rl.sequential_doubly_robust import SequentialDoublyRobust

from rl_coach.core_types import Transition

OpeSharedStats = namedtuple("OpeSharedStats", ['all_reward_model_rewards', 'all_policy_probs',
                                               'all_v_values_reward_model_based', 'all_rewards', 'all_actions',
                                               'all_old_policy_probs', 'new_policy_prob', 'rho_all_dataset'])
OpeEstimation = namedtuple("OpeEstimation", ['ips', 'dm', 'dr', 'seq_dr'])


class OpeManager(object):
    def __init__(self):
        self.doubly_robust = DoublyRobust()
        self.sequential_doubly_robust = SequentialDoublyRobust()

    @staticmethod
    def _prepare_ope_shared_stats(dataset_as_transitions: List[Transition], batch_size: int,
                                  reward_model: Architecture, q_network: Architecture,
                                  network_keys: List) -> OpeSharedStats:
        """
        Do the preparations needed for the different estimators.
        Some of the calculations are shared, so we centralize all the work here.

        :param dataset_as_transitions: The evaluation dataset in the form of transitions.
        :param batch_size: The batch size to use.
        :param reward_model: A reward model to be used by DR.
        :param q_network: The Q network whose policy we evaluate.
        :param network_keys: The network keys used for feeding the neural networks.
        :return: an OpeSharedStats tuple holding the statistics shared by the estimators
        """
        # IPS
        all_reward_model_rewards, all_policy_probs, all_old_policy_probs = [], [], []
        all_v_values_reward_model_based, all_v_values_q_model_based, all_rewards, all_actions = [], [], [], []

        for i in range(int(len(dataset_as_transitions) / batch_size) + 1):
            batch = dataset_as_transitions[i * batch_size: (i + 1) * batch_size]
            batch_for_inference = Batch(batch)

            all_reward_model_rewards.append(reward_model.predict(
                batch_for_inference.states(network_keys)))

            # TODO can we get rid of the 'output_heads[0]', and have some way of a cleaner API?
            q_values, sm_values = q_network.predict(batch_for_inference.states(network_keys),
                                                    outputs=[q_network.output_heads[0].output,
                                                             q_network.output_heads[0].softmax])
            # TODO why is this needed?
            q_values = q_values[0]

            all_policy_probs.append(sm_values)
            all_v_values_reward_model_based.append(np.sum(all_policy_probs[-1] * all_reward_model_rewards[-1], axis=1))
            all_v_values_q_model_based.append(np.sum(all_policy_probs[-1] * q_values, axis=1))
            all_rewards.append(batch_for_inference.rewards())
            all_actions.append(batch_for_inference.actions())
            all_old_policy_probs.append(batch_for_inference.info('all_action_probabilities')
                                        [range(len(batch_for_inference.actions())), batch_for_inference.actions()])

            for j, t in enumerate(batch):
                t.update_info({
                    'q_value': q_values[j],
                    'softmax_policy_prob': all_policy_probs[-1][j],
                    'v_value_q_model_based': all_v_values_q_model_based[-1][j],
                })

        all_reward_model_rewards = np.concatenate(all_reward_model_rewards, axis=0)
        all_policy_probs = np.concatenate(all_policy_probs, axis=0)
        all_v_values_reward_model_based = np.concatenate(all_v_values_reward_model_based, axis=0)
        all_rewards = np.concatenate(all_rewards, axis=0)
        all_actions = np.concatenate(all_actions, axis=0)
        all_old_policy_probs = np.concatenate(all_old_policy_probs, axis=0)

        # generate model probabilities
        new_policy_prob = all_policy_probs[np.arange(all_actions.shape[0]), all_actions]
        rho_all_dataset = new_policy_prob / all_old_policy_probs

        return OpeSharedStats(all_reward_model_rewards, all_policy_probs, all_v_values_reward_model_based,
                              all_rewards, all_actions, all_old_policy_probs, new_policy_prob, rho_all_dataset)

    def evaluate(self, dataset_as_episodes: List[Episode], batch_size: int, discount_factor: float,
                 reward_model: Architecture, q_network: Architecture, network_keys: List) -> OpeEstimation:
        """
        Run all the OPEs and get estimations of the current policy performance based on the evaluation dataset.

        :param dataset_as_episodes: The evaluation dataset.
        :param batch_size: Batch size to use for the estimators.
        :param discount_factor: The standard RL discount factor.
        :param reward_model: A reward model to be used by DR.
        :param q_network: The Q network whose policy we evaluate.
        :param network_keys: The network keys used for feeding the neural networks.

        :return: An OpeEstimation tuple which groups together all the OPE estimations
        """
        # TODO this seems kind of slow, review performance
        dataset_as_transitions = [t for e in dataset_as_episodes for t in e.transitions]
        ope_shared_stats = self._prepare_ope_shared_stats(dataset_as_transitions, batch_size, reward_model,
                                                          q_network, network_keys)

        ips, dm, dr = self.doubly_robust.evaluate(ope_shared_stats)
        seq_dr = self.sequential_doubly_robust.evaluate(dataset_as_episodes, discount_factor)
        return OpeEstimation(ips, dm, dr, seq_dr)

0  rl_coach/off_policy_evaluators/rl/__init__.py  Normal file

51  rl_coach/off_policy_evaluators/rl/sequential_doubly_robust.py  Normal file
@@ -0,0 +1,51 @@
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
import numpy as np

from rl_coach.core_types import Episode


class SequentialDoublyRobust(object):

    @staticmethod
    def evaluate(dataset_as_episodes: List[Episode], discount_factor: float) -> float:
        """
        Run the off-policy evaluator to get a score for the goodness of the new policy, based on a dataset
        which was collected using other policies.

        Paper: https://arxiv.org/pdf/1511.03722.pdf

        :return: the evaluation score
        """

        # Sequential Doubly Robust
        per_episode_seq_dr = []

        for episode in dataset_as_episodes:
            episode_seq_dr = 0
            for transition in episode.transitions:
                rho = transition.info['softmax_policy_prob'][transition.action] / \
                      transition.info['all_action_probabilities'][transition.action]
                episode_seq_dr = transition.info['v_value_q_model_based'] + rho * (transition.reward + discount_factor
                                                                                   * episode_seq_dr -
                                                                                   transition.info['q_value'][
                                                                                       transition.action])
            per_episode_seq_dr.append(episode_seq_dr)

        seq_dr = np.array(per_episode_seq_dr).mean()

        return seq_dr
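
# The per-transition update above follows the sequential doubly-robust recursion from the paper referenced
# in the docstring (the notation below is an editorial note, not taken from the code):
#     DR = V_hat(s_t) + rho_t * (r_t + gamma * DR_prev - Q_hat(s_t, a_t))
# where rho_t is the ratio between the evaluated policy's and the behavior policy's probability of the logged
# action, V_hat is the model-based state value and Q_hat is the state-action value estimate.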

@@ -25,6 +25,7 @@ from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.schedules import ConstantSchedule
from rl_coach.spaces import ImageObservationSpace
from rl_coach.utilities.carla_dataset_to_replay_buffer import create_dataset
from rl_coach.core_types import PickledReplayBuffer

####################
# Graph Scheduling #
@@ -130,7 +131,7 @@ agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)

# use the following command line to download and extract the CARLA dataset:
# python rl_coach/utilities/carla_dataset_to_replay_buffer.py
agent_params.memory.load_memory_from_file_path = "./datasets/carla_train_set_replay_buffer.p"
agent_params.memory.load_memory_from_file_path = PickledReplayBuffer("./datasets/carla_train_set_replay_buffer.p")
agent_params.memory.state_key_with_the_class_index = 'high_level_command'
agent_params.memory.num_classes = 4

91  rl_coach/presets/CartPole_DQN_BatchRL.py  Normal file
@@ -0,0 +1,91 @@
from copy import deepcopy

from rl_coach.agents.ddqn_agent import DDQNAgentParameters
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
from rl_coach.environments.gym_environment import GymVectorEnvironment
from rl_coach.filters.filter import InputFilter
from rl_coach.filters.reward import RewardRescaleFilter
from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.schedules import LinearSchedule
from rl_coach.memories.episodic import EpisodicExperienceReplayParameters

DATASET_SIZE = 40000

####################
# Graph Scheduling #
####################

schedule_params = ScheduleParameters()
schedule_params.improve_steps = TrainingSteps(10000000000)
schedule_params.steps_between_evaluation_periods = TrainingSteps(1)
schedule_params.evaluation_steps = EnvironmentEpisodes(10)
schedule_params.heatup_steps = EnvironmentSteps(DATASET_SIZE)

#########
# Agent #
#########
# TODO add a preset which uses a dataset to train a BatchRL graph. e.g. save a cartpole dataset in a csv format.
agent_params = DDQNAgentParameters()
agent_params.network_wrappers['main'].batch_size = 1024

# DQN params
# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(100)

# For making this become Fitted Q-Iteration we can keep the targets constant for the entire dataset size -
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(
    DATASET_SIZE / agent_params.network_wrappers['main'].batch_size)
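# With DATASET_SIZE = 40000 and batch_size = 1024, the target network is synced roughly every
# 39 training steps, i.e. about once per full pass over the dataset.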

agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)
agent_params.algorithm.discount = 0.98
# agent_params.algorithm.discount = 1.0


# NN configuration
agent_params.network_wrappers['main'].learning_rate = 0.0001
agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
agent_params.network_wrappers['main'].l2_regularization = 0.0001
agent_params.network_wrappers['main'].softmax_temperature = 0.2
# agent_params.network_wrappers['main'].learning_rate_decay_rate = 0.95
# agent_params.network_wrappers['main'].learning_rate_decay_steps = int(DATASET_SIZE /
#                                                                       agent_params.network_wrappers['main'].batch_size)

# reward model params
agent_params.network_wrappers['reward_model'] = deepcopy(agent_params.network_wrappers['main'])
agent_params.network_wrappers['reward_model'].learning_rate = 0.0001
agent_params.network_wrappers['reward_model'].l2_regularization = 0

# ER size
agent_params.memory = EpisodicExperienceReplayParameters()
agent_params.memory.max_size = (MemoryGranularity.Transitions, DATASET_SIZE)

# E-Greedy schedule
agent_params.exploration.epsilon_schedule = LinearSchedule(0, 0, 10000)
agent_params.exploration.evaluation_epsilon = 0

agent_params.input_filter = InputFilter()
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))

################
#  Environment #
################
env_params = GymVectorEnvironment(level='CartPole-v0')

########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 250

graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params,
                                    vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1),
                                    preset_validation_params=preset_validation_params,
                                    reward_model_num_epochs=50,
                                    train_to_eval_ratio=0.8)
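
# To train from this preset with the standard Coach CLI (an assumed invocation; adjust flags as needed):
# coach -p CartPole_DQN_BatchRL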

@@ -5,6 +5,7 @@ from rl_coach.environments.doom_environment import DoomEnvironmentParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.schedules import LinearSchedule
from rl_coach.core_types import PickledReplayBuffer

####################
# Graph Scheduling #
@@ -29,7 +30,7 @@ agent_params.exploration.evaluation_epsilon = 0
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)
agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
agent_params.network_wrappers['main'].batch_size = 120
agent_params.memory.load_memory_from_file_path = 'datasets/doom_basic.p'
agent_params.memory.load_memory_from_file_path = PickledReplayBuffer('datasets/doom_basic.p')


###############

@@ -5,6 +5,7 @@ from rl_coach.environments.gym_environment import Atari
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.memories.memory import MemoryGranularity
from rl_coach.core_types import PickledReplayBuffer

####################
# Graph Scheduling #
@@ -25,7 +26,7 @@ agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000)
# agent_params.memory.discount = 0.99
agent_params.algorithm.discount = 0.99
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)
agent_params.memory.load_memory_from_file_path = 'datasets/montezuma_revenge.p'
agent_params.memory.load_memory_from_file_path = PickledReplayBuffer('datasets/montezuma_revenge.p')

###############
# Environment #

@@ -648,7 +648,7 @@ class SpacesDefinition(object):
    """
    def __init__(self,
                 state: StateSpace,
                 goal: ObservationSpace,
                 goal: Union[ObservationSpace, None],
                 action: ActionSpace,
                 reward: RewardSpace):
        self.state = state