mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Release 0.9
Main changes are detailed below: New features - * CARLA 0.7 simulator integration * Human control of the game play * Recording of human game play and storing / loading the replay buffer * Behavioral cloning agent and presets * Golden tests for several presets * Selecting between deep / shallow image embedders * Rendering through pygame (with some boost in performance) API changes - * Improved environment wrapper API * Added an evaluate flag to allow convenient evaluation of existing checkpoints * Improve frameskip definition in Gym Bug fixes - * Fixed loading of checkpoints for agents with more than one network * Fixed the N Step Q learning agent python3 compatibility
This commit is contained in:
86
coach.py
86
coach.py
@@ -37,8 +37,29 @@ time_started = datetime.datetime.now()
|
||||
cur_time = time_started.time()
|
||||
cur_date = time_started.date()
|
||||
|
||||
def get_experiment_path(general_experiments_path):
|
||||
if not os.path.exists(general_experiments_path):
|
||||
|
||||
def get_experiment_name(initial_experiment_name=''):
|
||||
match = None
|
||||
while match is None:
|
||||
if initial_experiment_name == '':
|
||||
experiment_name = screen.ask_input("Please enter an experiment name: ")
|
||||
else:
|
||||
experiment_name = initial_experiment_name
|
||||
|
||||
experiment_name = experiment_name.replace(" ", "_")
|
||||
match = re.match("^$|^[\w -/]{1,100}$", experiment_name)
|
||||
|
||||
if match is None:
|
||||
screen.error('Experiment name must be composed only of alphanumeric letters, '
|
||||
'underscores and dashes and should not be longer than 100 characters.')
|
||||
|
||||
return match.group(0)
|
||||
|
||||
|
||||
def get_experiment_path(experiment_name, create_path=True):
|
||||
general_experiments_path = os.path.join('./experiments/', experiment_name)
|
||||
|
||||
if not os.path.exists(general_experiments_path) and create_path:
|
||||
os.makedirs(general_experiments_path)
|
||||
experiment_path = os.path.join(general_experiments_path, '{}_{}_{}-{}_{}'
|
||||
.format(logger.two_digits(cur_date.day), logger.two_digits(cur_date.month),
|
||||
@@ -52,7 +73,8 @@ def get_experiment_path(general_experiments_path):
|
||||
cur_time.minute, i))
|
||||
i += 1
|
||||
else:
|
||||
os.makedirs(experiment_path)
|
||||
if create_path:
|
||||
os.makedirs(experiment_path)
|
||||
return experiment_path
|
||||
|
||||
|
||||
@@ -96,55 +118,54 @@ def check_input_and_fill_run_dict(parser):
|
||||
num_workers = int(re.match("^\d+$", args.num_workers).group(0))
|
||||
except ValueError:
|
||||
screen.error("Parameter num_workers should be an integer.")
|
||||
exit(1)
|
||||
|
||||
preset_names = list_all_classes_in_module(presets)
|
||||
if args.preset is not None and args.preset not in preset_names:
|
||||
screen.error("A non-existing preset was selected. ")
|
||||
exit(1)
|
||||
|
||||
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
|
||||
screen.error("The requested checkpoint folder to load from does not exist. ")
|
||||
exit(1)
|
||||
|
||||
if args.save_model_sec is not None:
|
||||
try:
|
||||
args.save_model_sec = int(args.save_model_sec)
|
||||
except ValueError:
|
||||
screen.error("Parameter save_model_sec should be an integer.")
|
||||
exit(1)
|
||||
|
||||
if args.preset is None and (args.agent_type is None or args.environment_type is None
|
||||
or args.exploration_policy_type is None):
|
||||
or args.exploration_policy_type is None) and not args.play:
|
||||
screen.error('When no preset is given for Coach to run, the user is expected to input the desired agent_type,'
|
||||
' environment_type and exploration_policy_type to assemble a preset. '
|
||||
'\nAt least one of these parameters was not given.')
|
||||
exit(1)
|
||||
elif args.preset is None and args.play and args.environment_type is None:
|
||||
screen.error('When no preset is given for Coach to run, and the user requests human control over the environment,'
|
||||
' the user is expected to input the desired environment_type and level.'
|
||||
'\nAt least one of these parameters was not given.')
|
||||
elif args.preset is None and args.play and args.environment_type:
|
||||
args.agent_type = 'Human'
|
||||
args.exploration_policy_type = 'ExplorationParameters'
|
||||
|
||||
experiment_name = args.experiment_name
|
||||
# get experiment name and path
|
||||
experiment_name = get_experiment_name(args.experiment_name)
|
||||
experiment_path = get_experiment_path(experiment_name)
|
||||
|
||||
if args.experiment_name == '':
|
||||
experiment_name = screen.ask_input("Please enter an experiment name: ")
|
||||
|
||||
experiment_name = experiment_name.replace(" ", "_")
|
||||
match = re.match("^$|^\w{1,100}$", experiment_name)
|
||||
|
||||
if match is None:
|
||||
screen.error('Experiment name must be composed only of alphanumeric letters and underscores and should not be '
|
||||
'longer than 100 characters.')
|
||||
exit(1)
|
||||
experiment_path = os.path.join('./experiments/', match.group(0))
|
||||
experiment_path = get_experiment_path(experiment_path)
|
||||
if args.play and num_workers > 1:
|
||||
screen.warning("Playing the game as a human is only available with a single worker. "
|
||||
"The number of workers will be reduced to 1")
|
||||
num_workers = 1
|
||||
|
||||
# fill run_dict
|
||||
run_dict = dict()
|
||||
run_dict['agent_type'] = args.agent_type
|
||||
run_dict['environment_type'] = args.environment_type
|
||||
run_dict['exploration_policy_type'] = args.exploration_policy_type
|
||||
run_dict['level'] = args.level
|
||||
run_dict['preset'] = args.preset
|
||||
run_dict['custom_parameter'] = args.custom_parameter
|
||||
run_dict['experiment_path'] = experiment_path
|
||||
run_dict['framework'] = Frameworks().get(args.framework)
|
||||
run_dict['play'] = args.play
|
||||
run_dict['evaluate'] = args.evaluate# or args.play
|
||||
|
||||
# multi-threading parameters
|
||||
run_dict['num_threads'] = num_workers
|
||||
@@ -197,6 +218,14 @@ if __name__ == "__main__":
|
||||
help="(int) Number of workers for multi-process based agents, e.g. A3C",
|
||||
default='1',
|
||||
type=str)
|
||||
parser.add_argument('--play',
|
||||
help="(flag) Play as a human by controlling the game with the keyboard. "
|
||||
"This option will save a replay buffer with the game play.",
|
||||
action='store_true')
|
||||
parser.add_argument('--evaluate',
|
||||
help="(flag) Run evaluation only. This is a convenient way to disable "
|
||||
"training in order to evaluate an existing checkpoint.",
|
||||
action='store_true')
|
||||
parser.add_argument('-v', '--verbose',
|
||||
help="(flag) Don't suppress TensorFlow debug prints.",
|
||||
action='store_true')
|
||||
@@ -230,6 +259,12 @@ if __name__ == "__main__":
|
||||
,
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-lvl', '--level',
|
||||
help="(string) Choose the level that will be played in the environment that was selected."
|
||||
"This value will override the level parameter in the environment class."
|
||||
,
|
||||
default=None,
|
||||
type=str)
|
||||
parser.add_argument('-cp', '--custom_parameter',
|
||||
help="(string) Semicolon separated parameters used to override specific parameters on top of"
|
||||
" the selected preset (or on top of the command-line assembled one). "
|
||||
@@ -259,7 +294,12 @@ if __name__ == "__main__":
|
||||
tuning_parameters.task_index = 0
|
||||
env_instance = create_environment(tuning_parameters)
|
||||
agent = eval(tuning_parameters.agent.type + '(env_instance, tuning_parameters)')
|
||||
agent.improve()
|
||||
|
||||
# Start the training or evaluation
|
||||
if tuning_parameters.evaluate:
|
||||
agent.evaluate(sys.maxsize, keep_networks_synced=True) # evaluate forever
|
||||
else:
|
||||
agent.improve()
|
||||
|
||||
# Multi-threaded runs
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user