Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 03:30:19 +01:00)
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures.
- Supports PPO and DQN
- Tested with CartPole_PPO and CartPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in the CartPole_PPO preset
- Checkpointing is disabled for MXNet
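Reviewer note (not part of the patch): the recurring change below is an explicit dispatch on task_parameters.framework_type. A minimal, self-contained sketch of that pattern, assuming only a Frameworks-style enum like the one imported in the diff (everything else here is illustrative, not coach code):

    from enum import Enum

    class Frameworks(Enum):
        tensorflow = "tensorflow"
        mxnet = "mxnet"

    def create_session(framework_type):
        # TensorFlow builds an explicit session; MXNet has no session object,
        # so that branch only needs to trigger module initialization.
        if framework_type == Frameworks.tensorflow:
            return "tensorflow session created"
        elif framework_type == Frameworks.mxnet:
            return "mxnet modules initialized (no session)"
        else:
            raise ValueError('Invalid framework {}'.format(framework_type))

    print(create_session(Frameworks.mxnet))  # mxnet modules initialized (no session)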
@@ -22,7 +22,7 @@ from distutils.dir_util import copy_tree, remove_tree
 from typing import List, Tuple
 import contextlib
 
-from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, \
+from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
     VisualizationParameters, \
     Parameters, PresetValidationParameters
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
@@ -161,7 +161,8 @@ class GraphManager(object):
         """
         raise NotImplementedError("")
 
-    def create_worker_or_parameters_server(self, task_parameters: DistributedTaskParameters):
+    @staticmethod
+    def _create_worker_or_parameters_server_tf(task_parameters: DistributedTaskParameters):
         import tensorflow as tf
         config = tf.ConfigProto()
         config.allow_soft_placement = True  # allow placing ops on cpu if they are not fit for gpu
@@ -170,7 +171,8 @@ class GraphManager(object):
         config.intra_op_parallelism_threads = 1
         config.inter_op_parallelism_threads = 1
 
-        from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_and_start_parameters_server, \
+        from rl_coach.architectures.tensorflow_components.distributed_tf_utils import \
+            create_and_start_parameters_server, \
             create_cluster_spec, create_worker_server_and_device
 
         # create cluster spec
@@ -190,7 +192,16 @@ class GraphManager(object):
             raise ValueError("The job type should be either ps or worker and not {}"
                              .format(task_parameters.job_type))
 
-    def create_session(self, task_parameters: DistributedTaskParameters):
+    @staticmethod
+    def create_worker_or_parameters_server(task_parameters: DistributedTaskParameters):
+        if task_parameters.framework_type == Frameworks.tensorflow:
+            GraphManager._create_worker_or_parameters_server_tf(task_parameters)
+        elif task_parameters.framework_type == Frameworks.mxnet:
+            raise NotImplementedError('Distributed training not implemented for MXNet')
+        else:
+            raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
+
+    def _create_session_tf(self, task_parameters: TaskParameters):
         import tensorflow as tf
         config = tf.ConfigProto()
         config.allow_soft_placement = True  # allow placing ops on cpu if they are not fit for gpu
@@ -235,6 +246,15 @@ class GraphManager(object):
         if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
             self.save_graph()
 
+    def create_session(self, task_parameters: TaskParameters):
+        if task_parameters.framework_type == Frameworks.tensorflow:
+            self._create_session_tf(task_parameters)
+        elif task_parameters.framework_type == Frameworks.mxnet:
+            self.set_session(sess=None)  # Initialize all modules
+            # TODO add checkpoint loading
+        else:
+            raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
+
     def save_graph(self) -> None:
         """
         Save the TF graph to a protobuf description file in the experiment directory
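Why the MXNet branch above can pass sess=None: MXNet/Gluon modules initialize their parameters directly and never go through a session object. A hedged, standalone illustration in plain MXNet (the tiny network below is made up and is not coach code):

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.Dense(2)                   # small illustrative Gluon block
    net.initialize(mx.init.Xavier())    # parameters are initialized directly, no session involved
    out = net(mx.nd.ones((1, 4)))       # first forward pass allocates the deferred parameters
    print(out.shape)                    # (1, 2)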
@@ -490,27 +510,35 @@ class GraphManager(object):
             self.train_and_act(self.steps_between_evaluation_periods)
             self.evaluate(self.evaluation_steps)
 
+    def _restore_checkpoint_tf(self, checkpoint_dir: str):
+        import tensorflow as tf
+        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
+        screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
+        variables = {}
+        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
+            # Load the variable
+            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
+
+            # Set the new name
+            new_name = var_name
+            new_name = new_name.replace('global/', 'online/')
+            variables[new_name] = var
+
+        for v in self.variables_to_restore:
+            self.sess.run(v.assign(variables[v.name.split(':')[0]]))
+
     def restore_checkpoint(self):
         self.verify_graph_was_created()
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
         if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
-            import tensorflow as tf
-            checkpoint_dir = self.task_parameters.checkpoint_restore_dir
-            checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
-            screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-            variables = {}
-            for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
-                # Load the variable
-                var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
-
-                # Set the new name
-                new_name = var_name
-                new_name = new_name.replace('global/', 'online/')
-                variables[new_name] = var
-
-            for v in self.variables_to_restore:
-                self.sess.run(v.assign(variables[v.name.split(':')[0]]))
+            if self.task_parameters.framework_type == Frameworks.tensorflow:
+                self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
+            elif self.task_parameters.framework_type == Frameworks.mxnet:
+                # TODO implement checkpoint restore
+                pass
+            else:
+                raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
 
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
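The new _restore_checkpoint_tf helper keeps the existing rename of checkpoint variables saved under a 'global/' scope onto the 'online/' network. A small self-contained sketch of just that rename step (variable names and values below are made up for illustration):

    loaded = {'global/main/fc0/weights': 0.1, 'global/main/fc0/biases': 0.2}
    renamed = {name.replace('global/', 'online/'): value for name, value in loaded.items()}
    print(renamed)  # {'online/main/fc0/weights': 0.1, 'online/main/fc0/biases': 0.2}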
@@ -529,7 +557,10 @@ class GraphManager(object):
                                            self.checkpoint_id,
                                            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
         if not isinstance(self.task_parameters, DistributedTaskParameters):
-            saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
+            if self.checkpoint_saver is not None:
+                saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
+            else:
+                saved_checkpoint_path = "<Not Saved>"
         else:
             saved_checkpoint_path = checkpoint_path
 