mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Fix for NEC not saving the DND when saving a model
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
import numpy as np
|
||||
from annoy import AnnoyIndex
|
||||
import os, pickle
|
||||
|
||||
|
||||
class AnnoyDictionary(object):
|
||||
@@ -171,3 +172,25 @@ class QDND:
|
||||
if not self.dicts[a].has_enough_entries(k):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load_dnd(model_dir):
|
||||
max_id = 0
|
||||
|
||||
for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]:
|
||||
if int(f.split('.')[0]) > max_id:
|
||||
max_id = int(f.split('.')[0])
|
||||
|
||||
model_path = str(max_id) + '.dnd'
|
||||
with open(os.path.join(model_dir, model_path), 'rb') as f:
|
||||
DND = pickle.load(f)
|
||||
|
||||
for a in range(DND.num_actions):
|
||||
DND.dicts[a].index = AnnoyIndex(512, metric='euclidean')
|
||||
DND.dicts[a].index.set_seed(1)
|
||||
|
||||
for idx, key in zip(range(DND.dicts[a].curr_size), DND.dicts[a].embeddings[:DND.dicts[a].curr_size]):
|
||||
DND.dicts[a].index.add_item(idx, key)
|
||||
|
||||
DND.dicts[a].index.build(50)
|
||||
return DND
|
||||
|
||||
Reference in New Issue
Block a user