diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index 5d12e0b..2790a17 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -353,7 +353,7 @@ class Agent(AgentInterface): worker_device=self.worker_device) if self.ap.visualization.print_networks_summary: - print(networks[network_name]) + screen.print(networks[network_name]) return networks diff --git a/rl_coach/coach.py b/rl_coach/coach.py index e675c03..d428134 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -174,19 +174,19 @@ def handle_distributed_coach_orchestrator(args): data_store_params=ds_params_instance) orchestrator = Kubernetes(orchestration_params) if not orchestrator.setup(args.checkpoint_restore_dir): - print("Could not setup.") + screen.print("Could not setup.") return 1 if orchestrator.deploy_trainer(): - print("Successfully deployed trainer.") + screen.print("Successfully deployed trainer.") else: - print("Could not deploy trainer.") + screen.print("Could not deploy trainer.") return 1 if orchestrator.deploy_worker(): - print("Successfully deployed rollout worker(s).") + screen.print("Successfully deployed rollout worker(s).") else: - print("Could not deploy rollout worker(s).") + screen.print("Could not deploy rollout worker(s).") return 1 if args.dump_worker_logs: diff --git a/rl_coach/data_stores/nfs_data_store.py b/rl_coach/data_stores/nfs_data_store.py index cc72e01..1ff39ff 100644 --- a/rl_coach/data_stores/nfs_data_store.py +++ b/rl_coach/data_stores/nfs_data_store.py @@ -19,6 +19,7 @@ import uuid from rl_coach.data_stores.data_store import DataStoreParameters from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore +from rl_coach.logger import screen class NFSDataStoreParameters(DataStoreParameters): @@ -151,7 +152,7 @@ class NFSDataStore(CheckpointDataStore): k8s_apps_v1_api_client.create_namespaced_deployment(self.params.namespace, deployment) self.params.name = name except k8sclient.rest.ApiException as e: - print("Got 
exception: %s\n while creating nfs-server", e) + screen.print("Got exception: %s\n while creating nfs-server", e) return False k8s_core_v1_api_client = k8sclient.CoreV1Api() @@ -178,7 +179,7 @@ class NFSDataStore(CheckpointDataStore): self.params.svc_name = svc_name self.params.server = svc_response.spec.cluster_ip except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while creating a service for nfs-server", e) + screen.print("Got exception: %s\n while creating a service for nfs-server", e) return False return True @@ -212,7 +213,7 @@ class NFSDataStore(CheckpointDataStore): k8s_api_client.create_persistent_volume(persistent_volume) self.params.pv_name = pv_name except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while creating the NFS PV", e) + screen.print("Got exception: %s\n while creating the NFS PV", e) return False pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4()) @@ -238,7 +239,7 @@ class NFSDataStore(CheckpointDataStore): k8s_api_client.create_namespaced_persistent_volume_claim(self.params.namespace, persistent_volume_claim) self.params.pvc_name = pvc_name except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while creating the NFS PVC", e) + screen.print("Got exception: %s\n while creating the NFS PVC", e) return False return True @@ -252,14 +253,14 @@ class NFSDataStore(CheckpointDataStore): try: k8s_apps_v1_api_client.delete_namespaced_deployment(self.params.name, self.params.namespace, del_options) except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while deleting nfs-server", e) + screen.print("Got exception: %s\n while deleting nfs-server", e) return False k8s_core_v1_api_client = k8sclient.CoreV1Api() try: k8s_core_v1_api_client.delete_namespaced_service(self.params.svc_name, self.params.namespace, del_options) except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while deleting the service for nfs-server", e) + screen.print("Got exception: %s\n while deleting the 
service for nfs-server", e) return False return True @@ -276,13 +277,13 @@ class NFSDataStore(CheckpointDataStore): try: k8s_api_client.delete_persistent_volume(self.params.pv_name, del_options) except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while deleting NFS PV", e) + screen.print("Got exception: %s\n while deleting NFS PV", e) return False try: k8s_api_client.delete_namespaced_persistent_volume_claim(self.params.pvc_name, self.params.namespace, del_options) except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while deleting NFS PVC", e) + screen.print("Got exception: %s\n while deleting NFS PVC", e) return False return True diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py index 43b189a..b2567f2 100644 --- a/rl_coach/data_stores/s3_data_store.py +++ b/rl_coach/data_stores/s3_data_store.py @@ -22,6 +22,7 @@ from minio.error import S3Error from configparser import ConfigParser, Error from rl_coach.checkpoint import CheckpointStateFile from rl_coach.data_stores.data_store import SyncFiles +from rl_coach.logger import screen import os import time @@ -62,7 +63,7 @@ class S3DataStore(CheckpointDataStore): access_key = config.get('default', 'aws_access_key_id') secret_key = config.get('default', 'aws_secret_access_key') except Error as e: - print("Error when reading S3 credentials file: %s", e) + screen.print("Error when reading S3 credentials file: %s", e) else: access_key = os.environ.get('ACCESS_KEY_ID') secret_key = os.environ.get('SECRET_ACCESS_KEY') @@ -135,7 +136,7 @@ class S3DataStore(CheckpointDataStore): self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename)) except S3Error as e: - print("Got exception: %s\n while saving to S3", e) + screen.print("Got exception: %s\n while saving to S3", e) def load_from_store(self): """ @@ -191,7 +192,7 @@ class S3DataStore(CheckpointDataStore): self.mc.fget_object(obj.bucket_name, 
obj.object_name, filename) except S3Error as e: - print("Got exception: %s\n while loading from S3", e) + screen.print("Got exception: %s\n while loading from S3", e) def setup_checkpoint_dir(self, crd=None): if crd: diff --git a/rl_coach/logger.py b/rl_coach/logger.py index 37cc856..6831f70 100644 --- a/rl_coach/logger.py +++ b/rl_coach/logger.py @@ -60,6 +60,20 @@ class ScreenLogger(object): def __init__(self, name, use_colors=True): self.name = name self.set_use_colors(use_colors) + self.log_file = None + + def print(self, *text, end: str = "\n", flush: bool = True) -> None: + """ + Prints to console and appends the same message to log.txt + :param text: The objects to print; each one is converted with str() for the log + :param end: Line terminator, as in the built-in print + :param flush: Whether to flush stdout, as in the built-in print + """ + if not self.log_file: + self.log_file = open(os.path.join(experiment_path, "log.txt"), "a") + self.log_file.write(",".join(str(t) for t in text) + end) + self.log_file.flush() + print(*text, end=end, flush=flush) def set_use_colors(self, use_colors): self._use_colors = use_colors @@ -79,12 +93,12 @@ class ScreenLogger(object): self._suffix = "" def separator(self): - print("") - print("--------------------------------") - print("") + self.print("") + self.print("--------------------------------") + self.print("") def log(self, data): - print(data) + self.print(data) def log_dict(self, data, prefix=""): timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S.%f') + ' ' @@ -93,25 +107,25 @@ class ScreenLogger(object): str += "{}{}{} - ".format(Colors.PURPLE, prefix, Colors.END) for k, v in data.items(): str += "{}{}: {}{} ".format(Colors.BLUE, k, Colors.END, v) - print(str) + self.print(str) else: logentries = [timestamp] for k, v in data.items(): logentries.append("{}={}".format(k, v)) logline = "{}> {}".format(prefix, ", ".join(logentries)) - print(logline) + self.print(logline) def log_title(self, title): - print("{}{}{}".format(self._prefix_title, title, self._suffix)) + self.print("{}{}{}".format(self._prefix_title, title, self._suffix)) def success(self, text): -
print("{}{}{}".format(self._prefix_success, text, self._suffix)) + self.print("{}{}{}".format(self._prefix_success, text, self._suffix)) def warning(self, text): - print("{}{}{}".format(self._prefix_warning, text, self._suffix)) + self.print("{}{}{}".format(self._prefix_warning, text, self._suffix)) def error(self, text, crash=True): - print("{}{}{}".format(self._prefix_error, text, self._suffix)) + self.print("{}{}{}".format(self._prefix_error, text, self._suffix)) if crash: exit(1) @@ -167,9 +181,9 @@ class ScreenLogger(object): :return: None """ if self._use_colors: - print("\x1b]2;{}\x07".format(title)) + self.print("\x1b]2;{}\x07".format(title)) else: - print("Title: %s" % title) + self.print("Title: %s" % title) class BaseLogger(object): @@ -441,4 +455,4 @@ def get_experiment_path(experiment_name, initial_experiment_path=None, create_pa global screen -screen = ScreenLogger("") +screen = ScreenLogger(experiment_path) diff --git a/rl_coach/memories/backend/redis.py b/rl_coach/memories/backend/redis.py index 5face78..bc2c7eb 100644 --- a/rl_coach/memories/backend/redis.py +++ b/rl_coach/memories/backend/redis.py @@ -22,6 +22,7 @@ import time from rl_coach.memories.backend.memory import MemoryBackend, MemoryBackendParameters from rl_coach.core_types import Transition, Episode, EnvironmentSteps, EnvironmentEpisodes +from rl_coach.logger import screen class RedisPubSubMemoryBackendParameters(MemoryBackendParameters): @@ -115,10 +116,10 @@ class RedisPubSubBackend(MemoryBackend): config.load_kube_config() api_client = client.AppsV1Api() try: - print(self.params.orchestrator_params) + screen.print(self.params.orchestrator_params) api_client.create_namespaced_deployment(self.params.orchestrator_params['namespace'], deployment) except client.rest.ApiException as e: - print("Got exception: %s\n while creating redis-server", e) + screen.print("Got exception: %s\n while creating redis-server", e) return False core_v1_api = client.CoreV1Api() @@ -147,7 +148,7 @@ class 
RedisPubSubBackend(MemoryBackend): self.params.redis_port = 6379 return True except client.rest.ApiException as e: - print("Got exception: %s\n while creating a service for redis-server", e) + screen.print("Got exception: %s\n while creating a service for redis-server", e) return False def undeploy(self): @@ -164,13 +165,13 @@ class RedisPubSubBackend(MemoryBackend): try: api_client.delete_namespaced_deployment(self.redis_server_name, self.params.orchestrator_params['namespace'], delete_options) except client.rest.ApiException as e: - print("Got exception: %s\n while deleting redis-server", e) + screen.print("Got exception: %s\n while deleting redis-server", e) api_client = client.CoreV1Api() try: api_client.delete_namespaced_service(self.redis_service_name, self.params.orchestrator_params['namespace'], delete_options) except client.rest.ApiException as e: - print("Got exception: %s\n while deleting redis-server", e) + screen.print("Got exception: %s\n while deleting redis-server", e) def sample(self, size): pass diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py index caf6a71..c7300af 100644 --- a/rl_coach/orchestrators/kubernetes_orchestrator.py +++ b/rl_coach/orchestrators/kubernetes_orchestrator.py @@ -32,6 +32,7 @@ from rl_coach.memories.backend.memory import MemoryBackendParameters from rl_coach.memories.backend.memory_impl import get_memory_backend from rl_coach.data_stores.data_store import DataStoreParameters from rl_coach.data_stores.data_store_impl import get_data_store +from rl_coach.logger import screen class RunTypeParameters(): @@ -113,7 +114,7 @@ class Kubernetes(Deploy): self.s3_access_key = s3config.get('default', 'aws_access_key_id') self.s3_secret_key = s3config.get('default', 'aws_secret_access_key') except Error as e: - print("Error when reading S3 credentials file: %s", e) + screen.print("Error when reading S3 credentials file: %s", e) else: self.s3_access_key = 
os.environ.get('ACCESS_KEY_ID') self.s3_secret_key = os.environ.get('SECRET_ACCESS_KEY') @@ -244,7 +245,7 @@ class Kubernetes(Deploy): trainer_params.orchestration_params['job_name'] = name return True except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while creating job", e) + screen.print("Got exception: %s\n while creating job", e) return False def deploy_worker(self): @@ -357,7 +358,7 @@ class Kubernetes(Deploy): worker_params.orchestration_params['job_name'] = name return True except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while creating Job", e) + screen.print("Got exception: %s\n while creating Job", e) return False def worker_logs(self, path='./logs'): @@ -377,7 +378,7 @@ class Kubernetes(Deploy): # pod = pods.items[0] except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while reading pods", e) + screen.print("Got exception: %s\n while reading pods", e) return if not pods or len(pods.items) == 0: @@ -410,7 +411,7 @@ class Kubernetes(Deploy): pod = pods.items[0] except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while reading pods", e) + screen.print("Got exception: %s\n while reading pods", e) return if not pod: @@ -427,7 +428,7 @@ class Kubernetes(Deploy): pod_name, self.params.namespace, follow=True, _preload_content=False ): - print(line.decode('utf-8'), flush=True, end='') + screen.print(line.decode('utf-8'), flush=True, end='') except k8sclient.rest.ApiException as e: pass @@ -469,12 +470,12 @@ class Kubernetes(Deploy): try: api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options) except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while deleting trainer", e) + screen.print("Got exception: %s\n while deleting trainer", e) worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None) if worker_params: try: 
api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options) except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while deleting workers", e) + screen.print("Got exception: %s\n while deleting workers", e) self.memory_backend.undeploy() self.data_store.undeploy() diff --git a/rl_coach/utils.py b/rl_coach/utils.py index 72714fa..6c3c453 100644 --- a/rl_coach/utils.py +++ b/rl_coach/utils.py @@ -458,7 +458,7 @@ class Timer(object): self.start = time.time() def __exit__(self, type, value, traceback): - print(self.prefix, time.time() - self.start) + screen.print(self.prefix, time.time() - self.start) class ReaderWriterLock(object): @@ -516,7 +516,7 @@ class ProgressBar(object): sys.stdout.flush() def close(self): - print("") + screen.print("") def start_shell_command_and_wait(command):