Fix for NEC not saving the DND when saving a model
@@ -519,7 +519,7 @@ class Agent(object):
             current_snapshot_period = (int(total_training_time) // self.tp.save_model_sec)
             if current_snapshot_period > model_snapshots_periods_passed:
                 model_snapshots_periods_passed = current_snapshot_period
-                self.main_network.save_model(model_snapshots_periods_passed)
+                self.save_model(model_snapshots_periods_passed)
 
         # play and record in replay buffer
         if self.tp.agent.step_until_collecting_full_episodes:
@@ -539,3 +539,5 @@ class Agent(object):
             self.training_iteration += 1
             self.post_training_commands()
 
+    def save_model(self, model_id):
+        self.main_network.save_model(model_id)
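Design note: after this hunk the training loop no longer calls the network directly; it goes through the new Agent.save_model hook, so an agent subclass can persist extra state alongside the checkpoint. A minimal, self-contained sketch of that pattern (the class names below are hypothetical stand-ins, not from the repo):

class DummyNetwork(object):
    def save_model(self, model_id):
        print('saving network checkpoint %d' % model_id)

class BaseAgent(object):
    def __init__(self, network):
        self.main_network = network

    def save_model(self, model_id):
        # default behaviour: only the network weights are saved
        self.main_network.save_model(model_id)

class EpisodicAgent(BaseAgent):
    def save_model(self, model_id):
        # subclasses extend the hook to persist additional state
        super(EpisodicAgent, self).save_model(model_id)
        print('also saving episodic memory snapshot %d' % model_id)

EpisodicAgent(DummyNetwork()).save_model(0)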
@@ -25,6 +25,8 @@ class NECAgent(ValueOptimizationAgent):
         self.current_episode_state_embeddings = []
         self.current_episode_actions = []
         self.training_started = False
+        # if self.tp.checkpoint_restore_dir:
+        #     self.load_dnd(self.tp.checkpoint_restore_dir)
 
     def learn_from_batch(self, batch):
         if not self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
@@ -102,3 +104,8 @@ class NECAgent(ValueOptimizationAgent):
 
         self.current_episode_state_embeddings = []
         self.current_episode_actions = []
+
+    def save_model(self, model_id):
+        self.main_network.save_model(model_id)
+        with open(os.path.join(self.tp.save_model_dir, str(model_id) + '.dnd'), 'wb') as f:
+            pickle.dump(self.main_network.online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
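The .dnd file is written next to the checkpoint under the snapshot id, which is what lets load_dnd (further down) recover the newest dictionary by scanning for the highest numeric prefix. A standalone sketch of that round trip, with any picklable object standing in for the real QDND (the helper names and payloads here are made up for illustration):

import os
import pickle
import tempfile

def save_dnd(dnd, save_dir, model_id):
    # one pickle file per snapshot id, e.g. 0.dnd, 1.dnd, ...
    with open(os.path.join(save_dir, str(model_id) + '.dnd'), 'wb') as f:
        pickle.dump(dnd, f, pickle.HIGHEST_PROTOCOL)

def latest_dnd_path(save_dir):
    # same scan as load_dnd: take the highest numeric prefix among *.dnd files
    ids = [int(s.split('.')[0]) for s in os.listdir(save_dir) if s.endswith('.dnd')]
    return os.path.join(save_dir, str(max(ids)) + '.dnd')

save_dir = tempfile.mkdtemp()
for snapshot_id, payload in enumerate([{'keys': []}, {'keys': [1, 2]}]):
    save_dnd(payload, save_dir, snapshot_id)
with open(latest_dnd_path(save_dir), 'rb') as f:
    print(pickle.load(f))  # the payload from the newest snapshot, 1.dnd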
@@ -265,13 +265,18 @@ class DNDQHead(Head):
             self.loss_type = tf.losses.huber_loss
         else:
             self.loss_type = tf.losses.mean_squared_error
+        self.tp = tuning_parameters
 
     def _build_module(self, input_layer):
         # DND based Q head
         from memories import differentiable_neural_dictionary
-        self.DND = differentiable_neural_dictionary.QDND(
-            self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
-            key_error_threshold=self.DND_key_error_threshold)
+
+        if self.tp.checkpoint_restore_dir:
+            self.DND = differentiable_neural_dictionary.load_dnd(self.tp.checkpoint_restore_dir)
+        else:
+            self.DND = differentiable_neural_dictionary.QDND(
+                self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
+                key_error_threshold=self.DND_key_error_threshold)
 
         # Retrieve info from DND dictionary
         self.action = tf.placeholder(tf.int8, [None], name="action")
@@ -16,6 +16,7 @@
 
 import numpy as np
 from annoy import AnnoyIndex
+import os, pickle
 
 
 class AnnoyDictionary(object):
@@ -171,3 +172,25 @@ class QDND:
             if not self.dicts[a].has_enough_entries(k):
                 return False
         return True
+
+
+def load_dnd(model_dir):
+    max_id = 0
+
+    for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]:
+        if int(f.split('.')[0]) > max_id:
+            max_id = int(f.split('.')[0])
+
+    model_path = str(max_id) + '.dnd'
+    with open(os.path.join(model_dir, model_path), 'rb') as f:
+        DND = pickle.load(f)
+
+    for a in range(DND.num_actions):
+        DND.dicts[a].index = AnnoyIndex(512, metric='euclidean')
+        DND.dicts[a].index.set_seed(1)
+
+        for idx, key in zip(range(DND.dicts[a].curr_size), DND.dicts[a].embeddings[:DND.dicts[a].curr_size]):
+            DND.dicts[a].index.add_item(idx, key)
+
+        DND.dicts[a].index.build(50)
+    return DND
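An AnnoyIndex is a compiled C++ object and does not survive pickling, which is presumably why load_dnd rebuilds the index from the stored embeddings after unpickling the DND instead of restoring it directly. A standalone sketch of that rebuild step (not from the patch), assuming the annoy package is installed and using made-up 512-dimensional keys:

import numpy as np
from annoy import AnnoyIndex

embeddings = np.random.randn(100, 512).astype(np.float32)  # hypothetical stored keys

index = AnnoyIndex(512, metric='euclidean')
index.set_seed(1)
for idx, key in enumerate(embeddings):
    index.add_item(idx, key)
index.build(50)  # 50 trees, the same value load_dnd uses

# the rebuilt index answers the nearest-neighbour queries used by the DND lookup
print(index.get_nns_by_vector(embeddings[0], 10))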