mirror of
https://github.com/gryf/coach.git
synced 2026-03-31 09:03:33 +02:00
TD3 (#338)
This commit is contained in:
@@ -90,7 +90,7 @@ class DDQNBCQAgent(DQNAgent):
|
||||
if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
|
||||
return self.networks['reward_model'].online_network.predict(
|
||||
states,
|
||||
outputs=[self.networks['reward_model'].online_network.state_embedding])
|
||||
outputs=[self.networks['reward_model'].online_network.state_embedding[0]])
|
||||
else:
|
||||
return states['observation']
|
||||
self.embedding = to_embedding
|
||||
@@ -189,7 +189,7 @@ class DDQNBCQAgent(DQNAgent):
|
||||
if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
|
||||
self.knn_trees = [AnnoyDictionary(
|
||||
dict_size=knn_size,
|
||||
key_width=int(self.networks['reward_model'].online_network.state_embedding.shape[-1]),
|
||||
key_width=int(self.networks['reward_model'].online_network.state_embedding[0].shape[-1]),
|
||||
batch_size=knn_size)
|
||||
for _ in range(len(self.spaces.action.actions))]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user