mirror of
https://github.com/gryf/coach.git
synced 2026-02-23 10:35:46 +01:00
Batch RL Tutorial (#372)
This commit is contained in:
@@ -64,15 +64,14 @@ class RewardNormalizationFilter(RewardFilter):
|
||||
|
||||
def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
|
||||
if update_internal_state:
|
||||
if not isinstance(reward, np.ndarray) or len(reward.shape) < 2:
|
||||
reward = np.array([[reward]])
|
||||
self.running_rewards_stats.push(reward)
|
||||
|
||||
reward = (reward - self.running_rewards_stats.mean) / \
|
||||
(self.running_rewards_stats.std + 1e-15)
|
||||
reward = np.clip(reward, self.clip_min, self.clip_max)
|
||||
|
||||
return reward
|
||||
return self.running_rewards_stats.normalize(reward).squeeze()
|
||||
|
||||
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
|
||||
self.running_rewards_stats.set_params(shape=(1,), clip_values=(self.clip_min, self.clip_max))
|
||||
return input_reward_space
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
|
||||
Reference in New Issue
Block a user