mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Batch RL (#238)
This commit is contained in:
@@ -34,6 +34,9 @@ from rl_coach.spaces import SpacesDefinition, VectorObservationSpace, GoalsSpace
|
||||
from rl_coach.utils import Signal, force_list
|
||||
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
from rl_coach.core_types import TimeTypes
|
||||
from rl_coach.off_policy_evaluators.ope_manager import OpeManager
|
||||
from rl_coach.core_types import PickledReplayBuffer, CsvDataset
|
||||
|
||||
|
||||
class Agent(AgentInterface):
|
||||
@@ -49,10 +52,10 @@ class Agent(AgentInterface):
|
||||
and self.ap.memory.shared_memory
|
||||
if self.shared_memory:
|
||||
self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad
|
||||
self.name = agent_parameters.name
|
||||
self.parent = parent
|
||||
self.parent_level_manager = None
|
||||
self.full_name_id = agent_parameters.full_name_id = self.name
|
||||
# TODO this needs to be sorted out. Why the duplicates for the agent's name?
|
||||
self.full_name_id = agent_parameters.full_name_id = self.name = agent_parameters.name
|
||||
|
||||
if type(agent_parameters.task_parameters) == DistributedTaskParameters:
|
||||
screen.log_title("Creating agent - name: {} task id: {} (may take up to 30 seconds due to "
|
||||
@@ -84,9 +87,17 @@ class Agent(AgentInterface):
|
||||
self.memory.set_memory_backend(self.memory_backend)
|
||||
|
||||
if agent_parameters.memory.load_memory_from_file_path:
|
||||
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
|
||||
.format(agent_parameters.memory.load_memory_from_file_path))
|
||||
self.memory.load(agent_parameters.memory.load_memory_from_file_path)
|
||||
if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
|
||||
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
|
||||
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
|
||||
self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
|
||||
elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
|
||||
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
|
||||
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
|
||||
self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
|
||||
else:
|
||||
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
|
||||
.format(agent_parameters.memory.load_memory_from_file_path))
|
||||
|
||||
if self.shared_memory and self.is_chief:
|
||||
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
|
||||
@@ -147,6 +158,7 @@ class Agent(AgentInterface):
|
||||
self.total_steps_counter = 0
|
||||
self.running_reward = None
|
||||
self.training_iteration = 0
|
||||
self.training_epoch = 0
|
||||
self.last_target_network_update_step = 0
|
||||
self.last_training_phase_step = 0
|
||||
self.current_episode = self.ap.current_episode = 0
|
||||
@@ -184,6 +196,7 @@ class Agent(AgentInterface):
|
||||
self.discounted_return = self.register_signal('Discounted Return')
|
||||
if isinstance(self.in_action_space, GoalsSpace):
|
||||
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
|
||||
|
||||
# use seed
|
||||
if self.ap.task_parameters.seed is not None:
|
||||
random.seed(self.ap.task_parameters.seed)
|
||||
@@ -193,6 +206,9 @@ class Agent(AgentInterface):
|
||||
random.seed()
|
||||
np.random.seed()
|
||||
|
||||
# batch rl
|
||||
self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None
|
||||
|
||||
@property
|
||||
def parent(self) -> 'LevelManager':
|
||||
"""
|
||||
@@ -228,6 +244,7 @@ class Agent(AgentInterface):
|
||||
format(graph_name=self.parent_level_manager.parent_graph_manager.name,
|
||||
level_name=self.parent_level_manager.name,
|
||||
agent_full_id='.'.join(self.full_name_id.split('/')))
|
||||
self.agent_logger.set_index_name(self.parent_level_manager.parent_graph_manager.time_metric.value.name)
|
||||
self.agent_logger.set_logger_filenames(self.ap.task_parameters.experiment_path, logger_prefix=logger_prefix,
|
||||
add_timestamp=True, task_id=self.task_id)
|
||||
if self.ap.visualization.dump_in_episode_signals:
|
||||
@@ -387,7 +404,8 @@ class Agent(AgentInterface):
|
||||
elif ending_evaluation:
|
||||
# we write to the next episode, because it could be that the current episode was already written
|
||||
# to disk and then we won't write it again
|
||||
self.agent_logger.set_current_time(self.current_episode + 1)
|
||||
self.agent_logger.set_current_time(self.get_current_time() + 1)
|
||||
|
||||
evaluation_reward = self.accumulated_rewards_across_evaluation_episodes / self.num_evaluation_episodes_completed
|
||||
self.agent_logger.create_signal_value(
|
||||
'Evaluation Reward', evaluation_reward)
|
||||
@@ -471,8 +489,11 @@ class Agent(AgentInterface):
|
||||
:return: None
|
||||
"""
|
||||
# log all the signals to file
|
||||
self.agent_logger.set_current_time(self.current_episode)
|
||||
current_time = self.get_current_time()
|
||||
self.agent_logger.set_current_time(current_time)
|
||||
self.agent_logger.create_signal_value('Training Iter', self.training_iteration)
|
||||
self.agent_logger.create_signal_value('Episode #', self.current_episode)
|
||||
self.agent_logger.create_signal_value('Epoch', self.training_epoch)
|
||||
self.agent_logger.create_signal_value('In Heatup', int(self._phase == RunPhase.HEATUP))
|
||||
self.agent_logger.create_signal_value('ER #Transitions', self.call_memory('num_transitions'))
|
||||
self.agent_logger.create_signal_value('ER #Episodes', self.call_memory('length'))
|
||||
@@ -485,13 +506,17 @@ class Agent(AgentInterface):
|
||||
if self._phase == RunPhase.TRAIN else np.nan)
|
||||
|
||||
self.agent_logger.create_signal_value('Update Target Network', 0, overwrite=False)
|
||||
self.agent_logger.update_wall_clock_time(self.current_episode)
|
||||
self.agent_logger.update_wall_clock_time(current_time)
|
||||
|
||||
# The following signals are created with meaningful values only when an evaluation phase is completed.
|
||||
# Creating with default NaNs for any HEATUP/TRAIN/TEST episode which is not the last in an evaluation phase
|
||||
self.agent_logger.create_signal_value('Evaluation Reward', np.nan, overwrite=False)
|
||||
self.agent_logger.create_signal_value('Shaped Evaluation Reward', np.nan, overwrite=False)
|
||||
self.agent_logger.create_signal_value('Success Rate', np.nan, overwrite=False)
|
||||
self.agent_logger.create_signal_value('Inverse Propensity Score', np.nan, overwrite=False)
|
||||
self.agent_logger.create_signal_value('Direct Method Reward', np.nan, overwrite=False)
|
||||
self.agent_logger.create_signal_value('Doubly Robust', np.nan, overwrite=False)
|
||||
self.agent_logger.create_signal_value('Sequential Doubly Robust', np.nan, overwrite=False)
|
||||
|
||||
for signal in self.episode_signals:
|
||||
self.agent_logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
|
||||
@@ -500,8 +525,7 @@ class Agent(AgentInterface):
|
||||
self.agent_logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())
|
||||
|
||||
# dump
|
||||
if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0 \
|
||||
and self.current_episode > 0:
|
||||
if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0:
|
||||
self.agent_logger.dump_output_csv()
|
||||
|
||||
def handle_episode_ended(self) -> None:
|
||||
@@ -537,7 +561,8 @@ class Agent(AgentInterface):
|
||||
self.total_reward_in_current_episode >= self.spaces.reward.reward_success_threshold:
|
||||
self.num_successes_across_evaluation_episodes += 1
|
||||
|
||||
if self.ap.visualization.dump_csv:
|
||||
if self.ap.visualization.dump_csv and \
|
||||
self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber:
|
||||
self.update_log()
|
||||
|
||||
if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
|
||||
@@ -651,18 +676,22 @@ class Agent(AgentInterface):
|
||||
"""
|
||||
loss = 0
|
||||
if self._should_train():
|
||||
self.training_epoch += 1
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(True)
|
||||
|
||||
for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
|
||||
# TODO: this should be network dependent
|
||||
network_parameters = list(self.ap.network_wrappers.values())[0]
|
||||
# TODO: this should be network dependent
|
||||
network_parameters = list(self.ap.network_wrappers.values())[0]
|
||||
|
||||
# we either go sequentially through the entire replay buffer in the batch RL mode,
|
||||
# or sample randomly for the basic RL case.
|
||||
training_schedule = self.call_memory('get_shuffled_data_generator', network_parameters.batch_size) if \
|
||||
self.ap.is_batch_rl_training else [self.call_memory('sample', network_parameters.batch_size) for _ in
|
||||
range(self.ap.algorithm.num_consecutive_training_steps)]
|
||||
|
||||
for batch in training_schedule:
|
||||
# update counters
|
||||
self.training_iteration += 1
|
||||
|
||||
# sample a batch and train on it
|
||||
batch = self.call_memory('sample', network_parameters.batch_size)
|
||||
if self.pre_network_filter is not None:
|
||||
batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)
|
||||
|
||||
@@ -673,6 +702,7 @@ class Agent(AgentInterface):
|
||||
batch = Batch(batch)
|
||||
total_loss, losses, unclipped_grads = self.learn_from_batch(batch)
|
||||
loss += total_loss
|
||||
|
||||
self.unclipped_grads.add_sample(unclipped_grads)
|
||||
|
||||
# TODO: the learning rate decay should be done through the network instead of here
|
||||
@@ -697,6 +727,12 @@ class Agent(AgentInterface):
|
||||
if self.imitation:
|
||||
self.log_to_screen()
|
||||
|
||||
if self.ap.visualization.dump_csv and \
|
||||
self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.Epoch:
|
||||
# in BatchRL, or imitation learning, the agent never acts, so we have to get the stats out here.
|
||||
# we dump the data out every epoch
|
||||
self.update_log()
|
||||
|
||||
for network in self.networks.values():
|
||||
network.set_is_training(False)
|
||||
|
||||
@@ -1034,3 +1070,12 @@ class Agent(AgentInterface):
|
||||
for network in self.networks.values():
|
||||
savers.update(network.collect_savers(parent_path_suffix))
|
||||
return savers
|
||||
|
||||
def get_current_time(self):
|
||||
pass
|
||||
return {
|
||||
TimeTypes.EpisodeNumber: self.current_episode,
|
||||
TimeTypes.TrainingIteration: self.training_iteration,
|
||||
TimeTypes.EnvironmentSteps: self.total_steps_counter,
|
||||
TimeTypes.WallClockTime: self.agent_logger.get_current_wall_clock_time(),
|
||||
TimeTypes.Epoch: self.training_epoch}[self.parent_level_manager.parent_graph_manager.time_metric]
|
||||
|
||||
Reference in New Issue
Block a user