diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py
index ff0bdb5..de737eb 100644
--- a/rl_coach/base_parameters.py
+++ b/rl_coach/base_parameters.py
@@ -22,8 +22,8 @@ from collections import OrderedDict
 from enum import Enum
 from typing import Dict, List, Union

-from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase
-# from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
+from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase, \
+    SelectedPhaseOnlyDumpFilter, MaxDumpFilter
 from rl_coach.filters.filter import NoInputFilter


@@ -285,7 +285,6 @@ class NetworkComponentParameters(Parameters):
         self.dense_layer = dense_layer


-
 class VisualizationParameters(Parameters):
     def __init__(self,
                  print_networks_summary=False,
@@ -293,7 +292,7 @@ class VisualizationParameters(Parameters):
                  dump_signals_to_csv_every_x_episodes=5,
                  dump_gifs=False,
                  dump_mp4=False,
-                 video_dump_methods=[],
+                 video_dump_methods=None,
                  dump_in_episode_signals=False,
                  dump_parameters_documentation=True,
                  render=False,
@@ -352,6 +351,8 @@ class VisualizationParameters(Parameters):
             which will be passed to the agent and allow using those images.
         """
         super().__init__()
+        if video_dump_methods is None:
+            video_dump_methods = [SelectedPhaseOnlyDumpFilter(RunPhase.TEST), MaxDumpFilter()]
         self.print_networks_summary = print_networks_summary
         self.dump_csv = dump_csv
         self.dump_gifs = dump_gifs
@@ -363,7 +364,7 @@ class VisualizationParameters(Parameters):
         self.native_rendering = native_rendering
         self.max_fps_for_human_control = max_fps_for_human_control
         self.tensorboard = tensorboard
-        self.video_dump_methods = video_dump_methods
+        self.video_dump_filters = video_dump_methods
         self.add_rendered_image_to_env_response = add_rendered_image_to_env_response


diff --git a/rl_coach/core_types.py b/rl_coach/core_types.py
index 1e9fafa..8610c4f 100644
--- a/rl_coach/core_types.py
+++ b/rl_coach/core_types.py
@@ -22,6 +22,8 @@ from typing import List, Union, Dict, Any, Type

 import numpy as np

+from rl_coach.utils import force_list
+
 ActionType = Union[int, float, np.ndarray, List]
 GoalType = Union[None, np.ndarray]
 ObservationType = np.ndarray
@@ -692,3 +694,79 @@ class Episode(object):

     def __getitem__(self, sliced):
         return self.transitions[sliced]
+
+
+"""
+Video Dumping Filters
+"""
+
+
+class VideoDumpFilter(object):
+    """
+    Filter used to decide when to dump videos
+    """
+    def should_dump(self, episode_terminated=False, **kwargs):
+        raise NotImplementedError("should_dump must be implemented by a subclass")
+
+
+class AlwaysDumpFilter(VideoDumpFilter):
+    """
+    Dump video for every episode
+    """
+    def __init__(self):
+        super().__init__()
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        return True
+
+
+class MaxDumpFilter(VideoDumpFilter):
+    """
+    Dump video every time a new max total reward has been achieved
+    """
+    def __init__(self):
+        super().__init__()
+        self.max_reward_achieved = -np.inf
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        # if the episode has not finished yet, we want to be prepared for dumping a video
+        if not episode_terminated:
+            return True
+        if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
+            self.max_reward_achieved = kwargs['total_reward_in_current_episode']
+            return True
+        else:
+            return False
+
+
+class EveryNEpisodesDumpFilter(VideoDumpFilter):
+    """
+    Dump videos once every N episodes
+    """
+    def __init__(self, num_episodes_between_dumps: int):
+        super().__init__()
+        if num_episodes_between_dumps < 1:
+            raise ValueError("the number of episodes between dumps should be a positive number")
+        self.num_episodes_between_dumps = num_episodes_between_dumps
+        self.last_dumped_episode = 0
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
+            self.last_dumped_episode = kwargs['episode_idx']
+            return True
+        else:
+            return False
+
+
+class SelectedPhaseOnlyDumpFilter(VideoDumpFilter):
+    """
+    Dump videos only when the phase of the environment matches a predefined phase
+    """
+    def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
+        self.run_phases = force_list(run_phases)
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        if kwargs['_phase'] in self.run_phases:
+            return True
+        else:
+            return False
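Note: the filters added to `core_types.py` receive the environment's attributes as keyword arguments (the `environment.py` change below passes `**self.__dict__`), so a custom filter can key off any `Environment` field such as `episode_idx`, `_phase`, or `total_reward_in_current_episode`. A minimal sketch of such a filter, assuming only the `VideoDumpFilter` interface introduced above; the class name and threshold logic here are illustrative, not part of this change:

```python
from rl_coach.core_types import VideoDumpFilter


class MinRewardDumpFilter(VideoDumpFilter):
    """Illustrative only: dump a video once a total-reward threshold is crossed."""
    def __init__(self, min_reward: float):
        super().__init__()
        self.min_reward = min_reward

    def should_dump(self, episode_terminated=False, **kwargs):
        # keep collecting frames while the episode is still running,
        # mirroring the pattern used by MaxDumpFilter above
        if not episode_terminated:
            return True
        # 'total_reward_in_current_episode' is an Environment attribute,
        # forwarded to the filter via **kwargs
        return kwargs['total_reward_in_current_episode'] >= self.min_reward
```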
diff --git a/rl_coach/environments/environment.py b/rl_coach/environments/environment.py
index b055afc..295c168 100644
--- a/rl_coach/environments/environment.py
+++ b/rl_coach/environments/environment.py
@@ -317,6 +317,13 @@ class Environment(EnvironmentInterface):
         else:
             self.renderer.render_image(self.get_rendered_image())

+    def handle_episode_ended(self) -> None:
+        """
+        Handle the end of an episode, e.g. dump a video of it if needed
+        :return: None
+        """
+        self.dump_video_of_last_episode_if_needed()
+
     def reset_internal_state(self, force_environment_reset=False) -> EnvResponse:
         """
         Reset the environment and all the variable of the wrapper
@@ -324,7 +331,6 @@ class Environment(EnvironmentInterface):
         :return: A dictionary containing the observation, reward, done flag, action and measurements
         """
-        self.dump_video_of_last_episode_if_needed()

         self._restart_environment_episode(force_environment_reset)
         self.last_episode_time = time.time()

@@ -392,17 +398,16 @@ class Environment(EnvironmentInterface):
         self.goal = goal

     def should_dump_video_of_the_current_episode(self, episode_terminated=False):
-        if self.visualization_parameters.video_dump_methods:
-            for video_dump_method in force_list(self.visualization_parameters.video_dump_methods):
-                if not video_dump_method.should_dump(episode_terminated, **self.__dict__):
+        if self.visualization_parameters.video_dump_filters:
+            for video_dump_filter in force_list(self.visualization_parameters.video_dump_filters):
+                if not video_dump_filter.should_dump(episode_terminated, **self.__dict__):
                     return False
             return True
-        return False
+        return True

     def dump_video_of_last_episode_if_needed(self):
-        if self.visualization_parameters.video_dump_methods and self.last_episode_images != []:
-            if self.should_dump_video_of_the_current_episode(episode_terminated=True):
-                self.dump_video_of_last_episode()
+        if self.last_episode_images != [] and self.should_dump_video_of_the_current_episode(episode_terminated=True):
+            self.dump_video_of_last_episode()

     def dump_video_of_last_episode(self):
         frame_skipping = max(1, int(5 / self.frame_skip))
@@ -464,78 +469,3 @@ class Environment(EnvironmentInterface):
         """
         return np.transpose(self.state['observation'], [1, 2, 0])

-
-"""
-Video Dumping Methods
-"""
-
-
-class VideoDumpMethod(object):
-    """
-    Method used to decide when to dump videos
-    """
-    def should_dump(self, episode_terminated=False, **kwargs):
-        raise NotImplementedError("")
-
-
-class AlwaysDumpMethod(VideoDumpMethod):
-    """
-    Dump video for every episode
-    """
-    def __init__(self):
-        super().__init__()
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        return True
-
-
-class MaxDumpMethod(VideoDumpMethod):
-    """
-    Dump video every time a new max total reward has been achieved
-    """
-    def __init__(self):
-        super().__init__()
-        self.max_reward_achieved = -np.inf
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        # if the episode has not finished yet we want to be prepared for dumping a video
-        if not episode_terminated:
-            return True
-        if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
-            self.max_reward_achieved = kwargs['total_reward_in_current_episode']
-            return True
-        else:
-            return False
-
-
-class EveryNEpisodesDumpMethod(object):
-    """
-    Dump videos once in every N episodes
-    """
-    def __init__(self, num_episodes_between_dumps: int):
-        super().__init__()
-        self.num_episodes_between_dumps = num_episodes_between_dumps
-        self.last_dumped_episode = 0
-        if num_episodes_between_dumps < 1:
-            raise ValueError("the number of episodes between dumps should be a positive number")
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
-            self.last_dumped_episode = kwargs['episode_idx']
-            return True
-        else:
-            return False
-
-
-class SelectedPhaseOnlyDumpMethod(object):
-    """
-    Dump videos when the phase of the environment matches a predefined phase
-    """
-    def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
-        self.run_phases = force_list(run_phases)
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        if kwargs['_phase'] in self.run_phases:
-            return True
-        else:
-            return False
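With this change, a video is dumped only if every configured filter returns True (a logical AND), and the default filters set in `base_parameters.py` above apply when none are passed explicitly. A preset could override the defaults roughly as follows; a sketch using only names introduced in this diff, with the particular filter combination chosen purely for illustration:

```python
from rl_coach.base_parameters import VisualizationParameters
from rl_coach.core_types import (EveryNEpisodesDumpFilter,
                                 SelectedPhaseOnlyDumpFilter, RunPhase)

# Dump an mp4 at most once every 10 episodes, and only during evaluation:
# both filters must return True from should_dump() before a video is written.
vis_params = VisualizationParameters(
    dump_mp4=True,
    video_dump_methods=[SelectedPhaseOnlyDumpFilter(RunPhase.TEST),
                        EveryNEpisodesDumpFilter(10)])
```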
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 53525ea..ddbafd4 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -325,8 +325,8 @@ class GraphManager(object):
         """
         self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1

-        # TODO: we should disentangle ending the episode from resetting the internal state
-        # self.reset_internal_state()
+        for environment in self.environments:
+            environment.handle_episode_ended()

     def train(self, steps: TrainingSteps) -> None:
         """
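The explicit episode-end hook resolves the TODO above: dumping a video no longer depends on `reset_internal_state()` being invoked. For reference, the resulting order of operations in a custom run loop would be roughly the following; a sketch assuming `env` is an `Environment` and `action` comes from an agent, both outside this diff:

```python
# Illustrative step handling: end-of-episode work is now explicit and
# decoupled from resetting the environment.
env_response = env.step(action)
if env_response.game_over:
    env.handle_episode_ended()    # dumps the video if all filters agree
    env.reset_internal_state()    # resetting no longer triggers the dump
```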