mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Fix for NEC not saving the DND when saving a model
This commit is contained in:
@@ -265,13 +265,18 @@ class DNDQHead(Head):
|
||||
self.loss_type = tf.losses.huber_loss
|
||||
else:
|
||||
self.loss_type = tf.losses.mean_squared_error
|
||||
self.tp = tuning_parameters
|
||||
|
||||
def _build_module(self, input_layer):
|
||||
# DND based Q head
|
||||
from memories import differentiable_neural_dictionary
|
||||
self.DND = differentiable_neural_dictionary. QDND(
|
||||
self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
|
||||
key_error_threshold=self.DND_key_error_threshold)
|
||||
|
||||
if self.tp.checkpoint_restore_dir:
|
||||
self.DND = differentiable_neural_dictionary.load_dnd(self.tp.checkpoint_restore_dir)
|
||||
else:
|
||||
self.DND = differentiable_neural_dictionary.QDND(
|
||||
self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
|
||||
key_error_threshold=self.DND_key_error_threshold)
|
||||
|
||||
# Retrieve info from DND dictionary
|
||||
self.action = tf.placeholder(tf.int8, [None], name="action")
|
||||
|
||||
Reference in New Issue
Block a user