mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Fixes for syncing NumpySharedRunningStats in multi-node runs (#139)
1. Use the standard checkpoint prefix so that the data store picks the file up and syncs it to S3. 2. Remove the reference to Redis so that pickling does not try to include it. 3. Enable restoring a checkpoint that was saved by a single-node, multiple-worker run into a single-worker run.
This commit is contained in:
@@ -581,10 +581,13 @@ class GraphManager(object):
|
||||
def save_checkpoint(self):
|
||||
if self.task_parameters.checkpoint_save_dir is None:
|
||||
self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')
|
||||
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
|
||||
"{}_Step-{}.ckpt".format(
|
||||
|
||||
filename = "{}_Step-{}.ckpt".format(
|
||||
self.checkpoint_id,
|
||||
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
|
||||
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
|
||||
|
||||
checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
|
||||
filename)
|
||||
if not os.path.exists(os.path.dirname(checkpoint_path)):
|
||||
os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure
|
||||
if not isinstance(self.task_parameters, DistributedTaskParameters):
|
||||
@@ -593,7 +596,7 @@ class GraphManager(object):
|
||||
saved_checkpoint_path = checkpoint_path
|
||||
|
||||
# this is required in order for agents to save additional information like a DND for example
|
||||
[manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
|
||||
[manager.save_checkpoint(filename) for manager in self.level_managers]
|
||||
|
||||
# the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
|
||||
if self.task_parameters.export_onnx_graph:
|
||||
|
||||
Reference in New Issue
Block a user