diff --git a/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py b/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py index db84c9a..ccf3e89 100644 --- a/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py +++ b/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py @@ -75,13 +75,13 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind def create_monitored_session(target: tf.train.Server, task_index: int, - checkpoint_dir: str, save_checkpoint_secs: int, config: tf.ConfigProto=None) -> tf.Session: + checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session: """ Create a monitored session for the worker :param target: the target string for the tf.Session :param task_index: the task index of the worker :param checkpoint_dir: a directory path where the checkpoints will be stored - :param save_checkpoint_secs: number of seconds between checkpoints storing + :param checkpoint_save_secs: number of seconds between checkpoints storing :param config: the tensorflow configuration (optional) :return: the session to use for the run """ @@ -94,7 +94,7 @@ def create_monitored_session(target: tf.train.Server, task_index: int, is_chief=is_chief, hooks=[], checkpoint_dir=checkpoint_dir, - save_checkpoint_secs=save_checkpoint_secs, + save_checkpoint_secs=checkpoint_save_secs, config=config ) diff --git a/rl_coach/coach.py b/rl_coach/coach.py index 9ee8cfe..be7c81e 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -176,7 +176,7 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace: args.framework = Frameworks[args.framework.lower()] # checkpoints - args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None + args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None return args @@ -257,7 
+257,7 @@ def main(): help="(flag) TensorFlow verbosity level", default=3, type=int) - parser.add_argument('-s', '--save_checkpoint_secs', + parser.add_argument('-s', '--checkpoint_save_secs', help="(int) Time in seconds between saving checkpoints of the model.", default=None, type=int) diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 2b2c7e8..3260d09 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -208,7 +208,7 @@ class GraphManager(object): self.sess = create_monitored_session(target=task_parameters.worker_target, task_index=task_parameters.task_index, checkpoint_dir=checkpoint_dir, - save_checkpoint_secs=task_parameters.save_checkpoint_secs, + checkpoint_save_secs=task_parameters.checkpoint_save_secs, config=config) # set the session for all the modules self.set_session(self.sess) @@ -490,8 +490,8 @@ class GraphManager(object): def occasionally_save_checkpoint(self): # only the chief process saves checkpoints - if self.task_parameters.save_checkpoint_secs \ - and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \ + if self.task_parameters.checkpoint_save_secs \ + and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.checkpoint_save_secs \ and (self.task_parameters.task_index == 0 # distributed or self.task_parameters.task_index is None # single-worker ): diff --git a/tutorials/1. Implementing an Algorithm.ipynb b/tutorials/1. Implementing an Algorithm.ipynb index 7e7866d..a309964 100644 --- a/tutorials/1. Implementing an Algorithm.ipynb +++ b/tutorials/1. Implementing an Algorithm.ipynb @@ -367,7 +367,7 @@ " evaluate_only=False,\n", " experiment_path=log_path)\n", "\n", - "task_parameters.__dict__['save_checkpoint_secs'] = None\n", + "task_parameters.__dict__['checkpoint_save_secs'] = None\n", "\n", "graph_manager.create_graph(task_parameters)\n", "\n", diff --git a/tutorials/2. 
Adding an Environment.ipynb b/tutorials/2. Adding an Environment.ipynb index 5bec481..71869d1 100644 --- a/tutorials/2. Adding an Environment.ipynb +++ b/tutorials/2. Adding an Environment.ipynb @@ -345,7 +345,7 @@ " evaluate_only=False,\n", " experiment_path=log_path)\n", "\n", - "task_parameters.__dict__['save_checkpoint_secs'] = None\n", + "task_parameters.__dict__['checkpoint_save_secs'] = None\n", "\n", "\n", "graph_manager.create_graph(task_parameters)\n", diff --git a/tutorials/3. Implementing a Hierarchical RL Graph.ipynb b/tutorials/3. Implementing a Hierarchical RL Graph.ipynb index 32e832b..eafa78f 100644 --- a/tutorials/3. Implementing a Hierarchical RL Graph.ipynb +++ b/tutorials/3. Implementing a Hierarchical RL Graph.ipynb @@ -372,7 +372,7 @@ " evaluate_only=False,\n", " experiment_path=log_path)\n", "\n", - "task_parameters.__dict__['save_checkpoint_secs'] = None\n", + "task_parameters.__dict__['checkpoint_save_secs'] = None\n", "task_parameters.__dict__['verbosity'] = 'low'\n", "\n", "graph_manager.create_graph(task_parameters)\n",