From 9804b033a2fbca1cd92e8373fb556c80c141a9a8 Mon Sep 17 00:00:00 2001 From: Zach Dwiel Date: Fri, 5 Oct 2018 11:44:49 -0400 Subject: [PATCH] rename save_checkpoint_dir -> checkpoint_save_dir --- rl_coach/agents/nec_agent.py | 2 +- rl_coach/coach.py | 2 +- rl_coach/graph_managers/graph_manager.py | 4 ++-- rl_coach/training_worker.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rl_coach/agents/nec_agent.py b/rl_coach/agents/nec_agent.py index 9b8144d..891466d 100644 --- a/rl_coach/agents/nec_agent.py +++ b/rl_coach/agents/nec_agent.py @@ -171,5 +171,5 @@ class NECAgent(ValueOptimizationAgent): actions, returns) def save_checkpoint(self, checkpoint_id): - with open(os.path.join(self.ap.task_parameters.save_checkpoint_dir, str(checkpoint_id) + '.dnd'), 'wb') as f: + with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f: pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL) diff --git a/rl_coach/coach.py b/rl_coach/coach.py index 9c14003..9ee8cfe 100644 --- a/rl_coach/coach.py +++ b/rl_coach/coach.py @@ -176,7 +176,7 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace: args.framework = Frameworks[args.framework.lower()] # checkpoints - args.save_checkpoint_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None + args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None return args diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index b7ed951..2b2c7e8 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -203,7 +203,7 @@ class GraphManager(object): remove_tree(checkpoint_dir) copy_tree(task_parameters.checkpoint_restore_dir, checkpoint_dir) else: - checkpoint_dir = task_parameters.save_checkpoint_dir + checkpoint_dir = task_parameters.checkpoint_save_dir self.sess = create_monitored_session(target=task_parameters.worker_target, task_index=task_parameters.task_index, @@ -498,7 +498,7 @@ class GraphManager(object): self.save_checkpoint() def save_checkpoint(self): - checkpoint_path = os.path.join(self.task_parameters.save_checkpoint_dir, + checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, "{}_Step-{}.ckpt".format( self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])) diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index 4031cd9..9eb0898 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -24,7 +24,7 @@ def training_worker(graph_manager, checkpoint_dir, policy_type): """ # initialize graph task_parameters = TaskParameters() - task_parameters.__dict__['save_checkpoint_dir'] = checkpoint_dir + task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir task_parameters.__dict__['save_checkpoint_secs'] = 20 graph_manager.create_graph(task_parameters)