1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

bug-fix for dumping movies (+ small refactoring and rename 'VideoDumpMethod -> 'VideoDumpFilter')

This commit is contained in:
Gal Leibovich
2018-10-21 17:29:10 +03:00
parent 364168490f
commit 5a8da90d32
4 changed files with 98 additions and 90 deletions

View File

@@ -22,8 +22,8 @@ from collections import OrderedDict
from enum import Enum
from typing import Dict, List, Union
from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase
# from rl_coach.environments.environment import SelectedPhaseOnlyDumpMethod, MaxDumpMethod
from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase, \
SelectedPhaseOnlyDumpFilter, MaxDumpFilter
from rl_coach.filters.filter import NoInputFilter
@@ -285,7 +285,6 @@ class NetworkComponentParameters(Parameters):
self.dense_layer = dense_layer
class VisualizationParameters(Parameters):
def __init__(self,
print_networks_summary=False,
@@ -293,7 +292,7 @@ class VisualizationParameters(Parameters):
dump_signals_to_csv_every_x_episodes=5,
dump_gifs=False,
dump_mp4=False,
video_dump_methods=[],
video_dump_methods=None,
dump_in_episode_signals=False,
dump_parameters_documentation=True,
render=False,
@@ -352,6 +351,8 @@ class VisualizationParameters(Parameters):
which will be passed to the agent and allow using those images.
"""
super().__init__()
if video_dump_methods is None:
video_dump_methods = [SelectedPhaseOnlyDumpFilter(RunPhase.TEST), MaxDumpFilter()]
self.print_networks_summary = print_networks_summary
self.dump_csv = dump_csv
self.dump_gifs = dump_gifs
@@ -363,7 +364,7 @@ class VisualizationParameters(Parameters):
self.native_rendering = native_rendering
self.max_fps_for_human_control = max_fps_for_human_control
self.tensorboard = tensorboard
self.video_dump_methods = video_dump_methods
self.video_dump_filters = video_dump_methods
self.add_rendered_image_to_env_response = add_rendered_image_to_env_response