mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Uploading checkpoint if crd provided (#191)
* Uploading checkpoint if crd provided * Changing the calculation of total steps because of a recent change in core_types Fixes #195
This commit is contained in:
committed by
Scott Leishman
parent
b3db9ce77d
commit
33dc29ee99
@@ -37,6 +37,7 @@ from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
|
||||
from rl_coach.memories.backend.memory_impl import get_memory_backend
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
from rl_coach.checkpoint import CheckpointStateReader
|
||||
|
||||
from rl_coach.core_types import TimeTypes
|
||||
|
||||
@@ -589,6 +590,10 @@ class GraphManager(object):
|
||||
|
||||
[manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
|
||||
|
||||
# Set the last checkpoint ID
|
||||
chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
|
||||
self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
|
||||
|
||||
def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
|
||||
import tensorflow as tf
|
||||
return tf.train.get_checkpoint_state(checkpoint_restore_dir)
|
||||
@@ -721,6 +726,13 @@ class GraphManager(object):
|
||||
|
||||
return data_store_creator(param)
|
||||
|
||||
def signal_ready(self):
|
||||
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
|
||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.TRAINER_READY.value), 'w').close()
|
||||
if hasattr(self, 'data_store_params'):
|
||||
data_store = self.get_data_store(self.data_store_params)
|
||||
data_store.save_to_store()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Clean up to close environments.
|
||||
|
||||
Reference in New Issue
Block a user