diff --git a/agents/agent.py b/agents/agent.py
index 3e83c4b..ed9eabc 100644
--- a/agents/agent.py
+++ b/agents/agent.py
@@ -519,7 +519,7 @@ class Agent(object):
                 current_snapshot_period = (int(total_training_time) // self.tp.save_model_sec)
                 if current_snapshot_period > model_snapshots_periods_passed:
                     model_snapshots_periods_passed = current_snapshot_period
-                    self.main_network.save_model(model_snapshots_periods_passed)
+                    self.save_model(model_snapshots_periods_passed)
 
             # play and record in replay buffer
             if self.tp.agent.step_until_collecting_full_episodes:
@@ -539,3 +539,5 @@ class Agent(object):
                     self.training_iteration += 1
                 self.post_training_commands()
 
+    def save_model(self, model_id):
+        self.main_network.save_model(model_id)
diff --git a/agents/nec_agent.py b/agents/nec_agent.py
index 4e724c9..e8ac535 100644
--- a/agents/nec_agent.py
+++ b/agents/nec_agent.py
@@ -25,6 +25,8 @@ class NECAgent(ValueOptimizationAgent):
         self.current_episode_state_embeddings = []
         self.current_episode_actions = []
         self.training_started = False
+        # if self.tp.checkpoint_restore_dir:
+        #     self.load_dnd(self.tp.checkpoint_restore_dir)
 
     def learn_from_batch(self, batch):
         if not self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
@@ -102,3 +104,8 @@ class NECAgent(ValueOptimizationAgent):
 
         self.current_episode_state_embeddings = []
         self.current_episode_actions = []
+
+    def save_model(self, model_id):
+        self.main_network.save_model(model_id)
+        with open(os.path.join(self.tp.save_model_dir, str(model_id) + '.dnd'), 'wb') as f:
+            pickle.dump(self.main_network.online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
diff --git a/architectures/tensorflow_components/heads.py b/architectures/tensorflow_components/heads.py
index 1ebc10c..708fd1b 100644
--- a/architectures/tensorflow_components/heads.py
+++ b/architectures/tensorflow_components/heads.py
@@ -265,13 +265,18 @@ class DNDQHead(Head):
             self.loss_type = tf.losses.huber_loss
         else:
             self.loss_type = tf.losses.mean_squared_error
+        self.tp = tuning_parameters
 
     def _build_module(self, input_layer):
         # DND based Q head
         from memories import differentiable_neural_dictionary
-        self.DND = differentiable_neural_dictionary.QDND(
-            self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
-            key_error_threshold=self.DND_key_error_threshold)
+
+        if self.tp.checkpoint_restore_dir:
+            self.DND = differentiable_neural_dictionary.load_dnd(self.tp.checkpoint_restore_dir)
+        else:
+            self.DND = differentiable_neural_dictionary.QDND(
+                self.DND_size, input_layer.get_shape()[-1], self.num_actions, self.new_value_shift_coefficient,
+                key_error_threshold=self.DND_key_error_threshold)
 
         # Retrieve info from DND dictionary
         self.action = tf.placeholder(tf.int8, [None], name="action")
diff --git a/memories/differentiable_neural_dictionary.py b/memories/differentiable_neural_dictionary.py
index 4de4c9c..4904151 100644
--- a/memories/differentiable_neural_dictionary.py
+++ b/memories/differentiable_neural_dictionary.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 from annoy import AnnoyIndex
+import os, pickle
 
 
 class AnnoyDictionary(object):
@@ -171,3 +172,25 @@ class QDND:
             if not self.dicts[a].has_enough_entries(k):
                 return False
         return True
+
+
+def load_dnd(model_dir):
+    max_id = 0
+
+    for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]:
+        if int(f.split('.')[0]) > max_id:
+            max_id = int(f.split('.')[0])
+
+    model_path = str(max_id) + '.dnd'
+    with open(os.path.join(model_dir, model_path), 'rb') as f:
+        DND = pickle.load(f)
+
+    for a in range(DND.num_actions):
+        DND.dicts[a].index = AnnoyIndex(512, metric='euclidean')
+        DND.dicts[a].index.set_seed(1)
+
+        for idx, key in zip(range(DND.dicts[a].curr_size), DND.dicts[a].embeddings[:DND.dicts[a].curr_size]):
+            DND.dicts[a].index.add_item(idx, key)
+
+        DND.dicts[a].index.build(50)
+    return DND
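
Taken together, the change checkpoints the NEC agent's episodic memory alongside the regular model snapshot: Agent.save_model becomes an overridable hook, NECAgent.save_model pickles the DND head's QDND into a <model_id>.dnd file next to the TensorFlow checkpoint, and on restore DNDQHead calls load_dnd, which unpickles the newest .dnd snapshot and rebuilds each per-action Annoy index from the stored embeddings rather than relying on the pickled handle. Below is a minimal, self-contained sketch of that save/restore round trip, not Coach's actual code: TinyDND, save_dnd and restore_dnd are hypothetical names, the sketch explicitly drops the native index before dumping, and the Euclidean metric and 50-tree build simply mirror the hard-coded values in load_dnd above.

import os
import pickle

from annoy import AnnoyIndex


class TinyDND(object):
    # Hypothetical stand-in for one per-action AnnoyDictionary: it keeps the
    # raw embeddings around so the Annoy index can be rebuilt after unpickling.
    def __init__(self, dim):
        self.dim = dim
        self.embeddings = []   # picklable list of keys
        self.index = None      # native Annoy handle, rebuilt on load

    def add(self, key):
        self.embeddings.append(list(key))

    def rebuild_index(self, n_trees=50):
        # Recreate the nearest-neighbor index from the stored embeddings,
        # mirroring what load_dnd does for every action's dictionary.
        self.index = AnnoyIndex(self.dim, metric='euclidean')
        for idx, key in enumerate(self.embeddings):
            self.index.add_item(idx, key)
        self.index.build(n_trees)


def save_dnd(dnd, model_dir, model_id):
    # Drop the native index before dumping and write a '<model_id>.dnd' file,
    # analogous to NECAgent.save_model in the diff.
    dnd.index = None
    with open(os.path.join(model_dir, str(model_id) + '.dnd'), 'wb') as f:
        pickle.dump(dnd, f, pickle.HIGHEST_PROTOCOL)


def restore_dnd(model_dir):
    # Pick the newest snapshot by numeric id, unpickle it and rebuild the
    # search index, analogous to load_dnd above.
    snapshot_ids = [int(s.split('.')[0]) for s in os.listdir(model_dir) if s.endswith('.dnd')]
    with open(os.path.join(model_dir, str(max(snapshot_ids)) + '.dnd'), 'rb') as f:
        dnd = pickle.load(f)
    dnd.rebuild_index()
    return dnd

Usage would be along the lines of: d = TinyDND(512); d.add([0.0] * 512); save_dnd(d, '/tmp', 3); restored = restore_dnd('/tmp'). The rebuild step is the important part of the design: the Annoy index wraps a native object, so persisting the raw embeddings and re-adding them on load (as load_dnd does per action) is what makes the snapshot restorable.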