mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
temp commit
This commit is contained in:
58
coach.py
58
coach.py
@@ -30,6 +30,14 @@ from subprocess import Popen
|
||||
import datetime
|
||||
import presets
|
||||
import atexit
|
||||
import sys
|
||||
import subprocess
|
||||
from threading import Thread
|
||||
|
||||
try:
|
||||
from Queue import Queue, Empty
|
||||
except ImportError:
|
||||
from queue import Queue, Empty # for Python 3.x
|
||||
|
||||
if len(set(failed_imports)) > 0:
|
||||
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
|
||||
@@ -152,6 +160,38 @@ def run_dict_to_json(_run_dict, task_id=''):
|
||||
return json_path
|
||||
|
||||
|
||||
def enqueue_output(out, queue):
|
||||
for line in iter(out.readline, b''):
|
||||
queue.put(line)
|
||||
out.close()
|
||||
|
||||
|
||||
def merge_streams(processes, output_stream=sys.stdout):
|
||||
q = Queue()
|
||||
threads = []
|
||||
for p in processes:
|
||||
threads.append(Thread(target=enqueue_output, args=(p.stdout, q)))
|
||||
threads.append(Thread(target=enqueue_output, args=(p.stderr, q)))
|
||||
|
||||
for t in threads:
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
while True:
|
||||
try:
|
||||
line = q.get_nowait()
|
||||
except Empty:
|
||||
# break when all processes are done and q is empty
|
||||
if all(p.poll() is not None for p in processes):
|
||||
break
|
||||
else:
|
||||
# sys.stdout.write(line)
|
||||
output_stream.write(line.decode(output_stream.encoding))
|
||||
output_stream.flush()
|
||||
|
||||
print('All processes done')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--preset',
|
||||
@@ -252,6 +292,8 @@ if __name__ == "__main__":
|
||||
if not args.no_summary:
|
||||
atexit.register(logger.print_summary)
|
||||
|
||||
set_cpu()
|
||||
|
||||
# Single-threaded runs
|
||||
if run_dict['num_threads'] == 1:
|
||||
# set tuning parameters
|
||||
@@ -285,11 +327,13 @@ if __name__ == "__main__":
|
||||
set_cpu()
|
||||
|
||||
# create a parameter server
|
||||
Popen(["python3",
|
||||
"./parallel_actor.py",
|
||||
"--ps_hosts={}".format(ps_hosts),
|
||||
"--worker_hosts={}".format(worker_hosts),
|
||||
"--job_name=ps"])
|
||||
parameter_server = Popen([
|
||||
"python3",
|
||||
"./parallel_actor.py",
|
||||
"--ps_hosts={}".format(ps_hosts),
|
||||
"--worker_hosts={}".format(worker_hosts),
|
||||
"--job_name=ps",
|
||||
], stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)
|
||||
|
||||
screen.log_title("*** Distributed Training ***")
|
||||
time.sleep(1)
|
||||
@@ -314,13 +358,15 @@ if __name__ == "__main__":
|
||||
"--job_name=worker",
|
||||
"--load_json={}".format(json_run_dict_path)]
|
||||
|
||||
p = Popen(workers_args)
|
||||
p = Popen(workers_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)
|
||||
|
||||
if i != run_dict['num_threads']:
|
||||
workers.append(p)
|
||||
else:
|
||||
evaluation_worker = p
|
||||
|
||||
merge_streams(workers + [parameter_server])
|
||||
|
||||
# wait for all workers
|
||||
[w.wait() for w in workers]
|
||||
evaluation_worker.kill()
|
||||
|
||||
Reference in New Issue
Block a user