BCQ variant on top of DDQN (#276)
* kNN based model for predicting which actions to drop
* fix for seeds with batch rl
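For context on the first point: discrete BCQ restricts the greedy argmax over Q-values to actions the behavior policy plausibly took in similar states. Below is a minimal, self-contained sketch of that filtering idea using scikit-learn's NearestNeighbors; every name in it is illustrative and is not taken from this commit.

    # Illustrative sketch only -- not the model added by this commit.
    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    class KNNActionFilter:
        """Keeps only actions observed among the k nearest dataset states."""
        def __init__(self, states, actions, k=5):
            # states: (N, state_dim) array, actions: (N,) array from the offline dataset
            self.knn = NearestNeighbors(n_neighbors=k).fit(states)
            self.actions = np.asarray(actions)

        def allowed_actions(self, state):
            _, idx = self.knn.kneighbors(np.asarray(state).reshape(1, -1))
            return np.unique(self.actions[idx[0]])

    def bcq_greedy_action(q_values, state, action_filter):
        # mask out actions the kNN model predicts should be dropped,
        # then take the argmax over the remaining ones
        allowed = action_filter.allowed_actions(state)
        masked = np.full(len(q_values), -np.inf)
        masked[allowed] = np.asarray(q_values)[allowed]
        return int(np.argmax(masked))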
@@ -45,6 +45,15 @@ class Agent(AgentInterface):
         :param agent_parameters: A AgentParameters class instance with all the agent parameters
         """
         super().__init__()
+        # use seed
+        if agent_parameters.task_parameters.seed is not None:
+            random.seed(agent_parameters.task_parameters.seed)
+            np.random.seed(agent_parameters.task_parameters.seed)
+        else:
+            # we need to seed the RNG since the different processes are initialized with the same parent seed
+            random.seed()
+            np.random.seed()
+
         self.ap = agent_parameters
         self.task_id = self.ap.task_parameters.task_index
         self.is_chief = self.task_id == 0
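The hunk above moves seeding to the top of __init__ so it runs before anything consumes the RNGs. The reseed in the else branch matters because worker processes forked from the same parent inherit identical RNG state and would otherwise draw the same "random" numbers. A minimal standalone sketch of the pattern (the task-index offset is an illustrative variation; the commit itself applies the same configured seed in every process):

    import random
    import numpy as np

    def seed_process(base_seed=None, task_index=0):
        if base_seed is not None:
            # deterministic run: derive a distinct, reproducible stream per worker
            random.seed(base_seed + task_index)
            np.random.seed(base_seed + task_index)
        else:
            # no seed requested: pull fresh OS entropy so forked workers diverge
            random.seed()
            np.random.seed()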
@@ -197,15 +206,6 @@ class Agent(AgentInterface):
         if isinstance(self.in_action_space, GoalsSpace):
             self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)
 
-        # use seed
-        if self.ap.task_parameters.seed is not None:
-            random.seed(self.ap.task_parameters.seed)
-            np.random.seed(self.ap.task_parameters.seed)
-        else:
-            # we need to seed the RNG since the different processes are initialized with the same parent seed
-            random.seed()
-            np.random.seed()
-
         # batch rl
         self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None
 
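The OpeManager instantiated here drives off-policy evaluation when training from a fixed batch, i.e. scoring a new policy from logged data alone. Coach's manager combines several estimators; the sketch below shows only the simplest one, ordinary importance sampling, with illustrative names that are not coach's API:

    import numpy as np

    def importance_sampling_value(episodes, target_prob, behavior_prob, gamma=0.99):
        """episodes: list of trajectories, each a list of (state, action, reward).
        target_prob / behavior_prob: callables (state, action) -> probability."""
        estimates = []
        for episode in episodes:
            ratio, ret = 1.0, 0.0
            for t, (s, a, r) in enumerate(episode):
                # reweight the logged return by how likely the target policy
                # is to reproduce the logged actions
                ratio *= target_prob(s, a) / behavior_prob(s, a)
                ret += (gamma ** t) * r
            estimates.append(ratio * ret)
        return float(np.mean(estimates))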
@@ -688,13 +688,16 @@ class Agent(AgentInterface):
            for network in self.networks.values():
                network.set_is_training(True)
 
-            # TODO: this should be network dependent
-            network_parameters = list(self.ap.network_wrappers.values())[0]
+            # At the moment we only support a single batch size for all the networks
+            networks_parameters = list(self.ap.network_wrappers.values())
+            assert all(net.batch_size == networks_parameters[0].batch_size for net in networks_parameters)
+
+            batch_size = networks_parameters[0].batch_size
 
            # we either go sequentially through the entire replay buffer in the batch RL mode,
            # or sample randomly for the basic RL case.
-            training_schedule = self.call_memory('get_shuffled_data_generator', network_parameters.batch_size) if \
-                self.ap.is_batch_rl_training else [self.call_memory('sample', network_parameters.batch_size) for _ in
+            training_schedule = self.call_memory('get_shuffled_data_generator', batch_size) if \
+                self.ap.is_batch_rl_training else [self.call_memory('sample', batch_size) for _ in
                    range(self.ap.algorithm.num_consecutive_training_steps)]
 
            for batch in training_schedule:
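The two branches of training_schedule correspond to two data regimes: batch RL makes one shuffled, sequential pass over the whole fixed buffer per training round, while online RL draws num_consecutive_training_steps independent random batches. A standalone sketch with stand-in functions (not coach's memory API):

    import random

    def shuffled_epoch(replay_buffer, batch_size):
        # batch RL: every transition is visited exactly once, in random order
        indices = list(range(len(replay_buffer)))
        random.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            yield [replay_buffer[i] for i in indices[start:start + batch_size]]

    def random_batches(replay_buffer, batch_size, num_steps):
        # online RL: independent uniform samples; transitions may repeat or be skipped
        return [random.sample(replay_buffer, batch_size) for _ in range(num_steps)]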
@@ -713,13 +716,16 @@ class Agent(AgentInterface):
 
                    self.unclipped_grads.add_sample(unclipped_grads)
 
-                    # TODO: the learning rate decay should be done through the network instead of here
+                    # TODO: this only deals with the main network (if exists), need to do the same for other networks
+                    #  for instance, for DDPG, the LR signal is currently not shown. Probably should be done through the
+                    #  network directly instead of here
                    # decay learning rate
-                    if network_parameters.learning_rate_decay_rate != 0:
+                    if 'main' in self.ap.network_wrappers and \
+                            self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
                        self.curr_learning_rate.add_sample(self.networks['main'].sess.run(
                            self.networks['main'].online_network.current_learning_rate))
                    else:
-                        self.curr_learning_rate.add_sample(network_parameters.learning_rate)
+                        self.curr_learning_rate.add_sample(networks_parameters[0].learning_rate)
 
                    if any([network.has_target for network in self.networks.values()]) \
                            and self._should_update_online_weights_to_target():
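When learning_rate_decay_rate is non-zero, the current learning rate exists only as a tensor inside the TF graph, hence the sess.run to sample it for the signal; otherwise the static configured value is logged. Assuming the schedule is TF1's standard exponential decay (an assumption about coach's backend, not stated in this diff), the value being read corresponds to:

    def decayed_learning_rate(initial_lr, global_step, decay_steps, decay_rate):
        # the staircase=False form of tf.train.exponential_decay (assumed schedule)
        return initial_lr * decay_rate ** (global_step / decay_steps)

    # e.g. decayed_learning_rate(0.00025, global_step=10000, decay_steps=20000, decay_rate=0.96)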