mirror of
https://github.com/gryf/coach.git
synced 2026-03-16 22:53:37 +01:00
fix for finding the right filter checkpoint to restore + do not update internal filter state when evaluating + fix SharedRunningStats checkpoint filenames (#147)
This commit is contained in:
@@ -466,14 +466,14 @@ class InputFilter(Filter):
|
||||
if self.name is not None:
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
||||
for filter_name, filter in self._reward_filters.items():
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||
filter.save_state_to_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
|
||||
|
||||
for observation_name, filters_dict in self._observation_filters.items():
|
||||
for filter_name, filter in filters_dict.items():
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||
filter_name])
|
||||
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
filter.save_state_to_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
||||
"""
|
||||
@@ -486,14 +486,14 @@ class InputFilter(Filter):
|
||||
if self.name is not None:
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
||||
for filter_name, filter in self._reward_filters.items():
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||
filter.restore_state_from_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
|
||||
|
||||
for observation_name, filters_dict in self._observation_filters.items():
|
||||
for filter_name, filter in filters_dict.items():
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||
filter_name])
|
||||
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
filter.restore_state_from_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
|
||||
|
||||
|
||||
class NoInputFilter(InputFilter):
|
||||
|
||||
Reference in New Issue
Block a user