1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Changed run_multiple_seeds to support mxnet. And fix other bugs. (#122)

This commit is contained in:
Thom Lane
2018-11-24 22:33:09 -08:00
committed by Gal Leibovich
parent 77fb561668
commit de9b707fe1

View File

@@ -63,6 +63,10 @@ if __name__ == "__main__":
parser.add_argument('-c', '--use_cpu',
help="(flag) Use the cpu instead of the gpu",
action='store_true')
parser.add_argument('-f', '--framework',
help="(string) Neural network framework. Available values: tensorflow, mxnet",
default='tensorflow',
type=str)
args = parser.parse_args()
# dir_prefix = "benchmark_"
@@ -83,6 +87,7 @@ if __name__ == "__main__":
levels = args.level.split(',') if args.level is not None else [None]
num_seeds = args.seeds
num_workers = args.num_workers
framework = args.framework
gpu = [int(gpu) for gpu in args.gpu.split(',')]
level_as_sub_dir = args.level_as_sub_dir
@@ -97,7 +102,8 @@ if __name__ == "__main__":
set_gpu(gpu_list[curr_gpu_idx])
command = ['python3', 'rl_coach/coach.py', '-ns', '-p', '{}'.format(preset),
'--seed', '{}'.format(seed), '-n', '{}'.format(num_workers)]
'--seed', '{}'.format(seed), '-n', '{}'.format(num_workers),
'--framework', framework]
if args.use_cpu:
command.append("-c")
if args.evaluation_worker:
@@ -108,11 +114,12 @@ if __name__ == "__main__":
separator = "/"
else:
separator = "_"
command.extend(['-e', '{dir_prefix}{preset}{separator}{level}_{num_workers}_workers'.format(
dir_prefix=dir_prefix, preset=preset, level=level, separator=separator, num_workers=args.num_workers)])
command.extend(['-e', '{dir_prefix}{preset}_{seed}_{separator}{level}_{num_workers}_workers'.format(
dir_prefix=dir_prefix, preset=preset, seed=seed, level=level, separator=separator,
num_workers=args.num_workers)])
else:
command.extend(['-e', '{dir_prefix}{preset}_{num_workers}_workers'.format(
dir_prefix=dir_prefix, preset=preset, num_workers=args.num_workers)])
command.extend(['-e', '{dir_prefix}{preset}_{seed}_{num_workers}_workers'.format(
dir_prefix=dir_prefix, preset=preset, seed=seed, num_workers=args.num_workers)])
print(command)
p = Popen(command)