mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
149 lines
6.9 KiB
Python
149 lines
6.9 KiB
Python
import os
|
|
from os.path import basename
|
|
|
|
import pandas as pd
|
|
|
|
from dashboard_components.globals import x_axis_options, add_directory_csv_files, show_spinner
|
|
from dashboard_components.signals_file import SignalsFile
|
|
from dashboard_components.signals_file_base import SignalsFileBase
|
|
|
|
|
|
class SignalsFilesGroup(SignalsFileBase):
|
|
def __init__(self, csv_paths, plot=None):
|
|
super().__init__(plot)
|
|
self.full_csv_paths = csv_paths
|
|
self.signals_files = []
|
|
if len(csv_paths) == 1 and os.path.isdir(csv_paths[0]):
|
|
self.signals_files = [SignalsFile(str(file), load=False, plot=plot) for file in add_directory_csv_files(csv_paths[0])]
|
|
else:
|
|
for csv_path in csv_paths:
|
|
if os.path.isdir(csv_path):
|
|
self.signals_files.append(SignalsFilesGroup(add_directory_csv_files(csv_path), plot=plot))
|
|
else:
|
|
self.signals_files.append(SignalsFile(str(csv_path), load=False, plot=plot))
|
|
parent_directory_path = os.path.abspath(os.path.join(os.path.dirname(csv_paths[0]), '..'))
|
|
|
|
if len(csv_paths) == 1 and len(os.listdir(parent_directory_path)) == 1:
|
|
# get the parent directory name (since the current directory is the timestamp directory)
|
|
self.dir = os.path.abspath(os.path.join(os.path.dirname(csv_paths[0]), '..'))
|
|
else:
|
|
# get the common directory for all the experiments
|
|
self.dir = os.path.dirname(os.path.commonprefix(csv_paths))
|
|
|
|
self.filename = '{} - Group({})'.format(basename(self.dir), len(self.signals_files))
|
|
|
|
self.load()
|
|
|
|
def load_csv(self):
|
|
corrupted_files_idx = []
|
|
for idx, signal_file in enumerate(self.signals_files):
|
|
signal_file.load_csv()
|
|
if not all(option in signal_file.csv.keys() for option in x_axis_options):
|
|
print("Warning: {} file seems to be corrupted and does contain the necessary columns "
|
|
"and will not be rendered".format(signal_file.filename))
|
|
corrupted_files_idx.append(idx)
|
|
|
|
for file_idx in corrupted_files_idx:
|
|
del self.signals_files[file_idx]
|
|
|
|
# get the stats of all the columns
|
|
if len(self.signals_files) > 1:
|
|
csv_group = pd.concat([signals_file.csv for signals_file in self.signals_files])
|
|
columns_to_remove = [s for s in csv_group.columns if '/Stdev' in s] + \
|
|
[s for s in csv_group.columns if '/Min' in s] + \
|
|
[s for s in csv_group.columns if '/Max' in s]
|
|
for col in columns_to_remove:
|
|
del csv_group[col]
|
|
csv_group = csv_group.groupby(csv_group.index)
|
|
self.csv_mean = csv_group.mean()
|
|
self.csv_mean.columns = [s + '/Mean' for s in self.csv_mean.columns]
|
|
self.csv_stdev = csv_group.std()
|
|
self.csv_stdev.columns = [s + '/Stdev' for s in self.csv_stdev.columns]
|
|
self.csv_min = csv_group.min()
|
|
self.csv_min.columns = [s + '/Min' for s in self.csv_min.columns]
|
|
self.csv_max = csv_group.max()
|
|
self.csv_max.columns = [s + '/Max' for s in self.csv_max.columns]
|
|
|
|
# get the indices from the file with the least number of indices and which is not an evaluation worker
|
|
file_with_min_indices = self.signals_files[0]
|
|
for signals_file in self.signals_files:
|
|
if signals_file.csv.shape[0] < file_with_min_indices.csv.shape[0] and \
|
|
'Training reward' in signals_file.csv.keys():
|
|
file_with_min_indices = signals_file
|
|
self.index_columns = file_with_min_indices.csv[x_axis_options]
|
|
|
|
# concat the stats and the indices columns
|
|
num_rows = file_with_min_indices.csv.shape[0]
|
|
self.csv = pd.concat([self.index_columns, self.csv_mean.head(num_rows), self.csv_stdev.head(num_rows),
|
|
self.csv_min.head(num_rows), self.csv_max.head(num_rows)], axis=1)
|
|
|
|
# remove the stat columns for the indices columns
|
|
columns_to_remove = [s + '/Mean' for s in x_axis_options] + \
|
|
[s + '/Stdev' for s in x_axis_options] + \
|
|
[s + '/Min' for s in x_axis_options] + \
|
|
[s + '/Max' for s in x_axis_options]
|
|
for col in columns_to_remove:
|
|
del self.csv[col]
|
|
else: # This is a group of a single file
|
|
self.csv = self.signals_files[0].csv
|
|
|
|
# # convert wall clock time to minutes - isn't needed because the sub-signals are already scaled
|
|
# self.csv['Wall-Clock Time'] /= 60.
|
|
|
|
# remove NaNs
|
|
self.csv.fillna(value=0, inplace=True) # removing this line will make bollinger bands fail
|
|
for key in self.csv.keys():
|
|
if 'Stdev' in key and 'Evaluation' not in key:
|
|
self.csv[key] = self.csv[key].fillna(value=0)
|
|
|
|
for signal_file in self.signals_files:
|
|
signal_file.update_source_and_signals()
|
|
|
|
def reload_data(self):
|
|
for signal_file in self.signals_files:
|
|
signal_file.reload_data()
|
|
SignalsFileBase.reload_data(self)
|
|
|
|
def update_x_axis_index(self):
|
|
for signal_file in self.signals_files:
|
|
signal_file.update_x_axis_index()
|
|
SignalsFileBase.update_x_axis_index(self)
|
|
|
|
def toggle_y_axis(self, signal_name=None):
|
|
for signal in self.signals.values():
|
|
if signal.selected:
|
|
signal.toggle_axis()
|
|
for signal_file in self.signals_files:
|
|
signal_file.toggle_y_axis(signal.name)
|
|
|
|
def change_averaging_window(self, new_size, force=False, signals=None):
|
|
for signal_file in self.signals_files:
|
|
signal_file.change_averaging_window(new_size, force, signals)
|
|
SignalsFileBase.change_averaging_window(self, new_size, force, signals)
|
|
|
|
def set_signal_selection(self, signal_name, val):
|
|
self.show_files_separately(self.separate_files)
|
|
SignalsFileBase.set_signal_selection(self, signal_name, val)
|
|
|
|
def file_was_modified_on_disk(self):
|
|
for signal_file in self.signals_files:
|
|
if signal_file.file_was_modified_on_disk():
|
|
return True
|
|
return False
|
|
|
|
def show_files_separately(self, val):
|
|
self.separate_files = val
|
|
for signal in self.signals.values():
|
|
if signal.selected:
|
|
if val:
|
|
signal.set_dash("4 4")
|
|
else:
|
|
signal.set_dash("")
|
|
for signal_file in self.signals_files:
|
|
try:
|
|
if val:
|
|
signal_file.set_signal_selection(signal.name, signal.selected)
|
|
else:
|
|
signal_file.set_signal_selection(signal.name, False)
|
|
except:
|
|
pass |