
Adding checkpointing framework (#74)

* Adding checkpointing framework as well as an MXNet checkpointing implementation.

- MXNet checkpoint for each network is saved in a separate file.

* Adding checkpoint restore for MXNet to the graph manager

* Add a unit test for get_checkpoint_state()

* Added match.group() to fix a unit test that was failing on CI

* Added ONNX export support for MXNet
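
The last item, ONNX export for MXNet, lands in the MXNet architecture code; the hunks shown below only cover graph_manager.py. As a rough sketch of the mechanism, MXNet's bundled converter can be driven like this (file names and input shape are illustrative placeholders, not values from this commit):

    # Hedged sketch of MXNet -> ONNX export (MXNet >= 1.3 with the onnx
    # package installed); paths and input_shape are placeholders.
    import numpy as np
    from mxnet.contrib import onnx as onnx_mxnet

    onnx_path = onnx_mxnet.export_model(
        sym='model-symbol.json',      # serialized network definition
        params='model-0000.params',   # saved weights
        input_shape=[(1, 4)],         # single input: batch of 1, 4 features
        input_type=np.float32,
        onnx_file_path='model.onnx')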
Authored by Sina Afrooze on 2018-11-19 09:45:49 -08:00; committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
19 changed files with 598 additions and 29 deletions


@@ -31,7 +31,8 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
-from rl_coach.utils import set_cpu, start_shell_command_and_wait
+from rl_coach.saver import SaverCollection
+from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles
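
SaverCollection, imported here from rl_coach.saver, is the new framework-agnostic container the rest of this diff builds on. Its definition is in one of the other changed files; judging only from the call sites below (update(), save(), restore()), a minimal sketch of the interface could look like this (entirely inferred, not the actual class):

    # Hypothetical sketch of the SaverCollection interface, inferred from
    # its usage in this diff; the real implementation lives in rl_coach/saver.py.
    class SaverCollection(object):
        def __init__(self):
            self._savers = {}  # saver name -> saver object

        def update(self, other):
            # merge savers collected from another collection (e.g. a level manager)
            self._savers.update(other._savers)

        def save(self, sess, save_path):
            # have every registered saver write its checkpoint under save_path
            for saver in self._savers.values():
                saver.save(sess, save_path)
            return save_path

        def restore(self, sess, restore_path):
            # have every registered saver reload its variables from restore_path
            for saver in self._savers.values():
                saver.restore(sess, restore_path)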
@@ -87,7 +88,7 @@ class GraphManager(object):
                  schedule_params: ScheduleParameters,
                  vis_params: VisualizationParameters = VisualizationParameters()):
         self.sess = None
-        self.level_managers = []
+        self.level_managers = []  # type: List[LevelManager]
         self.top_level_manager = None
         self.environments = []
         self.heatup_steps = schedule_params.heatup_steps
@@ -248,12 +249,22 @@ class GraphManager(object):
         if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
             self.save_graph()

+    def _create_session_mx(self):
+        """
+        Call set_session to initialize parameters and construct checkpoint_saver
+        """
+        self.set_session(sess=None)  # Initialize all modules
+        self.checkpoint_saver = SaverCollection()
+        for level in self.level_managers:
+            self.checkpoint_saver.update(level.collect_savers())
+        # restore from checkpoint if given
+        self.restore_checkpoint()
+
     def create_session(self, task_parameters: TaskParameters):
         if task_parameters.framework_type == Frameworks.tensorflow:
             self._create_session_tf(task_parameters)
         elif task_parameters.framework_type == Frameworks.mxnet:
-            self.set_session(sess=None)  # Initialize all modules
-            # TODO add checkpoint loading
+            self._create_session_mx()
         else:
             raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
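
With this refactor, both frameworks go through the same public entry point, and the MXNet path now wires up checkpoint saving and restoring instead of the earlier TODO. A hypothetical caller selects the path via the framework type (the field usage matches this diff; other TaskParameters details are placeholders):

    # Illustrative only: routing create_session() to the MXNet branch.
    from rl_coach.base_parameters import TaskParameters, Frameworks

    task_parameters = TaskParameters()
    task_parameters.framework_type = Frameworks.mxnet
    graph_manager.create_session(task_parameters)  # ends up in _create_session_mx()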
@@ -270,14 +281,13 @@ class GraphManager(object):
                                  name='graphdef.pb',
                                  as_text=False)

-    def save_onnx_graph(self) -> None:
+    def _save_onnx_graph_tf(self) -> None:
         """
-        Save the graph as an ONNX graph.
+        Save the tensorflow graph as an ONNX graph.
         This requires the graph and the weights checkpoint to be stored in the experiment directory.
         It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
         :return: None
         """
-        # collect input and output nodes
         input_nodes = []
         output_nodes = []
@@ -290,11 +300,20 @@ class GraphManager(object):
             for output in network.online_network.outputs:
                 output_nodes.append(output.name)

-        # TODO: make this framework agnostic
         from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
         save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)

+    def save_onnx_graph(self) -> None:
+        """
+        Save the graph as an ONNX graph.
+        This requires the graph and the weights checkpoint to be stored in the experiment directory.
+        It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
+        :return: None
+        """
+        if self.task_parameters.framework_type == Frameworks.tensorflow:
+            self._save_onnx_graph_tf()
+
     def setup_logger(self) -> None:
         # dump documentation
         logger_prefix = "{graph_name}".format(graph_name=self.name)
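
The two hunks above split the old monolithic save_onnx_graph() into a TensorFlow-specific helper plus a thin dispatcher, mirroring the create_session() split. Presumably an MXNet branch can be added later without touching callers; a hypothetical extension (the method name is invented here, not part of this diff):

    # Hypothetical future branch in save_onnx_graph(); _save_onnx_graph_mx
    # does not exist in this commit.
    if self.task_parameters.framework_type == Frameworks.tensorflow:
        self._save_onnx_graph_tf()
    elif self.task_parameters.framework_type == Frameworks.mxnet:
        self._save_onnx_graph_mx()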
@@ -526,14 +545,13 @@ class GraphManager(object):
             if self.evaluate(self.evaluation_steps):
                 break

-    def _restore_checkpoint_tf(self, checkpoint_dir: str):
+    def _restore_checkpoint_tf(self, checkpoint_path: str):
         import tensorflow as tf
-        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
-        screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
         variables = {}
-        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
+        reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
+        for var_name, _ in reader.get_variable_to_shape_map().items():
             # Load the variable
-            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
+            var = reader.get_tensor(var_name)

             # Set the new name
             new_name = var_name
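
The rewrite above replaces list_variables() plus a per-variable load_variable() (which reopens the checkpoint for every variable) with a single CheckpointReader, and moves directory-to-path resolution out to the caller. A standalone sketch of the same TF 1.x reader API (the checkpoint path is a placeholder):

    # Minimal sketch of the TF 1.x checkpoint-reader API used above;
    # '/tmp/ckpt/model.ckpt' is a placeholder path.
    import tensorflow as tf

    reader = tf.contrib.framework.load_checkpoint('/tmp/ckpt/model.ckpt')
    for var_name, shape in reader.get_variable_to_shape_map().items():
        value = reader.get_tensor(var_name)  # numpy array of stored weights
        print(var_name, shape, value.dtype)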
@@ -548,11 +566,14 @@ class GraphManager(object):
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+            checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+            screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
             if self.task_parameters.framework_type == Frameworks.tensorflow:
-                self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
+                self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
             elif self.task_parameters.framework_type == Frameworks.mxnet:
-                # TODO implement checkpoint restore
-                pass
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
             else:
                 raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
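
Checkpoint discovery now happens once, up front, via the framework-agnostic get_checkpoint_state() imported from rl_coach.utils, and each branch receives a concrete model_checkpoint_path. The helper's body is in another changed file; going only by the commit notes (the unit test, the match.group() CI fix), a sketch of what it plausibly does (the file-name pattern is taken from save_checkpoint() below; everything else is inferred):

    # Hypothetical sketch of a framework-agnostic get_checkpoint_state(),
    # inferred from the commit description; the real code is in rl_coach/utils.py.
    import os
    import re

    class CheckpointState(object):
        def __init__(self, model_checkpoint_path):
            self.model_checkpoint_path = model_checkpoint_path

    def get_checkpoint_state(checkpoint_dir):
        pattern = re.compile(r'(\d+)_Step-(\d+)\.ckpt')
        latest, latest_id = None, -1
        for name in os.listdir(checkpoint_dir):
            match = pattern.match(name)
            if match and int(match.group(1)) > latest_id:
                latest_id = int(match.group(1))
                # match.group(0) drops any framework-specific file suffix
                latest = os.path.join(checkpoint_dir, match.group(0))
        return CheckpointState(latest) if latest else None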
@@ -572,6 +593,8 @@ class GraphManager(object):
                                            "{}_Step-{}.ckpt".format(
                                                self.checkpoint_id,
                                                self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
+        if not os.path.exists(os.path.dirname(checkpoint_path)):
+            os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure
         if not isinstance(self.task_parameters, DistributedTaskParameters):
             if self.checkpoint_saver is not None:
                 saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
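
One caveat on the new directory guard: os.mkdir() creates only the leaf directory and raises FileNotFoundError if intermediate directories are missing, and the exists()/mkdir() pair is racy if two workers save concurrently. Where that matters, an equivalent one-liner (a suggestion, not part of this commit):

    # Suggested alternative: creates missing parents too, and exist_ok
    # removes the check-then-create race.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)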