Mirror of https://github.com/gryf/coach.git, synced 2025-12-18 11:40:18 +01:00
Adding checkpointing framework (#74)
* Adding checkpointing framework as well as mxnet checkpointing implementation.
  - MXNet checkpoint for each network is saved in a separate file.
* Adding checkpoint restore for mxnet to graph-manager
* Add unit-test for get_checkpoint_state()
* Added match.group() to fix unit-test failing on CI
* Added ONNX export support for MXNet
committed by shadiendrawis
parent 4da56b1ff2
commit 67eb9e4c28
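The hunks below wire a framework-agnostic saver layer into GraphManager: each level manager contributes per-network savers, which are gathered into one SaverCollection that the graph manager uses for both saving and restoring. The real classes live in rl_coach/saver.py and the MXNet architecture code; the sketch below is not that implementation, only an illustration of the pattern (all names and bodies are made up) showing how "one checkpoint file per network" falls out of a collection of per-network savers.

# Minimal sketch only, NOT rl_coach's implementation; the real SaverCollection
# is defined in rl_coach/saver.py and its interface may differ.
class NetworkSaverSketch(object):
    """Hypothetical saver for a single network: writes one file per checkpoint."""
    def __init__(self, name, save_fn, restore_fn):
        self.name = name                # e.g. 'agent/online/main'
        self._save_fn = save_fn         # callable(path) that persists this network
        self._restore_fn = restore_fn   # callable(path) that loads this network

    def save(self, sess, save_path):
        # one file per network, derived from the shared checkpoint prefix
        self._save_fn('{}.{}'.format(save_path, self.name.replace('/', '.')))

    def restore(self, sess, restore_path):
        self._restore_fn('{}.{}'.format(restore_path, self.name.replace('/', '.')))


class SaverCollectionSketch(object):
    """Hypothetical collection that fans save()/restore() out to every saver."""
    def __init__(self):
        self._savers = {}

    def add(self, saver):
        self._savers[saver.name] = saver

    def update(self, other):
        self._savers.update(other._savers)

    def save(self, sess, save_path):
        for saver in self._savers.values():
            saver.save(sess, save_path)
        return save_path

    def restore(self, sess, restore_path):
        for saver in self._savers.values():
            saver.restore(sess, restore_path)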
@@ -31,7 +31,8 @@ from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, T
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
-from rl_coach.utils import set_cpu, start_shell_command_and_wait
+from rl_coach.saver import SaverCollection
+from rl_coach.utils import get_checkpoint_state, set_cpu, start_shell_command_and_wait
 from rl_coach.data_stores.data_store_impl import get_data_store as data_store_creator
 from rl_coach.memories.backend.memory_impl import get_memory_backend
 from rl_coach.data_stores.data_store import SyncFiles
@@ -87,7 +88,7 @@ class GraphManager(object):
                  schedule_params: ScheduleParameters,
                  vis_params: VisualizationParameters = VisualizationParameters()):
         self.sess = None
-        self.level_managers = []
+        self.level_managers = []  # type: List[LevelManager]
         self.top_level_manager = None
         self.environments = []
         self.heatup_steps = schedule_params.heatup_steps
@@ -248,12 +249,22 @@ class GraphManager(object):
         if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
             self.save_graph()
 
+    def _create_session_mx(self):
+        """
+        Call set_session to initialize parameters and construct checkpoint_saver
+        """
+        self.set_session(sess=None)  # Initialize all modules
+        self.checkpoint_saver = SaverCollection()
+        for level in self.level_managers:
+            self.checkpoint_saver.update(level.collect_savers())
+        # restore from checkpoint if given
+        self.restore_checkpoint()
+
     def create_session(self, task_parameters: TaskParameters):
         if task_parameters.framework_type == Frameworks.tensorflow:
             self._create_session_tf(task_parameters)
         elif task_parameters.framework_type == Frameworks.mxnet:
-            self.set_session(sess=None)  # Initialize all modules
-            # TODO add checkpoint loading
+            self._create_session_mx()
         else:
             raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
 
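With _create_session_mx in place, the MXNet path mirrors the TensorFlow one: initialize all modules, build the SaverCollection from every level manager, then attempt a restore. A hedged usage sketch follows; the TaskParameters keyword arguments and the pre-built graph_manager are assumptions for illustration, not verified against this revision.

# Assumed usage of the new flow; constructor arguments are illustrative.
from rl_coach.base_parameters import Frameworks, TaskParameters

# graph_manager: an already-constructed GraphManager (e.g. built from a preset)
task_parameters = TaskParameters(framework_type=Frameworks.mxnet,
                                 checkpoint_save_dir='./checkpoints',  # save target
                                 checkpoint_restore_dir=None)          # nothing to restore
graph_manager.create_session(task_parameters)
# -> _create_session_mx(): set_session(None), gather every level manager's
#    collect_savers() into the SaverCollection, then restore_checkpoint()
#    (a no-op here because checkpoint_restore_dir is None).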
@@ -270,14 +281,13 @@ class GraphManager(object):
                                  name='graphdef.pb',
                                  as_text=False)
 
-    def save_onnx_graph(self) -> None:
+    def _save_onnx_graph_tf(self) -> None:
         """
-        Save the graph as an ONNX graph.
+        Save the tensorflow graph as an ONNX graph.
         This requires the graph and the weights checkpoint to be stored in the experiment directory.
         It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
        :return: None
         """
-
         # collect input and output nodes
         input_nodes = []
         output_nodes = []
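The docstring above describes the TensorFlow export path: freeze the graph by folding the checkpointed weights into constants, then convert the frozen graph to ONNX. The actual converter is rl_coach's save_onnx_graph in tensorflow_components/architecture.py; the sketch below is only a generic illustration of that freeze-then-convert idea, assuming TF 1.x and the older tf2onnx API, and is not the code added by this commit.

# Generic freeze-then-convert sketch (TF 1.x, older tf2onnx API); not rl_coach code.
import tensorflow as tf
from tf2onnx import tfonnx

def freeze_and_convert(sess, input_nodes, output_nodes, onnx_path):
    # 1) Freeze: fold the checkpointed variable values into the graph as constants
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, [name.split(':')[0] for name in output_nodes])
    # 2) Re-import the frozen graph and convert it to ONNX
    with tf.Graph().as_default() as frozen_graph:
        tf.import_graph_def(frozen_graph_def, name='')
        onnx_graph = tfonnx.process_tf_graph(frozen_graph,
                                             input_names=input_nodes,
                                             output_names=output_nodes)
    model_proto = onnx_graph.make_model('frozen graph')
    with open(onnx_path, 'wb') as f:
        f.write(model_proto.SerializeToString())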
@@ -290,11 +300,20 @@ class GraphManager(object):
                     for output in network.online_network.outputs:
                         output_nodes.append(output.name)
 
         # TODO: make this framework agnostic
         from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph
-
         save_onnx_graph(input_nodes, output_nodes, self.task_parameters.checkpoint_save_dir)
 
+    def save_onnx_graph(self) -> None:
+        """
+        Save the graph as an ONNX graph.
+        This requires the graph and the weights checkpoint to be stored in the experiment directory.
+        It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
+        :return: None
+        """
+        if self.task_parameters.framework_type == Frameworks.tensorflow:
+            self._save_onnx_graph_tf()
+
     def setup_logger(self) -> None:
         # dump documentation
         logger_prefix = "{graph_name}".format(graph_name=self.name)
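The new save_onnx_graph wrapper only dispatches to the TensorFlow exporter in this hunk; the MXNet ONNX export mentioned in the commit message lives in the MXNet architecture code, which is not shown here. For reference, a minimal sketch of exporting an already-saved MXNet model with MXNet's bundled converter follows; the file names and input shape are made up.

# Illustrative only; not the code added by this commit. Uses MXNet's own
# mx.contrib.onnx converter on an already-saved symbol/params pair.
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet

onnx_file = onnx_mxnet.export_model(sym='./checkpoints/main-symbol.json',      # assumed file name
                                    params='./checkpoints/main-0000.params',   # assumed file name
                                    input_shape=[(1, 4)],                      # assumed observation shape
                                    input_type=np.float32,
                                    onnx_file_path='./checkpoints/main.onnx')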
@@ -526,14 +545,13 @@ class GraphManager(object):
                 if self.evaluate(self.evaluation_steps):
                     break
 
-    def _restore_checkpoint_tf(self, checkpoint_dir: str):
+    def _restore_checkpoint_tf(self, checkpoint_path: str):
         import tensorflow as tf
-        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
-        screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
         variables = {}
-        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
+        reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
+        for var_name, _ in reader.get_variable_to_shape_map().items():
             # Load the variable
-            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
+            var = reader.get_tensor(var_name)
 
             # Set the new name
             new_name = var_name
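_restore_checkpoint_tf now takes a concrete checkpoint path and reads variables through a checkpoint reader, instead of resolving the directory itself with tf.train.get_checkpoint_state and loading variables one by one. A standalone illustration of the same TF 1.x reader API, with a hypothetical checkpoint path:

# TF 1.x contrib API, as used above; the checkpoint path is hypothetical.
import tensorflow as tf

reader = tf.contrib.framework.load_checkpoint('./checkpoints/0_Step-0.ckpt')
for var_name in sorted(reader.get_variable_to_shape_map()):
    print(var_name, reader.get_tensor(var_name).shape)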
@@ -548,11 +566,14 @@ class GraphManager(object):
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+
+            checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+            screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
+
             if self.task_parameters.framework_type == Frameworks.tensorflow:
-                self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
+                self._restore_checkpoint_tf(checkpoint.model_checkpoint_path)
             elif self.task_parameters.framework_type == Frameworks.mxnet:
-                # TODO implement checkpoint restore
-                pass
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
             else:
                 raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
 
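restore_checkpoint now resolves the checkpoint once, through rl_coach.utils.get_checkpoint_state, and hands the resulting model_checkpoint_path to whichever framework is active; the commit message also adds a unit test for that helper and a match.group() fix for it. The real helper and its filename regex are in rl_coach/utils.py; the following is only a guess at its behaviour, included to make the shape of the returned object concrete.

# Guesswork sketch of get_checkpoint_state's assumed behaviour; the real helper
# is in rl_coach/utils.py and may differ. It mimics tf.train.get_checkpoint_state:
# find the newest '<id>_Step-<steps>.ckpt' prefix in a directory and expose it
# via a model_checkpoint_path attribute.
import os
import re
from collections import namedtuple

CheckpointState = namedtuple('CheckpointState',
                             ['model_checkpoint_path', 'all_model_checkpoint_paths'])

def get_checkpoint_state_sketch(checkpoint_dir):
    pattern = re.compile(r'(\d+)_Step-(\d+)\.ckpt')
    checkpoints = {}
    for file_name in os.listdir(checkpoint_dir):
        match = pattern.match(file_name)
        if match:
            # keep one entry per checkpoint id; match.group(0) is the bare prefix
            checkpoints[int(match.group(1))] = os.path.join(checkpoint_dir, match.group(0))
    if not checkpoints:
        return None
    ordered = [checkpoints[key] for key in sorted(checkpoints)]
    return CheckpointState(model_checkpoint_path=ordered[-1],
                           all_model_checkpoint_paths=ordered)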
@@ -572,6 +593,8 @@ class GraphManager(object):
                                        "{}_Step-{}.ckpt".format(
                                            self.checkpoint_id,
                                            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
+        if not os.path.exists(os.path.dirname(checkpoint_path)):
+            os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure
         if not isinstance(self.task_parameters, DistributedTaskParameters):
             if self.checkpoint_saver is not None:
                 saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
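With the new existence check, checkpoint_save_dir no longer has to exist before the first save; the SaverCollection then writes everything under a '<id>_Step-<steps>.ckpt' prefix. A small illustration of the resulting naming, with an invented directory and step count:

# Illustration of the naming scheme only; directory and numbers are invented.
import os

checkpoint_path = os.path.join('./experiments/my_exp/checkpoint',
                               '{}_Step-{}.ckpt'.format(3, 12000))
print(checkpoint_path)  # ./experiments/my_exp/checkpoint/3_Step-12000.ckpt
# With the MXNet SaverCollection, each network is expected to write its own file
# derived from this prefix (one file per network, per the commit message).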