1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-15 05:25:55 +01:00

Fix for NEC not saving the DND when saving a model

This commit is contained in:
galleibo-intel
2017-11-09 19:13:23 +02:00
parent f5d645d8a6
commit 3c330768f0
4 changed files with 41 additions and 4 deletions

View File

@@ -25,6 +25,8 @@ class NECAgent(ValueOptimizationAgent):
self.current_episode_state_embeddings = []
self.current_episode_actions = []
self.training_started = False
# if self.tp.checkpoint_restore_dir:
# self.load_dnd(self.tp.checkpoint_restore_dir)
def learn_from_batch(self, batch):
if not self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
@@ -102,3 +104,8 @@ class NECAgent(ValueOptimizationAgent):
self.current_episode_state_embeddings = []
self.current_episode_actions = []
def save_model(self, model_id):
self.main_network.save_model(model_id)
with open(os.path.join(self.tp.save_model_dir, str(model_id) + '.dnd'), 'wb') as f:
pickle.dump(self.main_network.online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)