1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Fix for NEC not saving the DND when saving a model

This commit is contained in:
galleibo-intel
2017-11-09 19:13:23 +02:00
parent f5d645d8a6
commit 3c330768f0
4 changed files with 41 additions and 4 deletions

View File

@@ -265,13 +265,18 @@ class DNDQHead(Head):
self.loss_type = tf.losses.huber_loss
else:
self.loss_type = tf.losses.mean_squared_error
self.tp = tuning_parameters
def _build_module(self, input_layer):
# DND based Q head
from memories import differentiable_neural_dictionary
self.DND = differentiable_neural_dictionary. QDND(
self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
key_error_threshold=self.DND_key_error_threshold)
if self.tp.checkpoint_restore_dir:
self.DND = differentiable_neural_dictionary.load_dnd(self.tp.checkpoint_restore_dir)
else:
self.DND = differentiable_neural_dictionary.QDND(
self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
key_error_threshold=self.DND_key_error_threshold)
# Retrieve info from DND dictionary
self.action = tf.placeholder(tf.int8, [None], name="action")