1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-31 09:03:33 +02:00
This commit is contained in:
Gal Leibovich
2019-06-16 11:11:21 +03:00
committed by GitHub
parent 8df3c46756
commit 7eb884c5b2
107 changed files with 2200 additions and 495 deletions

View File

@@ -90,7 +90,7 @@ class DDQNBCQAgent(DQNAgent):
if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
return self.networks['reward_model'].online_network.predict(
states,
outputs=[self.networks['reward_model'].online_network.state_embedding])
outputs=[self.networks['reward_model'].online_network.state_embedding[0]])
else:
return states['observation']
self.embedding = to_embedding
@@ -189,7 +189,7 @@ class DDQNBCQAgent(DQNAgent):
if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
self.knn_trees = [AnnoyDictionary(
dict_size=knn_size,
key_width=int(self.networks['reward_model'].online_network.state_embedding.shape[-1]),
key_width=int(self.networks['reward_model'].online_network.state_embedding[0].shape[-1]),
batch_size=knn_size)
for _ in range(len(self.spaces.action.actions))]
else: