1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00
Files
coach/logger.py

210 lines
6.2 KiB
Python

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pandas import *
import os
from pprint import pprint
import threading
import time
import datetime
from six.moves import input
from PIL import Image
# List of failed imports. `global` is unnecessary at module scope: this
# assignment already creates a module-level name.
# NOTE(review): nothing in this file appends to the list; presumably other
# modules record optional imports that failed here -- confirm.
failed_imports = []
class Colors(object):
    """ANSI terminal escape sequences used to colour console output.

    Emit one of these codes before some text and ``END`` after it to restore
    the terminal's default rendering.
    """

    # --- foreground colours ---
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    WHITE = '\033[37m'

    # --- background colours ---
    BG_RED = '\033[41m'
    BG_GREEN = '\033[42m'
    BG_YELLOW = '\033[43m'
    BG_BLUE = '\033[44m'
    BG_PURPLE = '\033[45m'
    # 30 selects a black foreground so text stays readable on the light background.
    BG_CYAN = '\033[30;46m'
    # '\x1b' is the same ESC character as '\033'.
    BG_WHITE = '\x1b[30;47m'
    BG_RESET = '\033[49m'

    # --- text attributes ---
    BOLD = '\033[1m'
    UNDERLINE_ON = '\033[4m'
    UNDERLINE_OFF = '\033[24m'

    # Reset all colours and attributes.
    END = '\033[0m'
