
rename save_checkpoint_secs -> checkpoint_save_secs

Authored by Zach Dwiel on 2018-10-05 11:47:35 -04:00
committed by zach dwiel
parent 9804b033a2
commit 700a175902
6 changed files with 11 additions and 11 deletions

View File

@@ -75,13 +75,13 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind
 def create_monitored_session(target: tf.train.Server, task_index: int,
-                             checkpoint_dir: str, save_checkpoint_secs: int, config: tf.ConfigProto=None) -> tf.Session:
+                             checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session:
     """
     Create a monitored session for the worker
     :param target: the target string for the tf.Session
     :param task_index: the task index of the worker
     :param checkpoint_dir: a directory path where the checkpoints will be stored
-    :param save_checkpoint_secs: number of seconds between checkpoints storing
+    :param checkpoint_save_secs: number of seconds between checkpoints storing
     :param config: the tensorflow configuration (optional)
     :return: the session to use for the run
     """
@@ -94,7 +94,7 @@ def create_monitored_session(target: tf.train.Server, task_index: int,
         is_chief=is_chief,
         hooks=[],
         checkpoint_dir=checkpoint_dir,
-        save_checkpoint_secs=save_checkpoint_secs,
+        checkpoint_save_secs=checkpoint_save_secs,
         config=config
     )
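For reference, a minimal call-site sketch using the renamed keyword. Only the create_monitored_session signature comes from this commit; the import path and the throwaway single-worker cluster below are assumptions for illustration (TF 1.x API).

import tensorflow as tf
# NOTE: the module path is an assumption about where the helper lives in rl_coach
from rl_coach.architectures.tensorflow_components.distributed_tf_utils import \
    create_monitored_session

# a one-worker local cluster, just to obtain a tf.train.Server to pass as target
cluster = tf.train.ClusterSpec({'worker': ['localhost:2222']})
server = tf.train.Server(cluster, job_name='worker', task_index=0)

# the renamed keyword: checkpoint_save_secs instead of save_checkpoint_secs
sess = create_monitored_session(target=server,
                                task_index=0,
                                checkpoint_dir='/tmp/coach_checkpoints',  # placeholder path
                                checkpoint_save_secs=600)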

View File

@@ -176,7 +176,7 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
     args.framework = Frameworks[args.framework.lower()]
     # checkpoints
-    args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None
+    args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
     return args
@@ -257,7 +257,7 @@ def main():
help="(flag) TensorFlow verbosity level", help="(flag) TensorFlow verbosity level",
default=3, default=3,
type=int) type=int)
parser.add_argument('-s', '--save_checkpoint_secs', parser.add_argument('-s', '--checkpoint_save_secs',
help="(int) Time in seconds between saving checkpoints of the model.", help="(int) Time in seconds between saving checkpoints of the model.",
default=None, default=None,
type=int) type=int)
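The same rename shows up on the command line. Below is a stand-alone argparse sketch of the renamed flag together with the checkpoint-directory rule from parse_arguments(); the --experiment_path flag and its value are illustrative assumptions, not part of this commit.

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('-s', '--checkpoint_save_secs',
                    help="(int) Time in seconds between saving checkpoints of the model.",
                    default=None,
                    type=int)
parser.add_argument('--experiment_path', default='./experiments/demo')  # assumed flag
args = parser.parse_args(['-s', '600'])

# same rule as in parse_arguments(): only set a checkpoint directory when a
# save interval was requested
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') \
    if args.checkpoint_save_secs is not None else None
print(args.checkpoint_save_secs)  # 600
print(args.checkpoint_save_dir)   # ./experiments/demo/checkpoint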

View File

@@ -208,7 +208,7 @@ class GraphManager(object):
         self.sess = create_monitored_session(target=task_parameters.worker_target,
                                              task_index=task_parameters.task_index,
                                              checkpoint_dir=checkpoint_dir,
-                                             save_checkpoint_secs=task_parameters.save_checkpoint_secs,
+                                             checkpoint_save_secs=task_parameters.checkpoint_save_secs,
                                              config=config)
         # set the session for all the modules
         self.set_session(self.sess)
@@ -490,8 +490,8 @@ class GraphManager(object):
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
-        if self.task_parameters.save_checkpoint_secs \
-                and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \
+        if self.task_parameters.checkpoint_save_secs \
+                and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.checkpoint_save_secs \
                 and (self.task_parameters.task_index == 0  # distributed
                      or self.task_parameters.task_index is None  # single-worker
                      ):
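To make the gating condition above easy to read on its own, here is a self-contained sketch of the renamed check; TaskParameters below is a minimal stand-in for illustration, not the real rl_coach class.

import time

class TaskParameters:
    # stand-in with only the two fields the check needs
    def __init__(self, checkpoint_save_secs=None, task_index=None):
        self.checkpoint_save_secs = checkpoint_save_secs
        self.task_index = task_index

def should_save_checkpoint(task_parameters, last_checkpoint_saving_time):
    # save only when an interval is configured, enough time has passed, and this
    # worker is the chief (distributed) or the only worker (single-worker run)
    return bool(task_parameters.checkpoint_save_secs) \
        and time.time() - last_checkpoint_saving_time >= task_parameters.checkpoint_save_secs \
        and (task_parameters.task_index == 0
             or task_parameters.task_index is None)

# single-worker run saving every 10 minutes; the last save was ~12 minutes ago
params = TaskParameters(checkpoint_save_secs=600)
print(should_save_checkpoint(params, time.time() - 700))  # True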

View File

@@ -367,7 +367,7 @@
" evaluate_only=False,\n", " evaluate_only=False,\n",
" experiment_path=log_path)\n", " experiment_path=log_path)\n",
"\n", "\n",
"task_parameters.__dict__['save_checkpoint_secs'] = None\n", "task_parameters.__dict__['checkpoint_save_secs'] = None\n",
"\n", "\n",
"graph_manager.create_graph(task_parameters)\n", "graph_manager.create_graph(task_parameters)\n",
"\n", "\n",

View File

@@ -345,7 +345,7 @@
" evaluate_only=False,\n", " evaluate_only=False,\n",
" experiment_path=log_path)\n", " experiment_path=log_path)\n",
"\n", "\n",
"task_parameters.__dict__['save_checkpoint_secs'] = None\n", "task_parameters.__dict__['checkpoint_save_secs'] = None\n",
"\n", "\n",
"\n", "\n",
"graph_manager.create_graph(task_parameters)\n", "graph_manager.create_graph(task_parameters)\n",

View File

@@ -372,7 +372,7 @@
" evaluate_only=False,\n", " evaluate_only=False,\n",
" experiment_path=log_path)\n", " experiment_path=log_path)\n",
"\n", "\n",
"task_parameters.__dict__['save_checkpoint_secs'] = None\n", "task_parameters.__dict__['checkpoint_save_secs'] = None\n",
"task_parameters.__dict__['verbosity'] = 'low'\n", "task_parameters.__dict__['verbosity'] = 'low'\n",
"\n", "\n",
"graph_manager.create_graph(task_parameters)\n", "graph_manager.create_graph(task_parameters)\n",