1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

logging screen output to file (#479)

Co-authored-by: Gal Leibovich <gal.leibovich@intel.com>
This commit is contained in:
Guy Jacob
2021-05-06 18:02:27 +03:00
committed by GitHub
parent 9106b69227
commit a1a2e67fbd
8 changed files with 63 additions and 45 deletions

View File

@@ -353,7 +353,7 @@ class Agent(AgentInterface):
worker_device=self.worker_device) worker_device=self.worker_device)
if self.ap.visualization.print_networks_summary: if self.ap.visualization.print_networks_summary:
print(networks[network_name]) screen.print(networks[network_name])
return networks return networks

View File

@@ -174,19 +174,19 @@ def handle_distributed_coach_orchestrator(args):
data_store_params=ds_params_instance) data_store_params=ds_params_instance)
orchestrator = Kubernetes(orchestration_params) orchestrator = Kubernetes(orchestration_params)
if not orchestrator.setup(args.checkpoint_restore_dir): if not orchestrator.setup(args.checkpoint_restore_dir):
print("Could not setup.") screen.print("Could not setup.")
return 1 return 1
if orchestrator.deploy_trainer(): if orchestrator.deploy_trainer():
print("Successfully deployed trainer.") screen.print("Successfully deployed trainer.")
else: else:
print("Could not deploy trainer.") screen.print("Could not deploy trainer.")
return 1 return 1
if orchestrator.deploy_worker(): if orchestrator.deploy_worker():
print("Successfully deployed rollout worker(s).") screen.print("Successfully deployed rollout worker(s).")
else: else:
print("Could not deploy rollout worker(s).") screen.print("Could not deploy rollout worker(s).")
return 1 return 1
if args.dump_worker_logs: if args.dump_worker_logs:

View File

@@ -19,6 +19,7 @@ import uuid
from rl_coach.data_stores.data_store import DataStoreParameters from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
from rl_coach.logger import screen
class NFSDataStoreParameters(DataStoreParameters): class NFSDataStoreParameters(DataStoreParameters):
@@ -151,7 +152,7 @@ class NFSDataStore(CheckpointDataStore):
k8s_apps_v1_api_client.create_namespaced_deployment(self.params.namespace, deployment) k8s_apps_v1_api_client.create_namespaced_deployment(self.params.namespace, deployment)
self.params.name = name self.params.name = name
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating nfs-server", e) screen.print("Got exception: %s\n while creating nfs-server", e)
return False return False
k8s_core_v1_api_client = k8sclient.CoreV1Api() k8s_core_v1_api_client = k8sclient.CoreV1Api()
@@ -178,7 +179,7 @@ class NFSDataStore(CheckpointDataStore):
self.params.svc_name = svc_name self.params.svc_name = svc_name
self.params.server = svc_response.spec.cluster_ip self.params.server = svc_response.spec.cluster_ip
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating a service for nfs-server", e) screen.print("Got exception: %s\n while creating a service for nfs-server", e)
return False return False
return True return True
@@ -212,7 +213,7 @@ class NFSDataStore(CheckpointDataStore):
k8s_api_client.create_persistent_volume(persistent_volume) k8s_api_client.create_persistent_volume(persistent_volume)
self.params.pv_name = pv_name self.params.pv_name = pv_name
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating the NFS PV", e) screen.print("Got exception: %s\n while creating the NFS PV", e)
return False return False
pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4()) pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())
@@ -238,7 +239,7 @@ class NFSDataStore(CheckpointDataStore):
k8s_api_client.create_namespaced_persistent_volume_claim(self.params.namespace, persistent_volume_claim) k8s_api_client.create_namespaced_persistent_volume_claim(self.params.namespace, persistent_volume_claim)
self.params.pvc_name = pvc_name self.params.pvc_name = pvc_name
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating the NFS PVC", e) screen.print("Got exception: %s\n while creating the NFS PVC", e)
return False return False
return True return True
@@ -252,14 +253,14 @@ class NFSDataStore(CheckpointDataStore):
try: try:
k8s_apps_v1_api_client.delete_namespaced_deployment(self.params.name, self.params.namespace, del_options) k8s_apps_v1_api_client.delete_namespaced_deployment(self.params.name, self.params.namespace, del_options)
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting nfs-server", e) screen.print("Got exception: %s\n while deleting nfs-server", e)
return False return False
k8s_core_v1_api_client = k8sclient.CoreV1Api() k8s_core_v1_api_client = k8sclient.CoreV1Api()
try: try:
k8s_core_v1_api_client.delete_namespaced_service(self.params.svc_name, self.params.namespace, del_options) k8s_core_v1_api_client.delete_namespaced_service(self.params.svc_name, self.params.namespace, del_options)
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting the service for nfs-server", e) screen.print("Got exception: %s\n while deleting the service for nfs-server", e)
return False return False
return True return True
@@ -276,13 +277,13 @@ class NFSDataStore(CheckpointDataStore):
try: try:
k8s_api_client.delete_persistent_volume(self.params.pv_name, del_options) k8s_api_client.delete_persistent_volume(self.params.pv_name, del_options)
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting NFS PV", e) screen.print("Got exception: %s\n while deleting NFS PV", e)
return False return False
try: try:
k8s_api_client.delete_namespaced_persistent_volume_claim(self.params.pvc_name, self.params.namespace, del_options) k8s_api_client.delete_namespaced_persistent_volume_claim(self.params.pvc_name, self.params.namespace, del_options)
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting NFS PVC", e) screen.print("Got exception: %s\n while deleting NFS PVC", e)
return False return False
return True return True

