mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding target reward and target sucess (#58)
* Adding target reward * Adding target successs * Addressing comments * Using custom_reward_threshold and target_success_rate * Adding exit message * Moving success rate to environment * Making target_success_rate optional
This commit is contained in:
committed by
Balaji Subramaniam
parent
0fe583186e
commit
875d6ef017
@@ -178,11 +178,11 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
|
||||
# Environment
|
||||
class GymEnvironment(Environment):
|
||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||
additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
|
||||
target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
|
||||
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
|
||||
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
|
||||
visualization_parameters)
|
||||
visualization_parameters, target_success_rate)
|
||||
|
||||
self.random_initialization_steps = random_initialization_steps
|
||||
self.max_over_num_frames = max_over_num_frames
|
||||
@@ -221,7 +221,7 @@ class GymEnvironment(Environment):
|
||||
try:
|
||||
self.env = env_class(**self.additional_simulator_parameters)
|
||||
except:
|
||||
screen.error("Failed to instantiate Gym environment class %s with arguments %s" %
|
||||
screen.error("Failed to instantiate Gym environment class %s with arguments %s" %
|
||||
(env_class, self.additional_simulator_parameters), crash=False)
|
||||
raise
|
||||
else:
|
||||
@@ -337,6 +337,8 @@ class GymEnvironment(Environment):
|
||||
self.reward_success_threshold = self.env.spec.reward_threshold
|
||||
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold)
|
||||
|
||||
self.target_success_rate = target_success_rate
|
||||
|
||||
def _wrap_state(self, state):
|
||||
if not isinstance(self.env.observation_space, gym.spaces.Dict):
|
||||
return {'observation': state}
|
||||
@@ -434,3 +436,6 @@ class GymEnvironment(Environment):
|
||||
if self.is_mujoco_env:
|
||||
self._set_mujoco_camera(0)
|
||||
return image
|
||||
|
||||
def get_target_success_rate(self) -> float:
|
||||
return self.target_success_rate
|
||||
|
||||
Reference in New Issue
Block a user