Batch RL (#238)
@@ -34,6 +34,9 @@ from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.core_types import TimeTypes
from rl_coach.off_policy_evaluators.ope_manager import OpeManager
from rl_coach.core_types import PickledReplayBuffer, CsvDataset


class Agent(AgentInterface):
@@ -49,10 +52,10 @@ class Agent(AgentInterface):
                             and self.ap.memory.shared_memory
        if self.shared_memory:
            self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad
        self.name = agent_parameters.name
        self.parent = parent
        self.parent_level_manager = None
        self.full_name_id = agent_parameters.full_name_id = self.name
        # TODO this needs to be sorted out. Why the duplicates for the agent's name?
        self.full_name_id = agent_parameters.full_name_id = self.name = agent_parameters.name

        if type(agent_parameters.task_parameters) == DistributedTaskParameters:
            screen.log_title("Creating agent - name: {} task id: {} (may take up to 30 seconds due to "
@@ -84,9 +87,17 @@ class Agent(AgentInterface):
            self.memory.set_memory_backend(self.memory_backend)

        if agent_parameters.memory.load_memory_from_file_path:
            screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
                             .format(agent_parameters.memory.load_memory_from_file_path))
            self.memory.load(agent_parameters.memory.load_memory_from_file_path)
            if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
                screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
                                 .format(agent_parameters.memory.load_memory_from_file_path.filepath))
                self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
            elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
                screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
                                 .format(agent_parameters.memory.load_memory_from_file_path.filepath))
                self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
            else:
                raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
                                 .format(agent_parameters.memory.load_memory_from_file_path))

        if self.shared_memory and self.is_chief:
            self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
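The hunk above is the entry point for Batch RL datasets: instead of only un-pickling a full replay buffer, the memory can now be seeded from either a pickled replay buffer or a CSV file of logged transitions. A minimal preset-style sketch of how this path might be exercised follows; the file paths are placeholders, and the exact CsvDataset/PickledReplayBuffer constructor arguments are assumptions inferred only from the .filepath attribute used above:

# Hypothetical preset snippet -- paths are placeholders and the constructor
# arguments are assumed, not taken from this diff.
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.core_types import CsvDataset, PickledReplayBuffer

agent_params = DQNAgentParameters()

# Option 1: seed the replay buffer from transitions exported to a CSV file.
agent_params.memory.load_memory_from_file_path = CsvDataset('/path/to/offline_dataset.csv')

# Option 2: restore a replay buffer that was previously pickled to disk.
# agent_params.memory.load_memory_from_file_path = PickledReplayBuffer('/path/to/replay_buffer.p')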
@@ -147,6 +158,7 @@ class Agent(AgentInterface):
        self.total_steps_counter = 0
        self.running_reward = None
        self.training_iteration = 0
        self.training_epoch = 0
        self.last_target_network_update_step = 0
        self.last_training_phase_step = 0
        self.current_episode = self.ap.current_episode = 0
@@ -184,6 +196,7 @@ class Agent(AgentInterface):
        self.discounted_return = self.register_signal('Discounted Return')
        if isinstance(self.in_action_space, GoalsSpace):
            self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)

        # use seed
        if self.ap.task_parameters.seed is not None:
            random.seed(self.ap.task_parameters.seed)
@@ -193,6 +206,9 @@ class Agent(AgentInterface):
            random.seed()
            np.random.seed()

        # batch rl
        self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None

    @property
    def parent(self) -> 'LevelManager':
        """
@@ -228,6 +244,7 @@ class Agent(AgentInterface):
            format(graph_name=self.parent_level_manager.parent_graph_manager.name,
                   level_name=self.parent_level_manager.name,
                   agent_full_id='.'.join(self.full_name_id.split('/')))
        self.agent_logger.set_index_name(self.parent_level_manager.parent_graph_manager.time_metric.value.name)
        self.agent_logger.set_logger_filenames(self.ap.task_parameters.experiment_path, logger_prefix=logger_prefix,
                                               add_timestamp=True, task_id=self.task_id)
        if self.ap.visualization.dump_in_episode_signals:
@@ -387,7 +404,8 @@ class Agent(AgentInterface):
        elif ending_evaluation:
            # we write to the next episode, because it could be that the current episode was already written
            # to disk and then we won't write it again
            self.agent_logger.set_current_time(self.current_episode + 1)
            self.agent_logger.set_current_time(self.get_current_time() + 1)

            evaluation_reward = self.accumulated_rewards_across_evaluation_episodes / self.num_evaluation_episodes_completed
            self.agent_logger.create_signal_value(
                'Evaluation Reward', evaluation_reward)
@@ -471,8 +489,11 @@ class Agent(AgentInterface):
        :return: None
        """
        # log all the signals to file
        self.agent_logger.set_current_time(self.current_episode)
        current_time = self.get_current_time()
        self.agent_logger.set_current_time(current_time)
        self.agent_logger.create_signal_value('Training Iter', self.training_iteration)
        self.agent_logger.create_signal_value('Episode #', self.current_episode)
        self.agent_logger.create_signal_value('Epoch', self.training_epoch)
        self.agent_logger.create_signal_value('In Heatup', int(self._phase == RunPhase.HEATUP))
        self.agent_logger.create_signal_value('ER #Transitions', self.call_memory('num_transitions'))
        self.agent_logger.create_signal_value('ER #Episodes', self.call_memory('length'))
@@ -485,13 +506,17 @@ class Agent(AgentInterface):
                                              if self._phase == RunPhase.TRAIN else np.nan)

        self.agent_logger.create_signal_value('Update Target Network', 0, overwrite=False)
        self.agent_logger.update_wall_clock_time(self.current_episode)
        self.agent_logger.update_wall_clock_time(current_time)

        # The following signals are created with meaningful values only when an evaluation phase is completed.
        # Creating with default NaNs for any HEATUP/TRAIN/TEST episode which is not the last in an evaluation phase
        self.agent_logger.create_signal_value('Evaluation Reward', np.nan, overwrite=False)
        self.agent_logger.create_signal_value('Shaped Evaluation Reward', np.nan, overwrite=False)
        self.agent_logger.create_signal_value('Success Rate', np.nan, overwrite=False)
        self.agent_logger.create_signal_value('Inverse Propensity Score', np.nan, overwrite=False)
        self.agent_logger.create_signal_value('Direct Method Reward', np.nan, overwrite=False)
        self.agent_logger.create_signal_value('Doubly Robust', np.nan, overwrite=False)
        self.agent_logger.create_signal_value('Sequential Doubly Robust', np.nan, overwrite=False)

        for signal in self.episode_signals:
            self.agent_logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
@@ -500,8 +525,7 @@ class Agent(AgentInterface):
            self.agent_logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())

        # dump
        if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0 \
                and self.current_episode > 0:
        if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0:
            self.agent_logger.dump_output_csv()

    def handle_episode_ended(self) -> None:
@@ -537,7 +561,8 @@ class Agent(AgentInterface):
                self.total_reward_in_current_episode >= self.spaces.reward.reward_success_threshold:
            self.num_successes_across_evaluation_episodes += 1

        if self.ap.visualization.dump_csv:
        if self.ap.visualization.dump_csv and \
                self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber:
            self.update_log()

        if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
@@ -651,18 +676,22 @@ class Agent(AgentInterface):
        """
        loss = 0
        if self._should_train():
            self.training_epoch += 1
            for network in self.networks.values():
                network.set_is_training(True)

            for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
                # TODO: this should be network dependent
                network_parameters = list(self.ap.network_wrappers.values())[0]
            # TODO: this should be network dependent
            network_parameters = list(self.ap.network_wrappers.values())[0]

            # we either go sequentially through the entire replay buffer in the batch RL mode,
            # or sample randomly for the basic RL case.
            training_schedule = self.call_memory('get_shuffled_data_generator', network_parameters.batch_size) if \
                self.ap.is_batch_rl_training else [self.call_memory('sample', network_parameters.batch_size) for _ in
                                                   range(self.ap.algorithm.num_consecutive_training_steps)]

            for batch in training_schedule:
                # update counters
                self.training_iteration += 1

                # sample a batch and train on it
                batch = self.call_memory('sample', network_parameters.batch_size)
                if self.pre_network_filter is not None:
                    batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
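The comment in the hunk above captures the key behavioural change for Batch RL: each training epoch sweeps once over the whole shuffled replay buffer, whereas ordinary RL keeps drawing random mini-batches. As a rough illustration of what such a shuffled sweep does, here is a generic sketch over a plain list of transitions; it is not Coach's get_shuffled_data_generator implementation:

import random
from typing import Iterator, List, Sequence

def shuffled_minibatch_generator(transitions: Sequence, batch_size: int) -> Iterator[List]:
    """Yield every stored transition exactly once per epoch, in random mini-batches."""
    indices = list(range(len(transitions)))
    random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [transitions[i] for i in indices[start:start + batch_size]]

# One batch-RL epoch then touches each logged transition once:
#     for batch in shuffled_minibatch_generator(logged_transitions, batch_size=32):
#         ...train on batch...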
@@ -673,6 +702,7 @@ class Agent(AgentInterface):
                batch = Batch(batch)
                total_loss, losses, unclipped_grads = self.learn_from_batch(batch)
                loss += total_loss

                self.unclipped_grads.add_sample(unclipped_grads)

                # TODO: the learning rate decay should be done through the network instead of here
@@ -697,6 +727,12 @@ class Agent(AgentInterface):
            if self.imitation:
                self.log_to_screen()

            if self.ap.visualization.dump_csv and \
                    self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.Epoch:
                # in BatchRL, or imitation learning, the agent never acts, so we have to get the stats out here.
                # we dump the data out every epoch
                self.update_log()

            for network in self.networks.values():
                network.set_is_training(False)

@@ -1034,3 +1070,12 @@ class Agent(AgentInterface):
        for network in self.networks.values():
            savers.update(network.collect_savers(parent_path_suffix))
        return savers

    def get_current_time(self):
        pass
        return {
            TimeTypes.EpisodeNumber: self.current_episode,
            TimeTypes.TrainingIteration: self.training_iteration,
            TimeTypes.EnvironmentSteps: self.total_steps_counter,
            TimeTypes.WallClockTime: self.agent_logger.get_current_wall_clock_time(),
            TimeTypes.Epoch: self.training_epoch}[self.parent_level_manager.parent_graph_manager.time_metric]

@@ -173,3 +173,12 @@ class AgentInterface(object):
        :return: None
        """
        raise NotImplementedError("")

    def run_off_policy_evaluation(self) -> None:
        """
        Run off-policy evaluation estimators to evaluate the trained policy performance against a dataset.
        Should only be implemented for off-policy RL algorithms.

        :return: None
        """
        raise NotImplementedError("")

@@ -64,6 +64,9 @@ class BootstrappedDQNAgent(ValueOptimizationAgent):
        q_st_plus_1 = result[:self.ap.exploration.architecture_num_q_heads]
        TD_targets = result[self.ap.exploration.architecture_num_q_heads:]

        # add Q value samples for logging
        self.q_values.add_sample(TD_targets)

        # initialize with the current prediction so that we will
        # only update the action that we have actually done in this transition
        for i in range(self.ap.network_wrappers['main'].batch_size):

@@ -100,6 +100,9 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
            (self.networks['main'].online_network, batch.states(network_keys))
        ])

        # add Q value samples for logging
        self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))

        # select the optimal actions for the next state
        target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))

@@ -432,3 +432,4 @@ class CompositeAgent(AgentInterface):
            savers.update(agent.collect_savers(
                parent_path_suffix="{}.{}".format(parent_path_suffix, self.name)))
        return savers

@@ -50,6 +50,9 @@ class DDQNAgent(ValueOptimizationAgent):
            (self.networks['main'].online_network, batch.states(network_keys))
        ])

        # add Q value samples for logging
        self.q_values.add_sample(TD_targets)

        # initialize with the current prediction so that we will
        # only update the action that we have actually done in this transition
        TD_errors = []

@@ -81,6 +81,9 @@ class DQNAgent(ValueOptimizationAgent):
            (self.networks['main'].online_network, batch.states(network_keys))
        ])

        # add Q value samples for logging
        self.q_values.add_sample(TD_targets)

        # only update the action that we have actually done in this transition
        TD_errors = []
        for i in range(self.ap.network_wrappers['main'].batch_size):

@@ -123,6 +123,9 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
        else:
            assert True, 'The available values for targets_horizon are: 1-Step, N-Step'

        # add Q value samples for logging
        self.q_values.add_sample(state_value_head_targets)

        # train
        result = self.networks['main'].online_network.accumulate_gradients(batch.states(network_keys), [state_value_head_targets])

@@ -88,6 +88,9 @@ class QuantileRegressionDQNAgent(ValueOptimizationAgent):
            (self.networks['main'].online_network, batch.states(network_keys))
        ])

        # add Q value samples for logging
        self.q_values.add_sample(self.get_q_values(current_quantiles))

        # get the optimal actions to take for the next states
        target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)

@@ -95,6 +95,9 @@ class RainbowDQNAgent(CategoricalDQNAgent):
            (self.networks['main'].online_network, batch.states(network_keys))
        ])

        # add Q value samples for logging
        self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))

        # only update the action that we have actually done in this transition (using the Double-DQN selected actions)
        target_actions = ddqn_selected_actions
        m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))

@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import OrderedDict
from typing import Union

import numpy as np

from rl_coach.agents.agent import Agent
from rl_coach.core_types import ActionInfo, StateType
from rl_coach.core_types import ActionInfo, StateType, Batch
from rl_coach.logger import screen
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
from rl_coach.spaces import DiscreteActionSpace

from copy import deepcopy

## This is an abstract agent - there is no learn_from_batch method ##

@@ -79,10 +80,12 @@ class ValueOptimizationAgent(Agent):
        # this is for bootstrapped dqn
        if type(actions_q_values) == list and len(actions_q_values) > 0:
            actions_q_values = self.exploration_policy.last_action_values
        actions_q_values = actions_q_values.squeeze()

        # store the q values statistics for logging
        self.q_values.add_sample(actions_q_values)

        actions_q_values = actions_q_values.squeeze()

        for i, q_value in enumerate(actions_q_values):
            self.q_value_for_action[i].add_sample(q_value)

@@ -96,3 +99,74 @@ class ValueOptimizationAgent(Agent):
    def learn_from_batch(self, batch):
        raise NotImplementedError("ValueOptimizationAgent is an abstract agent. Not to be used directly.")

    def run_off_policy_evaluation(self):
        """
        Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on
        an evaluation dataset, which was collected by another policy(ies).
        :return: None
        """
        assert self.ope_manager
        dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
                                               (self.call_memory('get_last_training_set_episode_id') + 1,
                                                self.call_memory('num_complete_episodes')))
        if len(dataset_as_episodes) == 0:
            raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
                             'Consider decreasing its value.')

        ips, dm, dr, seq_dr = self.ope_manager.evaluate(
            dataset_as_episodes=dataset_as_episodes,
            batch_size=self.ap.network_wrappers['main'].batch_size,
            discount_factor=self.ap.algorithm.discount,
            reward_model=self.networks['reward_model'].online_network,
            q_network=self.networks['main'].online_network,
            network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))

        # get the estimators out to the screen
        log = OrderedDict()
        log['Epoch'] = self.training_epoch
        log['IPS'] = ips
        log['DM'] = dm
        log['DR'] = dr
        log['Sequential-DR'] = seq_dr
        screen.log_dict(log, prefix='Off-Policy Evaluation')

        # get the estimators out to dashboard
        self.agent_logger.set_current_time(self.get_current_time() + 1)
        self.agent_logger.create_signal_value('Inverse Propensity Score', ips)
        self.agent_logger.create_signal_value('Direct Method Reward', dm)
        self.agent_logger.create_signal_value('Doubly Robust', dr)
        self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)

    def improve_reward_model(self, epochs: int):
        """
        Train a reward model to be used by the doubly-robust estimator

        :param epochs: The total number of epochs to use for training a reward model
        :return: None
        """
        batch_size = self.ap.network_wrappers['reward_model'].batch_size
        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()

        # this is fitted from the training dataset
        for epoch in range(epochs):
            loss = 0
            for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                batch = Batch(batch)
                current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
                current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
                loss += self.networks['reward_model'].train_and_sync_networks(
                    batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
                # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])

            log = OrderedDict()
            log['Epoch'] = epoch
            log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
            screen.log_dict(log, prefix='Training Reward Model')
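The four numbers produced above (IPS, DM, DR and Sequential DR) are standard off-policy evaluation estimators, with improve_reward_model supplying the fitted reward model that DM and DR rely on. As a generic, one-step (bandit-style) illustration of how the first three relate to each other, here is a textbook sketch under simplifying assumptions (known behaviour-policy propensities, a fitted reward model); it is not the OpeManager implementation used in this commit:

import numpy as np

def ope_estimates(rewards, actions, behavior_probs, target_policy_probs, reward_model_values):
    """Generic one-step off-policy estimators over N logged transitions.

    rewards:              observed reward of the logged action, shape [N]
    actions:              logged action indices, shape [N]
    behavior_probs:       pi_b(a_i | s_i) of the logging policy, shape [N]
    target_policy_probs:  pi_e(a | s_i) of the evaluated policy, shape [N, A]
    reward_model_values:  fitted reward model r_hat(s_i, a), shape [N, A]
    """
    rewards = np.asarray(rewards, dtype=float)
    actions = np.asarray(actions)
    behavior_probs = np.asarray(behavior_probs, dtype=float)
    target_policy_probs = np.asarray(target_policy_probs, dtype=float)
    reward_model_values = np.asarray(reward_model_values, dtype=float)

    n = len(rewards)
    # importance weights of the evaluated policy vs. the logging policy
    w = target_policy_probs[np.arange(n), actions] / behavior_probs

    # Inverse Propensity Scoring: reweight the observed rewards.
    ips = np.mean(w * rewards)

    # Direct Method: expected reward of the evaluated policy under the fitted model.
    dm = np.mean(np.sum(target_policy_probs * reward_model_values, axis=1))

    # Doubly Robust: DM baseline corrected by the importance-weighted model error.
    r_hat_taken = reward_model_values[np.arange(n), actions]
    dr = dm + np.mean(w * (rewards - r_hat_taken))
    return ips, dm, dr

The sequential (per-decision) doubly robust estimator extends the same correction along whole episodes with discounting, which is consistent with ope_manager.evaluate above also taking a discount_factor and the main Q-network.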