diff --git a/agents/agent.py b/agents/agent.py index 717ab76..274649e 100644 --- a/agents/agent.py +++ b/agents/agent.py @@ -22,6 +22,7 @@ except: failed_imports.append("matplotlib") import copy +from renderer import Renderer from configurations import Preset from collections import OrderedDict from utils import RunPhase, Signal, is_empty, RunningStat @@ -101,6 +102,7 @@ class Agent(object): self.main_network = None self.networks = [] self.last_episode_images = [] + self.renderer = Renderer() # signals self.signals = [] @@ -232,6 +234,13 @@ class Agent(object): r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2] observation = 0.2989 * r + 0.5870 * g + 0.1140 * b + # Render the processed observation which is how the agent will see it + # Warning: this cannot currently be done in parallel to rendering the environment + if self.tp.visualization.render_observation: + if not self.renderer.is_open: + self.renderer.create_screen(observation.shape[0], observation.shape[1]) + self.renderer.render_image(observation) + return observation.astype('uint8') else: if self.tp.env.normalize_observation: diff --git a/configurations.py b/configurations.py index 5815c5a..f929642 100644 --- a/configurations.py +++ b/configurations.py @@ -248,6 +248,7 @@ class VisualizationParameters(Parameters): video_path = '/home/llt_lab/temp/breakout-videos' plot_action_values_online = False show_saliency_maps_every_num_episodes = 1000000000 + render_observation = False print_summary = False dump_csv = True dump_signals_to_csv_every_x_episodes = 5 diff --git a/logger.py b/logger.py index b5a1302..ff05d6a 100644 --- a/logger.py +++ b/logger.py @@ -241,7 +241,7 @@ class Logger(BaseLogger): pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:], duration=1.0 / fps, loop=0) def remove_experiment_dir(self): - os.removedirs(self.experiments_path) + shutil.rmtree(self.experiments_path) def print_summary(self): screen.separator() @@ -252,7 +252,7 @@ class Logger(BaseLogger): screen.separator() if screen.ask_yes_no("Do you want to discard the experiment results (Warning: this cannot be undone)?", False): self.remove_experiment_dir() - if screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False): + elif screen.ask_yes_no("Do you want to specify a different experiment name to save to?", False): new_name = self.get_experiment_name() new_path = self.get_experiment_path(new_name, create_path=False) shutil.move(self.experiments_path, new_path) diff --git a/renderer.py b/renderer.py index cddc810..fee19af 100644 --- a/renderer.py +++ b/renderer.py @@ -41,6 +41,8 @@ class Renderer(object): :return: None """ if self.is_open: + if len(image.shape) == 2: + image = np.stack([image] * 3) if len(image.shape) == 3: if image.shape[0] == 3 or image.shape[0] == 1: image = np.transpose(image, (1, 2, 0))