1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Save filters' internal state (#127)

* save filters internal state

* moving the restore to be made from within NumpyRunningStats
This commit is contained in:
Gal Leibovich
2018-11-20 17:21:48 +02:00
committed by GitHub
parent 67eb9e4c28
commit a112ee69f6
13 changed files with 173 additions and 14 deletions

View File

@@ -15,6 +15,7 @@
#
import copy
import os
import random
from collections import OrderedDict
from typing import Dict, List, Union, Tuple
@@ -108,8 +109,12 @@ class Agent(AgentInterface):
# filters
self.input_filter = self.ap.input_filter
self.input_filter.set_name('input_filter')
self.output_filter = self.ap.output_filter
self.output_filter.set_name('output_filter')
self.pre_network_filter = self.ap.pre_network_filter
self.pre_network_filter.set_name('pre_network_filter')
device = self.replicated_device if self.replicated_device else self.worker_device
# TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
@@ -923,7 +928,26 @@ class Agent(AgentInterface):
:param checkpoint_id: the id of the checkpoint
:return: None
"""
pass
checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
*(self.full_name_id.split('/'))) # adds both level name and agent name
self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
def restore_checkpoint(self, checkpoint_dir: str) -> None:
"""
Allows agents to store additional information when saving checkpoints.
:param checkpoint_id: the id of the checkpoint
:return: None
"""
checkpoint_dir = os.path.join(checkpoint_dir,
*(self.full_name_id.split('/'))) # adds both level name and agent name
self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
# no output filters currently have an internal state to restore
# self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
def sync(self) -> None:
"""