
update of api docstrings across coach and tutorials [WIP] (#91)

* updating the documentation website
* adding the built docs
* update of api docstrings across coach and tutorials 0-2
* added some missing api documentation
* New Sphinx based documentation
Authored by Itai Caspi on 2018-11-15 15:00:13 +02:00, committed by Gal Novik
parent 524f8436a2
commit 6d40ad1650
517 changed files with 71034 additions and 12834 deletions


@@ -83,91 +83,91 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
def handle_distributed_coach_tasks(graph_manager, args):
    ckpt_inside_container = "/checkpoint"

    memory_backend_params = None
    if args.memory_backend_params:
        memory_backend_params = json.loads(args.memory_backend_params)
        memory_backend_params['run_type'] = str(args.distributed_coach_run_type)
        graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params))

    data_store_params = None
    if args.data_store_params:
        data_store_params = construct_data_store_params(json.loads(args.data_store_params))
        data_store_params.checkpoint_dir = ckpt_inside_container
        graph_manager.data_store_params = data_store_params

    if args.distributed_coach_run_type == RunType.TRAINER:
        training_worker(
            graph_manager=graph_manager,
            checkpoint_dir=ckpt_inside_container
        )

    if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
        data_store = None
        if args.data_store_params:
            data_store = get_data_store(data_store_params)
            wait_for_checkpoint(checkpoint_dir=ckpt_inside_container, data_store=data_store)

        rollout_worker(
            graph_manager=graph_manager,
            checkpoint_dir=ckpt_inside_container,
            data_store=data_store,
            num_workers=args.num_workers
        )
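
Note: the two flags consumed above, --memory_backend_params and --data_store_params, carry JSON strings. A minimal sketch of producing such payloads with hypothetical keys (only the json.loads() calls and the injected 'run_type' field are visible in this hunk, so the exact schema is an assumption):

    import json

    # Hypothetical memory backend payload; handle_distributed_coach_tasks()
    # parses it and overwrites 'run_type' with the current process's role.
    memory_backend_json = json.dumps({"store_type": "redispubsub"})  # key name is an assumption

    # Hypothetical data store payload; the parsed dict goes through
    # construct_data_store_params(), and checkpoint_dir is then forced to
    # the in-container path "/checkpoint".
    data_store_json = json.dumps({"store_type": "s3", "bucket_name": "coach-checkpoints"})  # keys are assumptions
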
def handle_distributed_coach_orchestrator(graph_manager, args):
    ckpt_inside_container = "/checkpoint"
    rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + sys.argv[1:]
    trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + sys.argv[1:]

    if '--experiment_name' not in rollout_command:
        rollout_command = rollout_command + ['--experiment_name', args.experiment_name]

    if '--experiment_name' not in trainer_command:
        trainer_command = trainer_command + ['--experiment_name', args.experiment_name]

    memory_backend_params = None
    if args.memory_backend == "redispubsub":
        memory_backend_params = RedisPubSubMemoryBackendParameters()

    ds_params_instance = None
    if args.data_store == "s3":
        ds_params = DataStoreParameters("s3", "", "")
        ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name,
                                                   creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container)

    worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
    trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER))

    orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params],
                                                kubeconfig='~/.kube/config',
                                                memory_backend_parameters=memory_backend_params,
                                                data_store_params=ds_params_instance)
    orchestrator = Kubernetes(orchestration_params)
    if not orchestrator.setup():
        print("Could not setup.")
        return

    if orchestrator.deploy_trainer():
        print("Successfully deployed trainer.")
    else:
        print("Could not deploy trainer.")
        return

    if orchestrator.deploy_worker():
        print("Successfully deployed rollout worker(s).")
    else:
        print("Could not deploy rollout worker(s).")
        return

    try:
        orchestrator.trainer_logs()
    except KeyboardInterrupt:
        pass

    orchestrator.undeploy()
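
The orchestrator path above can also be driven directly from Python. A minimal sketch that re-uses only the names appearing in this hunk (import paths are elided since they are not shown here; the image name and replica count are placeholder values):

    # Deploy one trainer and two rollout workers on Kubernetes, stream the
    # trainer's logs until Ctrl-C, then tear the deployment down.
    image = "my-registry/coach:latest"  # placeholder image name
    worker_cmd = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)]
    trainer_cmd = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)]

    params = KubernetesParameters(
        [RunTypeParameters(image, worker_cmd, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=2),
         RunTypeParameters(image, trainer_cmd, run_type=str(RunType.TRAINER))],
        kubeconfig='~/.kube/config',
        memory_backend_parameters=None,  # None is also what the function above passes when no backend is configured
        data_store_params=None)

    orchestrator = Kubernetes(params)
    if orchestrator.setup() and orchestrator.deploy_trainer() and orchestrator.deploy_worker():
        try:
            orchestrator.trainer_logs()
        except KeyboardInterrupt:
            pass
        orchestrator.undeploy()
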
class CoachLauncher(object):
@@ -192,7 +192,6 @@ class CoachLauncher(object):
        graph_manager = self.get_graph_manager_from_args(args)
        self.run_graph_manager(graph_manager, args)

    def get_graph_manager_from_args(self, args: argparse.Namespace) -> 'GraphManager':
        """
        Return the graph manager according to the command line arguments given by the user.
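
The two calls at the top of this hunk are the core of the launcher's control flow: build a GraphManager from the parsed arguments, then run it. A minimal sketch of driving the launcher programmatically, assuming the enclosing entry-point method (not shown in this hunk) is named launch():

    from rl_coach.coach import CoachLauncher

    # Parses the command line, builds a GraphManager via
    # get_graph_manager_from_args(), then hands it to run_graph_manager().
    CoachLauncher().launch()  # launch() is an assumed entry-point name
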
@@ -251,7 +250,6 @@ class CoachLauncher(object):
        return graph_manager

    def display_all_presets_and_exit(self):
        # list available presets
        screen.log_title("Available Presets:")
@@ -259,7 +257,6 @@ class CoachLauncher(object):
            print(preset)
        sys.exit(0)

    def expand_preset(self, preset):
        """
        Replace a short preset name with the full python path, and verify that it can be imported.
@@ -287,7 +284,6 @@ class CoachLauncher(object):
        return preset

    def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
        """
        Returns a Namespace object with all the user-specified configuration options needed to launch.
@@ -317,7 +313,6 @@ class CoachLauncher(object):
        if args.list:
            self.display_all_presets_and_exit()

        # Read args from config file for distributed Coach.
        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
            coach_config = ConfigParser({
@@ -401,7 +396,6 @@ class CoachLauncher(object):
        return args

    def get_argument_parser(self) -> argparse.ArgumentParser:
        """
        This returns an ArgumentParser object which defines the set of options that customers are expected to supply in order
@@ -545,7 +539,6 @@ class CoachLauncher(object):
        return parser

    def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace):
        if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
            screen.error("{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
@@ -581,7 +574,6 @@ class CoachLauncher(object):
        else:
            self.start_multi_threaded(graph_manager, args)

    def start_single_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
        # Start the training or evaluation
        task_parameters = TaskParameters(
@@ -598,7 +590,6 @@ class CoachLauncher(object):
        start_graph(graph_manager=graph_manager, task_parameters=task_parameters)

    def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
        total_tasks = args.num_workers
        if args.evaluation_worker: