mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
logging screen output to file (#479)
Co-authored-by: Gal Leibovich <gal.leibovich@intel.com>
This commit is contained in:
@@ -353,7 +353,7 @@ class Agent(AgentInterface):
|
||||
worker_device=self.worker_device)
|
||||
|
||||
if self.ap.visualization.print_networks_summary:
|
||||
print(networks[network_name])
|
||||
screen.print(networks[network_name])
|
||||
|
||||
return networks
|
||||
|
||||
|
||||
@@ -174,19 +174,19 @@ def handle_distributed_coach_orchestrator(args):
|
||||
data_store_params=ds_params_instance)
|
||||
orchestrator = Kubernetes(orchestration_params)
|
||||
if not orchestrator.setup(args.checkpoint_restore_dir):
|
||||
print("Could not setup.")
|
||||
screen.print("Could not setup.")
|
||||
return 1
|
||||
|
||||
if orchestrator.deploy_trainer():
|
||||
print("Successfully deployed trainer.")
|
||||
screen.print("Successfully deployed trainer.")
|
||||
else:
|
||||
print("Could not deploy trainer.")
|
||||
screen.print("Could not deploy trainer.")
|
||||
return 1
|
||||
|
||||
if orchestrator.deploy_worker():
|
||||
print("Successfully deployed rollout worker(s).")
|
||||
screen.print("Successfully deployed rollout worker(s).")
|
||||
else:
|
||||
print("Could not deploy rollout worker(s).")
|
||||
screen.print("Could not deploy rollout worker(s).")
|
||||
return 1
|
||||
|
||||
if args.dump_worker_logs:
|
||||
|
||||
@@ -19,6 +19,7 @@ import uuid
|
||||
|
||||
from rl_coach.data_stores.data_store import DataStoreParameters
|
||||
from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
|
||||
from rl_coach.logger import screen
|
||||
|
||||
|
||||
class NFSDataStoreParameters(DataStoreParameters):
|
||||
@@ -151,7 +152,7 @@ class NFSDataStore(CheckpointDataStore):
|
||||
k8s_apps_v1_api_client.create_namespaced_deployment(self.params.namespace, deployment)
|
||||
self.params.name = name
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating nfs-server", e)
|
||||
screen.print("Got exception: %s\n while creating nfs-server", e)
|
||||
return False
|
||||
|
||||
k8s_core_v1_api_client = k8sclient.CoreV1Api()
|
||||
@@ -178,7 +179,7 @@ class NFSDataStore(CheckpointDataStore):
|
||||
self.params.svc_name = svc_name
|
||||
self.params.server = svc_response.spec.cluster_ip
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating a service for nfs-server", e)
|
||||
screen.print("Got exception: %s\n while creating a service for nfs-server", e)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -212,7 +213,7 @@ class NFSDataStore(CheckpointDataStore):
|
||||
k8s_api_client.create_persistent_volume(persistent_volume)
|
||||
self.params.pv_name = pv_name
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating the NFS PV", e)
|
||||
screen.print("Got exception: %s\n while creating the NFS PV", e)
|
||||
return False
|
||||
|
||||
pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())
|
||||
@@ -238,7 +239,7 @@ class NFSDataStore(CheckpointDataStore):
|
||||
k8s_api_client.create_namespaced_persistent_volume_claim(self.params.namespace, persistent_volume_claim)
|
||||
self.params.pvc_name = pvc_name
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating the NFS PVC", e)
|
||||
screen.print("Got exception: %s\n while creating the NFS PVC", e)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -252,14 +253,14 @@ class NFSDataStore(CheckpointDataStore):
|
||||
try:
|
||||
k8s_apps_v1_api_client.delete_namespaced_deployment(self.params.name, self.params.namespace, del_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting nfs-server", e)
|
||||
screen.print("Got exception: %s\n while deleting nfs-server", e)
|
||||
return False
|
||||
|
||||
k8s_core_v1_api_client = k8sclient.CoreV1Api()
|
||||
try:
|
||||
k8s_core_v1_api_client.delete_namespaced_service(self.params.svc_name, self.params.namespace, del_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting the service for nfs-server", e)
|
||||
screen.print("Got exception: %s\n while deleting the service for nfs-server", e)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -276,13 +277,13 @@ class NFSDataStore(CheckpointDataStore):
|
||||
try:
|
||||
k8s_api_client.delete_persistent_volume(self.params.pv_name, del_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting NFS PV", e)
|
||||
screen.print("Got exception: %s\n while deleting NFS PV", e)
|
||||
return False
|
||||
|
||||
try:
|
||||
k8s_api_client.delete_namespaced_persistent_volume_claim(self.params.pvc_name, self.params.namespace, del_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting NFS PVC", e)
|
||||
screen.print("Got exception: %s\n while deleting NFS PVC", e)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -22,6 +22,7 @@ from minio.error import S3Error
|
||||
from configparser import ConfigParser, Error
|
||||
from rl_coach.checkpoint import CheckpointStateFile
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
from rl_coach.logger import screen
|
||||
|
||||
import os
|
||||
import time
|
||||
@@ -62,7 +63,7 @@ class S3DataStore(CheckpointDataStore):
|
||||
access_key = config.get('default', 'aws_access_key_id')
|
||||
secret_key = config.get('default', 'aws_secret_access_key')
|
||||
except Error as e:
|
||||
print("Error when reading S3 credentials file: %s", e)
|
||||
screen.print("Error when reading S3 credentials file: %s", e)
|
||||
else:
|
||||
access_key = os.environ.get('ACCESS_KEY_ID')
|
||||
secret_key = os.environ.get('SECRET_ACCESS_KEY')
|
||||
@@ -135,7 +136,7 @@ class S3DataStore(CheckpointDataStore):
|
||||
self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))
|
||||
|
||||
except S3Error as e:
|
||||
print("Got exception: %s\n while saving to S3", e)
|
||||
screen.print("Got exception: %s\n while saving to S3", e)
|
||||
|
||||
def load_from_store(self):
|
||||
"""
|
||||
@@ -191,7 +192,7 @@ class S3DataStore(CheckpointDataStore):
|
||||
self.mc.fget_object(obj.bucket_name, obj.object_name, filename)
|
||||
|
||||
except S3Error as e:
|
||||
print("Got exception: %s\n while loading from S3", e)
|
||||
screen.print("Got exception: %s\n while loading from S3", e)
|
||||
|
||||
def setup_checkpoint_dir(self, crd=None):
|
||||
if crd:
|
||||
|
||||
@@ -60,6 +60,20 @@ class ScreenLogger(object):
|
||||
def __init__(self, name, use_colors=True):
|
||||
self.name = name
|
||||
self.set_use_colors(use_colors)
|
||||
self.log_file = None
|
||||
|
||||
def print(self, *text: object, end: str = "\n", flush: bool = True) -> None:
    """
    Prints to the console and appends the same message to log.txt

    :param text: The values to print. They may be of any type (strings, exceptions, numbers, ...);
                 each one is converted with str() before being written to the log file, where the
                 values are joined by commas.
    :param end: The string written after the message, both on the console and in the log file.
                Pass end='' when the text already carries its own line terminator.
    :param flush: Whether to flush the console stream after printing. The log file is always flushed.
    :return: None
    """
    # This method shares its name with the builtin, so the builtin is referenced explicitly.
    import builtins

    if self.log_file is None:
        # The log file is opened lazily on the first call and kept open for the following calls.
        self.log_file = open(os.path.join(experiment_path, "log.txt"), "a")
    # str.join only accepts strings, while callers also pass exceptions, floats and other objects.
    self.log_file.write(",".join(str(t) for t in text))
    self.log_file.write(end)
    self.log_file.flush()
    builtins.print(*text, end=end, flush=flush)
|
||||
|
||||
def set_use_colors(self, use_colors):
|
||||
self._use_colors = use_colors
|
||||
@@ -79,12 +93,12 @@ class ScreenLogger(object):
|
||||
self._suffix = ""
|
||||
|
||||
def separator(self):
|
||||
print("")
|
||||
print("--------------------------------")
|
||||
print("")
|
||||
self.print("")
|
||||
self.print("--------------------------------")
|
||||
self.print("")
|
||||
|
||||
def log(self, data):
|
||||
print(data)
|
||||
self.print(data)
|
||||
|
||||
def log_dict(self, data, prefix=""):
|
||||
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S.%f') + ' '
|
||||
@@ -93,25 +107,25 @@ class ScreenLogger(object):
|
||||
str += "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END)
|
||||
for k, v in data.items():
|
||||
str += "{}{}: {}{} ".format(Colors.BLUE, k, Colors.END, v)
|
||||
print(str)
|
||||
self.print(str)
|
||||
else:
|
||||
logentries = [timestamp]
|
||||
for k, v in data.items():
|
||||
logentries.append("{}={}".format(k, v))
|
||||
logline = "{}> {}".format(prefix, ", ".join(logentries))
|
||||
print(logline)
|
||||
self.print(logline)
|
||||
|
||||
def log_title(self, title):
|
||||
print("{}{}{}".format(self._prefix_title, title, self._suffix))
|
||||
self.print("{}{}{}".format(self._prefix_title, title, self._suffix))
|
||||
|
||||
def success(self, text):
|
||||
print("{}{}{}".format(self._prefix_success, text, self._suffix))
|
||||
self.print("{}{}{}".format(self._prefix_success, text, self._suffix))
|
||||
|
||||
def warning(self, text):
|
||||
print("{}{}{}".format(self._prefix_warning, text, self._suffix))
|
||||
self.print("{}{}{}".format(self._prefix_warning, text, self._suffix))
|
||||
|
||||
def error(self, text, crash=True):
|
||||
print("{}{}{}".format(self._prefix_error, text, self._suffix))
|
||||
self.print("{}{}{}".format(self._prefix_error, text, self._suffix))
|
||||
if crash:
|
||||
exit(1)
|
||||
|
||||
@@ -167,9 +181,9 @@ class ScreenLogger(object):
|
||||
:return: None
|
||||
"""
|
||||
if self._use_colors:
|
||||
print("\x1b]2;{}\x07".format(title))
|
||||
self.print("\x1b]2;{}\x07".format(title))
|
||||
else:
|
||||
print("Title: %s" % title)
|
||||
self.print("Title: %s" % title)
|
||||
|
||||
|
||||
class BaseLogger(object):
|
||||
@@ -441,4 +455,4 @@ def get_experiment_path(experiment_name, initial_experiment_path=None, create_pa
|
||||
|
||||
|
||||
global screen
|
||||
screen = ScreenLogger("")
|
||||
screen = ScreenLogger(experiment_path)
|
||||
|
||||
@@ -22,6 +22,7 @@ import time
|
||||
|
||||
from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters
|
||||
from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes
|
||||
from rl_coach.logger import screen
|
||||
|
||||
|
||||
class RedisPubSubMemoryBackendParameters(MemoryBackendParameters):
|
||||
@@ -115,10 +116,10 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
config.load_kube_config()
|
||||
api_client = client.AppsV1Api()
|
||||
try:
|
||||
print(self.params.orchestrator_params)
|
||||
screen.print(self.params.orchestrator_params)
|
||||
api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'], deployment)
|
||||
except client.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating redis-server", e)
|
||||
screen.print("Got exception: %s\n while creating redis-server", e)
|
||||
return False
|
||||
|
||||
core_v1_api = client.CoreV1Api()
|
||||
@@ -147,7 +148,7 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
self.params.redis_port = 6379
|
||||
return True
|
||||
except client.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating a service for redis-server", e)
|
||||
screen.print("Got exception: %s\n while creating a service for redis-server", e)
|
||||
return False
|
||||
|
||||
def undeploy(self):
|
||||
@@ -164,13 +165,13 @@ class RedisPubSubBackend(MemoryBackend):
|
||||
try:
|
||||
api_client.delete_namespaced_deployment(self.redis_server_name, self.params.orchestrator_params['namespace'], delete_options)
|
||||
except client.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting redis-server", e)
|
||||
screen.print("Got exception: %s\n while deleting redis-server", e)
|
||||
|
||||
api_client = client.CoreV1Api()
|
||||
try:
|
||||
api_client.delete_namespaced_service(self.redis_service_name, self.params.orchestrator_params['namespace'], delete_options)
|
||||
except client.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting redis-server", e)
|
||||
screen.print("Got exception: %s\n while deleting redis-server", e)
|
||||
|
||||
def sample(self, size):
|
||||
pass
|
||||
|
||||
@@ -32,6 +32,7 @@ from rl_coach.memories.backend.memory import MemoryBackendParameters
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
from rl_coach.data_stores.data_store import DataStoreParameters
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||
from rl_coach.logger import screen
|
||||
|
||||
|
||||
class RunTypeParameters():
|
||||
@@ -113,7 +114,7 @@ class Kubernetes(Deploy):
|
||||
self.s3_access_key = s3config.get('default', 'aws_access_key_id')
|
||||
self.s3_secret_key = s3config.get('default', 'aws_secret_access_key')
|
||||
except Error as e:
|
||||
print("Error when reading S3 credentials file: %s", e)
|
||||
screen.print("Error when reading S3 credentials file: %s", e)
|
||||
else:
|
||||
self.s3_access_key = os.environ.get('ACCESS_KEY_ID')
|
||||
self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY')
|
||||
@@ -244,7 +245,7 @@ class Kubernetes(Deploy):
|
||||
trainer_params.orchestration_params['job_name'] = name
|
||||
return True
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating job", e)
|
||||
screen.print("Got exception: %s\n while creating job", e)
|
||||
return False
|
||||
|
||||
def deploy_worker(self):
|
||||
@@ -357,7 +358,7 @@ class Kubernetes(Deploy):
|
||||
worker_params.orchestration_params['job_name'] = name
|
||||
return True
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating Job", e)
|
||||
screen.print("Got exception: %s\n while creating Job", e)
|
||||
return False
|
||||
|
||||
def worker_logs(self, path='./logs'):
|
||||
@@ -377,7 +378,7 @@ class Kubernetes(Deploy):
|
||||
|
||||
# pod = pods.items[0]
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while reading pods", e)
|
||||
screen.print("Got exception: %s\n while reading pods", e)
|
||||
return
|
||||
|
||||
if not pods or len(pods.items) == 0:
|
||||
@@ -410,7 +411,7 @@ class Kubernetes(Deploy):
|
||||
|
||||
pod = pods.items[0]
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while reading pods", e)
|
||||
screen.print("Got exception: %s\n while reading pods", e)
|
||||
return
|
||||
|
||||
if not pod:
|
||||
@@ -427,7 +428,7 @@ class Kubernetes(Deploy):
|
||||
pod_name, self.params.namespace, follow=True,
|
||||
_preload_content=False
|
||||
):
|
||||
print(line.decode('utf-8'), flush=True, end='')
|
||||
screen.print(line.decode('utf-8'), flush=True, end='')
|
||||
except k8sclient.rest.ApiException as e:
|
||||
pass
|
||||
|
||||
@@ -469,12 +470,12 @@ class Kubernetes(Deploy):
|
||||
try:
|
||||
api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting trainer", e)
|
||||
screen.print("Got exception: %s\n while deleting trainer", e)
|
||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||
if worker_params:
|
||||
try:
|
||||
api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting workers", e)
|
||||
screen.print("Got exception: %s\n while deleting workers", e)
|
||||
self.memory_backend.undeploy()
|
||||
self.data_store.undeploy()
|
||||
|
||||
@@ -458,7 +458,7 @@ class Timer(object):
|
||||
self.start = time.time()
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
print(self.prefix, time.time() - self.start)
|
||||
screen.print(self.prefix, time.time() - self.start)
|
||||
|
||||
|
||||
class ReaderWriterLock(object):
|
||||
@@ -516,7 +516,7 @@ class ProgressBar(object):
|
||||
sys.stdout.flush()
|
||||
|
||||
def close(self):
|
||||
print("")
|
||||
screen.print("")
|
||||
|
||||
|
||||
def start_shell_command_and_wait(command):
|
||||
|
||||
Reference in New Issue
Block a user