View File

@@ -22,6 +22,7 @@ from minio.error import S3Error
from configparser import ConfigParser, Error from configparser import ConfigParser, Error
from rl_coach.checkpoint import CheckpointStateFile from rl_coach.checkpoint import CheckpointStateFile
from rl_coach.data_stores.data_store import SyncFiles from rl_coach.data_stores.data_store import SyncFiles
from rl_coach.logger import screen
import os import os
import time import time
@@ -62,7 +63,7 @@ class S3DataStore(CheckpointDataStore):
access_key = config.get('default', 'aws_access_key_id') access_key = config.get('default', 'aws_access_key_id')
secret_key = config.get('default', 'aws_secret_access_key') secret_key = config.get('default', 'aws_secret_access_key')
except Error as e: except Error as e:
print("Error when reading S3 credentials file: %s", e) screen.print("Error when reading S3 credentials file: %s", e)
else: else:
access_key = os.environ.get('ACCESS_KEY_ID') access_key = os.environ.get('ACCESS_KEY_ID')
secret_key = os.environ.get('SECRET_ACCESS_KEY') secret_key = os.environ.get('SECRET_ACCESS_KEY')
@@ -135,7 +136,7 @@ class S3DataStore(CheckpointDataStore):
self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename)) self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
except S3Error as e: except S3Error as e:
print("Got exception: %s\n while saving to S3", e) screen.print("Got exception: %s\n while saving to S3", e)
def load_from_store(self): def load_from_store(self):
""" """
@@ -191,7 +192,7 @@ class S3DataStore(CheckpointDataStore):
self.mc.fget_object(obj.bucket_name, obj.object_name, filename) self.mc.fget_object(obj.bucket_name, obj.object_name, filename)
except S3Error as e: except S3Error as e:
print("Got exception: %s\n while loading from S3", e) screen.print("Got exception: %s\n while loading from S3", e)
def setup_checkpoint_dir(self, crd=None): def setup_checkpoint_dir(self, crd=None):
if crd: if crd:

View File

@@ -60,6 +60,20 @@ class ScreenLogger(object):
def __init__(self, name, use_colors=True): def __init__(self, name, use_colors=True):
self.name = name self.name = name
self.set_use_colors(use_colors) self.set_use_colors(use_colors)
self.log_file = None
def print(self, *text, end: str = "\n", flush: bool = True) -> None:
    """
    Prints to the console as well as to log.txt
    :param text: The values to print. Each value is converted with str() before being
                 written to the log file, so non-string values (exceptions, numbers,
                 arbitrary objects) are accepted
    :param end: Terminator written after the values, as for the builtin print
    :param flush: Whether to flush the console stream after printing
    :return: None
    """
    # Imported locally: this method shares its name with the builtin, so the builtin is
    # addressed explicitly rather than through bare-name lookup.
    import builtins

    if self.log_file is None:
        # Opened lazily on first use and kept open afterwards; nothing in this method closes it.
        # NOTE(review): assumes the module-level experiment_path already names an existing
        # directory when the first message is printed - confirm against the callers.
        self.log_file = open(os.path.join(experiment_path, "log.txt"), "a")
    # str() each value: str.join raises TypeError on anything that is not already a string.
    self.log_file.write(",".join(str(t) for t in text))
    self.log_file.write(end)
    self.log_file.flush()
    builtins.print(*text, end=end, flush=flush)
