Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 11:40:18 +01:00)
Implement framework-agnostic rollout and training workers (#137)
* Added checkpoint state file to coach checkpointing.
* Removed TF-specific code from rollout_worker, training_worker, and s3_data_store.
Committed by: Balaji Subramaniam
Parent: 4a6c404070
Commit: 5332013bd1
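The hunks below (all in rl_coach/graph_manager.py, judging by the GraphManager context lines) introduce a framework-agnostic checkpoint state file: after every successful save, the trainer records the checkpoint's name through CheckpointStateUpdater, so rollout and training workers can discover the newest complete checkpoint without relying on TensorFlow's own bookkeeping. A minimal sketch of the pattern, assuming a plain-text state file; the real rl_coach.checkpoint module may store more, and the '.coach_checkpoint' file name is an assumption:

import os
from collections import namedtuple

# Illustrative stand-ins for rl_coach.checkpoint.SingleCheckpoint / CheckpointStateUpdater.
SingleCheckpoint = namedtuple('SingleCheckpoint', ['num', 'name'])


class CheckpointStateUpdater(object):
    """Sketch: record the newest fully-saved checkpoint in a state file."""

    def __init__(self, checkpoint_dir, state_filename='.coach_checkpoint'):  # name is an assumption
        self.checkpoint_dir = checkpoint_dir
        self.state_path = os.path.join(checkpoint_dir, state_filename)

    def update(self, checkpoint):
        # Write to a temp file, then rename: on POSIX the rename is atomic,
        # so a concurrent reader never sees a half-written state file.
        tmp_path = self.state_path + '.tmp'
        with open(tmp_path, 'w') as f:
            f.write(checkpoint.name)
        os.rename(tmp_path, self.state_path)


# Usage mirroring the diff below:
# updater = CheckpointStateUpdater('/tmp/checkpoint')
# updater.update(SingleCheckpoint(num=0, name='0_Step-1000.ckpt'))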
@@ -25,6 +25,7 @@ import contextlib
 from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
     VisualizationParameters, \
     Parameters, PresetValidationParameters, RunType
+from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
     StepMethod, Transition
@@ -32,7 +33,7 @@ from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
 from rl_coach.saver import SaverCollection
-from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
+from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles
@@ -115,6 +116,7 @@ class GraphManager(object):
         self.checkpoint_id = 0

         self.checkpoint_saver = None
+        self.checkpoint_state_updater = None
         self.graph_logger = Logger()
         self.data_store = None

@@ -495,7 +497,7 @@ class GraphManager(object):
             self.act(EnvironmentEpisodes(1))
             self.sync()
             if self.should_stop():
-                if self.task_parameters.checkpoint_save_dir:
+                if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
                     open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
                 if hasattr(self, 'data_store_params'):
                     data_store = self.get_data_store(self.data_store_params)
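The extra os.path.exists guard above avoids crashing when no checkpoint directory was ever created; when it does exist, touching an empty SyncFiles.FINISHED marker file is the cross-process signal that training is done. A hypothetical reader-side counterpart; the polling loop and the '.finished' marker name are assumptions, not rl_coach code:

import os
import time

def wait_until_finished(checkpoint_dir, marker='.finished', poll_seconds=10):
    # Sketch: block until the trainer touches the 'finished' marker file.
    # rl_coach would take the marker name from SyncFiles.FINISHED.value;
    # '.finished' here is an assumption.
    marker_path = os.path.join(checkpoint_dir, marker)
    while not os.path.exists(marker_path):
        time.sleep(poll_seconds)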
@@ -579,29 +581,35 @@ class GraphManager(object):
             self.save_checkpoint()

     def save_checkpoint(self):
         # create current session's checkpoint directory
         if self.task_parameters.checkpoint_save_dir is None:
             self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')

-        filename = "{}_Step-{}.ckpt".format(
-            self.checkpoint_id,
-            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
+        if not os.path.exists(self.task_parameters.checkpoint_save_dir):
+            os.mkdir(self.task_parameters.checkpoint_save_dir)  # Create directory structure
+
+        if self.checkpoint_state_updater is None:
+            self.checkpoint_state_updater = CheckpointStateUpdater(self.task_parameters.checkpoint_save_dir)
+
+        checkpoint_name = "{}_Step-{}.ckpt".format(
+            self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
+        checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, checkpoint_name)

-        checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
-                                       filename)
-        if not os.path.exists(os.path.dirname(checkpoint_path)):
-            os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure
         if not isinstance(self.task_parameters, DistributedTaskParameters):
             saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
         else:
             saved_checkpoint_path = checkpoint_path

         # this is required in order for agents to save additional information like a DND for example
-        [manager.save_checkpoint(filename) for manager in self.level_managers]
+        [manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]

         # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
         if self.task_parameters.export_onnx_graph:
             self.save_onnx_graph()

+        # write the new checkpoint name to a file to signal this checkpoint has been fully saved
+        self.checkpoint_state_updater.update(SingleCheckpoint(self.checkpoint_id, checkpoint_name))
+
         screen.log_dict(
             OrderedDict([
                 ("Saving in path", saved_checkpoint_path),