
BCQ variant on top of DDQN (#276)

* kNN based model for predicting which actions to drop
* fix for seeds with batch rl
Gal Leibovich
2019-04-16 17:06:23 +03:00
committed by GitHub
parent bdb9b224a8
commit 4741b0b916
11 changed files with 451 additions and 62 deletions
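The commit description above mentions a kNN based model for predicting which actions to drop, i.e. a Batch-Constrained Q-learning (BCQ) style restriction layered on top of DDQN. The sketch below is not the code from this commit; it is a minimal, assumed illustration of the idea: fit a nearest-neighbour model on the batch dataset and restrict the DDQN argmax to actions that were actually taken in states close to the current one. The names KNNActionFilter and bcq_style_argmax are hypothetical.

# Illustrative sketch (not this commit's implementation): restrict DDQN's argmax to
# actions that a kNN model deems "in-distribution" for the fixed batch dataset.
import numpy as np
from sklearn.neighbors import NearestNeighbors


class KNNActionFilter:
    """Per state, keep only actions that were actually taken in nearby dataset states."""

    def __init__(self, states, actions, k=5):
        # states: (N, state_dim) array, actions: (N,) array of discrete action indices
        self.actions = np.asarray(actions)
        self.knn = NearestNeighbors(n_neighbors=k).fit(np.asarray(states))

    def allowed_actions(self, state, num_actions):
        # an action is allowed if any of the k nearest dataset states used it
        _, idx = self.knn.kneighbors(np.asarray(state).reshape(1, -1))
        mask = np.zeros(num_actions, dtype=bool)
        mask[np.unique(self.actions[idx[0]])] = True
        return mask


def bcq_style_argmax(q_values, mask):
    # drop disallowed actions by pushing their Q-values to -inf before the argmax
    masked_q = np.where(mask, q_values, -np.inf)
    return int(np.argmax(masked_q))

In a batch RL setting this keeps the greedy action (and hence the bootstrapped target) inside the support of the fixed dataset, which is the failure mode BCQ is designed to address.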


@@ -45,6 +45,15 @@ class Agent(AgentInterface):
         :param agent_parameters: A AgentParameters class instance with all the agent parameters
         """
         super().__init__()
+        # use seed
+        if agent_parameters.task_parameters.seed is not None:
+            random.seed(agent_parameters.task_parameters.seed)
+            np.random.seed(agent_parameters.task_parameters.seed)
+        else:
+            # we need to seed the RNG since the different processes are initialized with the same parent seed
+            random.seed()
+            np.random.seed()
         self.ap = agent_parameters
         self.task_id = self.ap.task_parameters.task_index
         self.is_chief = self.task_id == 0
@@ -197,15 +206,6 @@ class Agent(AgentInterface):
         if isinstance(self.in_action_space, GoalsSpace):
             self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
-        # use seed
-        if self.ap.task_parameters.seed is not None:
-            random.seed(self.ap.task_parameters.seed)
-            np.random.seed(self.ap.task_parameters.seed)
-        else:
-            # we need to seed the RNG since the different processes are initialized with the same parent seed
-            random.seed()
-            np.random.seed()
         # batch rl
         self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None
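Taken together, these two hunks move the seeding block from a later initialization step into Agent.__init__, so the RNGs are seeded before anything else in the agent is constructed; the commit message's "fix for seeds with batch rl" presumably refers to making whatever batch RL machinery is built during construction reproducible as well. A minimal sketch of the ordering concern, with a made-up build_agent function standing in for the agent constructor:

# Minimal illustration of the assumed rationale for moving the seeding earlier:
# anything that consumes the RNG before random.seed()/np.random.seed() run is not reproducible.
import random
import numpy as np


def build_agent(seed=None):
    # seed first, exactly as the moved block does, so everything constructed afterwards
    # (replay memory shuffling, exploration noise, ...) sees the seeded RNGs
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
    else:
        # re-seed from OS entropy so forked workers do not all inherit the parent's RNG state
        random.seed()
        np.random.seed()

    # e.g. a data shuffle that happens during construction in the batch RL case
    return np.random.permutation(10)


assert (build_agent(seed=123) == build_agent(seed=123)).all()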
@@ -688,13 +688,16 @@ class Agent(AgentInterface):
             for network in self.networks.values():
                 network.set_is_training(True)
-            # TODO: this should be network dependent
-            network_parameters = list(self.ap.network_wrappers.values())[0]
+            # At the moment we only support a single batch size for all the networks
+            networks_parameters = list(self.ap.network_wrappers.values())
+            assert all(net.batch_size == networks_parameters[0].batch_size for net in networks_parameters)
+            batch_size = networks_parameters[0].batch_size
             # we either go sequentially through the entire replay buffer in the batch RL mode,
             # or sample randomly for the basic RL case.
-            training_schedule = self.call_memory('get_shuffled_data_generator', network_parameters.batch_size) if \
-                self.ap.is_batch_rl_training else [self.call_memory('sample', network_parameters.batch_size) for _ in
+            training_schedule = self.call_memory('get_shuffled_data_generator', batch_size) if \
+                self.ap.is_batch_rl_training else [self.call_memory('sample', batch_size) for _ in
                                                    range(self.ap.algorithm.num_consecutive_training_steps)]
             for batch in training_schedule:
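The hunk above does two things: it asserts that all network wrappers share a single batch size, and it builds the training schedule either as one shuffled pass over the whole replay buffer (batch RL) or as repeated random sampling (online RL). Below is a simplified sketch of those two sampling modes, not Coach's actual Memory API:

# Simplified sketch of the two sampling modes selected by the code above.
import numpy as np


def shuffled_data_generator(transitions, batch_size):
    # batch RL: a single pass over the entire (fixed) replay buffer, in shuffled order
    order = np.random.permutation(len(transitions))
    for start in range(0, len(transitions), batch_size):
        yield [transitions[i] for i in order[start:start + batch_size]]


def random_sampling_schedule(transitions, batch_size, num_steps):
    # online RL: num_steps independently, uniformly sampled mini-batches
    return [[transitions[i] for i in np.random.randint(len(transitions), size=batch_size)]
            for _ in range(num_steps)]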
@@ -713,13 +716,16 @@ class Agent(AgentInterface):
                     self.unclipped_grads.add_sample(unclipped_grads)
             # TODO: the learning rate decay should be done through the network instead of here
+            # TODO: this only deals with the main network (if exists), need to do the same for other networks
+            #  for instance, for DDPG, the LR signal is currently not shown. Probably should be done through the
+            #  network directly instead of here
             # decay learning rate
-            if network_parameters.learning_rate_decay_rate != 0:
+            if 'main' in self.ap.network_wrappers and \
+                    self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
                 self.curr_learning_rate.add_sample(self.networks['main'].sess.run(
                     self.networks['main'].online_network.current_learning_rate))
             else:
                 self.curr_learning_rate.add_sample(network_parameters.learning_rate)
+                self.curr_learning_rate.add_sample(networks_parameters[0].learning_rate)
             if any([network.has_target for network in self.networks.values()]) \
                     and self._should_update_online_weights_to_target():
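The new guard only queries the TensorFlow session for the current learning rate when a 'main' network wrapper exists and its decay is enabled; otherwise the static learning_rate of the first wrapper is logged. The sess.run is needed because, with decay enabled, the current learning rate lives in the graph rather than as a Python float. A self-contained TF1-style illustration of that pattern (the numbers are arbitrary, and this is an assumption about how the decayed rate is represented, not Coach's exact graph):

# TF1-style sketch: a decayed learning rate is a tensor tied to the global step,
# so reading its current value requires a session run.
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
current_learning_rate = tf.train.exponential_decay(
    learning_rate=0.001,   # illustrative initial value
    global_step=global_step,
    decay_steps=10000,     # illustrative
    decay_rate=0.96)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # prints the decayed value at the current global step (0.001 at step 0)
    print(sess.run(current_learning_rate))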