def set_use_colors(self, use_colors): def set_use_colors(self, use_colors):
self._use_colors = use_colors self._use_colors = use_colors
@@ -79,12 +93,12 @@ class ScreenLogger(object):
self._suffix = "" self._suffix = ""
def separator(self): def separator(self):
print("") self.print("")
print("--------------------------------") self.print("--------------------------------")
print("") self.print("")
def log(self, data): def log(self, data):
print(data) self.print(data)
def log_dict(self, data, prefix=""): def log_dict(self, data, prefix=""):
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S.%f') + ' ' timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S.%f') + ' '
@@ -93,25 +107,25 @@ class ScreenLogger(object):
str += "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END) str += "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END)
for k, v in data.items(): for k, v in data.items():
str += "{}{}: {}{} ".format(Colors.BLUE, k, Colors.END, v) str += "{}{}: {}{} ".format(Colors.BLUE, k, Colors.END, v)
print(str) self.print(str)
else: else:
logentries = [timestamp] logentries = [timestamp]
for k, v in data.items(): for k, v in data.items():
logentries.append("{}={}".format(k, v)) logentries.append("{}={}".format(k, v))
logline = "{}> {}".format(prefix, ", ".join(logentries)) logline = "{}> {}".format(prefix, ", ".join(logentries))
print(logline) self.print(logline)
def log_title(self, title): def log_title(self, title):
print("{}{}{}".format(self._prefix_title, title, self._suffix)) self.print("{}{}{}".format(self._prefix_title, title, self._suffix))
def success(self, text): def success(self, text):
print("{}{}{}".format(self._prefix_success, text, self._suffix)) self.print("{}{}{}".format(self._prefix_success, text, self._suffix))
def warning(self, text): def warning(self, text):
print("{}{}{}".format(self._prefix_warning, text, self._suffix)) self.print("{}{}{}".format(self._prefix_warning, text, self._suffix))
def error(self, text, crash=True): def error(self, text, crash=True):
print("{}{}{}".format(self._prefix_error, text, self._suffix)) self.print("{}{}{}".format(self._prefix_error, text, self._suffix))
if crash: if crash:
exit(1) exit(1)
@@ -167,9 +181,9 @@ class ScreenLogger(object):
:return: None :return: None
""" """
if self._use_colors: if self._use_colors:
print("\x1b]2;{}\x07".format(title)) self.print("\x1b]2;{}\x07".format(title))
else: else:
print("Title: %s" % title) self.print("Title: %s" % title)
class BaseLogger(object): class BaseLogger(object):
@@ -441,4 +455,4 @@ def get_experiment_path(experiment_name, initial_experiment_path=None, create_pa
global screen global screen
screen = ScreenLogger("") screen = ScreenLogger(experiment_path)

View File

@@ -22,6 +22,7 @@ import time
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes
from rl_coach.logger import screen
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters): class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
@@ -115,10 +116,10 @@ class RedisPubSubBackend(MemoryBackend):
config.load_kube_config() config.load_kube_config()
api_client = client.AppsV1Api() api_client = client.AppsV1Api()
try: try:
print(self.params.orchestrator_params) screen.print(self.params.orchestrator_params)
api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'], deployment) api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'], deployment)
except client.rest.ApiException as e: except client.rest.ApiException as e:
print("Got exception: %s\n while creating redis-server", e) screen.print("Got exception: %s\n while creating redis-server", e)
return False return False
core_v1_api = client.CoreV1Api() core_v1_api = client.CoreV1Api()
@@ -147,7 +148,7 @@ class RedisPubSubBackend(MemoryBackend):
self.params.redis_port = 6379 self.params.redis_port = 6379
return True return True
except client.rest.ApiException as e: except client.rest.ApiException as e:
print("Got exception: %s\n while creating a service for redis-server", e) screen.print("Got exception: %s\n while creating a service for redis-server", e)
return False return False
def undeploy(self): def undeploy(self):
@@ -164,13 +165,13 @@ class RedisPubSubBackend(MemoryBackend):
try: try:
api_client.delete_namespaced_deployment(self.redis_server_name, self.params.orchestrator_params['namespace'], delete_options) api_client.delete_namespaced_deployment(self.redis_server_name, self.params.orchestrator_params['namespace'], delete_options)
except client.rest.ApiException as e: except client.rest.ApiException as e:
print("Got exception: %s\n while deleting redis-server", e) screen.print("Got exception: %s\n while deleting redis-server", e)
api_client = client.CoreV1Api() api_client = client.CoreV1Api()
try: try:
api_client.delete_namespaced_service(self.redis_service_name, self.params.orchestrator_params['namespace'], delete_options) api_client.delete_namespaced_service(self.redis_service_name, self.params.orchestrator_params['namespace'], delete_options)
except client.rest.ApiException as e: except client.rest.ApiException as e:
print("Got exception: %s\n while deleting redis-server", e) screen.print("Got exception: %s\n while deleting redis-server", e)
def sample(self, size): def sample(self, size):
pass pass

