1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Allow visualizing the processed observation; fix bugs in the coach summary

This commit is contained in:
Itai Caspi
2018-02-13 18:47:24 +02:00
committed by Itai Caspi
parent 5d1a2bc392
commit 55c8c87afc
4 changed files with 14 additions and 2 deletions

View File

@@ -22,6 +22,7 @@ except:
failed_imports.append("matplotlib") failed_imports.append("matplotlib")
import copy import copy
from renderer import Renderer
from configurations import Preset from configurations import Preset
from collections import OrderedDict from collections import OrderedDict
from utils import RunPhase, Signal, is_empty, RunningStat from utils import RunPhase, Signal, is_empty, RunningStat
@@ -101,6 +102,7 @@ class Agent(object):
self.main_network = None self.main_network = None
self.networks = [] self.networks = []
self.last_episode_images = [] self.last_episode_images = []
self.renderer = Renderer()
# signals # signals
self.signals = [] self.signals = []
@@ -232,6 +234,13 @@ class Agent(object):
r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2] r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
observation = 0.2989 * r + 0.5870 * g + 0.1140 * b observation = 0.2989 * r + 0.5870 * g + 0.1140 * b
# Render the processed observation which is how the agent will see it
# Warning: this cannot currently be done in parallel to rendering the environment
if self.tp.visualization.render_observation:
if not self.renderer.is_open:
self.renderer.create_screen(observation.shape[0], observation.shape[1])
self.renderer.render_image(observation)
return observation.astype('uint8') return observation.astype('uint8')
else: else:
if self.tp.env.normalize_observation: if self.tp.env.normalize_observation:

View File

@@ -248,6 +248,7 @@ class VisualizationParameters(Parameters):
video_path = '/home/llt_lab/temp/breakout-videos' video_path = '/home/llt_lab/temp/breakout-videos'
plot_action_values_online = False plot_action_values_online = False
show_saliency_maps_every_num_episodes = 1000000000 show_saliency_maps_every_num_episodes = 1000000000
render_observation = False
print_summary = False print_summary = False
dump_csv = True dump_csv = True
dump_signals_to_csv_every_x_episodes = 5 dump_signals_to_csv_every_x_episodes = 5

View File

@@ -241,7 +241,7 @@ class Logger(BaseLogger):
pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0) pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0)
def remove_experiment_dir(self): def remove_experiment_dir(self):
- os.removedirs(self.experiments_path)
+ shutil.rmtree(self.experiments_path)
def print_summary(self): def print_summary(self):
screen.separator() screen.separator()
@@ -252,7 +252,7 @@ class Logger(BaseLogger):
screen.separator() screen.separator()
if screen.ask_yes_no("Do you want to discard the experiment results (Warning: this cannot be undone)?", False): if screen.ask_yes_no("Do you want to discard the experiment results (Warning: this cannot be undone)?", False):
self.remove_experiment_dir() self.remove_experiment_dir()
- if screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False):
+ elif screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False):
new_name = self.get_experiment_name() new_name = self.get_experiment_name()
new_path = self.get_experiment_path(new_name, create_path=False) new_path = self.get_experiment_path(new_name, create_path=False)
shutil.move(self.experiments_path, new_path) shutil.move(self.experiments_path, new_path)

View File

@@ -41,6 +41,8 @@ class Renderer(object):
:return: None :return: None
""" """
if self.is_open: if self.is_open:
if len(image.shape) == 2:
image = np.stack([image] * 3)
if len(image.shape) == 3: if len(image.shape) == 3:
if image.shape[0] == 3 or image.shape[0] == 1: if image.shape[0] == 3 or image.shape[0] == 1:
image = np.transpose(image, (1, 2, 0)) image = np.transpose(image, (1, 2, 0))