mirror of https://github.com/gryf/coach.git
commit e3c7e526c7 (parent 4a8451ff02)
authored by Gal Leibovich on 2019-03-19 18:07:09 +02:00, committed by GitHub
38 changed files with 1003 additions and 87 deletions


@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import OrderedDict
from typing import Union
import numpy as np
from rl_coach.agents.agent import Agent
-from rl_coach.core_types import ActionInfo, StateType
+from rl_coach.core_types import ActionInfo, StateType, Batch
from rl_coach.logger import screen
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplay
from rl_coach.spaces import DiscreteActionSpace
from copy import deepcopy
## This is an abstract agent - there is no learn_from_batch method ##
@@ -79,10 +80,12 @@ class ValueOptimizationAgent(Agent):
        # this is for bootstrapped DQN
        if type(actions_q_values) == list and len(actions_q_values) > 0:
            actions_q_values = self.exploration_policy.last_action_values
        actions_q_values = actions_q_values.squeeze()

        # store the q values statistics for logging
        self.q_values.add_sample(actions_q_values)
        actions_q_values = actions_q_values.squeeze()

        for i, q_value in enumerate(actions_q_values):
            self.q_value_for_action[i].add_sample(q_value)
@@ -96,3 +99,74 @@ class ValueOptimizationAgent(Agent):
    def learn_from_batch(self, batch):
        raise NotImplementedError("ValueOptimizationAgent is an abstract agent. Not to be used directly.")
    def run_off_policy_evaluation(self):
        """
        Run the off-policy evaluation estimators to get a prediction of the current policy's performance,
        based on an evaluation dataset that was collected by one or more other (behavior) policies.
        :return: None
        """
        assert self.ope_manager
        dataset_as_episodes = self.call_memory('get_all_complete_episodes_from_to',
                                               (self.call_memory('get_last_training_set_episode_id') + 1,
                                                self.call_memory('num_complete_episodes')))
        if len(dataset_as_episodes) == 0:
            raise ValueError('train_to_eval_ratio is too high, causing the evaluation set to be empty. '
                             'Consider decreasing its value.')

        ips, dm, dr, seq_dr = self.ope_manager.evaluate(
            dataset_as_episodes=dataset_as_episodes,
            batch_size=self.ap.network_wrappers['main'].batch_size,
            discount_factor=self.ap.algorithm.discount,
            reward_model=self.networks['reward_model'].online_network,
            q_network=self.networks['main'].online_network,
            network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))

        # log the estimator values to the screen
        log = OrderedDict()
        log['Epoch'] = self.training_epoch
        log['IPS'] = ips
        log['DM'] = dm
        log['DR'] = dr
        log['Sequential-DR'] = seq_dr
        screen.log_dict(log, prefix='Off-Policy Evaluation')

        # report the estimator values to the dashboard
        self.agent_logger.set_current_time(self.get_current_time() + 1)
        self.agent_logger.create_signal_value('Inverse Propensity Score', ips)
        self.agent_logger.create_signal_value('Direct Method Reward', dm)
        self.agent_logger.create_signal_value('Doubly Robust', dr)
        self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
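
For context (not part of this commit): the IPS, DM and DR values logged above are standard off-policy estimators. The sketch below is a minimal, self-contained illustration of their one-step (contextual-bandit style) forms, not the implementation behind self.ope_manager, which works over the episode dataset and also produces the sequential DR variant. The function and argument names are made up for the example, and per-sample action probabilities and reward-model predictions are assumed to be given as NumPy arrays.

import numpy as np

def one_step_ope_estimators(rewards, r_hat_taken, v_hat_target, pi_target, pi_behavior):
    # rewards      - observed reward for each logged (taken) action
    # r_hat_taken  - reward model prediction for the taken action
    # v_hat_target - model-based expected reward under the evaluated policy,
    #                i.e. sum over actions of pi_target(a|s) * r_hat(s, a)
    # pi_target    - probability the evaluated policy assigns to the taken action
    # pi_behavior  - probability the behavior (data-collection) policy assigned to it
    rho = pi_target / pi_behavior                                 # per-sample importance weights
    ips = np.mean(rho * rewards)                                  # Inverse Propensity Scoring
    dm = np.mean(v_hat_target)                                    # Direct Method (reward model only)
    dr = np.mean(v_hat_target + rho * (rewards - r_hat_taken))    # Doubly Robust
    return ips, dm, dr
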
    def improve_reward_model(self, epochs: int):
        """
        Train a reward model to be used by the doubly robust estimator.
        :param epochs: The total number of epochs to use for training the reward model
        :return: None
        """
        batch_size = self.ap.network_wrappers['reward_model'].batch_size
        network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()

        # the reward model is fitted on the training dataset
        for epoch in range(epochs):
            loss = 0
            for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
                batch = Batch(batch)
                current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
                current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
                loss += self.networks['reward_model'].train_and_sync_networks(
                    batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
                # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])

            log = OrderedDict()
            log['Epoch'] = epoch
            log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
            screen.log_dict(log, prefix='Training Reward Model')
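
Also not part of the commit: the target construction in improve_reward_model can be puzzling at first sight. The reward model predicts one reward per action, and the target starts from the model's own predictions with only the taken-action entries overwritten by the observed rewards, so an MSE loss yields zero error (and zero gradient) for actions that were never observed. A small NumPy illustration of that indexing trick, with made-up shapes and values:

import numpy as np

batch_size, num_actions = 4, 3                                    # illustrative sizes
predicted_rewards = np.random.randn(batch_size, num_actions)      # stands in for the reward model's per-action predictions
actions = np.array([0, 2, 1, 2])                                  # actions actually taken in the batch
observed_rewards = np.array([1.0, 0.0, 0.5, -1.0])                # rewards observed for those actions

# Copy the model's own predictions, then overwrite only the entries of the
# taken actions with the observed rewards.
targets = predicted_rewards.copy()
targets[np.arange(batch_size), actions] = observed_rewards

# Training towards `targets` with an MSE loss therefore only penalizes the
# (state, taken-action) outputs; the outputs for untaken actions already match
# the target and contribute nothing to the loss.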