Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
bug-fix for dumping movies (+ small refactoring and rename 'VideoDumpMethod' -> 'VideoDumpFilter')
rl_coach/base_parameters.py
@@ -22,8 +22,8 @@ from collections import OrderedDict
 from enum import Enum
 from typing import Dict, List, Union
 
-from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase
-# from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
+from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase, \
+    SelectedPhaseOnlyDumpFilter, MaxDumpFilter
 from rl_coach.filters.filter import NoInputFilter
@@ -285,7 +285,6 @@ class NetworkComponentParameters(Parameters):
         self.dense_layer = dense_layer
 
 
-
 class VisualizationParameters(Parameters):
     def __init__(self,
                  print_networks_summary=False,
@@ -293,7 +292,7 @@ class VisualizationParameters(Parameters):
                  dump_signals_to_csv_every_x_episodes=5,
                  dump_gifs=False,
                  dump_mp4=False,
-                 video_dump_methods=[],
+                 video_dump_methods=None,
                  dump_in_episode_signals=False,
                  dump_parameters_documentation=True,
                  render=False,
@@ -352,6 +351,8 @@ class VisualizationParameters(Parameters):
         which will be passed to the agent and allow using those images.
         """
         super().__init__()
+        if video_dump_methods is None:
+            video_dump_methods = [SelectedPhaseOnlyDumpFilter(RunPhase.TEST), MaxDumpFilter()]
         self.print_networks_summary = print_networks_summary
         self.dump_csv = dump_csv
         self.dump_gifs = dump_gifs
@@ -363,7 +364,7 @@ class VisualizationParameters(Parameters):
         self.native_rendering = native_rendering
         self.max_fps_for_human_control = max_fps_for_human_control
         self.tensorboard = tensorboard
-        self.video_dump_methods = video_dump_methods
+        self.video_dump_filters = video_dump_methods
         self.add_rendered_image_to_env_response = add_rendered_image_to_env_response
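A note on the `video_dump_methods=[]` -> `video_dump_methods=None` change above: a mutable default argument in Python is created once, when the function is defined, and is then shared by every call that omits it. That matters here because the new default filters are stateful (MaxDumpFilter tracks max_reward_achieved), so each VisualizationParameters instance should get its own instances. A minimal standalone sketch of the pitfall, with hypothetical names rather than coach code:

def buggy(filters=[]):       # one shared list, created at definition time
    filters.append("MaxDumpFilter()")
    return filters

def fixed(filters=None):     # None as a sentinel, resolved on every call
    if filters is None:
        filters = []         # fresh list per call
    filters.append("MaxDumpFilter()")
    return filters

print(buggy())   # ['MaxDumpFilter()']
print(buggy())   # ['MaxDumpFilter()', 'MaxDumpFilter()'] -- state leaked across calls
print(fixed())   # ['MaxDumpFilter()']
print(fixed())   # ['MaxDumpFilter()'] -- independent result each call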
rl_coach/core_types.py
@@ -22,6 +22,8 @@ from typing import List, Union, Dict, Any, Type
 
 import numpy as np
 
+from rl_coach.utils import force_list
+
 ActionType = Union[int, float, np.ndarray, List]
 GoalType = Union[None, np.ndarray]
 ObservationType = np.ndarray
@@ -692,3 +694,79 @@ class Episode(object):
 
     def __getitem__(self, sliced):
         return self.transitions[sliced]
+
+
+"""
+Video Dumping Methods
+"""
+
+
+class VideoDumpFilter(object):
+    """
+    Method used to decide when to dump videos
+    """
+    def should_dump(self, episode_terminated=False, **kwargs):
+        raise NotImplementedError("")
+
+
+class AlwaysDumpFilter(VideoDumpFilter):
+    """
+    Dump video for every episode
+    """
+    def __init__(self):
+        super().__init__()
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        return True
+
+
+class MaxDumpFilter(VideoDumpFilter):
+    """
+    Dump video every time a new max total reward has been achieved
+    """
+    def __init__(self):
+        super().__init__()
+        self.max_reward_achieved = -np.inf
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        # if the episode has not finished yet we want to be prepared for dumping a video
+        if not episode_terminated:
+            return True
+        if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
+            self.max_reward_achieved = kwargs['total_reward_in_current_episode']
+            return True
+        else:
+            return False
+
+
+class EveryNEpisodesDumpFilter(object):
+    """
+    Dump videos once in every N episodes
+    """
+    def __init__(self, num_episodes_between_dumps: int):
+        super().__init__()
+        self.num_episodes_between_dumps = num_episodes_between_dumps
+        self.last_dumped_episode = 0
+        if num_episodes_between_dumps < 1:
+            raise ValueError("the number of episodes between dumps should be a positive number")
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
+            self.last_dumped_episode = kwargs['episode_idx']
+            return True
+        else:
+            return False
+
+
+class SelectedPhaseOnlyDumpFilter(object):
+    """
+    Dump videos when the phase of the environment matches a predefined phase
+    """
+    def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
+        self.run_phases = force_list(run_phases)
+
+    def should_dump(self, episode_terminated=False, **kwargs):
+        if kwargs['_phase'] in self.run_phases:
+            return True
+        else:
+            return False
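For orientation, a sketch of how the relocated filters are configured after this commit. This is assembled only from the hunks above (note that the constructor parameter keeps its old name, video_dump_methods, even though it is now stored as video_dump_filters):

from rl_coach.base_parameters import VisualizationParameters
from rl_coach.core_types import RunPhase, EveryNEpisodesDumpFilter, SelectedPhaseOnlyDumpFilter

# Dump an mp4 at most once every 10 episodes, and only in the TEST phase.
# Every configured filter must agree before an episode is written out.
vis_params = VisualizationParameters(
    dump_mp4=True,
    video_dump_methods=[SelectedPhaseOnlyDumpFilter(RunPhase.TEST),
                        EveryNEpisodesDumpFilter(10)],
)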
rl_coach/environments/environment.py
@@ -317,6 +317,13 @@ class Environment(EnvironmentInterface):
         else:
             self.renderer.render_image(self.get_rendered_image())
 
+    def handle_episode_ended(self) -> None:
+        """
+        End an episode
+        :return: None
+        """
+        self.dump_video_of_last_episode_if_needed()
+
     def reset_internal_state(self, force_environment_reset=False) -> EnvResponse:
         """
         Reset the environment and all the variable of the wrapper
@@ -324,7 +331,6 @@ class Environment(EnvironmentInterface):
         :return: A dictionary containing the observation, reward, done flag, action and measurements
         """
 
-        self.dump_video_of_last_episode_if_needed()
         self._restart_environment_episode(force_environment_reset)
         self.last_episode_time = time.time()
 
@@ -392,16 +398,15 @@ class Environment(EnvironmentInterface):
         self.goal = goal
 
     def should_dump_video_of_the_current_episode(self, episode_terminated=False):
-        if self.visualization_parameters.video_dump_methods:
-            for video_dump_method in force_list(self.visualization_parameters.video_dump_methods):
-                if not video_dump_method.should_dump(episode_terminated, **self.__dict__):
+        if self.visualization_parameters.video_dump_filters:
+            for video_dump_filter in force_list(self.visualization_parameters.video_dump_filters):
+                if not video_dump_filter.should_dump(episode_terminated, **self.__dict__):
                     return False
             return True
         return False
 
     def dump_video_of_last_episode_if_needed(self):
-        if self.visualization_parameters.video_dump_methods and self.last_episode_images != []:
-            if self.should_dump_video_of_the_current_episode(episode_terminated=True):
+        if self.last_episode_images != [] and self.should_dump_video_of_the_current_episode(episode_terminated=True):
             self.dump_video_of_last_episode()
 
     def dump_video_of_last_episode(self):
@@ -464,78 +469,3 @@ class Environment(EnvironmentInterface):
         """
         return np.transpose(self.state['observation'], [1, 2, 0])
-
-
-"""
-Video Dumping Methods
-"""
-
-
-class VideoDumpMethod(object):
-    """
-    Method used to decide when to dump videos
-    """
-    def should_dump(self, episode_terminated=False, **kwargs):
-        raise NotImplementedError("")
-
-
-class AlwaysDumpMethod(VideoDumpMethod):
-    """
-    Dump video for every episode
-    """
-    def __init__(self):
-        super().__init__()
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        return True
-
-
-class MaxDumpMethod(VideoDumpMethod):
-    """
-    Dump video every time a new max total reward has been achieved
-    """
-    def __init__(self):
-        super().__init__()
-        self.max_reward_achieved = -np.inf
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        # if the episode has not finished yet we want to be prepared for dumping a video
-        if not episode_terminated:
-            return True
-        if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
-            self.max_reward_achieved = kwargs['total_reward_in_current_episode']
-            return True
-        else:
-            return False
-
-
-class EveryNEpisodesDumpMethod(object):
-    """
-    Dump videos once in every N episodes
-    """
-    def __init__(self, num_episodes_between_dumps: int):
-        super().__init__()
-        self.num_episodes_between_dumps = num_episodes_between_dumps
-        self.last_dumped_episode = 0
-        if num_episodes_between_dumps < 1:
-            raise ValueError("the number of episodes between dumps should be a positive number")
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
-            self.last_dumped_episode = kwargs['episode_idx']
-            return True
-        else:
-            return False
-
-
-class SelectedPhaseOnlyDumpMethod(object):
-    """
-    Dump videos when the phase of the environment matches a predefined phase
-    """
-    def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
-        self.run_phases = force_list(run_phases)
-
-    def should_dump(self, episode_terminated=False, **kwargs):
-        if kwargs['_phase'] in self.run_phases:
-            return True
-        else:
-            return False
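Two things happen in this file. First, the dump decision keeps its AND semantics: every configured filter can veto the dump, and an empty filter list disables dumping entirely. A standalone restatement of that logic (a sketch, not coach's code):

def should_dump(filters, **episode_state):
    # All filters must approve; no filters configured means no dumping.
    if not filters:
        return False
    return all(f.should_dump(**episode_state) for f in filters)

Second, the dump call moves out of reset_internal_state() and into the new handle_episode_ended(), so a finished episode can be written out when it ends rather than as a side effect of the next environment reset.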
rl_coach/graph_manager.py
@@ -325,8 +325,7 @@ class GraphManager(object):
         """
         self.total_steps_counters[self.phase][EnvironmentEpisodes] += 1
 
-        # TODO: we should disentangle ending the episode from resetting the internal state
-        # self.reset_internal_state()
+        [environment.handle_episode_ended() for environment in self.environments]
 
     def train(self, steps: TrainingSteps) -> None:
         """
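Taken together, the episode-end path is now explicit: the graph manager increments the episode counter and calls handle_episode_ended() on each environment, which triggers dump_video_of_last_episode_if_needed() and, from there, the filter check above. The removed TODO about disentangling episode end from internal-state reset appears to be resolved by exactly this change, which is the bug fix the commit title refers to: videos are dumped when an episode actually ends, not when the environment happens to be reset afterwards.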