From 4c914c057c02739bee25f8d4599a8260f86da296 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Mon, 17 Dec 2018 21:36:27 +0200 Subject: [PATCH] fix for finding the right filter checkpoint to restore + do not update internal filter state when evaluating + fix SharedRunningStats checkpoint filenames (#147) --- rl_coach/agents/agent.py | 10 +++++++--- rl_coach/agents/clipped_ppo_agent.py | 2 +- rl_coach/filters/filter.py | 16 ++++++++-------- rl_coach/utilities/shared_running_stats.py | 6 +++--- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index 53dc4c3..a6b5297 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -762,7 +762,8 @@ class Agent(AgentInterface): # informed action if self.pre_network_filter is not None: # before choosing an action, first use the pre_network_filter to filter out the current state - curr_state = self.run_pre_network_filter_for_inference(self.curr_state) + update_filter_internal_state = self.phase is not RunPhase.TEST + curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state) else: curr_state = self.curr_state @@ -772,15 +773,18 @@ class Agent(AgentInterface): return filtered_action_info - def run_pre_network_filter_for_inference(self, state: StateType) -> StateType: + def run_pre_network_filter_for_inference(self, state: StateType, update_filter_internal_state: bool=True)\ + -> StateType: """ Run filters which where defined for being applied right before using the state for inference. :param state: The state to run the filters on + :param update_filter_internal_state: Should update the filter's internal state - should not update when evaluating :return: The filtered state """ dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False) - return self.pre_network_filter.filter(dummy_env_response)[0].next_state + return self.pre_network_filter.filter(dummy_env_response, + update_internal_state=update_filter_internal_state)[0].next_state def get_state_embedding(self, state: dict) -> np.ndarray: """ diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py index 9fa8d72..71ccdce 100644 --- a/rl_coach/agents/clipped_ppo_agent.py +++ b/rl_coach/agents/clipped_ppo_agent.py @@ -325,7 +325,7 @@ class ClippedPPOAgent(ActorCriticAgent): self.update_log() return None - def run_pre_network_filter_for_inference(self, state: StateType): + def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False): dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False) return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state diff --git a/rl_coach/filters/filter.py b/rl_coach/filters/filter.py index 6ad2d55..6881f8d 100644 --- a/rl_coach/filters/filter.py +++ b/rl_coach/filters/filter.py @@ -466,14 +466,14 @@ class InputFilter(Filter): if self.name is not None: checkpoint_prefix = '.'.join([checkpoint_prefix, self.name]) for filter_name, filter in self._reward_filters.items(): - checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name]) - filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) + curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name]) + filter.save_state_to_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix) for observation_name, filters_dict in self._observation_filters.items(): for filter_name, filter in filters_dict.items(): - checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name, + curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name, filter_name]) - filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) + filter.save_state_to_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix) def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None: """ @@ -486,14 +486,14 @@ class InputFilter(Filter): if self.name is not None: checkpoint_prefix = '.'.join([checkpoint_prefix, self.name]) for filter_name, filter in self._reward_filters.items(): - checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name]) - filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) + curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name]) + filter.restore_state_from_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix) for observation_name, filters_dict in self._observation_filters.items(): for filter_name, filter in filters_dict.items(): - checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name, + curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name, filter_name]) - filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) + filter.restore_state_from_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix) class NoInputFilter(InputFilter): diff --git a/rl_coach/utilities/shared_running_stats.py b/rl_coach/utilities/shared_running_stats.py index 7f1176f..b78b66b 100644 --- a/rl_coach/utilities/shared_running_stats.py +++ b/rl_coach/utilities/shared_running_stats.py @@ -109,13 +109,13 @@ class SharedRunningStats(ABC): def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): pass - def get_latest_checkpoint(self, checkpoint_dir: str) -> str: + def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str: latest_checkpoint_id = -1 latest_checkpoint = '' # get all checkpoint files for fname in os.listdir(checkpoint_dir): path = os.path.join(checkpoint_dir, fname) - if os.path.isdir(path) or fname.split('.')[-1] != 'srs': + if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname: continue checkpoint_id = int(fname.split('_')[0]) if checkpoint_id > latest_checkpoint_id: @@ -189,7 +189,7 @@ class NumpySharedRunningStats(SharedRunningStats): pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL) def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): - latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir) + latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix) if latest_checkpoint_filename == '': raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")