mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 19:50:17 +01:00
Coach as a library (#348)
* CoachInterface + tutorial * Some improvements and typo fixes * merge tutorial 0 and 4 * typo fix + additional tutorial changes * tutorial changes * added reading signals and experiment path argument
This commit is contained in:
@@ -18,7 +18,6 @@ sys.path.append('.')
|
||||
|
||||
import copy
|
||||
from configparser import ConfigParser, Error
|
||||
from rl_coach.core_types import EnvironmentSteps
|
||||
import os
|
||||
from rl_coach import logger
|
||||
import traceback
|
||||
@@ -30,6 +29,8 @@ import sys
|
||||
import json
|
||||
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
|
||||
RunType, DistributedCoachSynchronizationType
|
||||
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
|
||||
EnvironmentSteps, StepMethod, Transition
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.managers import BaseManager
|
||||
import subprocess
|
||||
@@ -316,7 +317,7 @@ class CoachLauncher(object):
|
||||
|
||||
return preset
|
||||
|
||||
def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
|
||||
def get_config_args(self, parser: argparse.ArgumentParser, arguments=None) -> argparse.Namespace:
|
||||
"""
|
||||
Returns a Namespace object with all the user-specified configuration options needed to launch.
|
||||
This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by
|
||||
@@ -329,15 +330,19 @@ class CoachLauncher(object):
|
||||
|
||||
:param parser: a parser object which implicitly defines the format of the Namespace that
|
||||
is expected to be returned.
|
||||
:param arguments: command line arguments
|
||||
:return: the parsed arguments as a Namespace
|
||||
"""
|
||||
args = parser.parse_args()
|
||||
if arguments is None:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
args = parser.parse_args(arguments)
|
||||
|
||||
if args.nocolor:
|
||||
screen.set_use_colors(False)
|
||||
|
||||
# if no arg is given
|
||||
if len(sys.argv) == 1:
|
||||
if (len(sys.argv) == 1 and arguments is None) or (arguments is not None and len(arguments) <= 2):
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
@@ -417,7 +422,7 @@ class CoachLauncher(object):
|
||||
|
||||
# get experiment name and path
|
||||
args.experiment_name = logger.get_experiment_name(args.experiment_name)
|
||||
args.experiment_path = logger.get_experiment_path(args.experiment_name)
|
||||
args.experiment_path = logger.get_experiment_path(args.experiment_name, args.experiment_path)
|
||||
|
||||
if args.play and args.num_workers > 1:
|
||||
screen.warning("Playing the game as a human is only available with a single worker. "
|
||||
@@ -450,7 +455,11 @@ class CoachLauncher(object):
|
||||
action='store_true')
|
||||
parser.add_argument('-e', '--experiment_name',
|
||||
help="(string) Experiment name to be used to store the results.",
|
||||
default='',
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-ep', '--experiment_path',
|
||||
help="(string) Path to experiments folder.",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-r', '--render',
|
||||
help="(flag) Render environment",
|
||||
@@ -526,7 +535,8 @@ class CoachLauncher(object):
|
||||
" the selected preset (or on top of the command-line assembled one). "
|
||||
"Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
|
||||
"For ex.: "
|
||||
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
|
||||
"\"visualization_parameters.render=False; heatup_steps=EnvironmentSteps(1000);"
|
||||
"improve_steps=TrainingSteps(100000); optimizer='rmsprop'\"",
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('--print_networks_summary',
|
||||
@@ -589,14 +599,31 @@ class CoachLauncher(object):
|
||||
return parser
|
||||
|
||||
def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace):
    """
    Create the task parameters for the given graph manager and dispatch the run.

    :param graph_manager: the graph manager to run
    :param args: the parsed arguments as a Namespace
    :return: None
    """
    task_parameters = self.create_task_parameters(graph_manager, args)

    # Any distributed Coach run type other than the orchestrator is delegated to
    # handle_distributed_coach_tasks and nothing else is started from here.
    is_distributed_task = args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR
    if is_distributed_task:
        handle_distributed_coach_tasks(graph_manager, args, task_parameters)
        return

    # More than one worker -> multi-threaded launch (which does not receive the task parameters)
    if args.num_workers != 1:
        self.start_multi_threaded(graph_manager, args)
        return

    # Single-threaded runs
    self.start_single_threaded(task_parameters, graph_manager, args)
|
||||
|
||||
@staticmethod
|
||||
def create_task_parameters(graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
|
||||
screen.error("{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
|
||||
screen.error(
|
||||
"{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
|
||||
|
||||
if args.distributed_coach and args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
||||
screen.warning("The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")
|
||||
screen.warning(
|
||||
"The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")
|
||||
|
||||
if args.distributed_coach and not args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
|
||||
screen.error("Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")
|
||||
screen.error(
|
||||
"Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")
|
||||
|
||||
# Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
|
||||
# This will not affect GPU runs.
|
||||
@@ -617,6 +644,13 @@ class CoachLauncher(object):
|
||||
checkpoint_restore_path = args.checkpoint_restore_dir if args.checkpoint_restore_dir \
|
||||
else args.checkpoint_restore_file
|
||||
|
||||
# open dashboard
|
||||
if args.open_dashboard:
|
||||
open_dashboard(args.experiment_path)
|
||||
|
||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||
exit(handle_distributed_coach_orchestrator(args))
|
||||
|
||||
task_parameters = TaskParameters(
|
||||
framework_type=args.framework,
|
||||
evaluate_only=args.evaluate,
|
||||
@@ -630,22 +664,7 @@ class CoachLauncher(object):
|
||||
apply_stop_condition=args.apply_stop_condition
|
||||
)
|
||||
|
||||
# open dashboard
|
||||
if args.open_dashboard:
|
||||
open_dashboard(args.experiment_path)
|
||||
|
||||
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
|
||||
handle_distributed_coach_tasks(graph_manager, args, task_parameters)
|
||||
return
|
||||
|
||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||
exit(handle_distributed_coach_orchestrator(args))
|
||||
|
||||
# Single-threaded runs
|
||||
if args.num_workers == 1:
|
||||
self.start_single_threaded(task_parameters, graph_manager, args)
|
||||
else:
|
||||
self.start_multi_threaded(graph_manager, args)
|
||||
return task_parameters
|
||||
|
||||
@staticmethod
|
||||
def start_single_threaded(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
|
||||
@@ -708,6 +727,7 @@ class CoachLauncher(object):
|
||||
# wait a bit before spawning the non chief workers in order to make sure the session is already created
|
||||
workers = []
|
||||
workers.append(start_distributed_task("worker", 0))
|
||||
|
||||
time.sleep(2)
|
||||
for task_index in range(1, args.num_workers):
|
||||
workers.append(start_distributed_task("worker", task_index))
|
||||
@@ -722,6 +742,34 @@ class CoachLauncher(object):
|
||||
evaluation_worker.terminate()
|
||||
|
||||
|
||||
class CoachInterface(CoachLauncher):
    """
    Entry point for using Coach as a library instead of through the command line.

    Any command line argument can be supplied to the constructor as a keyword argument
    carrying the same (long) name.
    """
    def __init__(self, **kwargs):
        """
        Translate the keyword arguments into command line style arguments, parse them and
        prepare the graph manager.

        :param kwargs: command line options as name=value pairs. Each pair is forwarded to the
                       argument parser as '--<name>' followed by str(<value>).
        """
        parser = self.get_argument_parser()

        # Flatten the keyword arguments into an argv-like list: ['--name', 'value', ...]
        # NOTE(review): every value is stringified and passed after its option name, so
        # value-less flags presumably cannot be expressed this way - TODO confirm against the parser.
        arguments = []
        for name, value in kwargs.items():
            arguments.extend(('--' + name, str(value)))

        # Always hand the parser an experiment name; an empty one is supplied when none was given.
        if '--experiment_name' not in arguments:
            arguments.extend(('--experiment_name', ''))
        self.args = self.get_config_args(parser, arguments)

        self.graph_manager = self.get_graph_manager_from_args(self.args)

        # The graph is created up front only for a single worker.
        # NOTE(review): presumably multi-worker runs build their graph elsewhere - TODO confirm.
        if self.args.num_workers == 1:
            single_worker_parameters = self.create_task_parameters(self.graph_manager, self.args)
            self.graph_manager.create_graph(single_worker_parameters)

    def run(self):
        """
        Run the graph manager prepared in the constructor with the parsed arguments.
        """
        self.run_graph_manager(self.graph_manager, self.args)
|
||||
|
||||
|
||||
def main():
    """
    Construct a CoachLauncher and launch it.
    """
    CoachLauncher().launch()
|
||||
|
||||
Reference in New Issue
Block a user