1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Update how save checkpoint secs arg is handled in distributed Coach. (#151)

This commit is contained in:
Balaji Subramaniam
2018-11-25 00:05:24 -08:00
committed by GitHub
parent de9b707fe1
commit 8df425b6e1

View File

@@ -29,7 +29,7 @@ import time
import sys
import json
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
RunType
RunType, DistributedCoachSynchronizationType
from multiprocessing import Process
from multiprocessing.managers import BaseManager
import subprocess
@@ -336,9 +336,6 @@ class CoachLauncher(object):
if args.list:
self.display_all_presets_and_exit()
if args.distributed_coach and not args.checkpoint_save_secs:
screen.error("Distributed coach requires --checkpoint_save_secs or -s")
# Read args from config file for distributed Coach.
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
coach_config = ConfigParser({
@@ -578,6 +575,12 @@ class CoachLauncher(object):
if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
screen.error("{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
if args.distributed_coach and args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
screen.warning("The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")
if args.distributed_coach and not args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
screen.error("Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")
# Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
# This will not affect GPU runs.
os.environ["OMP_NUM_THREADS"] = "1"