Fix for NEC not saving the DND when saving a model
@@ -519,7 +519,7 @@ class Agent(object):
                 current_snapshot_period = (int(total_training_time) // self.tp.save_model_sec)
                 if current_snapshot_period > model_snapshots_periods_passed:
                     model_snapshots_periods_passed = current_snapshot_period
-                    self.main_network.save_model(model_snapshots_periods_passed)
+                    self.save_model(model_snapshots_periods_passed)
 
             # play and record in replay buffer
             if self.tp.agent.step_until_collecting_full_episodes:
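For context, the surrounding lines rate-limit checkpointing: a snapshot is written at most once per save_model_sec seconds of training time. A small standalone illustration of that counter follows; the 600-second interval, the starting counter value and the sampled times are made-up example values, not Coach defaults.

    # Illustration of the snapshot-period counter used above: at most one save
    # per save_model_sec window of training time. All numbers are examples.
    save_model_sec = 600
    model_snapshots_periods_passed = 0

    for total_training_time in (120, 640, 900, 1250, 1900):
        current_snapshot_period = int(total_training_time) // save_model_sec
        if current_snapshot_period > model_snapshots_periods_passed:
            model_snapshots_periods_passed = current_snapshot_period
            print('t=%4ds -> save snapshot %d' % (total_training_time, current_snapshot_period))
    # saves fire at t=640 (period 1), t=1250 (period 2) and t=1900 (period 3)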
@@ -539,3 +539,5 @@ class Agent(object):
                 self.training_iteration += 1
         self.post_training_commands()
 
+    def save_model(self, model_id):
+        self.main_network.save_model(model_id)
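This base-class indirection is what makes the NEC fix possible: checkpointing now funnels through a single overridable save_model(), so an agent that keeps state outside the network (as NEC does with its DND) can extend it. A minimal sketch of the pattern, assuming nothing about Coach's real classes; BaseAgent and MemoryAgent are hypothetical stand-ins.

    import os
    import pickle


    class BaseAgent(object):
        # Stand-in for Agent: all checkpointing goes through save_model().
        def __init__(self, save_dir):
            self.save_dir = save_dir

        def save_model(self, model_id):
            print('saving network weights for snapshot %s' % model_id)


    class MemoryAgent(BaseAgent):
        # Stand-in for NECAgent: also persists its external memory next to the weights.
        def __init__(self, save_dir):
            super(MemoryAgent, self).__init__(save_dir)
            self.memory = {'keys': [], 'values': []}

        def save_model(self, model_id):
            super(MemoryAgent, self).save_model(model_id)
            with open(os.path.join(self.save_dir, str(model_id) + '.dnd'), 'wb') as f:
                pickle.dump(self.memory, f, pickle.HIGHEST_PROTOCOL)


    if __name__ == '__main__':
        MemoryAgent('.').save_model(3)   # writes ./3.dnd alongside the (pretend) weights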
@@ -25,6 +25,8 @@ class NECAgent(ValueOptimizationAgent):
         self.current_episode_state_embeddings = []
         self.current_episode_actions = []
         self.training_started = False
+        # if self.tp.checkpoint_restore_dir:
+        #     self.load_dnd(self.tp.checkpoint_restore_dir)
 
     def learn_from_batch(self, batch):
         if not self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
@@ -102,3 +104,8 @@ class NECAgent(ValueOptimizationAgent):
 
         self.current_episode_state_embeddings = []
         self.current_episode_actions = []
+
+    def save_model(self, model_id):
+        self.main_network.save_model(model_id)
+        with open(os.path.join(self.tp.save_model_dir, str(model_id) + '.dnd'), 'wb') as f:
+            pickle.dump(self.main_network.online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
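Each call above writes a <model_id>.dnd file next to the network checkpoint, and load_dnd (last hunk below) later picks the highest-numbered one. A tiny illustration of that selection with made-up file names, showing that the comparison is numeric rather than lexicographic:

    # Hypothetical checkpoint directory listing; only the .dnd files participate.
    files = ['1.dnd', '2.dnd', '10.dnd', 'checkpoint', 'events.out']
    max_id = 0
    for f in [s for s in files if s.endswith('.dnd')]:
        if int(f.split('.')[0]) > max_id:
            max_id = int(f.split('.')[0])
    print(max_id)   # 10 -- '10.dnd' beats '2.dnd' because ids are compared as ints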
@@ -265,11 +265,16 @@ class DNDQHead(Head):
             self.loss_type = tf.losses.huber_loss
         else:
             self.loss_type = tf.losses.mean_squared_error
+        self.tp = tuning_parameters
 
     def _build_module(self, input_layer):
         # DND based Q head
         from memories import differentiable_neural_dictionary
-        self.DND = differentiable_neural_dictionary.QDND(
+
+        if self.tp.checkpoint_restore_dir:
+            self.DND = differentiable_neural_dictionary.load_dnd(self.tp.checkpoint_restore_dir)
+        else:
+            self.DND = differentiable_neural_dictionary.QDND(
             self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
             key_error_threshold=self.DND_key_error_threshold)
 
@@ -16,6 +16,7 @@
 
 import numpy as np
 from annoy import AnnoyIndex
+import os, pickle
 
 
 class AnnoyDictionary(object):
@@ -171,3 +172,25 @@ class QDND:
             if not self.dicts[a].has_enough_entries(k):
                 return False
         return True
+
+
+def load_dnd(model_dir):
+    max_id = 0
+
+    for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]:
+        if int(f.split('.')[0]) > max_id:
+            max_id = int(f.split('.')[0])
+
+    model_path = str(max_id) + '.dnd'
+    with open(os.path.join(model_dir, model_path), 'rb') as f:
+        DND = pickle.load(f)
+
+    for a in range(DND.num_actions):
+        DND.dicts[a].index = AnnoyIndex(512, metric='euclidean')
+        DND.dicts[a].index.set_seed(1)
+
+        for idx, key in zip(range(DND.dicts[a].curr_size), DND.dicts[a].embeddings[:DND.dicts[a].curr_size]):
+            DND.dicts[a].index.add_item(idx, key)
+
+        DND.dicts[a].index.build(50)
+    return DND
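load_dnd does not rely on whatever index state comes out of the pickle: it re-creates each AnnoyIndex and re-adds the stored embeddings before building. A self-contained sketch of that embed-then-rebuild round trip; the 4-dimensional keys and the tree count of 10 are made-up example values (the code above uses 512 dimensions and 50 trees).

    import pickle

    import numpy as np
    from annoy import AnnoyIndex

    dim = 4
    keys = np.random.rand(10, dim).astype(np.float32)

    # The raw embeddings round-trip through pickle; the index itself is rebuilt from them.
    restored = pickle.loads(pickle.dumps(keys, pickle.HIGHEST_PROTOCOL))

    index = AnnoyIndex(dim, metric='euclidean')
    index.set_seed(1)
    for i, key in enumerate(restored):
        index.add_item(i, key)
    index.build(10)

    print(index.get_nns_by_vector(restored[0], 3))   # nearest stored keys to the first one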