mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Fix for NEC not saving the DND when saving a model
This commit is contained in:
@@ -519,7 +519,7 @@ class Agent(object):
|
||||
current_snapshot_period = (int(total_training_time) // self.tp.save_model_sec)
|
||||
if current_snapshot_period > model_snapshots_periods_passed:
|
||||
model_snapshots_periods_passed = current_snapshot_period
|
||||
self.main_network.save_model(model_snapshots_periods_passed)
|
||||
self.save_model(model_snapshots_periods_passed)
|
||||
|
||||
# play and record in replay buffer
|
||||
if self.tp.agent.step_until_collecting_full_episodes:
|
||||
@@ -539,3 +539,5 @@ class Agent(object):
|
||||
self.training_iteration += 1
|
||||
self.post_training_commands()
|
||||
|
||||
def save_model(self, model_id):
|
||||
self.main_network.save_model(model_id)
|
||||
|
||||
@@ -25,6 +25,8 @@ class NECAgent(ValueOptimizationAgent):
|
||||
self.current_episode_state_embeddings = []
|
||||
self.current_episode_actions = []
|
||||
self.training_started = False
|
||||
# if self.tp.checkpoint_restore_dir:
|
||||
# self.load_dnd(self.tp.checkpoint_restore_dir)
|
||||
|
||||
def learn_from_batch(self, batch):
|
||||
if not self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
|
||||
@@ -102,3 +104,8 @@ class NECAgent(ValueOptimizationAgent):
|
||||
|
||||
self.current_episode_state_embeddings = []
|
||||
self.current_episode_actions = []
|
||||
|
||||
def save_model(self, model_id):
|
||||
self.main_network.save_model(model_id)
|
||||
with open(os.path.join(self.tp.save_model_dir, str(model_id) + '.dnd'), 'wb') as f:
|
||||
pickle.dump(self.main_network.online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
Reference in New Issue
Block a user