1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-15 05:25:55 +01:00

Save filters' internal state (#127)

* save filters internal state

* moving the restore to be made from within NumpyRunningStats
This commit is contained in:
Gal Leibovich
2018-11-20 17:21:48 +02:00
committed by GitHub
parent 67eb9e4c28
commit a112ee69f6
13 changed files with 173 additions and 14 deletions

View File

@@ -15,6 +15,7 @@
#
import copy
import os
import random
from collections import OrderedDict
from typing import Dict, List, Union, Tuple
@@ -108,8 +109,12 @@ class Agent(AgentInterface):
# filters
self.input_filter = self.ap.input_filter
self.input_filter.set_name('input_filter')
self.output_filter = self.ap.output_filter
self.output_filter.set_name('output_filter')
self.pre_network_filter = self.ap.pre_network_filter
self.pre_network_filter.set_name('pre_network_filter')
device = self.replicated_device if self.replicated_device else self.worker_device
# TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
@@ -923,7 +928,26 @@ class Agent(AgentInterface):
:param checkpoint_id: the id of the checkpoint
:return: None
"""
pass
checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
*(self.full_name_id.split('/'))) # adds both level name and agent name
self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
def restore_checkpoint(self, checkpoint_dir: str) -> None:
    """
    Restore the internal state of this agent's filters from a checkpoint directory.
    :param checkpoint_dir: the root checkpoint directory; the agent specific sub-directory is derived
                           from it using self.full_name_id
    :return: None
    """
    # full_name_id is '/'-separated (level name and agent name); each part becomes a nested directory
    name_parts = self.full_name_id.split('/')
    agent_checkpoint_dir = os.path.join(checkpoint_dir, *name_parts)

    self.input_filter.restore_state_from_checkpoint(agent_checkpoint_dir)
    self.pre_network_filter.restore_state_from_checkpoint(agent_checkpoint_dir)
    # the output filter is skipped on purpose: no output filters currently have an internal state to restore
def sync(self) -> None:
"""

View File

@@ -392,6 +392,9 @@ class CompositeAgent(AgentInterface):
def save_checkpoint(self, checkpoint_id: int) -> None:
    """
    Ask every agent held in self.agents to save its checkpoint.
    :param checkpoint_id: the id of the checkpoint, forwarded unchanged to each agent
    :return: None
    """
    # a plain loop rather than a list comprehension: the calls are made only for their side effects,
    # so there is no reason to build a throwaway list of their return values
    for agent in self.agents.values():
        agent.save_checkpoint(checkpoint_id)
def restore_checkpoint(self, checkpoint_dir: str) -> None:
    """
    Ask every agent held in self.agents to restore its state from a checkpoint directory.
    :param checkpoint_dir: the checkpoint directory, forwarded unchanged to each agent
    :return: None
    """
    # a plain loop rather than a list comprehension: the calls are made only for their side effects,
    # so there is no reason to build a throwaway list of their return values
    for agent in self.agents.values():
        agent.restore_checkpoint(checkpoint_dir)
def set_incoming_directive(self, action: ActionType) -> None:
self.incoming_action = action
if isinstance(self.decision_policy, SingleDecider) and isinstance(self.in_action_space, AgentSelection):

View File

@@ -204,5 +204,6 @@ class NECAgent(ValueOptimizationAgent):
actions, discounted_rewards)
def save_checkpoint(self, checkpoint_id):
    """
    Run the superclass checkpoint saving, then pickle the DND object of the first output head of the
    'main' online network into '<checkpoint_id>.dnd' inside the checkpoint save directory.
    :param checkpoint_id: the id of the checkpoint, used as the stem of the .dnd file name
    :return: None
    """
    super().save_checkpoint(checkpoint_id)

    save_dir = self.ap.task_parameters.checkpoint_save_dir
    dnd_path = os.path.join(save_dir, str(checkpoint_id) + '.dnd')
    with open(dnd_path, 'wb') as dnd_file:
        # looked up only after the file has been opened, same as the original ordering
        dnd = self.networks['main'].online_network.output_heads[0].DND
        pickle.dump(dnd, dnd_file, pickle.HIGHEST_PROTOCOL)