1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

Changed run_multiple_seeds to support mxnet. And fix other bugs. (#122)

This commit is contained in:
Thom Lane
2018-11-24 22:33:09 -08:00
committed by Gal Leibovich
parent 77fb561668
commit de9b707fe1

View File

@@ -63,6 +63,10 @@ if __name__ == "__main__":
parser.add_argument('-c', '--use_cpu', parser.add_argument('-c', '--use_cpu',
help="(flag) Use the cpu instead of the gpu", help="(flag) Use the cpu instead of the gpu",
action='store_true') action='store_true')
parser.add_argument('-f', '--framework',
help="(string) Neural network framework. Available values: tensorflow, mxnet",
default='tensorflow',
type=str)
args = parser.parse_args() args = parser.parse_args()
# dir_prefix = "benchmark_" # dir_prefix = "benchmark_"
@@ -83,6 +87,7 @@ if __name__ == "__main__":
levels = args.level.split(',') if args.level is not None else [None] levels = args.level.split(',') if args.level is not None else [None]
num_seeds = args.seeds num_seeds = args.seeds
num_workers = args.num_workers num_workers = args.num_workers
framework = args.framework
gpu = [int(gpu) for gpu in args.gpu.split(',')] gpu = [int(gpu) for gpu in args.gpu.split(',')]
level_as_sub_dir = args.level_as_sub_dir level_as_sub_dir = args.level_as_sub_dir
@@ -97,7 +102,8 @@ if __name__ == "__main__":
set_gpu(gpu_list[curr_gpu_idx]) set_gpu(gpu_list[curr_gpu_idx])
command = ['python3', 'rl_coach/coach.py', '-ns', '-p', '{}'.format(preset), command = ['python3', 'rl_coach/coach.py', '-ns', '-p', '{}'.format(preset),
'--seed', '{}'.format(seed), '-n', '{}'.format(num_workers)] '--seed', '{}'.format(seed), '-n', '{}'.format(num_workers),
'--framework', framework]
if args.use_cpu: if args.use_cpu:
command.append("-c") command.append("-c")
if args.evaluation_worker: if args.evaluation_worker:
@@ -108,11 +114,12 @@ if __name__ == "__main__":
separator = "/" separator = "/"
else: else:
separator = "_" separator = "_"
command.extend(['-e', '{dir_prefix}{preset}{separator}{level}_{num_workers}_workers'.format( command.extend(['-e', '{dir_prefix}{preset}_{seed}_{separator}{level}_{num_workers}_workers'.format(
dir_prefix=dir_prefix, preset=preset, level=level, separator=separator, num_workers=args.num_workers)]) dir_prefix=dir_prefix, preset=preset, seed=seed, level=level, separator=separator,
num_workers=args.num_workers)])
else: else:
command.extend(['-e', '{dir_prefix}{preset}_{num_workers}_workers'.format( command.extend(['-e', '{dir_prefix}{preset}_{seed}_{num_workers}_workers'.format(
dir_prefix=dir_prefix, preset=preset, num_workers=args.num_workers)]) dir_prefix=dir_prefix, preset=preset, seed=seed, num_workers=args.num_workers)])
print(command) print(command)
p = Popen(command) p = Popen(command)