mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
allow visualizing the observation + bug fixes to coach summary
This commit is contained in:
@@ -22,6 +22,7 @@ except:
|
|||||||
failed_imports.append("matplotlib")
|
failed_imports.append("matplotlib")
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
from renderer import Renderer
|
||||||
from configurations import Preset
|
from configurations import Preset
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from utils import RunPhase, Signal, is_empty, RunningStat
|
from utils import RunPhase, Signal, is_empty, RunningStat
|
||||||
@@ -101,6 +102,7 @@ class Agent(object):
|
|||||||
self.main_network = None
|
self.main_network = None
|
||||||
self.networks = []
|
self.networks = []
|
||||||
self.last_episode_images = []
|
self.last_episode_images = []
|
||||||
|
self.renderer = Renderer()
|
||||||
|
|
||||||
# signals
|
# signals
|
||||||
self.signals = []
|
self.signals = []
|
||||||
@@ -232,6 +234,13 @@ class Agent(object):
|
|||||||
r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
|
r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
|
||||||
observation = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
observation = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
||||||
|
|
||||||
|
# Render the processed observation which is how the agent will see it
|
||||||
|
# Warning: this cannot currently be done in parallel to rendering the environment
|
||||||
|
if self.tp.visualization.render_observation:
|
||||||
|
if not self.renderer.is_open:
|
||||||
|
self.renderer.create_screen(observation.shape[0], observation.shape[1])
|
||||||
|
self.renderer.render_image(observation)
|
||||||
|
|
||||||
return observation.astype('uint8')
|
return observation.astype('uint8')
|
||||||
else:
|
else:
|
||||||
if self.tp.env.normalize_observation:
|
if self.tp.env.normalize_observation:
|
||||||
|
|||||||
@@ -248,6 +248,7 @@ class VisualizationParameters(Parameters):
|
|||||||
video_path = '/home/llt_lab/temp/breakout-videos'
|
video_path = '/home/llt_lab/temp/breakout-videos'
|
||||||
plot_action_values_online = False
|
plot_action_values_online = False
|
||||||
show_saliency_maps_every_num_episodes = 1000000000
|
show_saliency_maps_every_num_episodes = 1000000000
|
||||||
|
render_observation = False
|
||||||
print_summary = False
|
print_summary = False
|
||||||
dump_csv = True
|
dump_csv = True
|
||||||
dump_signals_to_csv_every_x_episodes = 5
|
dump_signals_to_csv_every_x_episodes = 5
|
||||||
|
|||||||
@@ -241,7 +241,7 @@ class Logger(BaseLogger):
|
|||||||
pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0)
|
pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0)
|
||||||
|
|
||||||
def remove_experiment_dir(self):
|
def remove_experiment_dir(self):
|
||||||
os.removedirs(self.experiments_path)
|
shutil.rmtree(self.experiments_path)
|
||||||
|
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
screen.separator()
|
screen.separator()
|
||||||
@@ -252,7 +252,7 @@ class Logger(BaseLogger):
|
|||||||
screen.separator()
|
screen.separator()
|
||||||
if screen.ask_yes_no("Do you want to discard the experiment results (Warning: this cannot be undone)?", False):
|
if screen.ask_yes_no("Do you want to discard the experiment results (Warning: this cannot be undone)?", False):
|
||||||
self.remove_experiment_dir()
|
self.remove_experiment_dir()
|
||||||
if screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False):
|
elif screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False):
|
||||||
new_name = self.get_experiment_name()
|
new_name = self.get_experiment_name()
|
||||||
new_path = self.get_experiment_path(new_name, create_path=False)
|
new_path = self.get_experiment_path(new_name, create_path=False)
|
||||||
shutil.move(self.experiments_path, new_path)
|
shutil.move(self.experiments_path, new_path)
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ class Renderer(object):
|
|||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
if self.is_open:
|
if self.is_open:
|
||||||
|
if len(image.shape) == 2:
|
||||||
|
image = np.stack([image] * 3)
|
||||||
if len(image.shape) == 3:
|
if len(image.shape) == 3:
|
||||||
if image.shape[0] == 3 or image.shape[0] == 1:
|
if image.shape[0] == 3 or image.shape[0] == 1:
|
||||||
image = np.transpose(image, (1, 2, 0))
|
image = np.transpose(image, (1, 2, 0))
|
||||||
|
|||||||
Reference in New Issue
Block a user