1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Uploading checkpoint if crd provided (#191)

* Uploading checkpoint if crd provided
* Changing the calculation of total steps because of a recent change in core_types

Fixes #195
This commit is contained in:
Ajay Deshpande
2019-04-26 12:27:33 -07:00
committed by Scott Leishman
parent b3db9ce77d
commit 33dc29ee99
8 changed files with 122 additions and 40 deletions

View File

@@ -37,6 +37,7 @@ from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.data_stores.data_store import SyncFiles
from rl_coach.checkpoint import CheckpointStateReader
from rl_coach.core_types import TimeTypes
@@ -589,6 +590,10 @@ class GraphManager(object):
[manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
# Set the last checkpoint ID
chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path, checkpoint_state_optional=False)
self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
import tensorflow as tf
return tf.train.get_checkpoint_state(checkpoint_restore_dir)
@@ -721,6 +726,13 @@ class GraphManager(object):
return data_store_creator(param)
def signal_ready(self):
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.TRAINER_READY.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()
def close(self) -> None:
"""
Clean up to close environments.