mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Changed run_multiple_seeds to support mxnet. And fix other bugs. (#122)
This commit is contained in:
@@ -63,6 +63,10 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('-c', '--use_cpu',
|
parser.add_argument('-c', '--use_cpu',
|
||||||
help="(flag) Use the cpu instead of the gpu",
|
help="(flag) Use the cpu instead of the gpu",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
parser.add_argument('-f', '--framework',
|
||||||
|
help="(string) Neural network framework. Available values: tensorflow, mxnet",
|
||||||
|
default='tensorflow',
|
||||||
|
type=str)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# dir_prefix = "benchmark_"
|
# dir_prefix = "benchmark_"
|
||||||
@@ -83,6 +87,7 @@ if __name__ == "__main__":
|
|||||||
levels = args.level.split(',') if args.level is not None else [None]
|
levels = args.level.split(',') if args.level is not None else [None]
|
||||||
num_seeds = args.seeds
|
num_seeds = args.seeds
|
||||||
num_workers = args.num_workers
|
num_workers = args.num_workers
|
||||||
|
framework = args.framework
|
||||||
gpu = [int(gpu) for gpu in args.gpu.split(',')]
|
gpu = [int(gpu) for gpu in args.gpu.split(',')]
|
||||||
level_as_sub_dir = args.level_as_sub_dir
|
level_as_sub_dir = args.level_as_sub_dir
|
||||||
|
|
||||||
@@ -97,7 +102,8 @@ if __name__ == "__main__":
|
|||||||
set_gpu(gpu_list[curr_gpu_idx])
|
set_gpu(gpu_list[curr_gpu_idx])
|
||||||
|
|
||||||
command = ['python3', 'rl_coach/coach.py', '-ns', '-p', '{}'.format(preset),
|
command = ['python3', 'rl_coach/coach.py', '-ns', '-p', '{}'.format(preset),
|
||||||
'--seed', '{}'.format(seed), '-n', '{}'.format(num_workers)]
|
'--seed', '{}'.format(seed), '-n', '{}'.format(num_workers),
|
||||||
|
'--framework', framework]
|
||||||
if args.use_cpu:
|
if args.use_cpu:
|
||||||
command.append("-c")
|
command.append("-c")
|
||||||
if args.evaluation_worker:
|
if args.evaluation_worker:
|
||||||
@@ -108,11 +114,12 @@ if __name__ == "__main__":
|
|||||||
separator = "/"
|
separator = "/"
|
||||||
else:
|
else:
|
||||||
separator = "_"
|
separator = "_"
|
||||||
command.extend(['-e', '{dir_prefix}{preset}{separator}{level}_{num_workers}_workers'.format(
|
command.extend(['-e', '{dir_prefix}{preset}_{seed}_{separator}{level}_{num_workers}_workers'.format(
|
||||||
dir_prefix=dir_prefix, preset=preset, level=level, separator=separator, num_workers=args.num_workers)])
|
dir_prefix=dir_prefix, preset=preset, seed=seed, level=level, separator=separator,
|
||||||
|
num_workers=args.num_workers)])
|
||||||
else:
|
else:
|
||||||
command.extend(['-e', '{dir_prefix}{preset}_{num_workers}_workers'.format(
|
command.extend(['-e', '{dir_prefix}{preset}_{seed}_{num_workers}_workers'.format(
|
||||||
dir_prefix=dir_prefix, preset=preset, num_workers=args.num_workers)])
|
dir_prefix=dir_prefix, preset=preset, seed=seed, num_workers=args.num_workers)])
|
||||||
print(command)
|
print(command)
|
||||||
|
|
||||||
p = Popen(command)
|
p = Popen(command)
|
||||||
|
|||||||
Reference in New Issue
Block a user