mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Moved coach to its top level module.
This commit is contained in:
331
scripts/coach
Normal file
331
scripts/coach
Normal file
@@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import argparse
|
||||
import atexit
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from coach import agents # noqa
|
||||
from coach import configurations as conf
|
||||
from coach import environments
|
||||
from coach import logger
|
||||
from coach import presets
|
||||
from coach import utils
|
||||
|
||||
|
||||
if len(set(logger.failed_imports)) > 0:
|
||||
logger.screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(logger.failed_imports))))
|
||||
|
||||
|
||||
def set_framework(framework_type):
|
||||
# choosing neural network framework
|
||||
framework = conf.Frameworks().get(framework_type)
|
||||
sess = None
|
||||
if framework == conf.Frameworks.TensorFlow:
|
||||
import tensorflow as tf
|
||||
config = tf.ConfigProto()
|
||||
config.allow_soft_placement = True
|
||||
config.gpu_options.allow_growth = True
|
||||
config.gpu_options.per_process_gpu_memory_fraction = 0.2
|
||||
sess = tf.Session(config=config)
|
||||
elif framework == conf.Frameworks.Neon:
|
||||
import ngraph as ng
|
||||
sess = ng.transformers.make_transformer()
|
||||
logger.screen.log_title("Using {} framework".format(conf.Frameworks().to_string(framework)))
|
||||
return sess
|
||||
|
||||
|
||||
def check_input_and_fill_run_dict(parser):
|
||||
args = parser.parse_args()
|
||||
|
||||
# if no arg is given
|
||||
if len(sys.argv) == 1:
|
||||
parser.print_help()
|
||||
exit(0)
|
||||
|
||||
# list available presets
|
||||
if args.list:
|
||||
presets_lists = utils.list_all_classes_in_module(presets)
|
||||
logger.screen.log_title("Available Presets:")
|
||||
for preset in presets_lists:
|
||||
print(preset)
|
||||
sys.exit(0)
|
||||
|
||||
# check inputs
|
||||
try:
|
||||
# num_workers = int(args.num_workers)
|
||||
num_workers = int(re.match("^\d+$", args.num_workers).group(0))
|
||||
except ValueError:
|
||||
logger.screen.error("Parameter num_workers should be an integer.")
|
||||
|
||||
preset_names = utils.list_all_classes_in_module(presets)
|
||||
if args.preset is not None and args.preset not in preset_names:
|
||||
logger.screen.error("A non-existing preset was selected. ")
|
||||
|
||||
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
|
||||
logger.screen.error("The requested checkpoint folder to load from does not exist. ")
|
||||
|
||||
if args.save_model_sec is not None:
|
||||
try:
|
||||
args.save_model_sec = int(args.save_model_sec)
|
||||
except ValueError:
|
||||
logger.screen.error("Parameter save_model_sec should be an integer.")
|
||||
|
||||
if args.preset is None and (args.agent_type is None or args.environment_type is None
|
||||
or args.exploration_policy_type is None) and not args.play:
|
||||
logger.screen.error('When no preset is given for Coach to run, the user is expected to input the desired agent_type,'
|
||||
' environment_type and exploration_policy_type to assemble a preset. '
|
||||
'\nAt least one of these parameters was not given.')
|
||||
elif args.preset is None and args.play and args.environment_type is None:
|
||||
logger.screen.error('When no preset is given for Coach to run, and the user requests human control over the environment,'
|
||||
' the user is expected to input the desired environment_type and level.'
|
||||
'\nAt least one of these parameters was not given.')
|
||||
elif args.preset is None and args.play and args.environment_type:
|
||||
args.agent_type = 'Human'
|
||||
args.exploration_policy_type = 'ExplorationParameters'
|
||||
|
||||
# get experiment name and path
|
||||
experiment_name = logger.logger.get_experiment_name(args.experiment_name)
|
||||
experiment_path = logger.logger.get_experiment_path(experiment_name)
|
||||
|
||||
if args.play and num_workers > 1:
|
||||
logger.screen.warning("Playing the game as a human is only available with a single worker. "
|
||||
"The number of workers will be reduced to 1")
|
||||
num_workers = 1
|
||||
|
||||
# fill run_dict
|
||||
run_dict = dict()
|
||||
run_dict['agent_type'] = args.agent_type
|
||||
run_dict['environment_type'] = args.environment_type
|
||||
run_dict['exploration_policy_type'] = args.exploration_policy_type
|
||||
run_dict['level'] = args.level
|
||||
run_dict['preset'] = args.preset
|
||||
run_dict['custom_parameter'] = args.custom_parameter
|
||||
run_dict['experiment_path'] = experiment_path
|
||||
run_dict['framework'] = conf.Frameworks().get(args.framework)
|
||||
run_dict['play'] = args.play
|
||||
run_dict['evaluate'] = args.evaluate# or args.play
|
||||
|
||||
# multi-threading parameters
|
||||
run_dict['num_threads'] = num_workers
|
||||
|
||||
# checkpoints
|
||||
run_dict['save_model_sec'] = args.save_model_sec
|
||||
run_dict['save_model_dir'] = experiment_path if args.save_model_sec is not None else None
|
||||
run_dict['checkpoint_restore_dir'] = args.checkpoint_restore_dir
|
||||
|
||||
# visualization
|
||||
run_dict['visualization.dump_gifs'] = args.dump_gifs
|
||||
run_dict['visualization.render'] = args.render
|
||||
run_dict['visualization.tensorboard'] = args.tensorboard
|
||||
|
||||
return args, run_dict
|
||||
|
||||
|
||||
def run_dict_to_json(_run_dict, task_id=''):
|
||||
if task_id != '':
|
||||
json_path = os.path.join(_run_dict['experiment_path'], 'run_dict_worker{}.json'.format(task_id))
|
||||
else:
|
||||
json_path = os.path.join(_run_dict['experiment_path'], 'run_dict.json')
|
||||
|
||||
with open(json_path, 'w') as outfile:
|
||||
json.dump(_run_dict, outfile, indent=2)
|
||||
|
||||
return json_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--preset',
|
||||
help="(string) Name of a preset to run (as configured in presets.py)",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-l', '--list',
|
||||
help="(flag) List all available presets",
|
||||
action='store_true')
|
||||
parser.add_argument('-e', '--experiment_name',
|
||||
help="(string) Experiment name to be used to store the results.",
|
||||
default='',
|
||||
type=str)
|
||||
parser.add_argument('-r', '--render',
|
||||
help="(flag) Render environment",
|
||||
action='store_true')
|
||||
parser.add_argument('-f', '--framework',
|
||||
help="(string) Neural network framework. Available values: tensorflow, neon",
|
||||
default='tensorflow',
|
||||
type=str)
|
||||
parser.add_argument('-n', '--num_workers',
|
||||
help="(int) Number of workers for multi-process based agents, e.g. A3C",
|
||||
default='1',
|
||||
type=str)
|
||||
parser.add_argument('--play',
|
||||
help="(flag) Play as a human by controlling the game with the keyboard. "
|
||||
"This option will save a replay buffer with the game play.",
|
||||
action='store_true')
|
||||
parser.add_argument('--evaluate',
|
||||
help="(flag) Run evaluation only. This is a convenient way to disable "
|
||||
"training in order to evaluate an existing checkpoint.",
|
||||
action='store_true')
|
||||
parser.add_argument('-v', '--verbose',
|
||||
help="(flag) Don't suppress TensorFlow debug prints.",
|
||||
action='store_true')
|
||||
parser.add_argument('-s', '--save_model_sec',
|
||||
help="(int) Time in seconds between saving checkpoints of the model.",
|
||||
default=None,
|
||||
type=int)
|
||||
parser.add_argument('-crd', '--checkpoint_restore_dir',
|
||||
help='(string) Path to a folder containing a checkpoint to restore the model from.',
|
||||
type=str)
|
||||
parser.add_argument('-dg', '--dump_gifs',
|
||||
help="(flag) Enable the gif saving functionality.",
|
||||
action='store_true')
|
||||
parser.add_argument('-at', '--agent_type',
|
||||
help="(string) Choose an agent type class to override on top of the selected preset. "
|
||||
"If no preset is defined, a preset can be set from the command-line by combining settings "
|
||||
"which are set by using --agent_type, --experiment_type, --environemnt_type",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-et', '--environment_type',
|
||||
help="(string) Choose an environment type class to override on top of the selected preset."
|
||||
"If no preset is defined, a preset can be set from the command-line by combining settings "
|
||||
"which are set by using --agent_type, --experiment_type, --environemnt_type",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-ept', '--exploration_policy_type',
|
||||
help="(string) Choose an exploration policy type class to override on top of the selected "
|
||||
"preset."
|
||||
"If no preset is defined, a preset can be set from the command-line by combining settings "
|
||||
"which are set by using --agent_type, --experiment_type, --environemnt_type"
|
||||
,
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-lvl', '--level',
|
||||
help="(string) Choose the level that will be played in the environment that was selected."
|
||||
"This value will override the level parameter in the environment class."
|
||||
,
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-cp', '--custom_parameter',
|
||||
help="(string) Semicolon separated parameters used to override specific parameters on top of"
|
||||
" the selected preset (or on top of the command-line assembled one). "
|
||||
"Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
|
||||
"For ex.: "
|
||||
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('--print_parameters',
|
||||
help="(flag) Print tuning_parameters to stdout",
|
||||
action='store_true')
|
||||
parser.add_argument('-tb', '--tensorboard',
|
||||
help="(flag) When using the TensorFlow backend, enable TensorBoard log dumps. ",
|
||||
action='store_true')
|
||||
parser.add_argument('-ns', '--no_summary',
|
||||
help="(flag) Prevent Coach from printing a summary and asking questions at the end of runs",
|
||||
action='store_true')
|
||||
|
||||
args, run_dict = check_input_and_fill_run_dict(parser)
|
||||
|
||||
# turn TF debug prints off
|
||||
if not args.verbose and args.framework.lower() == 'tensorflow':
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
# dump documentation
|
||||
logger.logger.set_dump_dir(run_dict['experiment_path'], add_timestamp=True)
|
||||
if not args.no_summary:
|
||||
atexit.register(logger.logger.summarize_experiment)
|
||||
logger.screen.change_terminal_title(logger.logger.experiment_name)
|
||||
|
||||
# Single-threaded runs
|
||||
if run_dict['num_threads'] == 1:
|
||||
# set tuning parameters
|
||||
json_run_dict_path = run_dict_to_json(run_dict)
|
||||
tuning_parameters = presets.json_to_preset(json_run_dict_path)
|
||||
tuning_parameters.sess = set_framework(args.framework)
|
||||
|
||||
if args.print_parameters:
|
||||
print('tuning_parameters', tuning_parameters)
|
||||
|
||||
# Single-thread runs
|
||||
tuning_parameters.task_index = 0
|
||||
env_instance = environments.create_environment(tuning_parameters)
|
||||
agent = eval('agents.' + tuning_parameters.agent.type +
|
||||
'(env_instance, tuning_parameters)')
|
||||
|
||||
# Start the training or evaluation
|
||||
if tuning_parameters.evaluate:
|
||||
agent.evaluate(sys.maxsize, keep_networks_synced=True) # evaluate forever
|
||||
else:
|
||||
agent.improve()
|
||||
|
||||
# Multi-threaded runs
|
||||
else:
|
||||
assert args.framework.lower() == 'tensorflow', "Distributed training works only with TensorFlow"
|
||||
os.environ["OMP_NUM_THREADS"]="1"
|
||||
# set parameter server and workers addresses
|
||||
ps_hosts = "localhost:{}".format(utils.get_open_port())
|
||||
worker_hosts = ",".join(["localhost:{}".format(utils.get_open_port()) for i in range(run_dict['num_threads'] + 1)])
|
||||
|
||||
# Make sure to disable GPU so that all the workers will use the CPU
|
||||
utils.set_cpu()
|
||||
|
||||
# create a parameter server
|
||||
cmd = [
|
||||
"python3",
|
||||
"./parallel_actor.py",
|
||||
"--ps_hosts={}".format(ps_hosts),
|
||||
"--worker_hosts={}".format(worker_hosts),
|
||||
"--job_name=ps",
|
||||
]
|
||||
parameter_server = subprocess.Popen(cmd)
|
||||
|
||||
logger.screen.log_title("*** Distributed Training ***")
|
||||
time.sleep(1)
|
||||
|
||||
# create N training workers and 1 evaluating worker
|
||||
workers = []
|
||||
|
||||
for i in range(run_dict['num_threads'] + 1):
|
||||
# this is the evaluation worker
|
||||
run_dict['task_id'] = i
|
||||
if i == run_dict['num_threads']:
|
||||
run_dict['evaluate_only'] = True
|
||||
run_dict['visualization.render'] = args.render
|
||||
else:
|
||||
run_dict['evaluate_only'] = False
|
||||
run_dict['visualization.render'] = False # #In a parallel setting, only the evaluation agent renders
|
||||
|
||||
json_run_dict_path = run_dict_to_json(run_dict, i)
|
||||
workers_args = ["python3", "./parallel_actor.py",
|
||||
"--ps_hosts={}".format(ps_hosts),
|
||||
"--worker_hosts={}".format(worker_hosts),
|
||||
"--job_name=worker",
|
||||
"--load_json={}".format(json_run_dict_path)]
|
||||
|
||||
p = subprocess.Popen(workers_args)
|
||||
|
||||
if i != run_dict['num_threads']:
|
||||
workers.append(p)
|
||||
else:
|
||||
evaluation_worker = p
|
||||
|
||||
# wait for all workers
|
||||
[w.wait() for w in workers]
|
||||
evaluation_worker.kill()
|
||||
172
scripts/parallel_actor.py
Normal file
172
scripts/parallel_actor.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from coach import agents
|
||||
from coach import environments
|
||||
from coach import logger
|
||||
from coach import presets
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--ps_hosts',
|
||||
help="(string) Comma-separated list of hostname:port pairs",
|
||||
default='',
|
||||
type=str)
|
||||
parser.add_argument('--worker_hosts',
|
||||
help="(string) Comma-separated list of hostname:port pairs",
|
||||
default='',
|
||||
type=str)
|
||||
parser.add_argument('--job_name',
|
||||
help="(string) One of 'ps', 'worker'",
|
||||
default='',
|
||||
type=str)
|
||||
parser.add_argument('--load_json_path',
|
||||
help="(string) Path to a JSON file to load.",
|
||||
default='',
|
||||
type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
ps_hosts = args.ps_hosts.split(",")
|
||||
worker_hosts = args.worker_hosts.split(",")
|
||||
|
||||
# Create a cluster from the parameter server and worker hosts.
|
||||
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
|
||||
|
||||
if args.job_name == "ps":
|
||||
# Create and start a parameter server
|
||||
server = tf.train.Server(cluster,
|
||||
job_name="ps",
|
||||
task_index=0,
|
||||
config=tf.ConfigProto())#device_filters=["/job:ps"]))
|
||||
server.join()
|
||||
|
||||
elif args.job_name == "worker":
|
||||
# get tuning parameters
|
||||
tuning_parameters = presets.json_to_preset(args.load_json_path)
|
||||
|
||||
# dump documentation
|
||||
if not os.path.exists(tuning_parameters.experiment_path):
|
||||
os.makedirs(tuning_parameters.experiment_path)
|
||||
if tuning_parameters.evaluate_only:
|
||||
logger.logger.set_dump_dir(tuning_parameters.experiment_path, tuning_parameters.task_id, filename='evaluator')
|
||||
else:
|
||||
logger.logger.set_dump_dir(tuning_parameters.experiment_path, tuning_parameters.task_id)
|
||||
|
||||
# multi-threading parameters
|
||||
tuning_parameters.start_time = start_time
|
||||
|
||||
# User is allowed to override the number of synchronized threads if he wishes to do so.
|
||||
# Else, just sync over all of them.
|
||||
if not tuning_parameters.synchronize_over_num_threads:
|
||||
tuning_parameters.synchronize_over_num_threads = tuning_parameters.num_threads
|
||||
|
||||
tuning_parameters.distributed = True
|
||||
if tuning_parameters.evaluate_only:
|
||||
tuning_parameters.visualization.dump_signals_to_csv_every_x_episodes = 1
|
||||
|
||||
# Create and start a worker
|
||||
server = tf.train.Server(cluster,
|
||||
job_name="worker",
|
||||
task_index=tuning_parameters.task_id)
|
||||
|
||||
# Assigns ops to the local worker by default.
|
||||
device = tf.train.replica_device_setter(worker_device="/job:worker/task:%d/cpu:0" % tuning_parameters.task_id,
|
||||
cluster=cluster)
|
||||
|
||||
# create the agent and the environment
|
||||
env_instance = environments.create_environment(tuning_parameters)
|
||||
exec('agent = agents.' + tuning_parameters.agent.type + '(env_instance, tuning_parameters, replicated_device=device, '
|
||||
'thread_id=tuning_parameters.task_id)')
|
||||
|
||||
# building the scaffold
|
||||
# local vars
|
||||
local_variables = []
|
||||
for network in agent.networks:
|
||||
local_variables += network.get_local_variables()
|
||||
local_variables += tf.local_variables()
|
||||
|
||||
# global vars
|
||||
global_variables = []
|
||||
for network in agent.networks:
|
||||
global_variables += network.get_global_variables()
|
||||
|
||||
# out of scope variables - not sure why this variables are created out of scope
|
||||
variables_not_in_scope = [v for v in tf.global_variables() if v not in global_variables and v not in local_variables]
|
||||
|
||||
# init ops
|
||||
global_init_op = tf.variables_initializer(global_variables)
|
||||
local_init_op = tf.variables_initializer(local_variables + variables_not_in_scope)
|
||||
out_of_scope_init_op = tf.variables_initializer(variables_not_in_scope)
|
||||
init_all_op = tf.global_variables_initializer() # this includes global, local, and out of scope
|
||||
ready_op = tf.report_uninitialized_variables(global_variables + local_variables)
|
||||
ready_for_local_init_op = tf.report_uninitialized_variables([])
|
||||
|
||||
def init_fn(scaffold, session):
|
||||
session.run(init_all_op)
|
||||
|
||||
scaffold = tf.train.Scaffold(init_op=init_all_op,
|
||||
init_fn=init_fn,
|
||||
ready_op=ready_op,
|
||||
ready_for_local_init_op=ready_for_local_init_op,
|
||||
local_init_op=local_init_op)
|
||||
|
||||
# Due to awkward tensorflow behavior where the same variable is used to decide whether to restore a model
|
||||
# (and where from), or just save the model (and where to), we employ the below. In case where a restore folder
|
||||
# is given, it will also be used as the folder to save new checkpoints of the trained model to. Otherwise the
|
||||
# experiment's folder will be used as the folder to save the trained model to.
|
||||
if tuning_parameters.checkpoint_restore_dir:
|
||||
checkpoint_dir = tuning_parameters.checkpoint_restore_dir
|
||||
elif tuning_parameters.save_model_sec:
|
||||
checkpoint_dir = tuning_parameters.experiment_path
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
# Set the session
|
||||
sess = tf.train.MonitoredTrainingSession(
|
||||
server.target,
|
||||
is_chief=tuning_parameters.task_id == 0,
|
||||
scaffold=scaffold,
|
||||
hooks=[],
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
save_checkpoint_secs=tuning_parameters.save_model_sec)
|
||||
tuning_parameters.sess = sess
|
||||
for network in agent.networks:
|
||||
network.set_session(sess)
|
||||
|
||||
if tuning_parameters.visualization.tensorboard:
|
||||
# Write the merged summaries to the current experiment directory
|
||||
agent.main_network.online_network.train_writer = tf.summary.FileWriter(
|
||||
tuning_parameters.experiment_path + '/tensorboard_worker{}'.format(tuning_parameters.task_id),
|
||||
sess.graph)
|
||||
|
||||
# Start the training or evaluation
|
||||
if tuning_parameters.evaluate_only:
|
||||
agent.evaluate(sys.maxsize, keep_networks_synced=True) # evaluate forever
|
||||
else:
|
||||
agent.improve()
|
||||
else:
|
||||
logger.screen.error("Invalid mode requested for parallel_actor.")
|
||||
exit(1)
|
||||
|
||||
Reference in New Issue
Block a user