mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
fix for finding the right filter checkpoint to restore + do not update internal filter state when evaluating + fix SharedRunningStats checkpoint filenames (#147)
This commit is contained in:
@@ -762,7 +762,8 @@ class Agent(AgentInterface):
|
|||||||
# informed action
|
# informed action
|
||||||
if self.pre_network_filter is not None:
|
if self.pre_network_filter is not None:
|
||||||
# before choosing an action, first use the pre_network_filter to filter out the current state
|
# before choosing an action, first use the pre_network_filter to filter out the current state
|
||||||
curr_state = self.run_pre_network_filter_for_inference(self.curr_state)
|
update_filter_internal_state = self.phase is not RunPhase.TEST
|
||||||
|
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
curr_state = self.curr_state
|
curr_state = self.curr_state
|
||||||
@@ -772,15 +773,18 @@ class Agent(AgentInterface):
|
|||||||
|
|
||||||
return filtered_action_info
|
return filtered_action_info
|
||||||
|
|
||||||
def run_pre_network_filter_for_inference(self, state: StateType) -> StateType:
|
def run_pre_network_filter_for_inference(self, state: StateType, update_filter_internal_state: bool=True)\
|
||||||
|
-> StateType:
|
||||||
"""
|
"""
|
||||||
Run filters which where defined for being applied right before using the state for inference.
|
Run filters which where defined for being applied right before using the state for inference.
|
||||||
|
|
||||||
:param state: The state to run the filters on
|
:param state: The state to run the filters on
|
||||||
|
:param update_filter_internal_state: Should update the filter's internal state - should not update when evaluating
|
||||||
:return: The filtered state
|
:return: The filtered state
|
||||||
"""
|
"""
|
||||||
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
||||||
return self.pre_network_filter.filter(dummy_env_response)[0].next_state
|
return self.pre_network_filter.filter(dummy_env_response,
|
||||||
|
update_internal_state=update_filter_internal_state)[0].next_state
|
||||||
|
|
||||||
def get_state_embedding(self, state: dict) -> np.ndarray:
|
def get_state_embedding(self, state: dict) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -325,7 +325,7 @@ class ClippedPPOAgent(ActorCriticAgent):
|
|||||||
self.update_log()
|
self.update_log()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def run_pre_network_filter_for_inference(self, state: StateType):
|
def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
|
||||||
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
||||||
return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
|
return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
|
||||||
|
|
||||||
|
|||||||
@@ -466,14 +466,14 @@ class InputFilter(Filter):
|
|||||||
if self.name is not None:
|
if self.name is not None:
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
||||||
for filter_name, filter in self._reward_filters.items():
|
for filter_name, filter in self._reward_filters.items():
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||||
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
filter.save_state_to_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
|
||||||
|
|
||||||
for observation_name, filters_dict in self._observation_filters.items():
|
for observation_name, filters_dict in self._observation_filters.items():
|
||||||
for filter_name, filter in filters_dict.items():
|
for filter_name, filter in filters_dict.items():
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||||
filter_name])
|
filter_name])
|
||||||
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
filter.save_state_to_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
|
||||||
|
|
||||||
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
||||||
"""
|
"""
|
||||||
@@ -486,14 +486,14 @@ class InputFilter(Filter):
|
|||||||
if self.name is not None:
|
if self.name is not None:
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
||||||
for filter_name, filter in self._reward_filters.items():
|
for filter_name, filter in self._reward_filters.items():
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
curr_reward_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||||
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
filter.restore_state_from_checkpoint(checkpoint_dir, curr_reward_filter_ckpt_prefix)
|
||||||
|
|
||||||
for observation_name, filters_dict in self._observation_filters.items():
|
for observation_name, filters_dict in self._observation_filters.items():
|
||||||
for filter_name, filter in filters_dict.items():
|
for filter_name, filter in filters_dict.items():
|
||||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
curr_obs_filter_ckpt_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||||
filter_name])
|
filter_name])
|
||||||
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
filter.restore_state_from_checkpoint(checkpoint_dir, curr_obs_filter_ckpt_prefix)
|
||||||
|
|
||||||
|
|
||||||
class NoInputFilter(InputFilter):
|
class NoInputFilter(InputFilter):
|
||||||
|
|||||||
@@ -109,13 +109,13 @@ class SharedRunningStats(ABC):
|
|||||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_latest_checkpoint(self, checkpoint_dir: str) -> str:
|
def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
|
||||||
latest_checkpoint_id = -1
|
latest_checkpoint_id = -1
|
||||||
latest_checkpoint = ''
|
latest_checkpoint = ''
|
||||||
# get all checkpoint files
|
# get all checkpoint files
|
||||||
for fname in os.listdir(checkpoint_dir):
|
for fname in os.listdir(checkpoint_dir):
|
||||||
path = os.path.join(checkpoint_dir, fname)
|
path = os.path.join(checkpoint_dir, fname)
|
||||||
if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
|
if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
|
||||||
continue
|
continue
|
||||||
checkpoint_id = int(fname.split('_')[0])
|
checkpoint_id = int(fname.split('_')[0])
|
||||||
if checkpoint_id > latest_checkpoint_id:
|
if checkpoint_id > latest_checkpoint_id:
|
||||||
@@ -189,7 +189,7 @@ class NumpySharedRunningStats(SharedRunningStats):
|
|||||||
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||||
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir)
|
latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||||
|
|
||||||
if latest_checkpoint_filename == '':
|
if latest_checkpoint_filename == '':
|
||||||
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
|
raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
|
||||||
|
|||||||
Reference in New Issue
Block a user