Coach as a library (#348)
* CoachInterface + tutorial
* Some improvements and typo fixes
* merge tutorial 0 and 4
* typo fix + additional tutorial changes
* tutorial changes
* added reading signals and experiment path argument
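The headline change is the new CoachInterface class, which lets a Python program drive Coach directly instead of going through the coach command line. A minimal sketch of the intended usage, assuming the CartPole_DQN preset that ships with Coach (any of the CLI flags can be passed by their long names as keyword arguments):

from rl_coach.coach import CoachInterface

# keyword arguments are translated into the equivalent CLI flags,
# e.g. preset='CartPole_DQN' becomes '--preset CartPole_DQN'
coach = CoachInterface(preset='CartPole_DQN', num_workers=1)
coach.run()

The diff below wires this up: get_config_args learns to parse an explicit argument list, run_graph_manager is split so the graph can be created without launching a run, and the logger gains an experiment-path argument.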
@@ -18,7 +18,6 @@ sys.path.append('.')
 
 import copy
 from configparser import ConfigParser, Error
-from rl_coach.core_types import EnvironmentSteps
 import os
 from rl_coach import logger
 import traceback
@@ -30,6 +29,8 @@ import sys
 import json
 from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
     RunType, DistributedCoachSynchronizationType
+from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
+    EnvironmentSteps, StepMethod, Transition
 from multiprocessing import Process
 from multiprocessing.managers import BaseManager
 import subprocess
@@ -316,7 +317,7 @@ class CoachLauncher(object):
 
         return preset
 
-    def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
+    def get_config_args(self, parser: argparse.ArgumentParser, arguments=None) -> argparse.Namespace:
         """
         Returns a Namespace object with all the user-specified configuration options needed to launch.
         This implementation uses argparse to take arguments from the CLI, but this can be over-ridden by
@@ -329,15 +330,19 @@ class CoachLauncher(object):
         :param parser: a parser object which implicitly defines the format of the Namespace that
                        is expected to be returned.
+        :param arguments: command line arguments
         :return: the parsed arguments as a Namespace
         """
-        args = parser.parse_args()
+        if arguments is None:
+            args = parser.parse_args()
+        else:
+            args = parser.parse_args(arguments)
 
         if args.nocolor:
             screen.set_use_colors(False)
 
         # if no arg is given
-        if len(sys.argv) == 1:
+        if (len(sys.argv) == 1 and arguments is None) or (arguments is not None and len(arguments) <= 2):
            parser.print_help()
            sys.exit(1)
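With the new arguments parameter, the parser can be fed an explicit argument list instead of falling back to sys.argv, which is what CoachInterface relies on further down. A rough sketch (the preset and experiment name are placeholders):

launcher = CoachLauncher()
parser = launcher.get_argument_parser()
# an explicit list bypasses sys.argv; a list of two or fewer items would
# print the help text and exit, mirroring the bare-CLI behaviour
args = launcher.get_config_args(parser, ['--preset', 'CartPole_DQN',
                                         '--experiment_name', 'my_run'])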
@@ -417,7 +422,7 @@ class CoachLauncher(object):
 
         # get experiment name and path
         args.experiment_name = logger.get_experiment_name(args.experiment_name)
-        args.experiment_path = logger.get_experiment_path(args.experiment_name)
+        args.experiment_path = logger.get_experiment_path(args.experiment_name, args.experiment_path)
 
         if args.play and args.num_workers > 1:
             screen.warning("Playing the game as a human is only available with a single worker. "
@@ -450,7 +455,11 @@ class CoachLauncher(object):
                             action='store_true')
         parser.add_argument('-e', '--experiment_name',
                             help="(string) Experiment name to be used to store the results.",
-                            default='',
+                            default=None,
                             type=str)
+        parser.add_argument('-ep', '--experiment_path',
+                            help="(string) Path to experiments folder.",
+                            default=None,
+                            type=str)
         parser.add_argument('-r', '--render',
                             help="(flag) Render environment",
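The new -ep/--experiment_path option complements -e/--experiment_name so results can be stored outside the default ./experiments/ folder. From the library side the same option is reachable by keyword, since CoachInterface maps kwargs onto the long flag names; the path below is only illustrative:

coach = CoachInterface(preset='CartPole_DQN',
                       experiment_name='my_run',
                       experiment_path='/tmp/coach_runs')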
@@ -526,7 +535,8 @@ class CoachLauncher(object):
                                  " the selected preset (or on top of the command-line assembled one). "
                                  "Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
                                  "For ex.: "
-                                 "\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
+                                 "\"visualization_parameters.render=False; heatup_steps=EnvironmentSteps(1000);"
+                                 "improve_steps=TrainingSteps(100000); optimizer='rmsprop'\"",
                             default=None,
                             type=str)
         parser.add_argument('--print_networks_summary',
@@ -589,14 +599,31 @@ class CoachLauncher(object):
         return parser
 
     def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace):
+        task_parameters = self.create_task_parameters(graph_manager, args)
+
+        if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
+            handle_distributed_coach_tasks(graph_manager, args, task_parameters)
+            return
+
+        # Single-threaded runs
+        if args.num_workers == 1:
+            self.start_single_threaded(task_parameters, graph_manager, args)
+        else:
+            self.start_multi_threaded(graph_manager, args)
+
+    @staticmethod
+    def create_task_parameters(graph_manager: 'GraphManager', args: argparse.Namespace):
         if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
-            screen.error("{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
+            screen.error(
+                "{} algorithm is not supported using distributed Coach.".format(graph_manager.agent_params.algorithm))
 
         if args.distributed_coach and args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
-            screen.warning("The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")
+            screen.warning(
+                "The --checkpoint_save_secs or -s argument will be ignored as SYNC distributed coach sync type is used. Checkpoint will be saved every training iteration.")
 
         if args.distributed_coach and not args.checkpoint_save_secs and graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.ASYNC:
-            screen.error("Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")
+            screen.error(
+                "Distributed coach with ASYNC distributed coach sync type requires --checkpoint_save_secs or -s.")
 
         # Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
         # This will not affect GPU runs.
@@ -617,6 +644,13 @@ class CoachLauncher(object):
         checkpoint_restore_path = args.checkpoint_restore_dir if args.checkpoint_restore_dir \
             else args.checkpoint_restore_file
 
+        # open dashboard
+        if args.open_dashboard:
+            open_dashboard(args.experiment_path)
+
+        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
+            exit(handle_distributed_coach_orchestrator(args))
+
         task_parameters = TaskParameters(
             framework_type=args.framework,
             evaluate_only=args.evaluate,
@@ -630,22 +664,7 @@ class CoachLauncher(object):
             apply_stop_condition=args.apply_stop_condition
         )
 
-        # open dashboard
-        if args.open_dashboard:
-            open_dashboard(args.experiment_path)
-
-        if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
-            handle_distributed_coach_tasks(graph_manager, args, task_parameters)
-            return
-
-        if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
-            exit(handle_distributed_coach_orchestrator(args))
-
-        # Single-threaded runs
-        if args.num_workers == 1:
-            self.start_single_threaded(task_parameters, graph_manager, args)
-        else:
-            self.start_multi_threaded(graph_manager, args)
+        return task_parameters
 
     @staticmethod
     def start_single_threaded(task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
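Splitting create_task_parameters out of run_graph_manager (and having it return the TaskParameters) means library code can build the parameters and create the graph without immediately launching a full run. A sketch reusing the launcher and args from the snippet above:

# create_task_parameters is a @staticmethod on CoachLauncher
graph_manager = launcher.get_graph_manager_from_args(args)
task_parameters = launcher.create_task_parameters(graph_manager, args)
graph_manager.create_graph(task_parameters)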
@@ -708,6 +727,7 @@ class CoachLauncher(object):
         # wait a bit before spawning the non chief workers in order to make sure the session is already created
         workers = []
         workers.append(start_distributed_task("worker", 0))
+
         time.sleep(2)
         for task_index in range(1, args.num_workers):
             workers.append(start_distributed_task("worker", task_index))
@@ -722,6 +742,34 @@
         evaluation_worker.terminate()
 
 
+class CoachInterface(CoachLauncher):
+    """
+    This class is used as an interface to use coach as library. It can take any of the command line arguments
+    (with the respective names) as arguments to the class.
+    """
+    def __init__(self, **kwargs):
+        parser = self.get_argument_parser()
+
+        arguments = []
+        for key in kwargs:
+            arguments.append('--' + key)
+            arguments.append(str(kwargs[key]))
+
+        if '--experiment_name' not in arguments:
+            arguments.append('--experiment_name')
+            arguments.append('')
+        self.args = self.get_config_args(parser, arguments)
+
+        self.graph_manager = self.get_graph_manager_from_args(self.args)
+
+        if self.args.num_workers == 1:
+            task_parameters = self.create_task_parameters(self.graph_manager, self.args)
+            self.graph_manager.create_graph(task_parameters)
+
+    def run(self):
+        self.run_graph_manager(self.graph_manager, self.args)
+
+
 def main():
     launcher = CoachLauncher()
     launcher.launch()
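Because __init__ already creates the graph for single-worker runs, a script can either call run() for the standard improve schedule or step the underlying GraphManager itself. A sketch of the second style, assuming the usual GraphManager stepping methods (heatup and train_and_act) and an illustrative preset:

from rl_coach.coach import CoachInterface
from rl_coach.core_types import EnvironmentSteps

coach = CoachInterface(preset='CartPole_DQN', num_workers=1)
# instead of coach.run(), drive the graph manually
coach.graph_manager.heatup(EnvironmentSteps(1000))
coach.graph_manager.train_and_act(EnvironmentSteps(10000))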
@@ -30,7 +30,7 @@ except ImportError:
     failed_imports.append("RoboSchool")
 
 try:
-    from rl_coach.gym_extensions.continuous import mujoco
+    from gym_extensions.continuous import mujoco
 except:
     from rl_coach.logger import failed_imports
     failed_imports.append("GymExtensions")
@@ -35,6 +35,7 @@ class BasicRLGraphManager(GraphManager):
                  preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
                  name='simple_rl_graph'):
         super().__init__(name, schedule_params, vis_params)
+
         self.agent_params = agent_params
         self.env_params = env_params
         self.preset_validation_params = preset_validation_params
@@ -71,3 +72,13 @@ class BasicRLGraphManager(GraphManager):
         level_manager = LevelManager(agents=agent, environment=env, name="main_level")
 
         return [level_manager], [env]
+
+    def log_signal(self, signal_name, value):
+        self.level_managers[0].agents['agent'].agent_logger.create_signal_value(signal_name, value)
+
+    def get_signal_value(self, signal_name):
+        return self.level_managers[0].agents['agent'].agent_logger.get_signal_value(signal_name)
+
+    def get_agent(self):
+        return self.level_managers[0].agents['agent']
+
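These helpers expose the single agent's logger, which is what the commit message calls "reading signals". A hedged example; 'Training Reward' is one of the signals Coach normally logs, and the custom name is arbitrary:

graph_manager = coach.graph_manager  # a BasicRLGraphManager built from a preset
graph_manager.log_signal('my_custom_signal', 1.5)                   # write a value
latest_reward = graph_manager.get_signal_value('Training Reward')   # read one back
agent = graph_manager.get_agent()                                   # direct agent access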
@@ -23,12 +23,10 @@ from typing import List, Tuple
 import contextlib
 
 from rl_coach.base_parameters import iterable_to_items, TaskParameters, DistributedTaskParameters, Frameworks, \
-    VisualizationParameters, \
-    Parameters, PresetValidationParameters, RunType
+    VisualizationParameters, Parameters, PresetValidationParameters, RunType
 from rl_coach.checkpoint import CheckpointStateUpdater, get_checkpoint_state, SingleCheckpoint, CheckpointState
 from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
-    EnvironmentSteps, \
-    StepMethod, Transition
+    EnvironmentSteps, StepMethod, Transition
 from rl_coach.environments.environment import Environment
 from rl_coach.level_manager import LevelManager
 from rl_coach.logger import screen, Logger
@@ -123,6 +121,10 @@ class GraphManager(object):
         self.time_metric = TimeTypes.EpisodeNumber
 
     def create_graph(self, task_parameters: TaskParameters=TaskParameters()):
+        # check if create graph has been already called
+        if self.graph_creation_time is not None:
+            return self
+
         self.graph_creation_time = time.time()
         self.task_parameters = task_parameters
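The early return makes create_graph idempotent, which matters now that CoachInterface.__init__ may have already created the graph before run_graph_manager is called. Roughly:

graph_manager.create_graph(task_parameters)
graph_manager.create_graph(task_parameters)  # no-op: graph_creation_time is already set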
@@ -206,7 +206,7 @@ class BaseLogger(object):
             return True
         return False
 
-    def signal_value_exists(self, time, signal_name):
+    def signal_value_exists(self, signal_name, time):
         try:
             value = self.get_signal_value(time, signal_name)
             if value != value:  # value is nan
@@ -215,7 +215,9 @@ class BaseLogger(object):
                 return False
         return True
 
-    def get_signal_value(self, time, signal_name):
+    def get_signal_value(self, signal_name, time=None):
+        if not time:
+            time = self.time
         return self.data.loc[time, signal_name]
 
     def dump_output_csv(self, append=True):
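With time now optional, a caller can read the most recent value of a signal without tracking the logger's time index itself. For instance, assuming agent_logger is one of these BaseLogger instances:

# falls back to the logger's current time index when `time` is omitted
latest = agent.agent_logger.get_signal_value('Training Reward')
# an explicit index can still be supplied by keyword
at_100 = agent.agent_logger.get_signal_value('Training Reward', time=100)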
@@ -382,12 +384,12 @@ def summarize_experiment():
     screen.log_title("Results moved to: {}".format(new_path))
 
 
-def get_experiment_name(initial_experiment_name=''):
+def get_experiment_name(initial_experiment_name=None):
     global experiment_name
 
     match = None
     while match is None:
-        if initial_experiment_name == '':
+        if initial_experiment_name is None:
             msg_if_timeout = "Timeout waiting for experiement name."
             experiment_name = screen.ask_input_with_timeout("Please enter an experiment name: ", 60, msg_if_timeout)
         else:
@@ -407,10 +409,12 @@ def get_experiment_name(initial_experiment_name=''):
     return experiment_name
 
 
-def get_experiment_path(experiment_name, create_path=True):
+def get_experiment_path(experiment_name, initial_experiment_path=None, create_path=True):
     global experiment_path
 
-    general_experiments_path = os.path.join('./experiments/', experiment_name)
+    if not initial_experiment_path:
+        initial_experiment_path = './experiments/'
+    general_experiments_path = os.path.join(initial_experiment_path, experiment_name)
 
     cur_date = time_started.date()
     cur_time = time_started.time()
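This is the plumbing behind the new --experiment_path flag: when initial_experiment_path is given, the timestamped experiment folder is created under it instead of under ./experiments/. A small illustration with a placeholder directory:

from rl_coach import logger

name = logger.get_experiment_name('my_run')
path = logger.get_experiment_path(name, initial_experiment_path='/tmp/coach_runs')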