mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Until now, most of the modules used star imports, pulling all of a module's objects (variables, classes, functions, other imports) into their own namespace. This could (and did) cause unintentional use of classes or methods that were only imported indirectly. With this patch, all star imports are replaced with imports of the top-level module that provides the desired class or function. In addition, all imports were sorted (where possible) in the order PEP 8 [1] suggests: first standard-library imports, then third-party imports (numpy, tensorflow, etc.), and finally coach modules. Each of those sections is separated by one empty line. [1] https://www.python.org/dev/peps/pep-0008/#imports
108 lines
4.2 KiB
Python
108 lines
4.2 KiB
Python
import argparse
|
|
import os
|
|
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
|
|
from dashboard import SignalsFile
|
|
|
|
|
|
class FigureMaker(object):
    """Plot one signal for every environment found under an experiments root.

    The directory layout read by ``filter_environments`` is::

        <path>/<environment>/<experiment>/<file>.csv

    For each environment only the most recently ctime-stamped experiment
    directory is used, and one subplot is drawn per CSV file found in it.
    """

    def __init__(self, path, cols, smoothness, signal_to_plot, x_axis):
        """
        :param path: root directory holding one sub-directory per environment
        :param cols: number of subplot columns in the figure grid
        :param smoothness: averaging window handed to
                           ``SignalsFile.change_averaging_window``
        :param signal_to_plot: signal name looked up in
                               ``SignalsFile.bokeh_source.data`` for the y values
        :param x_axis: signal name looked up in
                       ``SignalsFile.bokeh_source.data`` for the x values
        """
        self.experiments_path = path
        self.environments = self.list_environments()
        self.cols = cols
        # ceiling division: enough rows to hold one subplot per entry
        self.rows = (len(self.environments) + cols - 1) // cols
        self.smoothness = smoothness
        self.signal_to_plot = signal_to_plot
        self.x_axis = x_axis

        params = {
            'axes.labelsize': 8,
            'font.size': 10,
            'legend.fontsize': 14,
            'xtick.labelsize': 8,
            'ytick.labelsize': 8,
            'text.usetex': False,
            'figure.figsize': [16, 30]
        }
        # NOTE: this mutates matplotlib's process-wide rcParams
        matplotlib.rcParams.update(params)

    def list_environments(self):
        """Return ``(environment, csv_path)`` pairs for the experiments root.

        Environment names are the sub-directories of ``self.experiments_path``,
        processed in alphabetical order.
        """
        environments = sorted(entry.name
                              for entry in os.scandir(self.experiments_path)
                              if entry.is_dir())
        return self.filter_environments(environments)

    def filter_environments(self, environments):
        """Resolve environment names to the CSV files of their latest experiment.

        :param environments: environment directory names, relative to
                             ``self.experiments_path``
        :return: list of ``(environment, full_csv_path)`` tuples. An environment
                 contributes one tuple per CSV file in its latest experiment
                 directory, and none when it has no experiment directory or
                 that directory holds no CSV file.
        """
        filtered_environments = []
        for environment in environments:
            path = os.path.join(self.experiments_path, environment)
            experiment_dirs = [os.path.join(path, entry.name)
                               for entry in os.scandir(path) if entry.is_dir()]
            if not experiment_dirs:
                # max() of an empty sequence raises ValueError; an environment
                # without any experiment directory simply has nothing to plot.
                continue

            # take only the last created experiment directory
            # (os.path.getctime is the inode change time on Unix)
            last_experiment_dir = max(experiment_dirs, key=os.path.getctime)

            # collect the csv files inside it; sorted so the plot order is
            # deterministic when there is more than one
            for file_name in sorted(os.listdir(last_experiment_dir)):
                full_file_path = os.path.join(last_experiment_dir, file_name)
                if os.path.isfile(full_file_path) and file_name.endswith('.csv'):
                    filtered_environments.append((environment, full_file_path))

        return filtered_environments

    def plot_figures(self):
        """Draw one subplot per ``(environment, csv)`` entry in a rows x cols grid."""
        for idx, (environment, full_file_path) in enumerate(self.environments):
            print(environment)
            # matplotlib subplot indices are 1-based
            axis = plt.subplot(self.rows, self.cols, idx + 1)

            signals = SignalsFile(full_file_path)
            signals.change_averaging_window(self.smoothness, force=True, signals=[self.signal_to_plot])
            steps = signals.bokeh_source.data[self.x_axis]
            rewards = signals.bokeh_source.data[self.signal_to_plot]

            # cap the y axis at 4 tick bins; force scientific notation on x
            axis.yaxis.set_major_locator(plt.MaxNLocator(4))
            axis.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
            plt.title(environment, fontsize=10, y=1.08)
            plt.plot(steps, rewards, linewidth=0.8)
            plt.subplots_adjust(hspace=2.0, wspace=0.4)

    def save_pdf(self, name):
        """Save the current figure to ``<name>.pdf``."""
        plt.savefig(name + ".pdf", bbox_inches='tight')

    def show_figures(self):
        """Open the interactive matplotlib window(s)."""
        plt.show()
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: plot one signal for every environment under
    # --path, save the figure as a PDF, then display it.
    parser = argparse.ArgumentParser()
    # --path is required: with the former default of None the script crashed
    # later with a TypeError from os.path.join(None, ...).
    parser.add_argument('-p', '--path',
                        help="(string) Root directory of the experiments",
                        required=True,
                        type=str)
    parser.add_argument('-c', '--cols',
                        help="(int) Number of plot columns",
                        default=6,
                        type=int)
    parser.add_argument('-s', '--smoothness',
                        help="(int) Number of consequent episodes to average over",
                        default=200,
                        type=int)
    parser.add_argument('-sig', '--signal',
                        help="(str) The name of the signal to plot",
                        default='Evaluation Reward',
                        type=str)
    parser.add_argument('-x', '--x_axis',
                        help="(str) The meaning of the x axis",
                        default='Total steps',
                        type=str)
    parser.add_argument('-pdf', '--pdf',
                        help="(str) A name of a pdf to save to",
                        default='atari',
                        type=str)
    args = parser.parse_args()

    maker = FigureMaker(args.path, cols=args.cols, smoothness=args.smoothness,
                        signal_to_plot=args.signal, x_axis=args.x_axis)
    maker.plot_figures()
    # save before show(): the figure may be unavailable once the window closes
    maker.save_pdf(args.pdf)
    maker.show_figures()
|