# # Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ To run Coach Dashboard, run the following command: python3 dashboard.py """ from utils import * import os import datetime import sys import wx import random import pandas as pd from pandas.io.common import EmptyDataError import numpy as np from bokeh.palettes import Dark2 from bokeh.layouts import row, column, widgetbox, Spacer from bokeh.models import ColumnDataSource, Range1d, LinearAxis, HoverTool, WheelZoomTool, PanTool from bokeh.models.widgets import RadioButtonGroup, MultiSelect, Button, Select, Slider, Div, CheckboxGroup from bokeh.models.glyphs import Patch from bokeh.plotting import figure, show, curdoc from utils import force_list from utils import squeeze_list from itertools import cycle from os import listdir from os.path import isfile, join, isdir, basename from enum import Enum class DialogApp(wx.App): def getFileDialog(self): with wx.FileDialog(None, "Open CSV file", wildcard="CSV files (*.csv)|*.csv", style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_CHANGE_DIR | wx.FD_MULTIPLE) as fileDialog: if fileDialog.ShowModal() == wx.ID_CANCEL: return None # the user changed their mind else: # Proceed loading the file chosen by the user return fileDialog.GetPaths() def getDirDialog(self): with wx.DirDialog (None, "Choose input directory", "", style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_CHANGE_DIR) as dirDialog: if dirDialog.ShowModal() == wx.ID_CANCEL: return None # the user changed their mind else: # Proceed loading the dir chosen by the user return dirDialog.GetPath() class Signal: def __init__(self, name, parent): self.name = name self.full_name = "{}/{}".format(parent.filename, self.name) self.selected = False self.color = random.choice(Dark2[8]) self.line = None self.bands = None self.bokeh_source = parent.bokeh_source self.min_val = 0 self.max_val = 0 self.axis = 'default' self.sub_signals = [] for name in self.bokeh_source.data.keys(): if (len(name.split('/')) == 1 and name == self.name) or '/'.join(name.split('/')[:-1]) == self.name: self.sub_signals.append(name) if len(self.sub_signals) > 1: self.mean_signal = squeeze_list([name for name in self.sub_signals if 'Mean' in name.split('/')[-1]]) self.stdev_signal = squeeze_list([name for name in self.sub_signals if 'Stdev' in name.split('/')[-1]]) self.min_signal = squeeze_list([name for name in self.sub_signals if 'Min' in name.split('/')[-1]]) self.max_signal = squeeze_list([name for name in self.sub_signals if 'Max' in name.split('/')[-1]]) else: self.mean_signal = squeeze_list(self.name) self.stdev_signal = None self.min_signal = None self.max_signal = None self.has_bollinger_bands = False if self.mean_signal and self.stdev_signal and self.min_signal and self.max_signal: self.has_bollinger_bands = True self.show_bollinger_bands = False self.bollinger_bands_source = None self.update_range() def set_selected(self, val): global current_color if self.selected != val: self.selected = val if self.line: self.color = Dark2[8][current_color] current_color = (current_color + 1) % len(Dark2[8]) self.line.glyph.line_color = self.color self.line.visible = self.selected if self.bands: self.bands.glyph.fill_color = self.color self.bands.visible = self.selected and self.show_bollinger_bands elif self.selected: # lazy plotting - plot only when selected for the first time show_spinner() self.color = Dark2[8][current_color] current_color = (current_color + 1) % len(Dark2[8]) if self.has_bollinger_bands: self.set_bands_source() self.create_bands() self.line = plot.line('index', self.mean_signal, source=self.bokeh_source, line_color=self.color, line_width=2) self.line.visible = True hide_spinner() def set_dash(self, dash): self.line.glyph.line_dash = dash def create_bands(self): self.bands = plot.patch(x='band_x', y='band_y', source=self.bollinger_bands_source, color=self.color, fill_alpha=0.4, alpha=0.1, line_width=0) self.bands.visible = self.show_bollinger_bands # self.min_line = plot.line('index', self.min_signal, source=self.bokeh_source, # line_color=self.color, line_width=3, line_dash="4 4") # self.max_line = plot.line('index', self.max_signal, source=self.bokeh_source, # line_color=self.color, line_width=3, line_dash="4 4") # self.min_line.visible = self.show_bollinger_bands # self.max_line.visible = self.show_bollinger_bands def set_bands_source(self): x_ticks = self.bokeh_source.data['index'] mean_values = self.bokeh_source.data[self.mean_signal] stdev_values = self.bokeh_source.data[self.stdev_signal] band_x = np.append(x_ticks, x_ticks[::-1]) band_y = np.append(mean_values - stdev_values, mean_values[::-1] + stdev_values[::-1]) source_data = {'band_x': band_x, 'band_y': band_y} if self.bollinger_bands_source: self.bollinger_bands_source.data = source_data else: self.bollinger_bands_source = ColumnDataSource(source_data) def change_bollinger_bands_state(self, new_state): self.show_bollinger_bands = new_state if self.bands and self.selected: self.bands.visible = new_state # self.min_line.visible = new_state # self.max_line.visible = new_state def update_range(self): self.min_val = np.min(self.bokeh_source.data[self.mean_signal]) self.max_val = np.max(self.bokeh_source.data[self.mean_signal]) def set_axis(self, axis): self.axis = axis self.line.y_range_name = axis def toggle_axis(self): if self.axis == 'default': self.set_axis('secondary') else: self.set_axis('default') class SignalsFileBase: def __init__(self): self.full_csv_path = "" self.dir = "" self.filename = "" self.signals_averaging_window = 1 self.show_bollinger_bands = False self.csv = None self.bokeh_source = None self.bokeh_source_orig = None self.last_modified = None self.signals = {} self.separate_files = False def load_csv(self): pass def update_source_and_signals(self): # create bokeh data sources self.bokeh_source_orig = ColumnDataSource(self.csv) self.bokeh_source_orig.data['index'] = self.bokeh_source_orig.data[x_axis] if self.bokeh_source is None: self.bokeh_source = ColumnDataSource(self.csv) else: # self.bokeh_source.data = self.bokeh_source_orig.data # smooth the data if necessary self.change_averaging_window(self.signals_averaging_window, force=True) # create all the signals if len(self.signals.keys()) == 0: self.signals = {} unique_signal_names = [] for name in self.csv.columns: if len(name.split('/')) == 1: unique_signal_names.append(name) else: unique_signal_names.append('/'.join(name.split('/')[:-1])) unique_signal_names = list(set(unique_signal_names)) for signal_name in unique_signal_names: self.signals[signal_name] = Signal(signal_name, self) def load(self): self.load_csv() self.update_source_and_signals() def reload_data(self, signals): # this function is a workaround to reload the data of all the signals # if the data doesn't change, bokeh does not refreshes the line self.change_averaging_window(self.signals_averaging_window + 1, force=True) self.change_averaging_window(self.signals_averaging_window - 1, force=True) def change_averaging_window(self, new_size, force=False, signals=None): if force or self.signals_averaging_window != new_size: self.signals_averaging_window = new_size win = np.ones(new_size) / new_size temp_data = self.bokeh_source_orig.data.copy() for col in self.bokeh_source.data.keys(): if col == 'index' or col in x_axis_options \ or (signals and not any(col in signal for signal in signals)): temp_data[col] = temp_data[col][:-new_size] continue temp_data[col] = np.convolve(self.bokeh_source_orig.data[col], win, mode='same')[:-new_size] self.bokeh_source.data = temp_data # smooth bollinger bands for signal in self.signals.values(): if signal.has_bollinger_bands: signal.set_bands_source() def hide_all_signals(self): for signal_name in self.signals.keys(): self.set_signal_selection(signal_name, False) def set_signal_selection(self, signal_name, val): self.signals[signal_name].set_selected(val) def change_bollinger_bands_state(self, new_state): self.show_bollinger_bands = new_state for signal in self.signals.values(): signal.change_bollinger_bands_state(new_state) def file_was_modified_on_disk(self): pass def get_range_of_selected_signals_on_axis(self, axis): max_val = -float('inf') min_val = float('inf') for signal in self.signals.values(): if signal.selected and signal.axis == axis: max_val = max(max_val, signal.max_val) min_val = min(min_val, signal.min_val) return min_val, max_val def get_selected_signals(self): signals = [] for signal in self.signals.values(): if signal.selected: signals.append(signal) return signals def show_files_separately(self, val): pass class SignalsFile(SignalsFileBase): def __init__(self, csv_path, load=True): SignalsFileBase.__init__(self) self.full_csv_path = csv_path self.dir, self.filename, _ = break_file_path(csv_path) if load: self.load() def load_csv(self): # load csv and fix sparse data. # csv can be in the middle of being written so we use try - except self.csv = None while self.csv is None: try: self.csv = pd.read_csv(self.full_csv_path) break except EmptyDataError: self.csv = None continue self.csv = self.csv.interpolate() self.csv.fillna(value=0, inplace=True) self.last_modified = os.path.getmtime(self.full_csv_path) def file_was_modified_on_disk(self): return self.last_modified != os.path.getmtime(self.full_csv_path) class SignalsFilesGroup(SignalsFileBase): def __init__(self, csv_paths): SignalsFileBase.__init__(self) self.full_csv_paths = csv_paths self.signals_files = [] if len(csv_paths) == 1 and os.path.isdir(csv_paths[0]): self.signals_files = [SignalsFile(str(file), load=False) for file in add_directory_csv_files(csv_paths[0])] else: for csv_path in csv_paths: if os.path.isdir(csv_path): self.signals_files.append(SignalsFilesGroup(add_directory_csv_files(csv_path))) else: self.signals_files.append(SignalsFile(str(csv_path), load=False)) self.dir = os.path.dirname(os.path.commonprefix(csv_paths)) self.filename = '{} - Group({})'.format(basename(self.dir), len(self.signals_files)) self.load() def load_csv(self): corrupted_files_idx = [] for idx, signal_file in enumerate(self.signals_files): signal_file.load_csv() if not all(option in signal_file.csv.keys() for option in x_axis_options): print("Warning: {} file seems to be corrupted and does contain the necessary columns " "and will not be rendered".format(signal_file.filename)) corrupted_files_idx.append(idx) for file_idx in corrupted_files_idx: del self.signals_files[file_idx] # get the stats of all the columns csv_group = pd.concat([signals_file.csv for signals_file in self.signals_files]) columns_to_remove = [s for s in csv_group.columns if '/Stdev' in s] + \ [s for s in csv_group.columns if '/Min' in s] + \ [s for s in csv_group.columns if '/Max' in s] for col in columns_to_remove: del csv_group[col] csv_group = csv_group.groupby(csv_group.index) self.csv_mean = csv_group.mean() self.csv_mean.columns = [s + '/Mean' for s in self.csv_mean.columns] self.csv_stdev = csv_group.std() self.csv_stdev.columns = [s + '/Stdev' for s in self.csv_stdev.columns] self.csv_min = csv_group.min() self.csv_min.columns = [s + '/Min' for s in self.csv_min.columns] self.csv_max = csv_group.max() self.csv_max.columns = [s + '/Max' for s in self.csv_max.columns] # get the indices from the file with the least number of indices and which is not an evaluation worker file_with_min_indices = self.signals_files[0] for signals_file in self.signals_files: if signals_file.csv.shape[0] < file_with_min_indices.csv.shape[0] and \ 'Training reward' in signals_file.csv.keys(): file_with_min_indices = signals_file self.index_columns = file_with_min_indices.csv[x_axis_options] # concat the stats and the indices columns num_rows = file_with_min_indices.csv.shape[0] self.csv = pd.concat([self.index_columns, self.csv_mean.head(num_rows), self.csv_stdev.head(num_rows), self.csv_min.head(num_rows), self.csv_max.head(num_rows)], axis=1) # remove the stat columns for the indices columns columns_to_remove = [s + '/Mean' for s in x_axis_options] + \ [s + '/Stdev' for s in x_axis_options] + \ [s + '/Min' for s in x_axis_options] + \ [s + '/Max' for s in x_axis_options] for col in columns_to_remove: del self.csv[col] # remove NaNs # self.csv.fillna(value=0, inplace=True) for key in self.csv.keys(): if 'Stdev' in key and 'Evaluation' not in key: self.csv[key] = self.csv[key].fillna(value=0) for signal_file in self.signals_files: signal_file.update_source_and_signals() def change_averaging_window(self, new_size, force=False, signals=None): for signal_file in self.signals_files: signal_file.change_averaging_window(new_size, force, signals) SignalsFileBase.change_averaging_window(self, new_size, force, signals) def set_signal_selection(self, signal_name, val): self.show_files_separately(self.separate_files) SignalsFileBase.set_signal_selection(self, signal_name, val) def file_was_modified_on_disk(self): for signal_file in self.signals_files: if signal_file.file_was_modified_on_disk(): return True return False def show_files_separately(self, val): self.separate_files = val for signal in self.signals.values(): if signal.selected: if val: signal.set_dash("4 4") else: signal.set_dash("") for signal_file in self.signals_files: try: if val: signal_file.set_signal_selection(signal.name, signal.selected) else: signal_file.set_signal_selection(signal.name, False) except: pass class RunType(Enum): SINGLE_FOLDER_SINGLE_FILE = 1 SINGLE_FOLDER_MULTIPLE_FILES = 2 MULTIPLE_FOLDERS_SINGLE_FILES = 3 MULTIPLE_FOLDERS_MULTIPLE_FILES = 4 UNKNOWN = 0 class FolderType(Enum): SINGLE_FILE = 1 MULTIPLE_FILES = 2 MULTIPLE_FOLDERS = 3 EMPTY = 4 dialog = DialogApp() # read data patches = {} signals_files = {} selected_file = None x_axis = 'Episode #' x_axis_options = ['Episode #', 'Total steps', 'Wall-Clock Time'] current_color = 0 # spinner root_dir = os.path.dirname(os.path.abspath(__file__)) with open(os.path.join(root_dir, 'spinner.css'), 'r') as f: spinner_style = """""".format(f.read()) spinner_html = """