#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
from multiprocessing import Process, Manager
from os.path import basename

import pandas as pd

from rl_coach.dashboard_components.globals import x_axis_options, add_directory_csv_files, show_spinner, x_axis
from rl_coach.dashboard_components.signals_file_base import SignalsFileBase
from rl_coach.dashboard_components.signals_file import SignalsFile


class SignalsFilesGroup(SignalsFileBase):
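    """A group of signals files (e.g. the CSVs written by the workers of a
    single experiment, or several experiment directories), presented as a
    single logical file that exposes the mean / stdev / min / max of every
    column across the group."""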
    def __init__(self, csv_paths, plot=None):
        super().__init__(plot)
        self.full_csv_paths = csv_paths
        self.signals_files = []
        if len(csv_paths) == 1 and os.path.isdir(csv_paths[0]):
            self.signals_files = [SignalsFile(str(file), load=False, plot=plot)
                                  for file in add_directory_csv_files(csv_paths[0])]
        else:
            for csv_path in csv_paths:
                if os.path.isdir(csv_path):
                    self.signals_files.append(SignalsFilesGroup(add_directory_csv_files(csv_path), plot=plot))
                else:
                    self.signals_files.append(SignalsFile(str(csv_path), load=False, plot=plot))
        parent_directory_path = os.path.abspath(os.path.join(os.path.dirname(csv_paths[0]), '..'))

        if len(os.listdir(parent_directory_path)) == 1:
            # get the parent directory name (since the current directory is the timestamp directory)
            self.dir = parent_directory_path
        else:
            # get the common directory for all the experiments
            self.dir = os.path.dirname('/'.join(os.path.commonprefix(csv_paths).split('/')[:-1]) + '/')

        self.filename = '{} - Group({})'.format(basename(self.dir), len(self.signals_files))

        self.signal_files_need_update = False

        self.load()

    def load_csv(self):
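        """Load the CSVs of all the files in the group and combine them into a
        single DataFrame holding the per-column group statistics."""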
        global x_axis
        # load the csv's for all workers
        processes = []
        results = Manager().dict()
        corrupted_files_idx = []
        for idx, signal_file in enumerate(self.signals_files):
            if not isinstance(signal_file, SignalsFilesGroup):
                processes.append(Process(target=signal_file.load_csv, args=(idx, results)))
                processes[-1].start()
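        # wait for all the worker processes to finish loading their CSVs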
        for p in processes:
            p.join()

        # load csv's for SignalsFilesGroup serially for now. TODO: we should later parallelize this as well.
        for idx, signal_file in enumerate(self.signals_files):
            if isinstance(signal_file, SignalsFilesGroup):
                signal_file.load_csv()

        for idx, signal_file in enumerate(self.signals_files):
            if len(list(results.keys())) > 0:
                signal_file.csv, signal_file.last_modified = results[idx]
            if not all(option in signal_file.csv.keys() for option in x_axis_options):
                print("Warning: {} file seems to be corrupted and does not contain the necessary columns "
                      "and will not be rendered".format(signal_file.filename))
                corrupted_files_idx.append(idx)

        # remove corrupted worker files (delete from the end so the remaining indices stay valid)
        for file_idx in sorted(corrupted_files_idx, reverse=True):
            del self.signals_files[file_idx]

        # get the stats of all the columns
        if len(self.signals_files) > 1:
            transformed_signals_files = []
            subsampling = None
            for idx in range(len(self.signals_files)):
                transformed_signals_files.append(self.signals_files[idx].csv.copy(deep=True))

                # change the index to be the currently selected x axis
                transformed_signals_files[-1].index = transformed_signals_files[-1][x_axis[0]]

                # remove all duplicate index rows
                transformed_signals_files[-1] = transformed_signals_files[-1][~transformed_signals_files[-1].index.duplicated()]

                # fill up missing row indices. we are going to take the mean over the group and we want to make sure
                # the entire group has some value for every possible index.
                num_rows = int(transformed_signals_files[-1].index.values[-1])
                transformed_signals_files[-1] = transformed_signals_files[-1].reindex(range(num_rows))
                transformed_signals_files[-1].interpolate(inplace=True)

                # sub-sample the csv to a max of 5000 indices (apply the same subsampling to all the files)
                if subsampling is None:
                    subsampling = max(1, num_rows // 5000)
                transformed_signals_files[-1] = transformed_signals_files[-1].iloc[::subsampling, :]

            csv_group = pd.concat([signals_file for signals_file in transformed_signals_files])
            columns_to_remove = [s for s in csv_group.columns if '/Stdev' in s] + \
                                [s for s in csv_group.columns if '/Min' in s] + \
                                [s for s in csv_group.columns if '/Max' in s]
            for col in columns_to_remove:
                del csv_group[col]
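            # rows that share the same x-axis index (across all the files) fall
            # into the same group, so the statistics below are taken across the
            # files of the group at each index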
            csv_group = csv_group.groupby(csv_group.index)
            self.csv_mean = csv_group.mean()
            self.csv_mean.columns = [s + '/Mean' for s in self.csv_mean.columns]
            self.csv_stdev = csv_group.std()
            self.csv_stdev.columns = [s + '/Stdev' for s in self.csv_stdev.columns]
            self.csv_min = csv_group.min()
            self.csv_min.columns = [s + '/Min' for s in self.csv_min.columns]
            self.csv_max = csv_group.max()
            self.csv_max.columns = [s + '/Max' for s in self.csv_max.columns]

            # get the indices from the file with the least number of indices and which is not an evaluation worker
            file_with_min_indices = transformed_signals_files[0]
            for signals_file in transformed_signals_files:
                if signals_file.shape[0] < file_with_min_indices.shape[0] and \
                        'Training reward' in signals_file.keys():
                    file_with_min_indices = signals_file
            self.index_columns = file_with_min_indices[x_axis_options]

            # concat the stats and the indices columns
            num_rows = file_with_min_indices.shape[0]
            self.csv = pd.concat([self.index_columns, self.csv_mean.head(num_rows), self.csv_stdev.head(num_rows),
                                  self.csv_min.head(num_rows), self.csv_max.head(num_rows)], axis=1)

            # remove the stat columns that were generated for the indices columns
            columns_to_remove = [s + '/Mean' for s in x_axis_options] + \
                                [s + '/Stdev' for s in x_axis_options] + \
                                [s + '/Min' for s in x_axis_options] + \
                                [s + '/Max' for s in x_axis_options]
            for col in columns_to_remove:
                if col in self.csv.keys():
                    del self.csv[col]
        else:  # this is a group of a single file
            self.csv = self.signals_files[0].csv

        # remove NaNs
        self.csv.fillna(value=0, inplace=True)  # removing this line will make bollinger bands fail
        for key in self.csv.keys():
            if 'Stdev' in key and 'Evaluation' not in key:
                self.csv[key] = self.csv[key].fillna(value=0)

        self.signal_files_need_update = True

    def reload_data(self):
        SignalsFileBase.reload_data(self)

    def update_x_axis_index(self):
        SignalsFileBase.update_x_axis_index(self)

        # update the x axis for the bollinger bands
        for signal in self.signals.values():
            if signal.has_bollinger_bands:
                signal.set_bands_source()

    def toggle_y_axis(self, signal_name=None):
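        """Toggle the y axis for all the currently selected signals in the
        group (the signal_name argument is ignored here)."""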
        for signal in self.signals.values():
            if signal.selected:
                signal.toggle_axis()

    def change_averaging_window(self, new_size, force=False, signals=None):
        SignalsFileBase.change_averaging_window(self, new_size, force, signals)

    def set_signal_selection(self, signal_name, val):
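        """Select / deselect a signal, propagating the selection to the
        underlying files when they are shown separately."""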
        self.show_files_separately(self.separate_files)
        SignalsFileBase.set_signal_selection(self, signal_name, val)

    def file_was_modified_on_disk(self):
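        """Return True if any of the files in the group changed on disk."""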
        for signal_file in self.signals_files:
            if signal_file.file_was_modified_on_disk():
                return True
        return False

    def show_files_separately(self, val):
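        """Show or hide the signals of the individual files in addition to the
        group statistics. The group's own curves are dashed while the
        individual files are shown."""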
        self.separate_files = val

        # lazy updating of the signals of each of the workers
        if self.separate_files and self.signal_files_need_update:
            for signal_file in self.signals_files:
                signal_file.update_source_and_signals()
            self.signal_files_need_update = False

        for signal in self.signals.values():
            if signal.selected:
                if val:
                    signal.set_dash("4 4")
                else:
                    signal.set_dash("")
                for signal_file in self.signals_files:
                    try:
                        if val:
                            signal_file.set_signal_selection(signal.name, signal.selected)
                        else:
                            signal_file.set_signal_selection(signal.name, False)
                    except Exception:
                        # not every file necessarily contains every signal
                        pass
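

# A minimal usage sketch (hypothetical paths; in practice the dashboard
# constructs a SignalsFilesGroup from the experiment directories the user
# selects):
#
#     group = SignalsFilesGroup(['./experiment_1', './experiment_2'])
#     group.show_files_separately(True)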