Coach as a library (#348)

* CoachInterface + tutorial

* Some improvements and typo fixes

* merge tutorials 0 and 4

* typo fix + additional tutorial changes

* tutorial changes

* added reading signals and experiment path argument
shadiendrawis committed 2019-06-19 18:05:03 +03:00 (committed by GitHub)
parent 1c90bc22a1
commit 8e812ef82f
6 changed files with 181 additions and 87 deletions
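
The headline change is the CoachInterface class added at the bottom of this diff, which lets Coach run as a library: each keyword argument is forwarded to the regular argument parser as the long-form CLI flag of the same name. A minimal usage sketch, assuming the CartPole_DQN preset that ships with Coach:

    from rl_coach.coach import CoachInterface

    # Each kwarg becomes '--<name> <value>' on the internal parser,
    # so any long-form CLI flag works here. Preset name is an assumption.
    coach = CoachInterface(preset='CartPole_DQN',
                           num_workers=1)
    coach.run()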

rl_coach/coach.py

@@ -18,7 +18,6 @@ sys.path.append('.')
 import copy
 from configparser import ConfigParser, Error
-from rl_coach.core_types import EnvironmentSteps
 import os
 from rl_coach import logger
 import traceback
@@ -30,6 +29,8 @@ import sys
 import json
 from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
     RunType, DistributedCoachSynchronizationType
+from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
+    EnvironmentSteps, StepMethod, Transition
 from multiprocessing import Process
 from multiprocessing.managers import BaseManager
 import subprocess
@@ -316,7 +317,7 @@ class CoachLauncher(object):
         return preset

-    def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
+    def get_config_args(self, parser: argparse.ArgumentParser, arguments=None) -> argparse.Namespace:
         """
         Returns a Namespace object with all the user-specified configuration options needed to launch.
         This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by
@@ -329,15 +330,19 @@ class CoachLauncher(object):
         :param parser: a parser object which implicitly defines the format of the Namespace that
                        is expected to be returned.
+        :param arguments: command line arguments
         :return: the parsed arguments as a Namespace
         """
-        args = parser.parse_args()
+        if arguments is None:
+            args = parser.parse_args()
+        else:
+            args = parser.parse_args(arguments)

         if args.nocolor:
             screen.set_use_colors(False)

         # if no arg is given
-        if len(sys.argv) == 1:
+        if (len(sys.argv) == 1 and arguments is None) or (arguments is not None and len(arguments) <= 2):
             parser.print_help()
             sys.exit(1)
@@ -417,7 +422,7 @@ class CoachLauncher(object):
         # get experiment name and path
         args.experiment_name = logger.get_experiment_name(args.experiment_name)
-        args.experiment_path = logger.get_experiment_path(args.experiment_name)
+        args.experiment_path = logger.get_experiment_path(args.experiment_name, args.experiment_path)

         if args.play and args.num_workers > 1:
             screen.warning("Playing the game as a human is only available with a single worker. "
@@ -450,7 +455,11 @@ class CoachLauncher(object):
                             action='store_true')
         parser.add_argument('-e', '--experiment_name',
                             help="(string) Experiment name to be used to store the results.",
-                            default='',
+                            default=None,
                             type=str)
+        parser.add_argument('-ep', '--experiment_path',
+                            help="(string) Path to experiments folder.",
+                            default=None,
+                            type=str)
         parser.add_argument('-r', '--render',
                             help="(flag) Render environment",
@@ -526,7 +535,8 @@ class CoachLauncher(object):
" the selected preset (or on top of the command-line assembled one). "
"Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
"For ex.: "
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
"\"visualization_parameters.render=False; heatup_steps=EnvironmentSteps(1000);"
"improve_steps=TrainingSteps(100000); optimizer='rmsprop'\"",
default=None,
type=str)
parser.add_argument('--print_networks_summary',
@@ -589,14 +599,31 @@ class CoachLauncher(object):
         return parser

     def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace):
+        task_parameters = self.create_task_parameters(graph_manager, args)
+
+        if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
+            handle_distributed_coach_tasks(graph_manager, args, task_parameters)
+            return
+
+        # Single-threaded runs
+        if args.num_workers == 1:
+            self.start_single_threaded(task_parameters, graph_manager, args)
+        else:
+            self.start_multi_threaded(graph_manager, args)
+
+    @staticmethod
+    def create_task_parameters(graph_manager: 'GraphManager', args: argparse.Namespace):
         if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
-            screen.error("{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
+            screen.error(
+                "{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))

         if args.distributed_coach and args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
-            screen.warning("The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")
+            screen.warning(
+                "The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")

         if args.distributed_coach and not args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
-            screen.error("Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")
+            screen.error(
+                "Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")

         # Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
         # This will not affect GPU runs.
@@ -617,6 +644,13 @@ class CoachLauncher(object):
         checkpoint_restore_path = args.checkpoint_restore_dir if args.checkpoint_restore_dir \
             else args.checkpoint_restore_file

+        # open dashboard
+        if args.open_dashboard:
+            open_dashboard(args.experiment_path)
+
+        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
+            exit(handle_distributed_coach_orchestrator(args))
+
         task_parameters = TaskParameters(
             framework_type=args.framework,
             evaluate_only=args.evaluate,
@@ -630,22 +664,7 @@ class CoachLauncher(object):
             apply_stop_condition=args.apply_stop_condition
         )

-        # open dashboard
-        if args.open_dashboard:
-            open_dashboard(args.experiment_path)
-
-        if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
-            handle_distributed_coach_tasks(graph_manager, args, task_parameters)
-            return
-
-        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
-            exit(handle_distributed_coach_orchestrator(args))
-
-        # Single-threaded runs
-        if args.num_workers == 1:
-            self.start_single_threaded(task_parameters, graph_manager, args)
-        else:
-            self.start_multi_threaded(graph_manager, args)
+        return task_parameters

     @staticmethod
     def start_single_threaded(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
@@ -708,6 +727,7 @@ class CoachLauncher(object):
         # wait a bit before spawning the non chief workers in order to make sure the session is already created
         workers = []
         workers.append(start_distributed_task("worker", 0))
+        time.sleep(2)
         for task_index in range(1, args.num_workers):
             workers.append(start_distributed_task("worker", task_index))
@@ -722,6 +742,34 @@ class CoachLauncher(object):
             evaluation_worker.terminate()

+
+class CoachInterface(CoachLauncher):
+    """
+    This class is used as an interface for using Coach as a library. It can take any of the command-line
+    arguments (with their respective names) as arguments to the class.
+    """
+    def __init__(self, **kwargs):
+        parser = self.get_argument_parser()
+
+        arguments = []
+        for key in kwargs:
+            arguments.append('--' + key)
+            arguments.append(str(kwargs[key]))
+        if '--experiment_name' not in arguments:
+            arguments.append('--experiment_name')
+            arguments.append('')
+
+        self.args = self.get_config_args(parser, arguments)
+        self.graph_manager = self.get_graph_manager_from_args(self.args)
+
+        if self.args.num_workers == 1:
+            task_parameters = self.create_task_parameters(self.graph_manager, self.args)
+            self.graph_manager.create_graph(task_parameters)
+
+    def run(self):
+        self.run_graph_manager(self.graph_manager, self.args)
+
+
 def main():
     launcher = CoachLauncher()
     launcher.launch()
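
Since CoachInterface routes everything through the same parser, the new --experiment_path argument and the updated --custom_parameter format from this diff are usable from Python as well. A sketch under the same assumptions (the preset name and output path are illustrative, not from this commit):

    from rl_coach.coach import CoachInterface

    # Illustrative values only; the custom_parameter string follows the
    # help-text format updated in this commit, and experiment_path maps to
    # the new -ep/--experiment_path flag.
    coach = CoachInterface(preset='CartPole_DQN',
                           experiment_path='./my_experiments',
                           custom_parameter="heatup_steps=EnvironmentSteps(1000); "
                                            "improve_steps=TrainingSteps(100000)")
    coach.run()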