1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Adding mxnet components to rl_coach/architectures (#60)

Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CarPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset
- Checkpointing is disabled for MXNet
This commit is contained in:
Sina Afrooze
2018-11-07 07:07:15 -08:00
committed by Itai Caspi
parent e7a91b4dc3
commit 5fadb9c18e
39 changed files with 3864 additions and 44 deletions

View File

@@ -22,7 +22,7 @@ from distutils.dir_util import copy_tree, remove_tree
from typing import List, Tuple
import contextlib
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, \
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
VisualizationParameters, \
Parameters, PresetValidationParameters
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
@@ -161,7 +161,8 @@ class GraphManager(object):
"""
raise NotImplementedError("")
def create_worker_or_parameters_server(self, task_parameters: DistributedTaskParameters):
@staticmethod
def _create_worker_or_parameters_server_tf(task_parameters: DistributedTaskParameters):
import tensorflow as tf
config = tf.ConfigProto()
config.allow_soft_placement = True # allow placing ops on cpu if they are not fit for gpu
@@ -170,7 +171,8 @@ class GraphManager(object):
config.intra_op_parallelism_threads = 1
config.inter_op_parallelism_threads = 1
from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_and_start_parameters_server, \
from rl_coach.architectures.tensorflow_components.distributed_tf_utils import \
create_and_start_parameters_server, \
create_cluster_spec, create_worker_server_and_device
# create cluster spec
@@ -190,7 +192,16 @@ class GraphManager(object):
raise ValueError("The job type should be either ps or worker and not {}"
.format(task_parameters.job_type))
def create_session(self, task_parameters: DistributedTaskParameters):
@staticmethod
def create_worker_or_parameters_server(task_parameters: DistributedTaskParameters):
if task_parameters.framework_type == Frameworks.tensorflow:
GraphManager._create_worker_or_parameters_server_tf(task_parameters)
elif task_parameters.framework_type == Frameworks.mxnet:
raise NotImplementedError('Distributed training not implemented for MXNet')
else:
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
def _create_session_tf(self, task_parameters: TaskParameters):
import tensorflow as tf
config = tf.ConfigProto()
config.allow_soft_placement = True # allow placing ops on cpu if they are not fit for gpu
@@ -235,6 +246,15 @@ class GraphManager(object):
if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
self.save_graph()
def create_session(self, task_parameters: TaskParameters):
if task_parameters.framework_type == Frameworks.tensorflow:
self._create_session_tf(task_parameters)
elif task_parameters.framework_type == Frameworks.mxnet:
self.set_session(sess=None) # Initialize all modules
# TODO add checkpoint loading
else:
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))
def save_graph(self) -> None:
"""
Save the TF graph to a protobuf description file in the experiment directory
@@ -490,27 +510,35 @@ class GraphManager(object):
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
def _restore_checkpoint_tf(self, checkpoint_dir: str):
import tensorflow as tf
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
variables = {}
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
# Load the variable
var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
# Set the new name
new_name = var_name
new_name = new_name.replace('global/', 'online/')
variables[new_name] = var
for v in self.variables_to_restore:
self.sess.run(v.assign(variables[v.name.split(':')[0]]))
def restore_checkpoint(self):
self.verify_graph_was_created()
# TODO: find better way to load checkpoints that were saved with a global network into the online network
if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
import tensorflow as tf
checkpoint_dir = self.task_parameters.checkpoint_restore_dir
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
variables = {}
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
# Load the variable
var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
# Set the new name
new_name = var_name
new_name = new_name.replace('global/', 'online/')
variables[new_name] = var
for v in self.variables_to_restore:
self.sess.run(v.assign(variables[v.name.split(':')[0]]))
if self.task_parameters.framework_type == Frameworks.tensorflow:
self._restore_checkpoint_tf(self.task_parameters.checkpoint_restore_dir)
elif self.task_parameters.framework_type == Frameworks.mxnet:
# TODO implement checkpoint restore
pass
else:
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
def occasionally_save_checkpoint(self):
# only the chief process saves checkpoints
@@ -529,7 +557,10 @@ class GraphManager(object):
self.checkpoint_id,
self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
if not isinstance(self.task_parameters, DistributedTaskParameters):
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
if self.checkpoint_saver is not None:
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else:
saved_checkpoint_path = "<Not Saved>"
else:
saved_checkpoint_path = checkpoint_path