1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Fix for NEC not saving the DND when saving a model

This commit is contained in:
galleibo-intel
2017-11-09 19:13:23 +02:00
parent f5d645d8a6
commit 3c330768f0
4 changed files with 41 additions and 4 deletions

View File

@@ -519,7 +519,7 @@ class Agent(object):
current_snapshot_period = (int(total_training_time) // self.tp.save_model_sec) current_snapshot_period = (int(total_training_time) // self.tp.save_model_sec)
if current_snapshot_period > model_snapshots_periods_passed: if current_snapshot_period > model_snapshots_periods_passed:
model_snapshots_periods_passed = current_snapshot_period model_snapshots_periods_passed = current_snapshot_period
self.main_network.save_model(model_snapshots_periods_passed) self.save_model(model_snapshots_periods_passed)
# play and record in replay buffer # play and record in replay buffer
if self.tp.agent.step_until_collecting_full_episodes: if self.tp.agent.step_until_collecting_full_episodes:
@@ -539,3 +539,5 @@ class Agent(object):
self.training_iteration += 1 self.training_iteration += 1
self.post_training_commands() self.post_training_commands()
def save_model(self, model_id):
self.main_network.save_model(model_id)

View File

@@ -25,6 +25,8 @@ class NECAgent(ValueOptimizationAgent):
self.current_episode_state_embeddings = [] self.current_episode_state_embeddings = []
self.current_episode_actions = [] self.current_episode_actions = []
self.training_started = False self.training_started = False
# if self.tp.checkpoint_restore_dir:
# self.load_dnd(self.tp.checkpoint_restore_dir)
def learn_from_batch(self, batch): def learn_from_batch(self, batch):
if not self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn): if not self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
@@ -102,3 +104,8 @@ class NECAgent(ValueOptimizationAgent):
self.current_episode_state_embeddings = [] self.current_episode_state_embeddings = []
self.current_episode_actions = [] self.current_episode_actions = []
def save_model(self, model_id):
self.main_network.save_model(model_id)
with open(os.path.join(self.tp.save_model_dir, str(model_id) + '.dnd'), 'wb') as f:
pickle.dump(self.main_network.online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)

View File

@@ -265,13 +265,18 @@ class DNDQHead(Head):
self.loss_type = tf.losses.huber_loss self.loss_type = tf.losses.huber_loss
else: else:
self.loss_type = tf.losses.mean_squared_error self.loss_type = tf.losses.mean_squared_error
self.tp = tuning_parameters
def _build_module(self, input_layer): def _build_module(self, input_layer):
# DND based Q head # DND based Q head
from memories import differentiable_neural_dictionary from memories import differentiable_neural_dictionary
self.DND = differentiable_neural_dictionary. QDND(
self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient, if self.tp.checkpoint_restore_dir:
key_error_threshold=self.DND_key_error_threshold) self.DND = differentiable_neural_dictionary.load_dnd(self.tp.checkpoint_restore_dir)
else:
self.DND = differentiable_neural_dictionary.QDND(
self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
key_error_threshold=self.DND_key_error_threshold)
# Retrieve info from DND dictionary # Retrieve info from DND dictionary
self.action = tf.placeholder(tf.int8, [None], name="action") self.action = tf.placeholder(tf.int8, [None], name="action")

View File

@@ -16,6 +16,7 @@
import numpy as np import numpy as np
from annoy import AnnoyIndex from annoy import AnnoyIndex
import os, pickle
class AnnoyDictionary(object): class AnnoyDictionary(object):
@@ -171,3 +172,25 @@ class QDND:
if not self.dicts[a].has_enough_entries(k): if not self.dicts[a].has_enough_entries(k):
return False return False
return True return True
def load_dnd(model_dir):
max_id = 0
for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]:
if int(f.split('.')[0]) > max_id:
max_id = int(f.split('.')[0])
model_path = str(max_id) + '.dnd'
with open(os.path.join(model_dir, model_path), 'rb') as f:
DND = pickle.load(f)
for a in range(DND.num_actions):
DND.dicts[a].index = AnnoyIndex(512, metric='euclidean')
DND.dicts[a].index.set_seed(1)
for idx, key in zip(range(DND.dicts[a].curr_size), DND.dicts[a].embeddings[:DND.dicts[a].curr_size]):
DND.dicts[a].index.add_item(idx, key)
DND.dicts[a].index.build(50)
return DND