Batch RL (#238)
@@ -13,16 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 
+from collections import OrderedDict
 from typing import Union
 
 import numpy as np
 
 from rl_coach.agents.agent import Agent
-from rl_coach.core_types import ActionInfo, StateType
+from rl_coach.core_types import ActionInfo, StateType, Batch
 from rl_coach.logger import screen
 from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
 from rl_coach.spaces import DiscreteActionSpace
 
 from copy import deepcopy
 
 ## This is an abstract agent - there is no learn_from_batch method ##
@@ -79,10 +80,12 @@ class ValueOptimizationAgent(Agent):
             # this is for bootstrapped dqn
             if type(actions_q_values) == list and len(actions_q_values) > 0:
                 actions_q_values = self.exploration_policy.last_action_values
-            actions_q_values = actions_q_values.squeeze()
 
             # store the q values statistics for logging
             self.q_values.add_sample(actions_q_values)
+
+            actions_q_values = actions_q_values.squeeze()
+
             for i, q_value in enumerate(actions_q_values):
                 self.q_value_for_action[i].add_sample(q_value)
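The only change in this hunk is where the prediction gets squeezed: the raw values are logged first, and the squeeze happens just before per-action indexing. A rough numpy illustration (not taken from Coach), assuming a single-state prediction of shape (1, num_actions):

import numpy as np

# Illustration only: a prediction for one state usually carries a leading batch
# axis, which squeeze() drops so the values can be indexed per action.
actions_q_values = np.array([[0.1, 0.7, 0.2]])   # shape (1, 3)
flat = actions_q_values.squeeze()                # shape (3,)
for i, q_value in enumerate(flat):
    print(i, float(q_value))                     # per-action values, as logged above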
@@ -96,3 +99,74 @@ class ValueOptimizationAgent(Agent):
 
     def learn_from_batch(self, batch):
         raise NotImplementedError("ValueOptimizationAgent is an abstract agent. Not to be used directly.")
+
+    def run_off_policy_evaluation(self):
+        """
+        Run the off-policy evaluation estimators to get a prediction for the performance of the current policy based on
+        an evaluation dataset, which was collected by another policy(ies).
+        :return: None
+        """
+        assert self.ope_manager
+        dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
+                                               (self.call_memory('get_last_training_set_episode_id') + 1,
+                                                self.call_memory('num_complete_episodes')))
+        if len(dataset_as_episodes) == 0:
+            raise ValueError('train_to_eval_ratio is too high causing the evaluation set to be empty. '
+                             'Consider decreasing its value.')
+
+        ips, dm, dr, seq_dr = self.ope_manager.evaluate(
+            dataset_as_episodes=dataset_as_episodes,
+            batch_size=self.ap.network_wrappers['main'].batch_size,
+            discount_factor=self.ap.algorithm.discount,
+            reward_model=self.networks['reward_model'].online_network,
+            q_network=self.networks['main'].online_network,
+            network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))
+
+        # get the estimators out to the screen
+        log = OrderedDict()
+        log['Epoch'] = self.training_epoch
+        log['IPS'] = ips
+        log['DM'] = dm
+        log['DR'] = dr
+        log['Sequential-DR'] = seq_dr
+        screen.log_dict(log, prefix='Off-Policy Evaluation')
+
+        # get the estimators out to dashboard
+        self.agent_logger.set_current_time(self.get_current_time() + 1)
+        self.agent_logger.create_signal_value('Inverse Propensity Score', ips)
+        self.agent_logger.create_signal_value('Direct Method Reward', dm)
+        self.agent_logger.create_signal_value('Doubly Robust', dr)
+        self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
+
+    def improve_reward_model(self, epochs: int):
+        """
+        Train a reward model to be used by the doubly-robust estimator
+
+        :param epochs: The total number of epochs to use for training a reward model
+        :return: None
+        """
+        batch_size = self.ap.network_wrappers['reward_model'].batch_size
+        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
+
+        # this is fitted from the training dataset
+        for epoch in range(epochs):
+            loss = 0
+            for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
+                batch = Batch(batch)
+                current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
+                current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
+                loss += self.networks['reward_model'].train_and_sync_networks(
+                    batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
+                # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])
+
+            log = OrderedDict()
+            log['Epoch'] = epoch
+            log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
+            screen.log_dict(log, prefix='Training Reward Model')
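For orientation, the four numbers logged by run_off_policy_evaluation are standard off-policy evaluation estimators. The sketch below gives their one-step (contextual bandit) forms on plain numpy arrays; it is not Coach's OpeManager, and every name in it is illustrative.

import numpy as np

def bandit_ope(rewards, behavior_probs, target_probs, reward_model_values, actions):
    """One-step forms of the estimators logged above (illustrative, not Coach code).

    rewards:             (n,)   observed rewards of the logged actions
    behavior_probs:      (n,)   probability the data-collecting policy gave each logged action
    target_probs:        (n, k) action probabilities of the policy being evaluated
    reward_model_values: (n, k) fitted reward model r_hat(s, a) for every action
    actions:             (n,)   indices of the logged actions
    """
    n = len(rewards)
    rho = target_probs[np.arange(n), actions] / behavior_probs        # importance weights

    ips = np.mean(rho * rewards)                                      # Inverse Propensity Score
    dm = np.mean(np.sum(target_probs * reward_model_values, axis=1))  # Direct Method
    r_hat_taken = reward_model_values[np.arange(n), actions]
    dr = dm + np.mean(rho * (rewards - r_hat_taken))                  # Doubly Robust
    return ips, dm, dr

The Sequential Doubly Robust value applies the same correction recursively along each episode, which is why the evaluate() call above takes the Q-network and a discount factor in addition to the reward model.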
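The target construction in improve_reward_model is worth spelling out: the network's own predictions are kept for actions that were not taken, and only the entry of the logged action is overwritten with the observed reward, so the training error comes almost entirely from logged actions. A minimal numpy sketch of that step, with shapes and values invented for illustration:

import numpy as np

batch_size, num_actions = 4, 3
predictions = np.random.rand(batch_size, num_actions)   # current reward-model output
actions = np.array([0, 2, 1, 2])                         # logged actions
rewards = np.array([1.0, 0.0, 0.5, 1.0])                 # logged rewards

targets = predictions.copy()
targets[range(batch_size), actions] = rewards            # only taken actions get a real target
# fitting on (states, targets) leaves the untouched columns with roughly zero error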
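How these two methods are driven lives outside this file; broadly, a batch RL run fits the reward model once and then alternates training with off-policy evaluation. A hypothetical outline, with the epoch counts and the surrounding loop invented for illustration:

def batch_rl_run(agent, reward_model_epochs=30, training_epochs=50):
    """Hypothetical driver (not part of this commit): `agent` is assumed to be a
    ValueOptimizationAgent subclass with a 'reward_model' network, an ope_manager,
    and a memory pre-filled with a fixed, previously collected dataset."""
    agent.improve_reward_model(epochs=reward_model_epochs)   # fit r_hat for the DM / DR estimators
    for _ in range(training_epochs):
        agent.train()                                        # learn_from_batch over the stored transitions
        agent.run_off_policy_evaluation()                    # log IPS / DM / DR / Sequential-DR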