1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Adding target reward and target sucess (#58)

* Adding target reward

* Adding target successs

* Addressing comments

* Using custom_reward_threshold and target_success_rate

* Adding exit message

* Moving success rate to environment

* Making target_success_rate optional
This commit is contained in:
Ajay Deshpande
2018-11-12 15:03:43 -08:00
committed by Balaji Subramaniam
parent 0fe583186e
commit 875d6ef017
17 changed files with 162 additions and 74 deletions

View File

@@ -66,10 +66,10 @@ control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING
# Environment
class ControlSuiteEnvironment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
seed: Union[None, int]=None, human_control: bool=False,
target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
observation_type: ObservationType=ObservationType.Measurements,
custom_reward_threshold: Union[int, float]=None, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
self.observation_type = observation_type
@@ -126,6 +126,8 @@ class ControlSuiteEnvironment(Environment):
if not self.native_rendering:
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
self.target_success_rate = target_success_rate
def _update_state(self):
self.state = {}
@@ -160,3 +162,6 @@ class ControlSuiteEnvironment(Environment):
def get_rendered_image(self):
return self.env.physics.render(camera_id=0)
def get_target_success_rate(self) -> float:
return self.target_success_rate