mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
Adding worker logs and plumbed task_parameters to distributed coach (#130)
This commit is contained in:
committed by
Balaji Subramaniam
parent
2b4c9c6774
commit
4a6c404070
@@ -84,7 +84,7 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
|
|||||||
graph_manager.close()
|
graph_manager.close()
|
||||||
|
|
||||||
|
|
||||||
def handle_distributed_coach_tasks(graph_manager, args):
|
def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
|
||||||
ckpt_inside_container = "/checkpoint"
|
ckpt_inside_container = "/checkpoint"
|
||||||
|
|
||||||
memory_backend_params = None
|
memory_backend_params = None
|
||||||
@@ -100,22 +100,24 @@ def handle_distributed_coach_tasks(graph_manager, args):
|
|||||||
graph_manager.data_store_params = data_store_params
|
graph_manager.data_store_params = data_store_params
|
||||||
|
|
||||||
if args.distributed_coach_run_type == RunType.TRAINER:
|
if args.distributed_coach_run_type == RunType.TRAINER:
|
||||||
|
task_parameters.checkpoint_save_dir = ckpt_inside_container
|
||||||
training_worker(
|
training_worker(
|
||||||
graph_manager=graph_manager,
|
graph_manager=graph_manager,
|
||||||
checkpoint_dir=ckpt_inside_container
|
task_parameters=task_parameters
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
|
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
|
||||||
|
task_parameters.checkpoint_restore_dir = ckpt_inside_container
|
||||||
|
|
||||||
data_store = None
|
data_store = None
|
||||||
if args.data_store_params:
|
if args.data_store_params:
|
||||||
data_store = get_data_store(data_store_params)
|
data_store = get_data_store(data_store_params)
|
||||||
wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)
|
|
||||||
|
|
||||||
rollout_worker(
|
rollout_worker(
|
||||||
graph_manager=graph_manager,
|
graph_manager=graph_manager,
|
||||||
checkpoint_dir=ckpt_inside_container,
|
|
||||||
data_store=data_store,
|
data_store=data_store,
|
||||||
num_workers=args.num_workers
|
num_workers=args.num_workers,
|
||||||
|
task_parameters=task_parameters
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -124,8 +126,16 @@ def handle_distributed_coach_orchestrator(args):
|
|||||||
RunTypeParameters
|
RunTypeParameters
|
||||||
|
|
||||||
ckpt_inside_container = "/checkpoint"
|
ckpt_inside_container = "/checkpoint"
|
||||||
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
|
arg_list = sys.argv[1:]
|
||||||
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]
|
try:
|
||||||
|
i = arg_list.index('--distributed_coach_run_type')
|
||||||
|
arg_list.pop(i)
|
||||||
|
arg_list.pop(i)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + arg_list
|
||||||
|
rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + arg_list
|
||||||
|
|
||||||
if '--experiment_name' not in rollout_command:
|
if '--experiment_name' not in rollout_command:
|
||||||
rollout_command = rollout_command + ['--experiment_name', args.experiment_name]
|
rollout_command = rollout_command + ['--experiment_name', args.experiment_name]
|
||||||
@@ -170,6 +180,10 @@ def handle_distributed_coach_orchestrator(args):
|
|||||||
print("Could not deploy rollout worker(s).")
|
print("Could not deploy rollout worker(s).")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if args.dump_worker_logs:
|
||||||
|
screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
|
||||||
|
orchestrator.worker_logs(path=args.experiment_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
orchestrator.trainer_logs()
|
orchestrator.trainer_logs()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -321,6 +335,9 @@ class CoachLauncher(object):
|
|||||||
if args.list:
|
if args.list:
|
||||||
self.display_all_presets_and_exit()
|
self.display_all_presets_and_exit()
|
||||||
|
|
||||||
|
if args.distributed_coach and not args.checkpoint_save_secs:
|
||||||
|
screen.error("Distributed coach requires --checkpoint_save_secs or -s")
|
||||||
|
|
||||||
# Read args from config file for distributed Coach.
|
# Read args from config file for distributed Coach.
|
||||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||||
coach_config = ConfigParser({
|
coach_config = ConfigParser({
|
||||||
@@ -550,6 +567,9 @@ class CoachLauncher(object):
|
|||||||
"target success rate as set by the environment or a custom success rate as defined "
|
"target success rate as set by the environment or a custom success rate as defined "
|
||||||
"in the preset. ",
|
"in the preset. ",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
parser.add_argument('--dump_worker_logs',
|
||||||
|
help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
|
||||||
|
action='store_true')
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@@ -570,26 +590,6 @@ class CoachLauncher(object):
|
|||||||
atexit.register(logger.summarize_experiment)
|
atexit.register(logger.summarize_experiment)
|
||||||
screen.change_terminal_title(args.experiment_name)
|
screen.change_terminal_title(args.experiment_name)
|
||||||
|
|
||||||
# open dashboard
|
|
||||||
if args.open_dashboard:
|
|
||||||
open_dashboard(args.experiment_path)
|
|
||||||
|
|
||||||
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
|
|
||||||
handle_distributed_coach_tasks(graph_manager, args)
|
|
||||||
return
|
|
||||||
|
|
||||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
|
||||||
handle_distributed_coach_orchestrator(args)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Single-threaded runs
|
|
||||||
if args.num_workers == 1:
|
|
||||||
self.start_single_threaded(graph_manager, args)
|
|
||||||
else:
|
|
||||||
self.start_multi_threaded(graph_manager, args)
|
|
||||||
|
|
||||||
def start_single_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
|
||||||
# Start the training or evaluation
|
|
||||||
task_parameters = TaskParameters(
|
task_parameters = TaskParameters(
|
||||||
framework_type=args.framework,
|
framework_type=args.framework,
|
||||||
evaluate_only=args.evaluate,
|
evaluate_only=args.evaluate,
|
||||||
@@ -603,6 +603,26 @@ class CoachLauncher(object):
|
|||||||
apply_stop_condition=args.apply_stop_condition
|
apply_stop_condition=args.apply_stop_condition
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# open dashboard
|
||||||
|
if args.open_dashboard:
|
||||||
|
open_dashboard(args.experiment_path)
|
||||||
|
|
||||||
|
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
|
||||||
|
handle_distributed_coach_tasks(graph_manager, args, task_parameters)
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||||
|
handle_distributed_coach_orchestrator(args)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Single-threaded runs
|
||||||
|
if args.num_workers == 1:
|
||||||
|
self.start_single_threaded(task_parameters, graph_manager, args)
|
||||||
|
else:
|
||||||
|
self.start_multi_threaded(graph_manager, args)
|
||||||
|
|
||||||
|
def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||||
|
# Start the training or evaluation
|
||||||
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
|
||||||
|
|
||||||
def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class RedisPubSubBackend(MemoryBackend):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def undeploy(self):
|
def undeploy(self):
|
||||||
|
from kubernetes import client
|
||||||
if self.params.deployed:
|
if self.params.deployed:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
import sys
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import List
|
||||||
from configparser import ConfigParser, Error
|
from configparser import ConfigParser, Error
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
from rl_coach.base_parameters import RunType
|
from rl_coach.base_parameters import RunType
|
||||||
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
|
from rl_coach.orchestrators.deploy import Deploy, DeployParameters
|
||||||
@@ -255,8 +257,35 @@ class Kubernetes(Deploy):
|
|||||||
print("Got exception: %s\n while creating Job", e)
|
print("Got exception: %s\n while creating Job", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def worker_logs(self):
|
def worker_logs(self, path='./logs'):
|
||||||
pass
|
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||||
|
if not worker_params:
|
||||||
|
return
|
||||||
|
|
||||||
|
api_client = k8sclient.CoreV1Api()
|
||||||
|
pods = None
|
||||||
|
try:
|
||||||
|
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
|
||||||
|
worker_params.orchestration_params['job_name']
|
||||||
|
))
|
||||||
|
|
||||||
|
# pod = pods.items[0]
|
||||||
|
except k8sclient.rest.ApiException as e:
|
||||||
|
print("Got exception: %s\n while reading pods", e)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not pods or len(pods.items) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for pod in pods.items:
|
||||||
|
Process(target=self._tail_log_file, args=(pod.metadata.name, api_client, self.params.namespace, path)).start()
|
||||||
|
|
||||||
|
def _tail_log_file(self, pod_name, api_client, namespace, path):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.mkdir(path)
|
||||||
|
|
||||||
|
sys.stdout = open(os.path.join(path, pod_name), 'w')
|
||||||
|
self.tail_log(pod_name, api_client)
|
||||||
|
|
||||||
def trainer_logs(self):
|
def trainer_logs(self):
|
||||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||||
|
|||||||
@@ -68,21 +68,17 @@ def get_latest_checkpoint(checkpoint_dir):
|
|||||||
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
|
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
|
||||||
return int(rel_path.split('_Step')[0])
|
return int(rel_path.split('_Step')[0])
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def should_stop(checkpoint_dir):
|
def should_stop(checkpoint_dir):
|
||||||
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
|
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
|
||||||
|
|
||||||
|
|
||||||
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
def rollout_worker(graph_manager, data_store, num_workers, task_parameters):
|
||||||
"""
|
"""
|
||||||
wait for first checkpoint then perform rollouts using the model
|
wait for first checkpoint then perform rollouts using the model
|
||||||
"""
|
"""
|
||||||
wait_for_checkpoint(checkpoint_dir)
|
checkpoint_dir = task_parameters.checkpoint_restore_dir
|
||||||
|
wait_for_checkpoint(checkpoint_dir, data_store)
|
||||||
task_parameters = TaskParameters()
|
|
||||||
task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
|
|
||||||
|
|
||||||
graph_manager.create_graph(task_parameters)
|
graph_manager.create_graph(task_parameters)
|
||||||
with graph_manager.phase_context(RunPhase.TRAIN):
|
with graph_manager.phase_context(RunPhase.TRAIN):
|
||||||
|
|||||||
@@ -12,14 +12,11 @@ def data_store_ckpt_save(data_store):
|
|||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
def training_worker(graph_manager, checkpoint_dir):
|
def training_worker(graph_manager, task_parameters):
|
||||||
"""
|
"""
|
||||||
restore a checkpoint then perform rollouts using the restored model
|
restore a checkpoint then perform rollouts using the restored model
|
||||||
"""
|
"""
|
||||||
# initialize graph
|
# initialize graph
|
||||||
task_parameters = TaskParameters()
|
|
||||||
task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
|
|
||||||
task_parameters.__dict__['checkpoint_save_secs'] = 20
|
|
||||||
graph_manager.create_graph(task_parameters)
|
graph_manager.create_graph(task_parameters)
|
||||||
|
|
||||||
# save randomly initialized graph
|
# save randomly initialized graph
|
||||||
|
|||||||
Reference in New Issue
Block a user