mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
rename save_checkpoint_secs -> checkpoint_save_secs
This commit is contained in:
@@ -75,13 +75,13 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind
|
||||
|
||||
|
||||
def create_monitored_session(target: tf.train.Server, task_index: int,
|
||||
checkpoint_dir: str, save_checkpoint_secs: int, config: tf.ConfigProto=None) -> tf.Session:
|
||||
checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session:
|
||||
"""
|
||||
Create a monitored session for the worker
|
||||
:param target: the target string for the tf.Session
|
||||
:param task_index: the task index of the worker
|
||||
:param checkpoint_dir: a directory path where the checkpoints will be stored
|
||||
:param save_checkpoint_secs: number of seconds between checkpoints storing
|
||||
:param checkpoint_save_secs: number of seconds between checkpoints storing
|
||||
:param config: the tensorflow configuration (optional)
|
||||
:return: the session to use for the run
|
||||
"""
|
||||
@@ -94,7 +94,7 @@ def create_monitored_session(target: tf.train.Server, task_index: int,
|
||||
is_chief=is_chief,
|
||||
hooks=[],
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
save_checkpoint_secs=save_checkpoint_secs,
|
||||
checkpoint_save_secs=checkpoint_save_secs,
|
||||
config=config
|
||||
)
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
args.framework = Frameworks[args.framework.lower()]
|
||||
|
||||
# checkpoints
|
||||
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None
|
||||
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
|
||||
|
||||
return args
|
||||
|
||||
@@ -257,7 +257,7 @@ def main():
|
||||
help="(flag) TensorFlow verbosity level",
|
||||
default=3,
|
||||
type=int)
|
||||
parser.add_argument('-s', '--save_checkpoint_secs',
|
||||
parser.add_argument('-s', '--checkpoint_save_secs',
|
||||
help="(int) Time in seconds between saving checkpoints of the model.",
|
||||
default=None,
|
||||
type=int)
|
||||
|
||||
@@ -208,7 +208,7 @@ class GraphManager(object):
|
||||
self.sess = create_monitored_session(target=task_parameters.worker_target,
|
||||
task_index=task_parameters.task_index,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
save_checkpoint_secs=task_parameters.save_checkpoint_secs,
|
||||
checkpoint_save_secs=task_parameters.checkpoint_save_secs,
|
||||
config=config)
|
||||
# set the session for all the modules
|
||||
self.set_session(self.sess)
|
||||
@@ -490,8 +490,8 @@ class GraphManager(object):
|
||||
|
||||
def occasionally_save_checkpoint(self):
|
||||
# only the chief process saves checkpoints
|
||||
if self.task_parameters.save_checkpoint_secs \
|
||||
and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \
|
||||
if self.task_parameters.checkpoint_save_secs \
|
||||
and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.checkpoint_save_secs \
|
||||
and (self.task_parameters.task_index == 0 # distributed
|
||||
or self.task_parameters.task_index is None # single-worker
|
||||
):
|
||||
|
||||
@@ -367,7 +367,7 @@
|
||||
" evaluate_only=False,\n",
|
||||
" experiment_path=log_path)\n",
|
||||
"\n",
|
||||
"task_parameters.__dict__['save_checkpoint_secs'] = None\n",
|
||||
"task_parameters.__dict__['checkpoint_save_secs'] = None\n",
|
||||
"\n",
|
||||
"graph_manager.create_graph(task_parameters)\n",
|
||||
"\n",
|
||||
|
||||
@@ -345,7 +345,7 @@
|
||||
" evaluate_only=False,\n",
|
||||
" experiment_path=log_path)\n",
|
||||
"\n",
|
||||
"task_parameters.__dict__['save_checkpoint_secs'] = None\n",
|
||||
"task_parameters.__dict__['checkpoint_save_secs'] = None\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"graph_manager.create_graph(task_parameters)\n",
|
||||
|
||||
@@ -372,7 +372,7 @@
|
||||
" evaluate_only=False,\n",
|
||||
" experiment_path=log_path)\n",
|
||||
"\n",
|
||||
"task_parameters.__dict__['save_checkpoint_secs'] = None\n",
|
||||
"task_parameters.__dict__['checkpoint_save_secs'] = None\n",
|
||||
"task_parameters.__dict__['verbosity'] = 'low'\n",
|
||||
"\n",
|
||||
"graph_manager.create_graph(task_parameters)\n",
|
||||
|
||||
Reference in New Issue
Block a user