
Coach as a library (#348)

* CoachInterface + tutorial

* Some improvements and typo fixes

* merge tutorial 0 and 4

* typo fix + additional tutorial changes

* tutorial changes

* added reading signals and experiment path argument
shadiendrawis
2019-06-19 18:05:03 +03:00
committed by GitHub
parent 1c90bc22a1
commit 8e812ef82f
6 changed files with 181 additions and 87 deletions


@@ -18,7 +18,6 @@ sys.path.append('.')
import copy
from configparser import ConfigParser, Error
from rl_coach.core_types import EnvironmentSteps
import os
from rl_coach import logger
import traceback
@@ -30,6 +29,8 @@ import sys
import json
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
RunType, DistributedCoachSynchronizationType
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
EnvironmentSteps, StepMethod, Transition
from multiprocessing import Process
from multiprocessing.managers import BaseManager
import subprocess
@@ -316,7 +317,7 @@ class CoachLauncher(object):
return preset
def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
def get_config_args(self, parser: argparse.ArgumentParser, arguments=None) -> argparse.Namespace:
"""
Returns a Namespace object with all the user-specified configuration options needed to launch.
This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by
@@ -329,15 +330,19 @@ class CoachLauncher(object):
:param parser: a parser object which implicitly defines the format of the Namespace that
is expected to be returned.
:param arguments: command line arguments
:return: the parsed arguments as a Namespace
"""
args = parser.parse_args()
if arguments is None:
args = parser.parse_args()
else:
args = parser.parse_args(arguments)
if args.nocolor:
screen.set_use_colors(False)
# if no arg is given
if len(sys.argv) == 1:
if (len(sys.argv) == 1 and arguments is None) or (arguments is not None and len(arguments) <= 2):
parser.print_help()
sys.exit(1)
@@ -417,7 +422,7 @@ class CoachLauncher(object):
# get experiment name and path
args.experiment_name = logger.get_experiment_name(args.experiment_name)
args.experiment_path = logger.get_experiment_path(args.experiment_name)
args.experiment_path = logger.get_experiment_path(args.experiment_name, args.experiment_path)
if args.play and args.num_workers > 1:
screen.warning("Playing the game as a human is only available with a single worker. "
@@ -450,7 +455,11 @@ class CoachLauncher(object):
action='store_true')
parser.add_argument('-e', '--experiment_name',
help="(string) Experiment name to be used to store the results.",
default='',
default=None,
type=str)
parser.add_argument('-ep', '--experiment_path',
help="(string) Path to experiments folder.",
default=None,
type=str)
parser.add_argument('-r', '--render',
help="(flag) Render environment",
@@ -526,7 +535,8 @@ class CoachLauncher(object):
" the selected preset (or on top of the command-line assembled one). "
"Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
"For ex.: "
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
"\"visualization_parameters.render=False; heatup_steps=EnvironmentSteps(1000);"
"improve_steps=TrainingSteps(100000); optimizer='rmsprop'\"",
default=None,
type=str)
parser.add_argument('--print_networks_summary',
@@ -589,14 +599,31 @@ class CoachLauncher(object):
return parser
def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace):
task_parameters = self.create_task_parameters(graph_manager, args)
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
handle_distributed_coach_tasks(graph_manager, args, task_parameters)
return
# Single-threaded runs
if args.num_workers == 1:
self.start_single_threaded(task_parameters, graph_manager, args)
else:
self.start_multi_threaded(graph_manager, args)
@staticmethod
def create_task_parameters(graph_manager: 'GraphManager', args: argparse.Namespace):
if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
screen.error("{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
screen.error(
"{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
if args.distributed_coach and args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
screen.warning("The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")
screen.warning(
"The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")
if args.distributed_coach and not args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
screen.error("Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")
screen.error(
"Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")
# Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
# This will not affect GPU runs.
@@ -617,6 +644,13 @@ class CoachLauncher(object):
checkpoint_restore_path = args.checkpoint_restore_dir if args.checkpoint_restore_dir \
else args.checkpoint_restore_file
# open dashboard
if args.open_dashboard:
open_dashboard(args.experiment_path)
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
exit(handle_distributed_coach_orchestrator(args))
task_parameters = TaskParameters(
framework_type=args.framework,
evaluate_only=args.evaluate,
@@ -630,22 +664,7 @@ class CoachLauncher(object):
apply_stop_condition=args.apply_stop_condition
)
# open dashboard
if args.open_dashboard:
open_dashboard(args.experiment_path)
if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
handle_distributed_coach_tasks(graph_manager, args, task_parameters)
return
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
exit(handle_distributed_coach_orchestrator(args))
# Single-threaded runs
if args.num_workers == 1:
self.start_single_threaded(task_parameters, graph_manager, args)
else:
self.start_multi_threaded(graph_manager, args)
return task_parameters
@staticmethod
def start_single_threaded(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
@@ -708,6 +727,7 @@ class CoachLauncher(object):
# wait a bit before spawning the non chief workers in order to make sure the session is already created
workers = []
workers.append(start_distributed_task("worker", 0))
time.sleep(2)
for task_index in range(1, args.num_workers):
workers.append(start_distributed_task("worker", task_index))
@@ -722,6 +742,34 @@ class CoachLauncher(object):
evaluation_worker.terminate()
class CoachInterface(CoachLauncher):
"""
This class is used as an interface to use coach as library. It can take any of the command line arguments
(with the respective names) as arguments to the class.
"""
def __init__(self, **kwargs):
parser = self.get_argument_parser()
arguments = []
for key in kwargs:
arguments.append('--' + key)
arguments.append(str(kwargs[key]))
if '--experiment_name' not in arguments:
arguments.append('--experiment_name')
arguments.append('')
self.args = self.get_config_args(parser, arguments)
self.graph_manager = self.get_graph_manager_from_args(self.args)
if self.args.num_workers == 1:
task_parameters = self.create_task_parameters(self.graph_manager, self.args)
self.graph_manager.create_graph(task_parameters)
def run(self):
self.run_graph_manager(self.graph_manager, self.args)
def main():
launcher = CoachLauncher()
launcher.launch()
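
The new CoachInterface turns each keyword argument into the matching long-form CLI flag ('--' + name followed by the stringified value) and then reuses the regular launcher machinery. A minimal usage sketch, with an illustrative preset name and values that are not part of this commit:

# Rough sketch of library-style usage; assumes rl_coach is installed and that
# 'CartPole_DQN' (or any other preset available in your installation) exists.
from rl_coach.coach import CoachInterface

coach = CoachInterface(preset='CartPole_DQN',         # becomes "--preset CartPole_DQN"
                       num_workers=1,                 # becomes "--num_workers 1"
                       experiment_name='library_run')
coach.run()  # runs the same flow the command-line launcher would start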


@@ -30,7 +30,7 @@ except ImportError:
failed_imports.append("RoboSchool")
try:
from rl_coach.gym_extensions.continuous import mujoco
from gym_extensions.continuous import mujoco
except:
from rl_coach.logger import failed_imports
failed_imports.append("GymExtensions")


@@ -35,6 +35,7 @@ class BasicRLGraphManager(GraphManager):
preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
name='simple_rl_graph'):
super().__init__(name, schedule_params, vis_params)
self.agent_params = agent_params
self.env_params = env_params
self.preset_validation_params = preset_validation_params
@@ -71,3 +72,13 @@ class BasicRLGraphManager(GraphManager):
level_manager = LevelManager(agents=agent, environment=env, name="main_level")
return [level_manager], [env]
def log_signal(self, signal_name, value):
self.level_managers[0].agents['agent'].agent_logger.create_signal_value(signal_name, value)
def get_signal_value(self, signal_name):
return self.level_managers[0].agents['agent'].agent_logger.get_signal_value(signal_name)
def get_agent(self):
return self.level_managers[0].agents['agent']
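
Together with CoachInterface, these helpers expose the agent and its signal logger to library users. A rough sketch, reusing the coach object from the CoachInterface example above and assuming the loaded preset builds a BasicRLGraphManager (the signal name is made up for illustration):

# Requires the graph to have been created already (CoachInterface does this
# in its constructor for single-worker runs).
gm = coach.graph_manager
gm.log_signal('custom_metric', 0.42)           # write a value into the agent's signal log
latest = gm.get_signal_value('custom_metric')  # read the most recently logged value back
agent = gm.get_agent()                         # direct handle on the underlying agent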


@@ -23,12 +23,10 @@ from typing import List, Tuple
import contextlib
from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
VisualizationParameters, \
Parameters, PresetValidationParameters, RunType
VisualizationParameters, Parameters, PresetValidationParameters, RunType
from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint, CheckpointState
from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
EnvironmentSteps, \
StepMethod, Transition
EnvironmentSteps, StepMethod, Transition
from rl_coach.environments.environment import Environment
from rl_coach.level_manager import LevelManager
from rl_coach.logger import screen, Logger
@@ -123,6 +121,10 @@ class GraphManager(object):
self.time_metric = TimeTypes.EpisodeNumber
def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
# check if create graph has been already called
if self.graph_creation_time is not None:
return self
self.graph_creation_time = time.time()
self.task_parameters = task_parameters
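
This early return appears to make create_graph idempotent, which matters for library use: CoachInterface already creates the graph in its constructor for single-worker runs, so the call issued again when the run starts returns the existing graph instead of rebuilding it. A rough illustration, reusing the coach object from the sketch above:

gm = coach.graph_manager
task_parameters = coach.create_task_parameters(gm, coach.args)
gm.create_graph(task_parameters)  # graph was already built in __init__, so this simply returns gm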


@@ -206,7 +206,7 @@ class BaseLogger(object):
return True
return False
def signal_value_exists(self, time, signal_name):
def signal_value_exists(self, signal_name, time):
try:
value = self.get_signal_value(time, signal_name)
if value != value: # value is nan
@@ -215,7 +215,9 @@ class BaseLogger(object):
return False
return True
def get_signal_value(self, time, signal_name):
def get_signal_value(self, signal_name, time=None):
if not time:
time = self.time
return self.data.loc[time, signal_name]
def dump_output_csv(self, append=True):
@@ -382,12 +384,12 @@ def summarize_experiment():
screen.log_title("Results moved to: {}".format(new_path))
def get_experiment_name(initial_experiment_name=''):
def get_experiment_name(initial_experiment_name=None):
global experiment_name
match = None
while match is None:
if initial_experiment_name == '':
if initial_experiment_name is None:
msg_if_timeout = "Timeout waiting for experiement name."
experiment_name = screen.ask_input_with_timeout("Please enter an experiment name: ", 60, msg_if_timeout)
else:
@@ -407,10 +409,12 @@ def get_experiment_name(initial_experiment_name=''):
return experiment_name
def get_experiment_path(experiment_name, create_path=True):
def get_experiment_path(experiment_name, initial_experiment_path=None, create_path=True):
global experiment_path
general_experiments_path = os.path.join('./experiments/', experiment_name)
if not initial_experiment_path:
initial_experiment_path = './experiments/'
general_experiments_path = os.path.join(initial_experiment_path, experiment_name)
cur_date = time_started.date()
cur_time = time_started.time()