1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-16 22:53:37 +01:00

Fix finding the right filter checkpoint to restore, do not update internal filter state when evaluating, and fix SharedRunningStats checkpoint filenames (#147)

This commit is contained in:
Gal Leibovich
2018-12-17 21:36:27 +02:00
committed by GitHub
parent b4bc8a476c
commit 4c914c057c
4 changed files with 19 additions and 15 deletions

View File

@@ -466,14 +466,14 @@ class InputFilter(Filter):
if self.name is not None:
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
for filter_name, filter in self._reward_filters.items():
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.save_state_to_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
for observation_name, filters_dict in self._observation_filters.items():
for filter_name, filter in filters_dict.items():
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
filter_name])
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
filter.save_state_to_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
"""
@@ -486,14 +486,14 @@ class InputFilter(Filter):
if self.name is not None:
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
for filter_name, filter in self._reward_filters.items():
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.restore_state_from_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
for observation_name, filters_dict in self._observation_filters.items():
for filter_name, filter in filters_dict.items():
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
filter_name])
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
filter.restore_state_from_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
class NoInputFilter(InputFilter):