
Implement framework-agnostic rollout and training workers (#137)

* Added a checkpoint state file to Coach checkpointing.

* Removed TF-specific code from rollout_worker, training_worker, and s3_data_store.
Authored by Sina Afrooze on 2018-11-23 18:05:44 -08:00, committed by Balaji Subramaniam
parent 4a6c404070
commit 5332013bd1
7 changed files with 350 additions and 117 deletions
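
The heart of the change is dropping TensorFlow's checkpoint bookkeeping (tf.train.get_checkpoint_state and its CheckpointState proto) in favor of a plain state file that records the name of the last fully written checkpoint, so rollout and training workers can coordinate regardless of framework. A minimal sketch of the idea, assuming a single-line state file; the actual file name and on-disk format are defined in rl_coach/checkpoint.py and may differ:

    import os

    STATE_FILE = '.coach_checkpoint'  # assumed name, for illustration only

    def read_latest_checkpoint(checkpoint_dir):
        # Return the name of the last fully saved checkpoint, or None if absent.
        state_path = os.path.join(checkpoint_dir, STATE_FILE)
        if not os.path.exists(state_path):
            return None
        with open(state_path) as state_file:
            return state_file.read().strip()

    def write_latest_checkpoint(checkpoint_dir, checkpoint_name):
        # Write to a temp file and rename, so readers never see a partial update.
        state_path = os.path.join(checkpoint_dir, STATE_FILE)
        with open(state_path + '.tmp', 'w') as state_file:
            state_file.write(checkpoint_name)
        os.rename(state_path + '.tmp', state_path)  # atomic on POSIX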

rl_coach/graph_managers/graph_manager.py

@@ -25,6 +25,7 @@ import contextlib
 from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
     VisualizationParameters, \
     Parameters, PresetValidationParameters, RunType
+from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
     EnvironmentSteps, \
     StepMethod, Transition
@@ -32,7 +33,7 @@ from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
 from rl_coach.saver import SaverCollection
-from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
+from rl_coach.utils import set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles
@@ -115,6 +116,7 @@ class GraphManager(object):
         self.checkpoint_id = 0
         self.checkpoint_saver = None
+        self.checkpoint_state_updater = None
         self.graph_logger = Logger()
         self.data_store = None
@@ -495,7 +497,7 @@ class GraphManager(object):
             self.act(EnvironmentEpisodes(1))
         self.sync()
         if self.should_stop():
-            if self.task_parameters.checkpoint_save_dir:
+            if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
                 open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
             if hasattr(self, 'data_store_params'):
                 data_store = self.get_data_store(self.data_store_params)
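
This guard matters in distributed runs: the empty SyncFiles.FINISHED marker is how the training side signals rollout workers to stop, and opening it inside a directory that was never created would fail. On the other side, a rollout worker can simply poll for the marker. A sketch, assuming only that the marker file name comes from SyncFiles.FINISHED.value as in the diff; the polling helper itself is hypothetical:

    import os
    import time

    def wait_until_finished(checkpoint_dir, marker_name, poll_secs=10):
        # Block until the training side drops the FINISHED marker file.
        marker_path = os.path.join(checkpoint_dir, marker_name)
        while not os.path.exists(marker_path):
            time.sleep(poll_secs)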
@@ -579,29 +581,35 @@ class GraphManager(object):
             self.save_checkpoint()

     def save_checkpoint(self):
         # create current session's checkpoint directory
         if self.task_parameters.checkpoint_save_dir is None:
             self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')

-        filename = "{}_Step-{}.ckpt".format(
-            self.checkpoint_id,
-            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
-        checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
-                                       filename)
-        if not os.path.exists(os.path.dirname(checkpoint_path)):
-            os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure
+        if not os.path.exists(self.task_parameters.checkpoint_save_dir):
+            os.mkdir(self.task_parameters.checkpoint_save_dir)  # Create directory structure
+
+        if self.checkpoint_state_updater is None:
+            self.checkpoint_state_updater = CheckpointStateUpdater(self.task_parameters.checkpoint_save_dir)
+
+        checkpoint_name = "{}_Step-{}.ckpt".format(
+            self.checkpoint_id, self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
+        checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, checkpoint_name)

         if not isinstance(self.task_parameters, DistributedTaskParameters):
             saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
         else:
             saved_checkpoint_path = checkpoint_path

         # this is required in order for agents to save additional information like a DND for example
-        [manager.save_checkpoint(filename) for manager in self.level_managers]
+        [manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]

         # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
         if self.task_parameters.export_onnx_graph:
             self.save_onnx_graph()

+        # write the new checkpoint name to a file to signal this checkpoint has been fully saved
+        self.checkpoint_state_updater.update(SingleCheckpoint(self.checkpoint_id, checkpoint_name))
+
         screen.log_dict(
             OrderedDict([
                 ("Saving in path", saved_checkpoint_path),