1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

temp commit

This commit is contained in:
Zach Dwiel
2018-02-16 09:35:58 -05:00
parent 16c5032735
commit 85afb86893
14 changed files with 244 additions and 127 deletions

View File

@@ -30,6 +30,14 @@ from subprocess import Popen
import datetime
import presets
import atexit
import sys
import subprocess
from threading import Thread
try:
from Queue import Queue, Empty
except ImportError:
from queue import Queue, Empty # for Python 3.x
if len(set(failed_imports)) > 0:
screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
@@ -152,6 +160,38 @@ def run_dict_to_json(_run_dict, task_id=''):
return json_path
def enqueue_output(out, queue):
for line in iter(out.readline, b''):
queue.put(line)
out.close()
def merge_streams(processes, output_stream=sys.stdout):
q = Queue()
threads = []
for p in processes:
threads.append(Thread(target=enqueue_output, args=(p.stdout, q)))
threads.append(Thread(target=enqueue_output, args=(p.stderr, q)))
for t in threads:
t.daemon = True
t.start()
while True:
try:
line = q.get_nowait()
except Empty:
# break when all processes are done and q is empty
if all(p.poll() is not None for p in processes):
break
else:
# sys.stdout.write(line)
output_stream.write(line.decode(output_stream.encoding))
output_stream.flush()
print('All processes done')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--preset',
@@ -252,6 +292,8 @@ if __name__ == "__main__":
if not args.no_summary:
atexit.register(logger.print_summary)
set_cpu()
# Single-threaded runs
if run_dict['num_threads'] == 1:
# set tuning parameters
@@ -285,11 +327,13 @@ if __name__ == "__main__":
set_cpu()
# create a parameter server
Popen(["python3",
"./parallel_actor.py",
"--ps_hosts={}".format(ps_hosts),
"--worker_hosts={}".format(worker_hosts),
"--job_name=ps"])
parameter_server = Popen([
"python3",
"./parallel_actor.py",
"--ps_hosts={}".format(ps_hosts),
"--worker_hosts={}".format(worker_hosts),
"--job_name=ps",
], stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)
screen.log_title("*** Distributed Training ***")
time.sleep(1)
@@ -314,13 +358,15 @@ if __name__ == "__main__":
"--job_name=worker",
"--load_json={}".format(json_run_dict_path)]
p = Popen(workers_args)
p = Popen(workers_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)
if i != run_dict['num_threads']:
workers.append(p)
else:
evaluation_worker = p
merge_streams(workers + [parameter_server])
# wait for all workers
[w.wait() for w in workers]
evaluation_worker.kill()