class ScreenLogger(object):
    """Console printer for status messages, titles and coloured notices.

    Only :meth:`log` prefixes its output with ``name`` (identifying the origin
    of the message). The other helpers print their text wrapped in ANSI codes
    from ``Colors``.
    """

    def __init__(self, name):
        # Prefix that log() prepends to each message.
        self.name = name

    def separator(self):
        """Print a horizontal rule surrounded by blank lines."""
        print("")
        print("--------------------------------")
        print("")

    def log(self, data):
        """Print ``data`` prefixed by this logger's name, as ``"<name>: <data>"``.

        ``data`` is formatted with ``str.format``, so non-string values
        (numbers, objects) are printed instead of raising ``TypeError``.
        """
        print("{}: {}".format(self.name, data))

    def log_dict(self, dict, prefix=""):
        """Print all key/value pairs of a mapping on a single line.

        :param dict: mapping whose items are printed in iteration order. The
            parameter name shadows the builtin; it is kept for keyword callers.
        :param prefix: label printed in purple before the items.
        """
        header = "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END)
        # Each item ends with a trailing space, matching the historical output.
        items = "".join("{}{}: {}{} ".format(Colors.BLUE, k, Colors.END, v)
                        for k, v in dict.items())
        print(header + items)

    def log_title(self, title):
        """Print ``title`` in black on a cyan background."""
        print("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))

    def success(self, text):
        """Print ``text`` in green."""
        print("{}{}{}".format(Colors.GREEN, text, Colors.END))

    def warning(self, text):
        """Print ``text`` in yellow."""
        print("{}{}{}".format(Colors.YELLOW, text, Colors.END))

    def error(self, text):
        """Print ``text`` in red."""
        print("{}{}{}".format(Colors.RED, text, Colors.END))

    def ask_input(self, title):
        """Prompt with a highlighted ``title`` and return the line the user typed."""
        return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
class BaseLogger(object):
    """No-op logger interface.

    Every method accepts its arguments, does nothing and returns ``None``.
    Subclasses (e.g. ``Logger``) override these with real behaviour, so this
    class can stand in wherever logging should be disabled.
    """

    def __init__(self):
        """Nothing to initialise."""

    def set_current_time(self, time):
        """Set the default time key for new signals (ignored)."""

    def set_dump_dir(self, path, task_id):
        """Choose the output location for dumped data (ignored)."""

    def create_signal_value(self, signal_name, value, overwrite=True, time=None):
        """Record ``value`` for ``signal_name`` (ignored)."""

    def change_signal_value(self, signal_name, time, value):
        """Replace an already-recorded value (ignored)."""

    def signal_value_exists(self, time, signal_name):
        """Report whether a value was recorded (always ``None`` here)."""

    def get_signal_value(self, time, signal_name):
        """Fetch a recorded value (always ``None`` here)."""

    def dump_output_csv(self):
        """Write recorded data to disk (ignored)."""

    def update_wall_clock_time(self, episode):
        """Record elapsed wall-clock time for ``episode`` (ignored)."""
class Logger(BaseLogger):
    """Logger that accumulates named signals in a pandas DataFrame.

    Rows are indexed by a time key (labelled "Episode #" in the CSV) and
    columns are signal names. The table is written to ``csv_path`` by
    :meth:`dump_output_csv`.
    """

    def __init__(self):
        BaseLogger.__init__(self)
        self.data = DataFrame()
        self.csv_path = ''
        self.doc_path = ''
        # NOTE(review): not read or written elsewhere in this class.
        self.aggregated_data_across_threads = None
        # Wall-clock reference; set by dump_output_csv() when the table has one row.
        self.start_time = None
        # Default row key used by create_signal_value() when no time is given.
        self.time = None
        self.experiments_path = ""
        # Set by set_dump_dir(); initialised here so the attribute always exists.
        self.time_started = None

    def set_current_time(self, time):
        """Set the default row key used when a signal is created without ``time``."""
        self.time = time

    def two_digits(self, num):
        """Return ``num`` zero-padded to at least two digits (e.g. 7 -> '07')."""
        return '%02d' % num

    def set_dump_dir(self, experiments_path, task_id=None, add_timestamp=False, filename='worker'):
        """Choose output paths for this logger's CSV and doc files.

        The base name is ``filename``, optionally followed by ``_<task_id>``
        and by a ``_DD_MM_YYYY-HH_MM`` timestamp. It then gets ``_<idx>``,
        where ``idx`` is the smallest non-negative integer for which neither
        the ``.csv`` nor the ``.doc`` file already exists in ``experiments_path``.
        """
        self.experiments_path = experiments_path
        if task_id is not None:
            filename += "_{}".format(task_id)
        self.time_started = datetime.datetime.now()
        if add_timestamp:
            t = self.time_started.time()
            d = self.time_started.date()
            filename += '_{}_{}_{}-{}_{}'.format(self.two_digits(d.day), self.two_digits(d.month),
                                                 d.year, self.two_digits(t.hour), self.two_digits(t.minute))
        # Add an index in case an experiment with the same name/timestamp already exists.
        idx = 0
        while True:
            self.csv_path = os.path.join(experiments_path, '{}_{}.csv'.format(filename, idx))
            self.doc_path = os.path.join(experiments_path, '{}_{}.doc'.format(filename, idx))
            if not (os.path.exists(self.csv_path) or os.path.exists(self.doc_path)):
                break
            idx += 1

    def create_signal_value(self, signal_name, value, overwrite=True, time=None):
        """Write ``value`` into cell (``time``, ``signal_name``).

        :param time: row key. It defaults to the key set by set_current_time()
            only when it is ``None``; a key of 0 (e.g. the first episode) is used as given.
        :param overwrite: when False, write only if the cell does not already exist.
        :return: True if the value was written, False otherwise.
        """
        if time is None:
            time = self.time
        # NOTE(review): a cell that exists but holds NaN (e.g. its row was added
        # by another signal) counts as existing, so overwrite=False will not
        # fill it -- confirm this is intended.
        if overwrite or not self.signal_value_exists(time, signal_name):
            self.data.loc[time, signal_name] = value
            return True
        return False

    def change_signal_value(self, signal_name, time, value):
        """Overwrite cell (``time``, ``signal_name``) only if it already exists.

        :return: True if the value was written, False if the cell was missing.
        """
        if self.signal_value_exists(time, signal_name):
            self.data.loc[time, signal_name] = value
            return True
        return False

    def signal_value_exists(self, time, signal_name):
        """Return True if cell (``time``, ``signal_name``) can be read from the table."""
        try:
            self.get_signal_value(time, signal_name)
        except (KeyError, IndexError, TypeError, ValueError):
            # Missing row/column label, or a key the index cannot look up.
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return False
        return True

    def get_signal_value(self, time, signal_name):
        """Return the cell (``time``, ``signal_name``); pandas raises if it is missing."""
        return self.data.loc[time, signal_name]

    def dump_output_csv(self):
        """Write the signal table to ``self.csv_path``.

        When the table has exactly one row, the current wall-clock time is
        recorded as the reference used by update_wall_clock_time().
        """
        self.data.index.name = "Episode #"
        if len(self.data.index) == 1:
            self.start_time = time.time()
        self.data.to_csv(self.csv_path)

    def update_wall_clock_time(self, episode):
        """Record seconds since ``start_time`` as 'Wall-Clock Time' for row ``episode``.

        Does nothing until dump_output_csv() has set ``start_time``.
        """
        if self.start_time is not None:
            self.create_signal_value('Wall-Clock Time', time.time() - self.start_time, time=episode)

    def create_gif(self, images, fps=10, name="Gif"):
        """Save ``images`` as a looping animated GIF under ``<experiments_path>/gifs``.

        :param images: iterable of frames accepted by ``PIL.Image.fromarray``
            (presumably uint8 HxW or HxWxC arrays -- confirm against callers).
        :param fps: playback rate in frames per second. Pillow's GIF writer
            expects the per-frame ``duration`` in milliseconds.
        :param name: suffix of the output file name, after a timestamp.
        :return: path of the written file, or None when ``images`` is empty.
        """
        pil_images = [Image.fromarray(image) for image in images]
        if not pil_images:
            # Nothing to animate (indexing pil_images[0] would raise IndexError).
            return None
        output_file = '{}_{}.gif'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), name)
        output_dir = os.path.join(self.experiments_path, 'gifs')
        try:
            os.makedirs(output_dir)
        except OSError:
            # The directory may already exist (e.g. created by another worker).
            if not os.path.isdir(output_dir):
                raise
        output_path = os.path.join(output_dir, output_file)
        pil_images[0].save(output_path, save_all=True, append_images=pil_images[1:],
                           duration=1000.0 / fps, loop=0)
        return output_path
# Module-level default instances. `global` is unnecessary at module scope:
# these assignments already create module-level names.
logger = Logger()
screen = ScreenLogger("")