View File

@@ -32,6 +32,7 @@ from rl_coach.memories.backend.memory import MemoryBackendParameters
from rl_coach.memories.backend.memory_impl import get_memory_backend from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import DataStoreParameters from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.data_store_impl import get_data_store from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.logger import screen
class RunTypeParameters(): class RunTypeParameters():
@@ -113,7 +114,7 @@ class Kubernetes(Deploy):
self.s3_access_key = s3config.get('default', 'aws_access_key_id') self.s3_access_key = s3config.get('default', 'aws_access_key_id')
self.s3_secret_key = s3config.get('default', 'aws_secret_access_key') self.s3_secret_key = s3config.get('default', 'aws_secret_access_key')
except Error as e: except Error as e:
print("Error when reading S3 credentials file: %s", e) screen.print("Error when reading S3 credentials file: %s", e)
else: else:
self.s3_access_key = os.environ.get('ACCESS_KEY_ID') self.s3_access_key = os.environ.get('ACCESS_KEY_ID')
self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY') self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')
@@ -244,7 +245,7 @@ class Kubernetes(Deploy):
trainer_params.orchestration_params['job_name'] = name trainer_params.orchestration_params['job_name'] = name
return True return True
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating job", e) screen.print("Got exception: %s\n while creating job", e)
return False return False
def deploy_worker(self): def deploy_worker(self):
@@ -357,7 +358,7 @@ class Kubernetes(Deploy):
worker_params.orchestration_params['job_name'] = name worker_params.orchestration_params['job_name'] = name
return True return True
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating Job", e) screen.print("Got exception: %s\n while creating Job", e)
return False return False
def worker_logs(self, path='./logs'): def worker_logs(self, path='./logs'):
@@ -377,7 +378,7 @@ class Kubernetes(Deploy):
# pod = pods.items[0] # pod = pods.items[0]
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while reading pods", e) screen.print("Got exception: %s\n while reading pods", e)
return return
if not pods or len(pods.items) == 0: if not pods or len(pods.items) == 0:
@@ -410,7 +411,7 @@ class Kubernetes(Deploy):
pod = pods.items[0] pod = pods.items[0]
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while reading pods", e) screen.print("Got exception: %s\n while reading pods", e)
return return
if not pod: if not pod:
@@ -427,7 +428,7 @@ class Kubernetes(Deploy):
pod_name, self.params.namespace, follow=True, pod_name, self.params.namespace, follow=True,
_preload_content=False _preload_content=False
): ):
print(line.decode('utf-8'), flush=True, end='') screen.print(line.decode('utf-8'), flush=True, end='')
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
pass pass
@@ -469,12 +470,12 @@ class Kubernetes(Deploy):
try: try:
api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options) api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options)
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting trainer", e) screen.print("Got exception: %s\n while deleting trainer", e)
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None) worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
if worker_params: if worker_params:
try: try:
api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options) api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options)
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting workers", e) screen.print("Got exception: %s\n while deleting workers", e)
self.memory_backend.undeploy() self.memory_backend.undeploy()
self.data_store.undeploy() self.data_store.undeploy()

View File

@@ -458,7 +458,7 @@ class Timer(object):
self.start = time.time() self.start = time.time()
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
print(self.prefix, time.time() - self.start) screen.print(self.prefix, time.time() - self.start)
class ReaderWriterLock(object): class ReaderWriterLock(object):
@@ -516,7 +516,7 @@ class ProgressBar(object):
sys.stdout.flush() sys.stdout.flush()
def close(self): def close(self):
print("") screen.print("")
def start_shell_command_and_wait(command): def start_shell_command_and_wait(command):