mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
BCQ variant on top of DDQN (#276)
* kNN based model for predicting which actions to drop * fix for seeds with batch rl
This commit is contained in:
@@ -45,6 +45,15 @@ class Agent(AgentInterface):
|
|||||||
:param agent_parameters: A AgentParameters class instance with all the agent parameters
|
:param agent_parameters: A AgentParameters class instance with all the agent parameters
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
# use seed
|
||||||
|
if agent_parameters.task_parameters.seed is not None:
|
||||||
|
random.seed(agent_parameters.task_parameters.seed)
|
||||||
|
np.random.seed(agent_parameters.task_parameters.seed)
|
||||||
|
else:
|
||||||
|
# we need to seed the RNG since the different processes are initialized with the same parent seed
|
||||||
|
random.seed()
|
||||||
|
np.random.seed()
|
||||||
|
|
||||||
self.ap = agent_parameters
|
self.ap = agent_parameters
|
||||||
self.task_id = self.ap.task_parameters.task_index
|
self.task_id = self.ap.task_parameters.task_index
|
||||||
self.is_chief = self.task_id == 0
|
self.is_chief = self.task_id == 0
|
||||||
@@ -197,15 +206,6 @@ class Agent(AgentInterface):
|
|||||||
if isinstance(self.in_action_space, GoalsSpace):
|
if isinstance(self.in_action_space, GoalsSpace):
|
||||||
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
|
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
|
||||||
|
|
||||||
# use seed
|
|
||||||
if self.ap.task_parameters.seed is not None:
|
|
||||||
random.seed(self.ap.task_parameters.seed)
|
|
||||||
np.random.seed(self.ap.task_parameters.seed)
|
|
||||||
else:
|
|
||||||
# we need to seed the RNG since the different processes are initialized with the same parent seed
|
|
||||||
random.seed()
|
|
||||||
np.random.seed()
|
|
||||||
|
|
||||||
# batch rl
|
# batch rl
|
||||||
self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None
|
self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None
|
||||||
|
|
||||||
@@ -688,13 +688,16 @@ class Agent(AgentInterface):
|
|||||||
for network in self.networks.values():
|
for network in self.networks.values():
|
||||||
network.set_is_training(True)
|
network.set_is_training(True)
|
||||||
|
|
||||||
# TODO: this should be network dependent
|
# At the moment we only support a single batch size for all the networks
|
||||||
network_parameters = list(self.ap.network_wrappers.values())[0]
|
networks_parameters = list(self.ap.network_wrappers.values())
|
||||||
|
assert all(net.batch_size == networks_parameters[0].batch_size for net in networks_parameters)
|
||||||
|
|
||||||
|
batch_size = networks_parameters[0].batch_size
|
||||||
|
|
||||||
# we either go sequentially through the entire replay buffer in the batch RL mode,
|
# we either go sequentially through the entire replay buffer in the batch RL mode,
|
||||||
# or sample randomly for the basic RL case.
|
# or sample randomly for the basic RL case.
|
||||||
training_schedule = self.call_memory('get_shuffled_data_generator', network_parameters.batch_size) if \
|
training_schedule = self.call_memory('get_shuffled_data_generator', batch_size) if \
|
||||||
self.ap.is_batch_rl_training else [self.call_memory('sample', network_parameters.batch_size) for _ in
|
self.ap.is_batch_rl_training else [self.call_memory('sample', batch_size) for _ in
|
||||||
range(self.ap.algorithm.num_consecutive_training_steps)]
|
range(self.ap.algorithm.num_consecutive_training_steps)]
|
||||||
|
|
||||||
for batch in training_schedule:
|
for batch in training_schedule:
|
||||||
@@ -713,13 +716,16 @@ class Agent(AgentInterface):
|
|||||||
|
|
||||||
self.unclipped_grads.add_sample(unclipped_grads)
|
self.unclipped_grads.add_sample(unclipped_grads)
|
||||||
|
|
||||||
# TODO: the learning rate decay should be done through the network instead of here
|
# TODO: this only deals with the main network (if exists), need to do the same for other networks
|
||||||
|
# for instance, for DDPG, the LR signal is currently not shown. Probably should be done through the
|
||||||
|
# network directly instead of here
|
||||||
# decay learning rate
|
# decay learning rate
|
||||||
if network_parameters.learning_rate_decay_rate != 0:
|
if 'main' in self.ap.network_wrappers and \
|
||||||
|
self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
|
||||||
self.curr_learning_rate.add_sample(self.networks['main'].sess.run(
|
self.curr_learning_rate.add_sample(self.networks['main'].sess.run(
|
||||||
self.networks['main'].online_network.current_learning_rate))
|
self.networks['main'].online_network.current_learning_rate))
|
||||||
else:
|
else:
|
||||||
self.curr_learning_rate.add_sample(network_parameters.learning_rate)
|
self.curr_learning_rate.add_sample(networks_parameters[0].learning_rate)
|
||||||
|
|
||||||
if any([network.has_target for network in self.networks.values()]) \
|
if any([network.has_target for network in self.networks.values()]) \
|
||||||
and self._should_update_online_weights_to_target():
|
and self._should_update_online_weights_to_target():
|
||||||
|
|||||||
@@ -17,9 +17,7 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from rl_coach.agents.dqn_agent import DQNAgent, DQNAgentParameters
|
||||||
from rl_coach.agents.dqn_agent import DQNAgentParameters
|
|
||||||
from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
|
|
||||||
from rl_coach.core_types import EnvironmentSteps
|
from rl_coach.core_types import EnvironmentSteps
|
||||||
from rl_coach.schedules import LinearSchedule
|
from rl_coach.schedules import LinearSchedule
|
||||||
|
|
||||||
@@ -37,36 +35,10 @@ class DDQNAgentParameters(DQNAgentParameters):
|
|||||||
|
|
||||||
|
|
||||||
# Double DQN - https://arxiv.org/abs/1509.06461
|
# Double DQN - https://arxiv.org/abs/1509.06461
|
||||||
class DDQNAgent(ValueOptimizationAgent):
|
class DDQNAgent(DQNAgent):
|
||||||
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
|
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
|
||||||
super().__init__(agent_parameters, parent)
|
super().__init__(agent_parameters, parent)
|
||||||
|
|
||||||
def learn_from_batch(self, batch):
|
def select_actions(self, next_states, q_st_plus_1):
|
||||||
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
|
return np.argmax(self.networks['main'].online_network.predict(next_states), 1)
|
||||||
|
|
||||||
selected_actions = np.argmax(self.networks['main'].online_network.predict(batch.next_states(network_keys)), 1)
|
|
||||||
q_st_plus_1, TD_targets = self.networks['main'].parallel_prediction([
|
|
||||||
(self.networks['main'].target_network, batch.next_states(network_keys)),
|
|
||||||
(self.networks['main'].online_network, batch.states(network_keys))
|
|
||||||
])
|
|
||||||
|
|
||||||
# add Q value samples for logging
|
|
||||||
self.q_values.add_sample(TD_targets)
|
|
||||||
|
|
||||||
# initialize with the current prediction so that we will
|
|
||||||
# only update the action that we have actually done in this transition
|
|
||||||
TD_errors = []
|
|
||||||
for i in range(batch.size):
|
|
||||||
new_target = batch.rewards()[i] + \
|
|
||||||
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * q_st_plus_1[i][selected_actions[i]]
|
|
||||||
TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))
|
|
||||||
TD_targets[i, batch.actions()[i]] = new_target
|
|
||||||
|
|
||||||
# update errors in prioritized replay buffer
|
|
||||||
importance_weights = self.update_transition_priorities_and_get_weights(TD_errors, batch)
|
|
||||||
|
|
||||||
result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets,
|
|
||||||
importance_weights=importance_weights)
|
|
||||||
total_loss, losses, unclipped_grads = result[:3]
|
|
||||||
|
|
||||||
return total_loss, losses, unclipped_grads
|
|
||||||
|
|||||||
223
rl_coach/agents/ddqn_bcq_agent.py
Normal file
223
rl_coach/agents/ddqn_bcq_agent.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Union, List, Dict
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from rl_coach.agents.dqn_agent import DQNAgentParameters, DQNAlgorithmParameters, DQNAgent
|
||||||
|
from rl_coach.base_parameters import Parameters
|
||||||
|
from rl_coach.core_types import EnvironmentSteps, Batch, StateType
|
||||||
|
from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
|
||||||
|
from rl_coach.logger import screen
|
||||||
|
from rl_coach.memories.non_episodic.differentiable_neural_dictionary import AnnoyDictionary
|
||||||
|
from rl_coach.schedules import LinearSchedule
|
||||||
|
|
||||||
|
|
||||||
|
class NNImitationModelParameters(Parameters):
|
||||||
|
"""
|
||||||
|
A parameters module grouping together parameters related to a neural network based action selection.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.imitation_model_num_epochs = 100
|
||||||
|
self.mask_out_actions_threshold = 0.35
|
||||||
|
|
||||||
|
|
||||||
|
class KNNParameters(Parameters):
|
||||||
|
"""
|
||||||
|
A parameters module grouping together parameters related to a k-Nearest Neighbor based action selection.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.average_dist_coefficient = 1
|
||||||
|
self.knn_size = 50000
|
||||||
|
self.use_state_embedding_instead_of_state = True # useful when the state is too big to be used for kNN
|
||||||
|
|
||||||
|
|
||||||
|
class DDQNBCQAlgorithmParameters(DQNAlgorithmParameters):
|
||||||
|
"""
|
||||||
|
:param action_drop_method_parameters: (Parameters)
|
||||||
|
Defines the mode and related parameters according to which low confidence actions will be filtered out
|
||||||
|
:param num_steps_between_copying_online_weights_to_target (StepMethod)
|
||||||
|
Defines the number of steps between every phase of copying online network's weights to the target network's weights
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.action_drop_method_parameters = KNNParameters()
|
||||||
|
self.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(30000)
|
||||||
|
|
||||||
|
|
||||||
|
class DDQNBCQAgentParameters(DQNAgentParameters):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.algorithm = DDQNBCQAlgorithmParameters()
|
||||||
|
self.exploration.epsilon_schedule = LinearSchedule(1, 0.01, 1000000)
|
||||||
|
self.exploration.evaluation_epsilon = 0.001
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self):
|
||||||
|
return 'rl_coach.agents.ddqn_bcq_agent:DDQNBCQAgent'
|
||||||
|
|
||||||
|
|
||||||
|
# Double DQN - https://arxiv.org/abs/1509.06461
|
||||||
|
# (a variant on) BCQ - https://arxiv.org/pdf/1812.02900v2.pdf
|
||||||
|
class DDQNBCQAgent(DQNAgent):
|
||||||
|
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
|
||||||
|
super().__init__(agent_parameters, parent)
|
||||||
|
|
||||||
|
if isinstance(self.ap.algorithm.action_drop_method_parameters, KNNParameters):
|
||||||
|
self.knn_trees = [] # will be filled out later, as we don't have the action space size yet
|
||||||
|
self.average_dist = 0
|
||||||
|
|
||||||
|
def to_embedding(states: Union[List[StateType], Dict]):
|
||||||
|
if isinstance(states, list):
|
||||||
|
states = self.prepare_batch_for_inference(states, 'reward_model')
|
||||||
|
if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
|
||||||
|
return self.networks['reward_model'].online_network.predict(
|
||||||
|
states,
|
||||||
|
outputs=[self.networks['reward_model'].online_network.state_embedding])
|
||||||
|
else:
|
||||||
|
return states['observation']
|
||||||
|
self.embedding = to_embedding
|
||||||
|
|
||||||
|
elif isinstance(self.ap.algorithm.action_drop_method_parameters, NNImitationModelParameters):
|
||||||
|
if 'imitation_model' not in self.ap.network_wrappers:
|
||||||
|
# user hasn't defined params for the reward model. we will use the same params as used for the 'main'
|
||||||
|
# network.
|
||||||
|
self.ap.network_wrappers['imitation_model'] = deepcopy(self.ap.network_wrappers['reward_model'])
|
||||||
|
else:
|
||||||
|
raise ValueError('Unsupported action drop method {} for DDQNBCQAgent'.format(
|
||||||
|
type(self.ap.algorithm.action_drop_method_parameters)))
|
||||||
|
|
||||||
|
def select_actions(self, next_states, q_st_plus_1):
|
||||||
|
if isinstance(self.ap.algorithm.action_drop_method_parameters, KNNParameters):
|
||||||
|
familiarity = np.array([[distance[0] for distance in
|
||||||
|
knn_tree._get_k_nearest_neighbors_indices(self.embedding(next_states), 1)[0]]
|
||||||
|
for knn_tree in self.knn_trees]).transpose()
|
||||||
|
actions_to_mask_out = familiarity > self.ap.algorithm.action_drop_method_parameters.average_dist_coefficient \
|
||||||
|
* self.average_dist
|
||||||
|
|
||||||
|
elif isinstance(self.ap.algorithm.action_drop_method_parameters, NNImitationModelParameters):
|
||||||
|
familiarity = self.networks['imitation_model'].online_network.predict(next_states)
|
||||||
|
actions_to_mask_out = familiarity < \
|
||||||
|
self.ap.algorithm.action_drop_method_parameters.mask_out_actions_threshold
|
||||||
|
else:
|
||||||
|
raise ValueError('Unsupported action drop method {} for DDQNBCQAgent'.format(
|
||||||
|
type(self.ap.algorithm.action_drop_method_parameters)))
|
||||||
|
|
||||||
|
masked_next_q_values = self.networks['main'].online_network.predict(next_states)
|
||||||
|
masked_next_q_values[actions_to_mask_out] = -np.inf
|
||||||
|
|
||||||
|
# occassionaly there are states in the batch for which our model shows no confidence for either of the actions
|
||||||
|
# in that case, we will just randomly assign q_values to actions, since otherwise argmax will always return
|
||||||
|
# the first action
|
||||||
|
zero_confidence_rows = (masked_next_q_values.max(axis=1) == -np.inf)
|
||||||
|
masked_next_q_values[zero_confidence_rows] = np.random.rand(np.sum(zero_confidence_rows),
|
||||||
|
masked_next_q_values.shape[1])
|
||||||
|
|
||||||
|
return np.argmax(masked_next_q_values, 1)
|
||||||
|
|
||||||
|
def improve_reward_model(self, epochs: int):
|
||||||
|
"""
|
||||||
|
Train both a reward model to be used by the doubly-robust estimator, and some model to be used for BCQ
|
||||||
|
|
||||||
|
:param epochs: The total number of epochs to use for training a reward model
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
|
||||||
|
# we'll be assuming that these gets drawn from the reward model parameters
|
||||||
|
batch_size = self.ap.network_wrappers['reward_model'].batch_size
|
||||||
|
network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
|
||||||
|
|
||||||
|
# if using a NN to decide which actions to drop, we'll train the NN here
|
||||||
|
if isinstance(self.ap.algorithm.action_drop_method_parameters, NNImitationModelParameters):
|
||||||
|
total_epochs = max(epochs, self.ap.algorithm.action_drop_method_parameters.imitation_model_num_epochs)
|
||||||
|
else:
|
||||||
|
total_epochs = epochs
|
||||||
|
|
||||||
|
for epoch in range(total_epochs):
|
||||||
|
# this is fitted from the training dataset
|
||||||
|
reward_model_loss = 0
|
||||||
|
imitation_model_loss = 0
|
||||||
|
total_transitions_processed = 0
|
||||||
|
for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
|
||||||
|
batch = Batch(batch)
|
||||||
|
|
||||||
|
# reward model
|
||||||
|
if epoch < epochs:
|
||||||
|
reward_model_loss += self.get_reward_model_loss(batch)
|
||||||
|
|
||||||
|
# imitation model
|
||||||
|
if isinstance(self.ap.algorithm.action_drop_method_parameters, NNImitationModelParameters) and \
|
||||||
|
epoch < self.ap.algorithm.action_drop_method_parameters.imitation_model_num_epochs:
|
||||||
|
target_actions = np.zeros((batch.size, len(self.spaces.action.actions)))
|
||||||
|
target_actions[range(batch.size), batch.actions()] = 1
|
||||||
|
imitation_model_loss += self.networks['imitation_model'].train_and_sync_networks(
|
||||||
|
batch.states(network_keys), target_actions)[0]
|
||||||
|
|
||||||
|
total_transitions_processed += batch.size
|
||||||
|
|
||||||
|
log = OrderedDict()
|
||||||
|
log['Epoch'] = epoch
|
||||||
|
|
||||||
|
if reward_model_loss:
|
||||||
|
log['Reward Model Loss'] = reward_model_loss / total_transitions_processed
|
||||||
|
if imitation_model_loss:
|
||||||
|
log['Imitation Model Loss'] = imitation_model_loss / total_transitions_processed
|
||||||
|
|
||||||
|
screen.log_dict(log, prefix='Training Batch RL Models')
|
||||||
|
|
||||||
|
# if using a kNN based model, we'll initialize and build it here.
|
||||||
|
# initialization cannot be moved to the constructor as we don't have the agent's spaces initialized yet.
|
||||||
|
if isinstance(self.ap.algorithm.action_drop_method_parameters, KNNParameters):
|
||||||
|
knn_size = self.ap.algorithm.action_drop_method_parameters.knn_size
|
||||||
|
if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
|
||||||
|
self.knn_trees = [AnnoyDictionary(
|
||||||
|
dict_size=knn_size,
|
||||||
|
key_width=int(self.networks['reward_model'].online_network.state_embedding.shape[-1]),
|
||||||
|
batch_size=knn_size)
|
||||||
|
for _ in range(len(self.spaces.action.actions))]
|
||||||
|
else:
|
||||||
|
self.knn_trees = [AnnoyDictionary(
|
||||||
|
dict_size=knn_size,
|
||||||
|
key_width=self.spaces.state['observation'].shape[0],
|
||||||
|
batch_size=knn_size)
|
||||||
|
for _ in range(len(self.spaces.action.actions))]
|
||||||
|
|
||||||
|
for i, knn_tree in enumerate(self.knn_trees):
|
||||||
|
state_embeddings = self.embedding([transition.state for transition in self.memory.transitions
|
||||||
|
if transition.action == i])
|
||||||
|
knn_tree.add(
|
||||||
|
keys=state_embeddings,
|
||||||
|
values=np.expand_dims(np.zeros(state_embeddings.shape[0]), axis=1))
|
||||||
|
|
||||||
|
for knn_tree in self.knn_trees:
|
||||||
|
knn_tree._rebuild_index()
|
||||||
|
|
||||||
|
self.average_dist = [[dist[0] for dist in knn_tree._get_k_nearest_neighbors_indices(
|
||||||
|
keys=self.embedding([transition.state for transition in self.memory.transitions]),
|
||||||
|
k=1)[0]] for knn_tree in self.knn_trees]
|
||||||
|
self.average_dist = sum([x for l in self.average_dist for x in l]) # flatten and sum
|
||||||
|
self.average_dist /= len(self.memory.transitions)
|
||||||
|
|
||||||
|
def set_session(self, sess) -> None:
|
||||||
|
super().set_session(sess)
|
||||||
|
|
||||||
|
# we check here if we are in batch-rl, since this is the only place where we have a graph_manager to question
|
||||||
|
assert isinstance(self.parent_level_manager.parent_graph_manager, BatchRLGraphManager),\
|
||||||
|
'DDQNBCQ agent can only be used in BatchRL'
|
||||||
@@ -70,6 +70,9 @@ class DQNAgent(ValueOptimizationAgent):
|
|||||||
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
|
def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
|
||||||
super().__init__(agent_parameters, parent)
|
super().__init__(agent_parameters, parent)
|
||||||
|
|
||||||
|
def select_actions(self, next_states, q_st_plus_1):
|
||||||
|
return np.argmax(q_st_plus_1, 1)
|
||||||
|
|
||||||
def learn_from_batch(self, batch):
|
def learn_from_batch(self, batch):
|
||||||
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
|
network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
|
||||||
|
|
||||||
@@ -81,6 +84,8 @@ class DQNAgent(ValueOptimizationAgent):
|
|||||||
(self.networks['main'].online_network, batch.states(network_keys))
|
(self.networks['main'].online_network, batch.states(network_keys))
|
||||||
])
|
])
|
||||||
|
|
||||||
|
selected_actions = self.select_actions(batch.next_states(network_keys), q_st_plus_1)
|
||||||
|
|
||||||
# add Q value samples for logging
|
# add Q value samples for logging
|
||||||
self.q_values.add_sample(TD_targets)
|
self.q_values.add_sample(TD_targets)
|
||||||
|
|
||||||
@@ -88,7 +93,7 @@ class DQNAgent(ValueOptimizationAgent):
|
|||||||
TD_errors = []
|
TD_errors = []
|
||||||
for i in range(batch.size):
|
for i in range(batch.size):
|
||||||
new_target = batch.rewards()[i] +\
|
new_target = batch.rewards()[i] +\
|
||||||
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * np.max(q_st_plus_1[i], 0)
|
(1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * q_st_plus_1[i][selected_actions[i]]
|
||||||
TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))
|
TD_errors.append(np.abs(new_target - TD_targets[i, batch.actions()[i]]))
|
||||||
TD_targets[i, batch.actions()[i]] = new_target
|
TD_targets[i, batch.actions()[i]] = new_target
|
||||||
|
|
||||||
|
|||||||
@@ -139,6 +139,15 @@ class ValueOptimizationAgent(Agent):
|
|||||||
self.agent_logger.create_signal_value('Doubly Robust', dr)
|
self.agent_logger.create_signal_value('Doubly Robust', dr)
|
||||||
self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
|
self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
|
||||||
|
|
||||||
|
def get_reward_model_loss(self, batch: Batch):
|
||||||
|
network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
|
||||||
|
current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(
|
||||||
|
batch.states(network_keys))
|
||||||
|
current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
|
||||||
|
|
||||||
|
return self.networks['reward_model'].train_and_sync_networks(
|
||||||
|
batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
|
||||||
|
|
||||||
def improve_reward_model(self, epochs: int):
|
def improve_reward_model(self, epochs: int):
|
||||||
"""
|
"""
|
||||||
Train a reward model to be used by the doubly-robust estimator
|
Train a reward model to be used by the doubly-robust estimator
|
||||||
@@ -147,7 +156,6 @@ class ValueOptimizationAgent(Agent):
|
|||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
batch_size = self.ap.network_wrappers['reward_model'].batch_size
|
batch_size = self.ap.network_wrappers['reward_model'].batch_size
|
||||||
network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()
|
|
||||||
|
|
||||||
# this is fitted from the training dataset
|
# this is fitted from the training dataset
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
@@ -155,10 +163,7 @@ class ValueOptimizationAgent(Agent):
|
|||||||
total_transitions_processed = 0
|
total_transitions_processed = 0
|
||||||
for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
|
for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
|
||||||
batch = Batch(batch)
|
batch = Batch(batch)
|
||||||
current_rewards_prediction_for_all_actions = self.networks['reward_model'].online_network.predict(batch.states(network_keys))
|
loss += self.get_reward_model_loss(batch)
|
||||||
current_rewards_prediction_for_all_actions[range(batch.size), batch.actions()] = batch.rewards()
|
|
||||||
loss += self.networks['reward_model'].train_and_sync_networks(
|
|
||||||
batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
|
|
||||||
total_transitions_processed += batch.size
|
total_transitions_processed += batch.size
|
||||||
|
|
||||||
log = OrderedDict()
|
log = OrderedDict()
|
||||||
@@ -166,9 +171,3 @@ class ValueOptimizationAgent(Agent):
|
|||||||
log['loss'] = loss / total_transitions_processed
|
log['loss'] = loss / total_transitions_processed
|
||||||
screen.log_dict(log, prefix='Training Reward Model')
|
screen.log_dict(log, prefix='Training Reward Model')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,16 @@ class QHeadParameters(HeadParameters):
|
|||||||
loss_weight=loss_weight)
|
loss_weight=loss_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHeadParameters(HeadParameters):
|
||||||
|
def __init__(self, activation_function: str ='relu', name: str='classification_head_params',
|
||||||
|
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||||
|
loss_weight: float = 1.0, dense_layer=None):
|
||||||
|
super().__init__(parameterized_class_name="ClassificationHead", activation_function=activation_function, name=name,
|
||||||
|
dense_layer=dense_layer, num_output_head_copies=num_output_head_copies,
|
||||||
|
rescale_gradient_from_head_by_factor=rescale_gradient_from_head_by_factor,
|
||||||
|
loss_weight=loss_weight)
|
||||||
|
|
||||||
|
|
||||||
class QuantileRegressionQHeadParameters(HeadParameters):
|
class QuantileRegressionQHeadParameters(HeadParameters):
|
||||||
def __init__(self, activation_function: str ='relu', name: str='quantile_regression_q_head_params',
|
def __init__(self, activation_function: str ='relu', name: str='quantile_regression_q_head_params',
|
||||||
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
num_output_head_copies: int = 1, rescale_gradient_from_head_by_factor: float = 1.0,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from .quantile_regression_q_head import QuantileRegressionQHead
|
|||||||
from .rainbow_q_head import RainbowQHead
|
from .rainbow_q_head import RainbowQHead
|
||||||
from .v_head import VHead
|
from .v_head import VHead
|
||||||
from .acer_policy_head import ACERPolicyHead
|
from .acer_policy_head import ACERPolicyHead
|
||||||
|
from .classification_head import ClassificationHead
|
||||||
from .cil_head import RegressionHead
|
from .cil_head import RegressionHead
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -29,5 +30,6 @@ __all__ = [
|
|||||||
'RainbowQHead',
|
'RainbowQHead',
|
||||||
'VHead',
|
'VHead',
|
||||||
'ACERPolicyHead',
|
'ACERPolicyHead',
|
||||||
|
'ClassificationHead'
|
||||||
'RegressionHead'
|
'RegressionHead'
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,60 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2019 Intel Corporation
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||||
|
from rl_coach.architectures.tensorflow_components.heads.head import Head
|
||||||
|
from rl_coach.base_parameters import AgentParameters
|
||||||
|
from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(Head):
|
||||||
|
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||||
|
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
|
||||||
|
dense_layer=Dense):
|
||||||
|
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
|
||||||
|
dense_layer=dense_layer)
|
||||||
|
self.name = 'classification_head'
|
||||||
|
if isinstance(self.spaces.action, BoxActionSpace):
|
||||||
|
self.num_actions = 1
|
||||||
|
elif isinstance(self.spaces.action, DiscreteActionSpace):
|
||||||
|
self.num_actions = len(self.spaces.action.actions)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'ClassificationHead does not support action spaces of type: {class_name}'.format(
|
||||||
|
class_name=self.spaces.action.__class__.__name__,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_module(self, input_layer):
|
||||||
|
# Standard classification Network
|
||||||
|
self.class_values = self.output = self.dense_layer(self.num_actions)(input_layer, name='output')
|
||||||
|
|
||||||
|
self.output = tf.nn.softmax(self.class_values)
|
||||||
|
|
||||||
|
# calculate cross entropy loss
|
||||||
|
self.target = tf.placeholder(tf.float32, shape=(None, self.num_actions), name="target")
|
||||||
|
self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=self.class_values)
|
||||||
|
tf.losses.add_loss(self.loss)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
result = [
|
||||||
|
"Dense (num outputs = {})".format(self.num_actions)
|
||||||
|
]
|
||||||
|
return '\n'.join(result)
|
||||||
|
|
||||||
|
|
||||||
@@ -259,7 +259,7 @@ class CoachLauncher(object):
|
|||||||
graph_manager.env_params.level = args.level
|
graph_manager.env_params.level = args.level
|
||||||
|
|
||||||
# set the seed for the environment
|
# set the seed for the environment
|
||||||
if args.seed is not None:
|
if args.seed is not None and graph_manager.env_params is not None:
|
||||||
graph_manager.env_params.seed = args.seed
|
graph_manager.env_params.seed = args.seed
|
||||||
|
|
||||||
# visualization
|
# visualization
|
||||||
|
|||||||
@@ -77,8 +77,9 @@ class BatchRLGraphManager(BasicRLGraphManager):
|
|||||||
self.agent_params.name = "agent"
|
self.agent_params.name = "agent"
|
||||||
self.agent_params.is_batch_rl_training = True
|
self.agent_params.is_batch_rl_training = True
|
||||||
|
|
||||||
# user hasn't defined params for the reward model. we will use the same params as used for the 'main' network.
|
|
||||||
if 'reward_model' not in self.agent_params.network_wrappers:
|
if 'reward_model' not in self.agent_params.network_wrappers:
|
||||||
|
# user hasn't defined params for the reward model. we will use the same params as used for the 'main'
|
||||||
|
# network.
|
||||||
self.agent_params.network_wrappers['reward_model'] = deepcopy(self.agent_params.network_wrappers['main'])
|
self.agent_params.network_wrappers['reward_model'] = deepcopy(self.agent_params.network_wrappers['main'])
|
||||||
|
|
||||||
agent = short_dynamic_import(self.agent_params.path)(self.agent_params)
|
agent = short_dynamic_import(self.agent_params.path)(self.agent_params)
|
||||||
|
|||||||
111
rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py
Normal file
111
rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from rl_coach.architectures.tensorflow_components.layers import Dense
|
||||||
|
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
|
||||||
|
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
|
||||||
|
from rl_coach.environments.gym_environment import GymVectorEnvironment
|
||||||
|
from rl_coach.filters.filter import InputFilter
|
||||||
|
from rl_coach.filters.reward import RewardRescaleFilter
|
||||||
|
from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
|
||||||
|
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||||
|
from rl_coach.memories.memory import MemoryGranularity
|
||||||
|
from rl_coach.schedules import LinearSchedule
|
||||||
|
from rl_coach.memories.episodic import EpisodicExperienceReplayParameters
|
||||||
|
from rl_coach.architectures.head_parameters import ClassificationHeadParameters
|
||||||
|
from rl_coach.agents.ddqn_bcq_agent import DDQNBCQAgentParameters
|
||||||
|
|
||||||
|
from rl_coach.agents.ddqn_bcq_agent import KNNParameters
|
||||||
|
from rl_coach.agents.ddqn_bcq_agent import NNImitationModelParameters
|
||||||
|
|
||||||
|
DATASET_SIZE = 10000
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Graph Scheduling #
|
||||||
|
####################
|
||||||
|
|
||||||
|
schedule_params = ScheduleParameters()
|
||||||
|
schedule_params.improve_steps = TrainingSteps(10000000000)
|
||||||
|
schedule_params.steps_between_evaluation_periods = TrainingSteps(1)
|
||||||
|
schedule_params.evaluation_steps = EnvironmentEpisodes(10)
|
||||||
|
schedule_params.heatup_steps = EnvironmentSteps(DATASET_SIZE)
|
||||||
|
|
||||||
|
#########
|
||||||
|
# Agent #
|
||||||
|
#########
|
||||||
|
|
||||||
|
# using a set of 'unstable' hyper-params to showcase the value of BCQ. Using the same hyper-params with standard DDQN
|
||||||
|
# will cause Q values to unboundedly increase, and the policy convergence to be unstable.
|
||||||
|
agent_params = DDQNBCQAgentParameters()
|
||||||
|
agent_params.network_wrappers['main'].batch_size = 128
|
||||||
|
# agent_params.network_wrappers['main'].batch_size = 1024
|
||||||
|
|
||||||
|
# DQN params
|
||||||
|
|
||||||
|
# For making this become Fitted Q-Iteration we can keep the targets constant for the entire dataset size -
|
||||||
|
agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(
|
||||||
|
DATASET_SIZE / agent_params.network_wrappers['main'].batch_size)
|
||||||
|
#
|
||||||
|
# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(
|
||||||
|
# 3)
|
||||||
|
|
||||||
|
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(0)
|
||||||
|
agent_params.algorithm.discount = 0.98
|
||||||
|
|
||||||
|
# can use either a kNN or a NN based model for predicting which actions not to max over in the bellman equation
|
||||||
|
agent_params.algorithm.action_drop_method_parameters = KNNParameters()
|
||||||
|
# agent_params.algorithm.action_drop_method_parameters = NNImitationModelParameters()
|
||||||
|
# agent_params.algorithm.action_drop_method_parameters.imitation_model_num_epochs = 500
|
||||||
|
|
||||||
|
# NN configuration
|
||||||
|
agent_params.network_wrappers['main'].learning_rate = 0.0001
|
||||||
|
agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
|
||||||
|
agent_params.network_wrappers['main'].l2_regularization = 0.0001
|
||||||
|
agent_params.network_wrappers['main'].softmax_temperature = 0.2
|
||||||
|
|
||||||
|
# reward model params
|
||||||
|
agent_params.network_wrappers['reward_model'] = deepcopy(agent_params.network_wrappers['main'])
|
||||||
|
agent_params.network_wrappers['reward_model'].learning_rate = 0.0001
|
||||||
|
agent_params.network_wrappers['reward_model'].l2_regularization = 0
|
||||||
|
|
||||||
|
agent_params.network_wrappers['imitation_model'] = deepcopy(agent_params.network_wrappers['main'])
|
||||||
|
agent_params.network_wrappers['imitation_model'].learning_rate = 0.0001
|
||||||
|
agent_params.network_wrappers['imitation_model'].l2_regularization = 0
|
||||||
|
|
||||||
|
agent_params.network_wrappers['imitation_model'].heads_parameters = [ClassificationHeadParameters()]
|
||||||
|
agent_params.network_wrappers['imitation_model'].input_embedders_parameters['observation'].scheme = \
|
||||||
|
[Dense(1024), Dense(1024), Dense(512), Dense(512), Dense(256)]
|
||||||
|
agent_params.network_wrappers['imitation_model'].middleware_parameters.scheme = [Dense(128), Dense(64)]
|
||||||
|
|
||||||
|
|
||||||
|
# ER size
|
||||||
|
agent_params.memory = EpisodicExperienceReplayParameters()
|
||||||
|
agent_params.memory.max_size = (MemoryGranularity.Transitions, DATASET_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
# E-Greedy schedule
|
||||||
|
agent_params.exploration.epsilon_schedule = LinearSchedule(0, 0, 10000)
|
||||||
|
agent_params.exploration.evaluation_epsilon = 0
|
||||||
|
|
||||||
|
|
||||||
|
agent_params.input_filter = InputFilter()
|
||||||
|
agent_params.input_filter.add_reward_filter('rescale', RewardRescaleFilter(1/200.))
|
||||||
|
|
||||||
|
################
|
||||||
|
# Environment #
|
||||||
|
################
|
||||||
|
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||||
|
|
||||||
|
########
|
||||||
|
# Test #
|
||||||
|
########
|
||||||
|
preset_validation_params = PresetValidationParameters()
|
||||||
|
preset_validation_params.test = True
|
||||||
|
preset_validation_params.min_reward_threshold = 150
|
||||||
|
preset_validation_params.max_episodes_to_achieve_reward = 2000
|
||||||
|
|
||||||
|
graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||||
|
schedule_params=schedule_params,
|
||||||
|
vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1),
|
||||||
|
preset_validation_params=preset_validation_params,
|
||||||
|
reward_model_num_epochs=30,
|
||||||
|
train_to_eval_ratio=0.8)
|
||||||
Reference in New Issue
Block a user