1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

bug-fix for dumping movies (+ small refactoring and rename 'VideoDumpMethod -> 'VideoDumpFilter')

This commit is contained in:
Gal Leibovich
2018-10-21 17:29:10 +03:00
parent 364168490f
commit 5a8da90d32
4 changed files with 98 additions and 90 deletions

View File

@@ -317,6 +317,13 @@ class Environment(EnvironmentInterface):
else:
self.renderer.render_image(self.get_rendered_image())
def handle_episode_ended(self) -> None:
"""
End an episode
:return: None
"""
self.dump_video_of_last_episode_if_needed()
def reset_internal_state(self, force_environment_reset=False) -> EnvResponse:
"""
Reset the environment and all the variable of the wrapper
@@ -324,7 +331,6 @@ class Environment(EnvironmentInterface):
:return: A dictionary containing the observation, reward, done flag, action and measurements
"""
self.dump_video_of_last_episode_if_needed()
self._restart_environment_episode(force_environment_reset)
self.last_episode_time = time.time()
@@ -392,17 +398,16 @@ class Environment(EnvironmentInterface):
self.goal = goal
def should_dump_video_of_the_current_episode(self, episode_terminated=False):
if self.visualization_parameters.video_dump_methods:
for video_dump_method in force_list(self.visualization_parameters.video_dump_methods):
if not video_dump_method.should_dump(episode_terminated, **self.__dict__):
if self.visualization_parameters.video_dump_filters:
for video_dump_filter in force_list(self.visualization_parameters.video_dump_filters):
if not video_dump_filter.should_dump(episode_terminated, **self.__dict__):
return False
return True
return False
return True
def dump_video_of_last_episode_if_needed(self):
if self.visualization_parameters.video_dump_methods and self.last_episode_images != []:
if self.should_dump_video_of_the_current_episode(episode_terminated=True):
self.dump_video_of_last_episode()
if self.last_episode_images != [] and self.should_dump_video_of_the_current_episode(episode_terminated=True):
self.dump_video_of_last_episode()
def dump_video_of_last_episode(self):
frame_skipping = max(1, int(5 / self.frame_skip))
@@ -464,78 +469,3 @@ class Environment(EnvironmentInterface):
"""
return np.transpose(self.state['observation'], [1, 2, 0])
"""
Video Dumping Methods
"""
class VideoDumpMethod(object):
"""
Method used to decide when to dump videos
"""
def should_dump(self, episode_terminated=False, **kwargs):
raise NotImplementedError("")
class AlwaysDumpMethod(VideoDumpMethod):
"""
Dump video for every episode
"""
def __init__(self):
super().__init__()
def should_dump(self, episode_terminated=False, **kwargs):
return True
class MaxDumpMethod(VideoDumpMethod):
"""
Dump video every time a new max total reward has been achieved
"""
def __init__(self):
super().__init__()
self.max_reward_achieved = -np.inf
def should_dump(self, episode_terminated=False, **kwargs):
# if the episode has not finished yet we want to be prepared for dumping a video
if not episode_terminated:
return True
if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
self.max_reward_achieved = kwargs['total_reward_in_current_episode']
return True
else:
return False
class EveryNEpisodesDumpMethod(object):
"""
Dump videos once in every N episodes
"""
def __init__(self, num_episodes_between_dumps: int):
super().__init__()
self.num_episodes_between_dumps = num_episodes_between_dumps
self.last_dumped_episode = 0
if num_episodes_between_dumps < 1:
raise ValueError("the number of episodes between dumps should be a positive number")
def should_dump(self, episode_terminated=False, **kwargs):
if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
self.last_dumped_episode = kwargs['episode_idx']
return True
else:
return False
class SelectedPhaseOnlyDumpMethod(object):
"""
Dump videos when the phase of the environment matches a predefined phase
"""
def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
self.run_phases = force_list(run_phases)
def should_dump(self, episode_terminated=False, **kwargs):
if kwargs['_phase'] in self.run_phases:
return True
else:
return False