From fb1039fcb514cb10278a77eb1356aaa8b3dfbe1a Mon Sep 17 00:00:00 2001
From: Ajay Deshpande
Date: Mon, 8 Oct 2018 17:49:40 -0700
Subject: [PATCH] Checkpoint and evaluation optimizations

---
 .../orchestrators/kubernetes_orchestrator.py |  1 +
 rl_coach/orchestrators/start_training.py     | 13 +++--
 rl_coach/rollout_worker.py                   | 49 ++++++++++++-------
 rl_coach/training_worker.py                  | 24 +++++++--
 4 files changed, 61 insertions(+), 26 deletions(-)

diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py
index 229a26e..8f5f405 100644
--- a/rl_coach/orchestrators/kubernetes_orchestrator.py
+++ b/rl_coach/orchestrators/kubernetes_orchestrator.py
@@ -179,6 +179,7 @@ class Kubernetes(Deploy):
         worker_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
         worker_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
 
+        worker_params.command += ['--num-workers', worker_params.num_replicas]
 
         name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
 
diff --git a/rl_coach/orchestrators/start_training.py b/rl_coach/orchestrators/start_training.py
index aea8237..6f93a7d 100644
--- a/rl_coach/orchestrators/start_training.py
+++ b/rl_coach/orchestrators/start_training.py
@@ -8,9 +8,10 @@ from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
 
 
 def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str=None, nfs_path: str=None,
-         memory_backend: str=None, data_store: str=None, s3_end_point: str=None, s3_bucket_name: str=None):
-    rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset]
-    training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset]
+         memory_backend: str=None, data_store: str=None, s3_end_point: str=None, s3_bucket_name: str=None,
+         policy_type: str="OFF"):
+    rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset, '--policy-type', policy_type]
+    training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset, '--policy-type', policy_type]
 
     memory_backend_params = None
     if memory_backend == "redispubsub":
@@ -95,6 +96,10 @@ if __name__ == '__main__':
                         type=int,
                         required=False,
                         default=1)
+    parser.add_argument('--policy-type',
+                        help="(string) The type of policy: OFF/ON",
+                        type=str,
+                        default='OFF')
 
     # parser.add_argument('--checkpoint_dir',
     #                     help='(string) Path to a folder containing a checkpoint to write the model to.',
@@ -104,4 +109,4 @@ if __name__ == '__main__':
 
     main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
          memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
-         s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers)
+         s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers, policy_type=args.policy_type)
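Note: the two diffs above only plumb configuration through. kubernetes_orchestrator.py appends the pool's replica count to each worker command as --num-workers, and start_training.py adds a --policy-type flag (ON/OFF, defaulting to OFF) that it forwards to both worker entry points. A minimal sketch of how the launcher-side pieces fit together, with illustrative values; the str() cast is added here for illustration because container command lists are strings, whereas the patch passes num_replicas through unchanged:

    # Sketch only: how the worker commands end up being composed.
    def build_worker_commands(preset, policy_type='OFF', num_replicas=1):
        rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset,
                           '--policy-type', policy_type,
                           '--num-workers', str(num_replicas)]  # str() added for illustration
        training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset,
                            '--policy-type', policy_type]
        return rollout_command, training_command

    print(build_worker_commands('CartPole_DQN', policy_type='ON', num_replicas=4))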
diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index 744c92e..d2a97db 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -11,6 +11,7 @@ import argparse
 import time
 import os
 import json
+import math
 
 from threading import Thread
 
@@ -69,20 +70,16 @@ def data_store_ckpt_load(data_store):
         time.sleep(10)
 
 
-def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
+def get_latest_checkpoint(checkpoint_dir):
     if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
         ckpt = CheckpointState()
         contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
         text_format.Merge(contents, ckpt)
         rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
-        current_checkpoint = int(rel_path.split('_Step')[0])
-        if current_checkpoint > last_checkpoint:
-            last_checkpoint = current_checkpoint
-
-    return last_checkpoint
+        return int(rel_path.split('_Step')[0])
 
 
-def rollout_worker(graph_manager, checkpoint_dir, data_store):
+def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, policy_type):
     """
     wait for first checkpoint then perform rollouts using the model
     """
@@ -98,22 +95,28 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store):
 
     last_checkpoint = 0
 
-    act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation
-
-    print(act_steps, graph_manager.improve_steps.num_steps)
+    act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
 
     for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
        graph_manager.act(EnvironmentSteps(num_steps=act_steps))
 
-        new_checkpoint = last_checkpoint + 1
-        while last_checkpoint < new_checkpoint:
-            if data_store:
-                data_store.load_from_store()
-            last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
+        new_checkpoint = get_latest_checkpoint(checkpoint_dir)
+
+        if policy_type == 'ON':
+            while new_checkpoint < last_checkpoint + 1:
+                if data_store:
+                    data_store.load_from_store()
+                new_checkpoint = get_latest_checkpoint(checkpoint_dir)
+
+            graph_manager.restore_checkpoint()
+
+        if policy_type == "OFF":
+
+            if new_checkpoint > last_checkpoint:
+                graph_manager.restore_checkpoint()
 
         last_checkpoint = new_checkpoint
 
-        graph_manager.restore_checkpoint()
 
     graph_manager.phase = RunPhase.UNDEFINED
 
@@ -134,6 +137,14 @@ def main():
     parser.add_argument('--data-store-params',
                         help="(string) JSON string of the data store params",
                         type=str)
+    parser.add_argument('--num-workers',
+                        help="(int) The number of workers started in this pool",
+                        type=int,
+                        default=1)
+    parser.add_argument('--policy-type',
+                        help="(string) The type of policy: OFF/ON",
+                        type=str,
+                        default='OFF')
 
     args = parser.parse_args()
 
@@ -142,9 +153,7 @@ def main():
     data_store = None
     if args.memory_backend_params:
         args.memory_backend_params = json.loads(args.memory_backend_params)
-        print(args.memory_backend_params)
         args.memory_backend_params['run_type'] = 'worker'
-        print(construct_memory_params(args.memory_backend_params))
         graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))
 
     if args.data_store_params:
@@ -159,7 +168,9 @@ def main():
     rollout_worker(
         graph_manager=graph_manager,
         checkpoint_dir=args.checkpoint_dir,
-        data_store=data_store
+        data_store=data_store,
+        num_workers=args.num_workers,
+        policy_type=args.policy_type
     )
 
 if __name__ == '__main__':
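Note: in the rollout worker, get_latest_checkpoint() now simply reports the newest checkpoint number found in TensorFlow's 'checkpoint' index file (Coach checkpoints carry a '<number>_Step-...' prefix), and each of the num_workers rollout workers acts for only its share of the steps needed per training iteration. On-policy runs ('ON') block until a strictly newer checkpoint is published and always restore it; off-policy runs ('OFF') keep acting with the current weights and restore only when a newer checkpoint happens to be available. A standalone sketch of the checkpoint lookup and the per-worker step split, using a regular expression instead of the CheckpointState protobuf for brevity:

    import math
    import os
    import re

    def latest_checkpoint_number(checkpoint_dir):
        # Parse TensorFlow's 'checkpoint' index file and return the numeric
        # prefix of model_checkpoint_path, e.g. '3_Step-1200.ckpt' -> 3.
        index_file = os.path.join(checkpoint_dir, 'checkpoint')
        if not os.path.exists(index_file):
            return -1
        match = re.search(r'model_checkpoint_path:\s*"(.+?)"', open(index_file).read())
        if match is None:
            return -1
        rel_path = os.path.relpath(match.group(1), checkpoint_dir)
        return int(rel_path.split('_Step')[0])

    def steps_per_worker(steps_per_iteration, error_compensation, num_workers):
        # Each rollout worker only contributes its share of the steps per iteration.
        return math.ceil((steps_per_iteration + error_compensation) / num_workers)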
diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py
index ba15df1..4a6e769 100644
--- a/rl_coach/training_worker.py
+++ b/rl_coach/training_worker.py
@@ -18,13 +18,14 @@ def data_store_ckpt_save(data_store):
         data_store.save_to_store()
         time.sleep(10)
 
-def training_worker(graph_manager, checkpoint_dir):
+def training_worker(graph_manager, checkpoint_dir, policy_type):
     """
     restore a checkpoint then perform rollouts using the restored model
     """
     # initialize graph
     task_parameters = TaskParameters()
     task_parameters.__dict__['save_checkpoint_dir'] = checkpoint_dir
+    task_parameters.__dict__['save_checkpoint_secs'] = 60
     graph_manager.create_graph(task_parameters)
 
     # save randomly initialized graph
@@ -32,14 +33,26 @@ def training_worker(graph_manager, checkpoint_dir):
 
     # training loop
     steps = 0
+
+    # evaluation offset
+    eval_offset = 1
+
     while(steps < graph_manager.improve_steps.num_steps):
         if graph_manager.should_train():
             steps += 1
+
             graph_manager.phase = core_types.RunPhase.TRAIN
             graph_manager.train(core_types.TrainingSteps(1))
             graph_manager.phase = core_types.RunPhase.UNDEFINED
-            graph_manager.evaluate(graph_manager.evaluation_steps)
-            graph_manager.save_checkpoint()
+
+            if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
+                graph_manager.evaluate(graph_manager.evaluation_steps)
+                eval_offset += 1
+
+            if policy_type == 'ON':
+                graph_manager.save_checkpoint()
+            else:
+                graph_manager.occasionally_save_checkpoint()
 
 
 def main():
@@ -58,6 +71,10 @@ def main():
     parser.add_argument('--data-store-params',
                         help="(string) JSON string of the data store params",
                         type=str)
+    parser.add_argument('--policy-type',
+                        help="(string) The type of policy: OFF/ON",
+                        type=str,
+                        default='OFF')
 
     args = parser.parse_args()
     graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
@@ -78,6 +95,7 @@ def main():
     training_worker(
         graph_manager=graph_manager,
         checkpoint_dir=args.checkpoint_dir,
+        policy_type=args.policy_type
     )
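Note: the training worker now decouples evaluation and checkpointing from the per-step train loop. Training steps are converted to environment steps by multiplying with num_consecutive_playing_steps, and an evaluation runs each time that total crosses the next steps_between_evaluation_periods boundary, tracked by eval_offset. Checkpoints are written after every training step only for on-policy runs; otherwise the time-based cadence is used, which the patch sets to every 60 seconds via save_checkpoint_secs. A standalone sketch of the evaluation cadence, with made-up numbers:

    def evaluation_schedule(total_train_steps, playing_steps_per_train_step,
                            steps_between_evaluation_periods):
        # Return the training steps at which an evaluation would be triggered.
        eval_offset = 1
        evaluate_at = []
        for step in range(1, total_train_steps + 1):
            env_steps = step * playing_steps_per_train_step
            if env_steps > steps_between_evaluation_periods * eval_offset:
                evaluate_at.append(step)
                eval_offset += 1
        return evaluate_at

    # 10 environment steps per training step, evaluate roughly every 40 env steps:
    print(evaluation_schedule(20, 10, 40))  # -> [5, 9, 13, 17]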