1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Fix for NEC not saving the DND when saving a model

This commit is contained in:
galleibo-intel
2017-11-09 19:13:23 +02:00
parent f5d645d8a6
commit 3c330768f0
4 changed files with 41 additions and 4 deletions

View File

@@ -16,6 +16,7 @@
import numpy as np
from annoy import AnnoyIndex
import os, pickle
class AnnoyDictionary(object):
@@ -171,3 +172,25 @@ class QDND:
if not self.dicts[a].has_enough_entries(k):
return False
return True
def load_dnd(model_dir):
max_id = 0
for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]:
if int(f.split('.')[0]) > max_id:
max_id = int(f.split('.')[0])
model_path = str(max_id) + '.dnd'
with open(os.path.join(model_dir, model_path), 'rb') as f:
DND = pickle.load(f)
for a in range(DND.num_actions):
DND.dicts[a].index = AnnoyIndex(512, metric='euclidean')
DND.dicts[a].index.set_seed(1)
for idx, key in zip(range(DND.dicts[a].curr_size), DND.dicts[a].embeddings[:DND.dicts[a].curr_size]):
DND.dicts[a].index.add_item(idx, key)
DND.dicts[a].index.build(50)
return DND