mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
rename save_checkpoint_secs -> checkpoint_save_secs
This commit is contained in:
@@ -75,13 +75,13 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind
|
|||||||
|
|
||||||
|
|
||||||
def create_monitored_session(target: tf.train.Server, task_index: int,
|
def create_monitored_session(target: tf.train.Server, task_index: int,
|
||||||
checkpoint_dir: str, save_checkpoint_secs: int, config: tf.ConfigProto=None) -> tf.Session:
|
checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session:
|
||||||
"""
|
"""
|
||||||
Create a monitored session for the worker
|
Create a monitored session for the worker
|
||||||
:param target: the target string for the tf.Session
|
:param target: the target string for the tf.Session
|
||||||
:param task_index: the task index of the worker
|
:param task_index: the task index of the worker
|
||||||
:param checkpoint_dir: a directory path where the checkpoints will be stored
|
:param checkpoint_dir: a directory path where the checkpoints will be stored
|
||||||
:param save_checkpoint_secs: number of seconds between checkpoints storing
|
:param checkpoint_save_secs: number of seconds between storing checkpoints
|
||||||
:param config: the tensorflow configuration (optional)
|
:param config: the tensorflow configuration (optional)
|
||||||
:return: the session to use for the run
|
:return: the session to use for the run
|
||||||
"""
|
"""
|
||||||
@@ -94,7 +94,7 @@ def create_monitored_session(target: tf.train.Server, task_index: int,
|
|||||||
is_chief=is_chief,
|
is_chief=is_chief,
|
||||||
hooks=[],
|
hooks=[],
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
save_checkpoint_secs=save_checkpoint_secs,
|
checkpoint_save_secs=checkpoint_save_secs,
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
|||||||
args.framework = Frameworks[args.framework.lower()]
|
args.framework = Frameworks[args.framework.lower()]
|
||||||
|
|
||||||
# checkpoints
|
# checkpoints
|
||||||
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None
|
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -257,7 +257,7 @@ def main():
|
|||||||
help="(flag) TensorFlow verbosity level",
|
help="(flag) TensorFlow verbosity level",
|
||||||
default=3,
|
default=3,
|
||||||
type=int)
|
type=int)
|
||||||
parser.add_argument('-s', '--save_checkpoint_secs',
|
parser.add_argument('-s', '--checkpoint_save_secs',
|
||||||
help="(int) Time in seconds between saving checkpoints of the model.",
|
help="(int) Time in seconds between saving checkpoints of the model.",
|
||||||
default=None,
|
default=None,
|
||||||
type=int)
|
type=int)
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ class GraphManager(object):
|
|||||||
self.sess = create_monitored_session(target=task_parameters.worker_target,
|
self.sess = create_monitored_session(target=task_parameters.worker_target,
|
||||||
task_index=task_parameters.task_index,
|
task_index=task_parameters.task_index,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
save_checkpoint_secs=task_parameters.save_checkpoint_secs,
|
checkpoint_save_secs=task_parameters.checkpoint_save_secs,
|
||||||
config=config)
|
config=config)
|
||||||
# set the session for all the modules
|
# set the session for all the modules
|
||||||
self.set_session(self.sess)
|
self.set_session(self.sess)
|
||||||
@@ -490,8 +490,8 @@ class GraphManager(object):
|
|||||||
|
|
||||||
def occasionally_save_checkpoint(self):
|
def occasionally_save_checkpoint(self):
|
||||||
# only the chief process saves checkpoints
|
# only the chief process saves checkpoints
|
||||||
if self.task_parameters.save_checkpoint_secs \
|
if self.task_parameters.checkpoint_save_secs \
|
||||||
and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \
|
and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.checkpoint_save_secs \
|
||||||
and (self.task_parameters.task_index == 0 # distributed
|
and (self.task_parameters.task_index == 0 # distributed
|
||||||
or self.task_parameters.task_index is None # single-worker
|
or self.task_parameters.task_index is None # single-worker
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -367,7 +367,7 @@
|
|||||||
" evaluate_only=False,\n",
|
" evaluate_only=False,\n",
|
||||||
" experiment_path=log_path)\n",
|
" experiment_path=log_path)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"task_parameters.__dict__['save_checkpoint_secs'] = None\n",
|
"task_parameters.__dict__['checkpoint_save_secs'] = None\n",
|
||||||
"\n",
|
"\n",
|
||||||
"graph_manager.create_graph(task_parameters)\n",
|
"graph_manager.create_graph(task_parameters)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
@@ -345,7 +345,7 @@
|
|||||||
" evaluate_only=False,\n",
|
" evaluate_only=False,\n",
|
||||||
" experiment_path=log_path)\n",
|
" experiment_path=log_path)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"task_parameters.__dict__['save_checkpoint_secs'] = None\n",
|
"task_parameters.__dict__['checkpoint_save_secs'] = None\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"graph_manager.create_graph(task_parameters)\n",
|
"graph_manager.create_graph(task_parameters)\n",
|
||||||
|
|||||||
@@ -372,7 +372,7 @@
|
|||||||
" evaluate_only=False,\n",
|
" evaluate_only=False,\n",
|
||||||
" experiment_path=log_path)\n",
|
" experiment_path=log_path)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"task_parameters.__dict__['save_checkpoint_secs'] = None\n",
|
"task_parameters.__dict__['checkpoint_save_secs'] = None\n",
|
||||||
"task_parameters.__dict__['verbosity'] = 'low'\n",
|
"task_parameters.__dict__['verbosity'] = 'low'\n",
|
||||||
"\n",
|
"\n",
|
||||||
"graph_manager.create_graph(task_parameters)\n",
|
"graph_manager.create_graph(task_parameters)\n",
|
||||||
|
|||||||
Reference in New Issue
Block a user