1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Distiller's AMC induced changes (#359)

* override episode rewards with the last transition reward

* EWMA normalization filter

* allowing control over when the pre_network filter runs
This commit is contained in:
Gal Leibovich
2019-08-05 10:24:58 +03:00
committed by GitHub
parent 7df67dafa3
commit c1d1fae342
10 changed files with 137 additions and 30 deletions

View File

@@ -20,6 +20,8 @@ import pickle
import redis
import numpy as np
from rl_coach.utils import get_latest_checkpoint
class SharedRunningStatsSubscribe(threading.Thread):
def __init__(self, shared_running_stats):
@@ -109,27 +111,13 @@ class SharedRunningStats(ABC):
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
pass
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
latest_checkpoint_id = -1
latest_checkpoint = ''
# get all checkpoint files
for fname in os.listdir(checkpoint_dir):
path = os.path.join(checkpoint_dir, fname)
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
continue
checkpoint_id = int(fname.split('_')[0])
if checkpoint_id > latest_checkpoint_id:
latest_checkpoint = fname
latest_checkpoint_id = checkpoint_id
return latest_checkpoint
class NumpySharedRunningStats(SharedRunningStats):
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
super().__init__(name=name, pubsub_params=pubsub_params)
self._count = epsilon
self.epsilon = epsilon
self.checkpoint_file_extension = 'srs'
def set_params(self, shape=[1], clip_values=None):
self._shape = shape
@@ -185,11 +173,12 @@ class NumpySharedRunningStats(SharedRunningStats):
'_sum': self._sum,
'_sum_squares': self._sum_squares}
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.' + self.checkpoint_file_extension), 'wb') as f:
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
latest_checkpoint_filename = get_latest_checkpoint(checkpoint_dir, checkpoint_prefix,
self.checkpoint_file_extension)
if latest_checkpoint_filename == '':
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")