mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
64 lines
2.3 KiB
Python
64 lines
2.3 KiB
Python
import os
|
|
from os.path import basename
|
|
|
|
import pandas as pd
|
|
from pandas.errors import EmptyDataError
|
|
|
|
from rl_coach.dashboard_components.signals_file_base import SignalsFileBase
|
|
from rl_coach.dashboard_components.globals import x_axis_options
|
|
from rl_coach.utils import break_file_path
|
|
|
|
|
|
class SignalsFile(SignalsFileBase):
|
|
def __init__(self, csv_path, load=True, plot=None, use_dir_name=False):
|
|
super().__init__(plot)
|
|
self.use_dir_name = use_dir_name
|
|
self.full_csv_path = csv_path
|
|
self.dir, self.filename, _ = break_file_path(csv_path)
|
|
|
|
if use_dir_name:
|
|
parent_directory_path = os.path.abspath(os.path.join(os.path.dirname(csv_path), '..'))
|
|
if len(os.listdir(parent_directory_path)) == 1:
|
|
# get the parent directory name (since the current directory is the timestamp directory)
|
|
self.dir = parent_directory_path
|
|
self.filename = basename(self.dir)
|
|
else:
|
|
# get the common directory for all the experiments
|
|
self.dir = os.path.dirname(csv_path)
|
|
self.filename = "{}/{}".format(basename(parent_directory_path), basename(self.dir))
|
|
|
|
if load:
|
|
self.load()
|
|
# this helps set the correct x axis
|
|
self.change_averaging_window(1, force=True)
|
|
|
|
def load_csv(self, idx=None, result=None):
|
|
# load csv and fix sparse data.
|
|
# csv can be in the middle of being written so we use try - except
|
|
new_csv = None
|
|
while new_csv is None:
|
|
try:
|
|
new_csv = pd.read_csv(self.full_csv_path)
|
|
break
|
|
except EmptyDataError:
|
|
new_csv = None
|
|
continue
|
|
|
|
new_csv['Wall-Clock Time'] /= 60.
|
|
new_csv = new_csv.interpolate()
|
|
# remove signals which don't contain any values
|
|
for k, v in new_csv.isna().all().items():
|
|
if v and k not in x_axis_options:
|
|
del new_csv[k]
|
|
new_csv.fillna(value=0, inplace=True)
|
|
|
|
self.csv = new_csv
|
|
|
|
self.last_modified = os.path.getmtime(self.full_csv_path)
|
|
|
|
if idx is not None:
|
|
result[idx] = (self.csv, self.last_modified)
|
|
|
|
def file_was_modified_on_disk(self):
|
|
return self.last_modified != os.path.getmtime(self.full_csv_path)
|