mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Save filters' internal state (#127)
* save filters internal state * moving the restore to be made from within NumpyRunningStats
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#
|
||||
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Union, Tuple
|
||||
@@ -108,8 +109,12 @@ class Agent(AgentInterface):
|
||||
|
||||
# filters
|
||||
self.input_filter = self.ap.input_filter
|
||||
self.input_filter.set_name('input_filter')
|
||||
self.output_filter = self.ap.output_filter
|
||||
self.output_filter.set_name('output_filter')
|
||||
self.pre_network_filter = self.ap.pre_network_filter
|
||||
self.pre_network_filter.set_name('pre_network_filter')
|
||||
|
||||
device = self.replicated_device if self.replicated_device else self.worker_device
|
||||
|
||||
# TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
|
||||
@@ -923,7 +928,26 @@ class Agent(AgentInterface):
|
||||
:param checkpoint_id: the id of the checkpoint
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
|
||||
*(self.full_name_id.split('/'))) # adds both level name and agent name
|
||||
self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
|
||||
def restore_checkpoint(self, checkpoint_dir: str) -> None:
|
||||
"""
|
||||
Allows agents to store additional information when saving checkpoints.
|
||||
|
||||
:param checkpoint_id: the id of the checkpoint
|
||||
:return: None
|
||||
"""
|
||||
checkpoint_dir = os.path.join(checkpoint_dir,
|
||||
*(self.full_name_id.split('/'))) # adds both level name and agent name
|
||||
self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||
self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||
|
||||
# no output filters currently have an internal state to restore
|
||||
# self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||
|
||||
def sync(self) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user