mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
rename save_checkpoint_dir -> checkpoint_save_dir
This commit is contained in:
@@ -171,5 +171,5 @@ class NECAgent(ValueOptimizationAgent):
|
|||||||
actions, returns)
|
actions, returns)
|
||||||
|
|
||||||
def save_checkpoint(self, checkpoint_id):
|
def save_checkpoint(self, checkpoint_id):
|
||||||
with open(os.path.join(self.ap.task_parameters.save_checkpoint_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
|
with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
|
||||||
pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
|||||||
args.framework = Frameworks[args.framework.lower()]
|
args.framework = Frameworks[args.framework.lower()]
|
||||||
|
|
||||||
# checkpoints
|
# checkpoints
|
||||||
args.save_checkpoint_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None
|
args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.save_checkpoint_secs is not None else None
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ class GraphManager(object):
|
|||||||
remove_tree(checkpoint_dir)
|
remove_tree(checkpoint_dir)
|
||||||
copy_tree(task_parameters.checkpoint_restore_dir, checkpoint_dir)
|
copy_tree(task_parameters.checkpoint_restore_dir, checkpoint_dir)
|
||||||
else:
|
else:
|
||||||
checkpoint_dir = task_parameters.save_checkpoint_dir
|
checkpoint_dir = task_parameters.checkpoint_save_dir
|
||||||
|
|
||||||
self.sess = create_monitored_session(target=task_parameters.worker_target,
|
self.sess = create_monitored_session(target=task_parameters.worker_target,
|
||||||
task_index=task_parameters.task_index,
|
task_index=task_parameters.task_index,
|
||||||
@@ -498,7 +498,7 @@ class GraphManager(object):
|
|||||||
self.save_checkpoint()
|
self.save_checkpoint()
|
||||||
|
|
||||||
def save_checkpoint(self):
|
def save_checkpoint(self):
|
||||||
checkpoint_path = os.path.join(self.task_parameters.save_checkpoint_dir,
|
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
|
||||||
"{}_Step-{}.ckpt".format(
|
"{}_Step-{}.ckpt".format(
|
||||||
self.checkpoint_id,
|
self.checkpoint_id,
|
||||||
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
|
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def training_worker(graph_manager, checkpoint_dir, policy_type):
|
|||||||
"""
|
"""
|
||||||
# initialize graph
|
# initialize graph
|
||||||
task_parameters = TaskParameters()
|
task_parameters = TaskParameters()
|
||||||
task_parameters.__dict__['save_checkpoint_dir'] = checkpoint_dir
|
task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
|
||||||
task_parameters.__dict__['save_checkpoint_secs'] = 20
|
task_parameters.__dict__['save_checkpoint_secs'] = 20
|
||||||
graph_manager.create_graph(task_parameters)
|
graph_manager.create_graph(task_parameters)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user