mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
fix for finding the right filter checkpoint to restore + do not update internal filter state when evaluating + fix SharedRunningStats checkpoint filenames (#147)
This commit is contained in:
@@ -762,7 +762,8 @@ class Agent(AgentInterface):
|
||||
# informed action
|
||||
if self.pre_network_filter is not None:
|
||||
# before choosing an action, first use the pre_network_filter to filter out the current state
|
||||
curr_state = self.run_pre_network_filter_for_inference(self.curr_state)
|
||||
update_filter_internal_state = self.phase is not RunPhase.TEST
|
||||
curr_state = self.run_pre_network_filter_for_inference(self.curr_state, update_filter_internal_state)
|
||||
|
||||
else:
|
||||
curr_state = self.curr_state
|
||||
@@ -772,15 +773,18 @@ class Agent(AgentInterface):
|
||||
|
||||
return filtered_action_info
|
||||
|
||||
def run_pre_network_filter_for_inference(self, state: StateType) -> StateType:
|
||||
def run_pre_network_filter_for_inference(self, state: StateType, update_filter_internal_state: bool=True)\
|
||||
-> StateType:
|
||||
"""
|
||||
Run filters which where defined for being applied right before using the state for inference.
|
||||
|
||||
:param state: The state to run the filters on
|
||||
:param update_filter_internal_state: Should update the filter's internal state - should not update when evaluating
|
||||
:return: The filtered state
|
||||
"""
|
||||
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
||||
return self.pre_network_filter.filter(dummy_env_response)[0].next_state
|
||||
return self.pre_network_filter.filter(dummy_env_response,
|
||||
update_internal_state=update_filter_internal_state)[0].next_state
|
||||
|
||||
def get_state_embedding(self, state: dict) -> np.ndarray:
|
||||
"""
|
||||
|
||||
@@ -325,7 +325,7 @@ class ClippedPPOAgent(ActorCriticAgent):
|
||||
self.update_log()
|
||||
return None
|
||||
|
||||
def run_pre_network_filter_for_inference(self, state: StateType):
|
||||
def run_pre_network_filter_for_inference(self, state: StateType, update_internal_state: bool=False):
|
||||
dummy_env_response = EnvResponse(next_state=state, reward=0, game_over=False)
|
||||
return self.pre_network_filter.filter(dummy_env_response, update_internal_state=False)[0].next_state
|
||||
|
||||
|
||||
Reference in New Issue
Block a user