mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Distiller's AMC induced changes (#359)
* override episode rewards with the last transition reward * EWMA normalization filter * allowing control over when the pre_network filter runs
This commit is contained in:
@@ -20,6 +20,8 @@ import pickle
|
||||
import redis
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.utils import get_latest_checkpoint
|
||||
|
||||
|
||||
class SharedRunningStatsSubscribe(threading.Thread):
|
||||
def __init__(self, shared_running_stats):
|
||||
@@ -109,27 +111,13 @@ class SharedRunningStats(ABC):
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
pass
|
||||
|
||||
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
|
||||
latest_checkpoint_id = -1
|
||||
latest_checkpoint = ''
|
||||
# get all checkpoint files
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, fname)
|
||||
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
|
||||
continue
|
||||
checkpoint_id = int(fname.split('_')[0])
|
||||
if checkpoint_id > latest_checkpoint_id:
|
||||
latest_checkpoint = fname
|
||||
latest_checkpoint_id = checkpoint_id
|
||||
|
||||
return latest_checkpoint
|
||||
|
||||
|
||||
class NumpySharedRunningStats(SharedRunningStats):
|
||||
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
|
||||
super().__init__(name=name, pubsub_params=pubsub_params)
|
||||
self._count = epsilon
|
||||
self.epsilon = epsilon
|
||||
self.checkpoint_file_extension = 'srs'
|
||||
|
||||
def set_params(self, shape=[1], clip_values=None):
|
||||
self._shape = shape
|
||||
@@ -185,11 +173,12 @@ class NumpySharedRunningStats(SharedRunningStats):
|
||||
'_sum': self._sum,
|
||||
'_sum_squares': self._sum_squares}
|
||||
|
||||
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
|
||||
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
|
||||
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
|
||||
self.checkpoint_file_extension)
|
||||
|
||||
if latest_checkpoint_filename == '':
|
||||
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
|
||||
|
||||
Reference in New Issue
Block a user