mirror of https://github.com/gryf/coach.git
coach v0.8.0
parallel_actor.py (new file, 176 lines)

@@ -0,0 +1,176 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import os
import shutil
import sys
import time

import tensorflow as tf

from architectures import *
from environments import *
from agents import *
from utils import *
from logger import *
from configurations import *
from presets import *

start_time = time.time()


# Saver that disables the write_meta_graph argument, which freezes the entire
# process and is mostly useless.
class FastSaver(tf.train.Saver):
    def save(self, sess, save_path, global_step=None, latest_filename=None,
             meta_graph_suffix="meta", write_meta_graph=True):
        super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
                                    meta_graph_suffix, False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ps_hosts',
                        help="(string) Comma-separated list of hostname:port pairs",
                        default='',
                        type=str)
    parser.add_argument('--worker_hosts',
                        help="(string) Comma-separated list of hostname:port pairs",
                        default='',
                        type=str)
    parser.add_argument('--job_name',
                        help="(string) One of 'ps', 'worker'",
                        default='',
                        type=str)
    parser.add_argument('--load_json_path',
                        help="(string) Path to a JSON file to load.",
                        default='',
                        type=str)

    args = parser.parse_args()
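
    # Illustrative launch commands (the hostnames, ports, and preset JSON path
    # below are placeholders, not values shipped with this file):
    #   python parallel_actor.py --job_name=ps \
    #       --ps_hosts=localhost:2222 --worker_hosts=localhost:2223,localhost:2224
    #   python parallel_actor.py --job_name=worker --load_json_path=run.json \
    #       --ps_hosts=localhost:2222 --worker_hosts=localhost:2223,localhost:2224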

    ps_hosts = args.ps_hosts.split(",")
    worker_hosts = args.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
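    # With the illustrative hosts above this resolves to
    # ClusterSpec({"ps": ["localhost:2222"],
    #              "worker": ["localhost:2223", "localhost:2224"]}).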

    if args.job_name == "ps":
        # Create and start a parameter server
        server = tf.train.Server(cluster,
                                 job_name="ps",
                                 task_index=0,
                                 config=tf.ConfigProto())  # device_filters=["/job:ps"]))
        server.join()

    elif args.job_name == "worker":
        # get tuning parameters
        tuning_parameters = json_to_preset(args.load_json_path)

        # dump documentation
        if not os.path.exists(tuning_parameters.experiment_path):
            os.makedirs(tuning_parameters.experiment_path)
        if tuning_parameters.evaluate_only:
            logger.set_dump_dir(tuning_parameters.experiment_path, tuning_parameters.task_id, filename='evaluator')
        else:
            logger.set_dump_dir(tuning_parameters.experiment_path, tuning_parameters.task_id)

        # multi-threading parameters
        tuning_parameters.start_time = start_time

        # The user may override the number of threads to synchronize over;
        # otherwise, synchronize over all of them.
        if not tuning_parameters.synchronize_over_num_threads:
            tuning_parameters.synchronize_over_num_threads = tuning_parameters.num_threads

        tuning_parameters.distributed = True
        if tuning_parameters.evaluate_only:
            tuning_parameters.visualization.dump_signals_to_csv_every_x_episodes = 1

        # Create and start a worker
        server = tf.train.Server(cluster,
                                 job_name="worker",
                                 task_index=tuning_parameters.task_id)

        # Assigns ops to the local worker by default.
        device = tf.train.replica_device_setter(worker_device="/job:worker/task:%d/cpu:0" % tuning_parameters.task_id,
                                                cluster=cluster)
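        # replica_device_setter pins each variable to the "ps" job (round-robin
        # across the ps tasks) while leaving all other ops on this worker's CPU.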

        # create the agent and the environment
        env_instance = create_environment(tuning_parameters)
        exec('agent = ' + tuning_parameters.agent.type + '(env_instance, tuning_parameters, replicated_device=device, '
                                                         'thread_id=tuning_parameters.task_id)')

        # building the scaffold
        # local vars
        local_variables = []
        for network in agent.networks:
            local_variables += network.get_local_variables()
        local_variables += tf.local_variables()

        # global vars
        global_variables = []
        for network in agent.networks:
            global_variables += network.get_global_variables()

        # out of scope variables - not sure why these variables are created out of scope
        variables_not_in_scope = [v for v in tf.global_variables()
                                  if v not in global_variables and v not in local_variables]

        # init ops
        global_init_op = tf.variables_initializer(global_variables)
        local_init_op = tf.variables_initializer(local_variables + variables_not_in_scope)
        out_of_scope_init_op = tf.variables_initializer(variables_not_in_scope)
        init_all_op = tf.global_variables_initializer()  # this includes global, local, and out of scope
        ready_op = tf.report_uninitialized_variables(global_variables + local_variables)
        ready_for_local_init_op = tf.report_uninitialized_variables([])
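        # checking an empty variable list always reports "ready", so non-chief
        # workers run local_init_op immediately and then block on ready_op until
        # the chief has initialized the global variables.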

        def init_fn(scaffold, session):
            session.run(init_all_op)

        scaffold = tf.train.Scaffold(init_op=init_all_op,
                                     init_fn=init_fn,
                                     ready_op=ready_op,
                                     ready_for_local_init_op=ready_for_local_init_op,
                                     local_init_op=local_init_op)

        # Due to awkward tensorflow behavior where the same variable is used to decide whether to restore a model
        # (and where from) or just to save the model (and where to), we employ the logic below. If a restore folder
        # is given, it will also be used as the folder to save new checkpoints of the trained model to. Otherwise,
        # the experiment's folder will be used as the folder to save the trained model to.
        if tuning_parameters.checkpoint_restore_dir:
            checkpoint_dir = tuning_parameters.checkpoint_restore_dir
        elif tuning_parameters.save_model_sec:
            checkpoint_dir = tuning_parameters.experiment_path
        else:
            checkpoint_dir = None

        # Set the session
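        # The chief worker (task_id == 0) runs the scaffold initialization and,
        # when checkpoint_dir is set, saves checkpoints every save_model_sec
        # seconds; non-chief workers wait until the chief has initialized the graph.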
        sess = tf.train.MonitoredTrainingSession(
            server.target,
            is_chief=tuning_parameters.task_id == 0,
            scaffold=scaffold,
            hooks=[],
            checkpoint_dir=checkpoint_dir,
            save_checkpoint_secs=tuning_parameters.save_model_sec)
        tuning_parameters.sess = sess
        for network in agent.networks:
            network.set_session(sess)

        # Start the training or evaluation
        if tuning_parameters.evaluate_only:
            agent.evaluate(sys.maxsize, keep_networks_synced=True)  # evaluate forever
        else:
            agent.improve()
    else:
        screen.error("Invalid mode requested for parallel_actor.")
        exit(1)