diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 44a5df8..9e23610 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -541,28 +541,22 @@ class GraphManager(object):
             if self.evaluate(self.evaluation_steps):
                 break
 
-    def _restore_checkpoint_tf(self, checkpoint_path: str):
-        import tensorflow as tf
-        variables = {}
-        reader = tf.contrib.framework.load_checkpoint(checkpoint_path)
-        for var_name, _ in reader.get_variable_to_shape_map().items():
-            # Load the variable
-            var = reader.get_tensor(var_name)
-
-            # Set the new name
-            new_name = var_name
-            new_name = new_name.replace('global/', 'online/')
-            variables[new_name] = var
-
-        for v in self.variables_to_restore:
-            self.sess.run(v.assign(variables[v.name.split(':')[0]]))
-
     def restore_checkpoint(self):
         self.verify_graph_was_created()
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if self.task_parameters.checkpoint_restore_dir:
-            checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+            if self.task_parameters.framework_type == Frameworks.tensorflow:
+                # TODO-fixme checkpointing
+                # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
+                # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt" filename
+                # pattern. The names used are maintained in a CheckpointState protobuf file named 'checkpoint'. Using
+                # Coach's '.coach_checkpoint' protobuf file results in an error when trying to restore the model, as
+                # the checkpoint names defined do not match the actual checkpoint names.
+                checkpoint = self._get_checkpoint_state_tf()
+            else:
+                checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+
             if checkpoint is None:
                 screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
             else:
@@ -571,6 +565,10 @@ class GraphManager(object):
 
             [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
 
+    def _get_checkpoint_state_tf(self):
+        import tensorflow as tf
+        return tf.train.get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
+
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
         if self.task_parameters.checkpoint_save_secs \
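
For context, a minimal standalone sketch (not part of the patch) of the mechanism _get_checkpoint_state_tf relies on: TensorFlow's tf.train.get_checkpoint_state reads the CheckpointState protobuf file named 'checkpoint' that MonitoredTrainingSession maintains in the save directory, and returns None when no such file exists. The directory path below is a placeholder assumption; the sketch assumes a TensorFlow 1.x environment.

    import tensorflow as tf

    checkpoint_dir = '/tmp/checkpoints'  # placeholder restore directory (assumption)

    # Reads the CheckpointState protobuf file named 'checkpoint' that TensorFlow's
    # savers (including MonitoredTrainingSession) maintain in checkpoint_dir.
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)

    if checkpoint is None:
        # Mirrors the warning branch in restore_checkpoint(): no 'checkpoint' file was found.
        print("No checkpoint to restore in: {}".format(checkpoint_dir))
    else:
        # model_checkpoint_path holds the prefix of the most recent checkpoint,
        # e.g. '/tmp/checkpoints/model.ckpt-1000', using the names TensorFlow chose
        # rather than Coach's "{}_Step-{}.ckpt" pattern.
        print("Most recent checkpoint: {}".format(checkpoint.model_checkpoint_path))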