mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Cleaning up coach code + removing play/Human agent

This commit is contained in:
Roman Dobosz
2018-05-10 09:06:10 +02:00
parent 50d38b4b98
commit 5d47368972
11 changed files with 120 additions and 281 deletions

View File

@@ -23,7 +23,6 @@ from coach.agents.ddpg_agent import DDPGAgent
 from coach.agents.ddqn_agent import DDQNAgent
 from coach.agents.dfp_agent import DFPAgent
 from coach.agents.dqn_agent import DQNAgent
-from coach.agents.human_agent import HumanAgent
 from coach.agents.imitation_agent import ImitationAgent
 from coach.agents.mmc_agent import MixedMonteCarloAgent
 from coach.agents.n_step_q_agent import NStepQAgent
@@ -46,7 +45,6 @@ __all__ = [ActorCriticAgent,
            DDQNAgent,
            DFPAgent,
            DQNAgent,
-           HumanAgent,
            ImitationAgent,
            MixedMonteCarloAgent,
            NAFAgent,

View File

@@ -1,73 +0,0 @@
-#
-# Copyright (c) 2017 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import collections
-import os
-import pygame
-from pandas.io import pickle
-from coach.agents import agent
-from coach import logger
-from coach import utils
-class HumanAgent(agent.Agent):
-    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
-        agent.Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
-        self.clock = pygame.time.Clock()
-        self.max_fps = int(self.tp.visualization.max_fps_for_human_control)
-        logger.screen.log_title("Human Control Mode")
-        available_keys = self.env.get_available_keys()
-        if available_keys:
-            logger.screen.log("Use keyboard keys to move. Press escape to quit. Available keys:")
-            logger.screen.log("")
-            for action, key in self.env.get_available_keys():
-                logger.screen.log("\t- {}: {}".format(action, key))
-        logger.screen.separator()
-    def train(self):
-        return 0
-    def choose_action(self, curr_state, phase=utils.RunPhase.TRAIN):
-        action = self.env.get_action_from_user()
-        # keep constant fps
-        self.clock.tick(self.max_fps)
-        if not self.env.renderer.is_open:
-            self.save_replay_buffer_and_exit()
-        return action, {"action_value": 0}
-    def save_replay_buffer_and_exit(self):
-        replay_buffer_path = os.path.join(logger.logger.experiments_path, 'replay_buffer.p')
-        self.memory.tp = None
-        pickle.to_pickle(self.memory, replay_buffer_path)
-        logger.screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path))
-        exit()
-    def log_to_screen(self, phase):
-        # log to logger.screen
-        logger.screen.log_dict(
-            collections.OrderedDict([
-                ("Episode", self.current_episode),
-                ("total reward", self.total_reward_in_current_episode),
-                ("steps", self.total_steps_counter)
-            ]),
-            prefix="Recording"
-        )
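The removed agent persisted its replay buffer with pandas' pickle helper (pickle.to_pickle above). A minimal sketch of reading such a dump back for inspection, assuming the default file name used above and a hypothetical experiment directory:

import os
import pandas as pd

# Hypothetical experiment directory; the removed code wrote to logger.logger.experiments_path.
replay_buffer_path = os.path.join('experiments', 'last_run', 'replay_buffer.p')

# pandas.read_pickle is the reading counterpart of pandas.io.pickle.to_pickle used above.
memory = pd.read_pickle(replay_buffer_path)
print(type(memory))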

View File

@@ -16,18 +16,10 @@
 import os
 import collections
-from coach import configurations as conf
-from coach import logger
-try:
-    import tensorflow as tf
-    from coach.architectures.tensorflow_components import general_network as tf_net
-except ImportError:
-    logger.failed_imports.append("TensorFlow")
-try:
-    from coach.architectures.neon_components import general_network as neon_net
-except ImportError:
-    logger.failed_imports.append("Neon")
+import tensorflow as tf
+from coach.architectures.tensorflow_components import general_network as tf_net
+from coach import logger
 class NetworkWrapper(object):
@@ -50,13 +42,7 @@ class NetworkWrapper(object):
         self.has_global = has_global
         self.name = name
         self.sess = tuning_parameters.sess
-        if self.tp.framework == conf.Frameworks.TensorFlow:
-            general_network = tf_net.GeneralTensorFlowNetwork
-        elif self.tp.framework == conf.Frameworks.Neon:
-            general_network = neon_net.GeneralNeonNetwork
-        else:
-            raise Exception("{} Framework is not supported".format(conf.Frameworks().to_string(self.tp.framework)))
+        general_network = tf_net.GeneralTensorFlowNetwork
         # Global network - the main network shared between threads
         self.global_network = None
@@ -78,7 +64,7 @@ class NetworkWrapper(object):
             self.target_network = general_network(tuning_parameters, '{}/target'.format(name),
                                                   network_is_local=True)
-        if not self.tp.distributed and self.tp.framework == conf.Frameworks.TensorFlow:
+        if not self.tp.distributed:
             variables_to_restore = tf.global_variables()
             variables_to_restore = [v for v in variables_to_restore if '/online' in v.name]
             self.model_saver = tf.train.Saver(variables_to_restore)

View File

@@ -22,6 +22,8 @@ import subprocess
 import sys
 import time
+import tensorflow as tf
 from coach import agents # noqa
 from coach import configurations as conf
 from coach import environments
@@ -30,28 +32,6 @@ from coach import presets
 from coach import utils
-if len(set(logger.failed_imports)) > 0:
-    logger.screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(logger.failed_imports))))
-def set_framework(framework_type):
-    # choosing neural network framework
-    framework = conf.Frameworks().get(framework_type)
-    sess = None
-    if framework == conf.Frameworks.TensorFlow:
-        import tensorflow as tf
-        config = tf.ConfigProto()
-        config.allow_soft_placement = True
-        config.gpu_options.allow_growth = True
-        config.gpu_options.per_process_gpu_memory_fraction = 0.2
-        sess = tf.Session(config=config)
-    elif framework == conf.Frameworks.Neon:
-        import ngraph as ng
-        sess = ng.transformers.make_transformer()
-    logger.screen.log_title("Using {} framework".format(conf.Frameworks().to_string(framework)))
-    return sess
 def check_input_and_fill_run_dict(parser):
     args = parser.parse_args()
@@ -68,48 +48,28 @@ def check_input_and_fill_run_dict(parser):
         print(preset)
         sys.exit(0)
-    # check inputs
-    try:
-        # num_workers = int(args.num_workers)
-        num_workers = int(re.match("^\d+$", args.num_workers).group(0))
-    except ValueError:
-        logger.screen.error("Parameter num_workers should be an integer.")
     preset_names = utils.list_all_classes_in_module(presets)
     if args.preset is not None and args.preset not in preset_names:
         logger.screen.error("A non-existing preset was selected. ")
-    if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
-        logger.screen.error("The requested checkpoint folder to load from does not exist. ")
-    if args.save_model_sec is not None:
-        try:
-            args.save_model_sec = int(args.save_model_sec)
-        except ValueError:
-            logger.screen.error("Parameter save_model_sec should be an integer.")
-    if args.preset is None and (args.agent_type is None or args.environment_type is None
-                                or args.exploration_policy_type is None) and not args.play:
-        logger.screen.error('When no preset is given for Coach to run, the user is expected to input the desired agent_type,'
-                            ' environment_type and exploration_policy_type to assemble a preset. '
-                            '\nAt least one of these parameters was not given.')
-    elif args.preset is None and args.play and args.environment_type is None:
-        logger.screen.error('When no preset is given for Coach to run, and the user requests human control over the environment,'
-                            ' the user is expected to input the desired environment_type and level.'
-                            '\nAt least one of these parameters was not given.')
-    elif args.preset is None and args.play and args.environment_type:
-        args.agent_type = 'Human'
-        args.exploration_policy_type = 'ExplorationParameters'
+    if (args.checkpoint_restore_dir is not None and not
+            os.path.exists(args.checkpoint_restore_dir)):
+        logger.screen.error("The requested checkpoint folder to load from "
+                            "does not exist. ")
+    if (not args.preset and not
+            all([args.agent_type, args.environment_type,
+                 args.exploration_policy_type])):
+        logger.screen.error('When no preset is given for Coach to run, the '
+                            'user is expected to input the desired agent_type,'
+                            ' environment_type and exploration_policy_type to'
+                            ' assemble a preset.\nAt least one of these '
+                            'parameters was not given.')
     # get experiment name and path
     experiment_name = logger.logger.get_experiment_name(args.experiment_name)
     experiment_path = logger.logger.get_experiment_path(experiment_name)
-    if args.play and num_workers > 1:
-        logger.screen.warning("Playing the game as a human is only available with a single worker. "
-                              "The number of workers will be reduced to 1")
-        num_workers = 1
     # fill run_dict
     run_dict = dict()
     run_dict['agent_type'] = args.agent_type
@@ -119,16 +79,16 @@ def check_input_and_fill_run_dict(parser):
     run_dict['preset'] = args.preset
     run_dict['custom_parameter'] = args.custom_parameter
     run_dict['experiment_path'] = experiment_path
-    run_dict['framework'] = conf.Frameworks().get(args.framework)
-    run_dict['play'] = args.play
-    run_dict['evaluate'] = args.evaluate# or args.play
+    run_dict['evaluate'] = args.evaluate
     # multi-threading parameters
-    run_dict['num_threads'] = num_workers
+    run_dict['num_threads'] = args.num_workers
     # checkpoints
     run_dict['save_model_sec'] = args.save_model_sec
-    run_dict['save_model_dir'] = experiment_path if args.save_model_sec is not None else None
+    run_dict['save_model_dir'] = None
+    if args.save_model_sec:
+        run_dict['save_model_dir'] = experiment_path
     run_dict['checkpoint_restore_dir'] = args.checkpoint_restore_dir
     # visualization
@@ -141,7 +101,8 @@ def check_input_and_fill_run_dict(parser):
 def run_dict_to_json(_run_dict, task_id=''):
     if task_id != '':
-        json_path = os.path.join(_run_dict['experiment_path'], 'run_dict_worker{}.json'.format(task_id))
+        json_path = os.path.join(_run_dict['experiment_path'],
+                                 'run_dict_worker{}.json'.format(task_id))
     else:
         json_path = os.path.join(_run_dict['experiment_path'], 'run_dict.json')
@@ -153,97 +114,82 @@ def run_dict_to_json(_run_dict, task_id=''):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-p', '--preset',
-                        help="(string) Name of a preset to run (as configured in presets.py)",
-                        default=None,
-                        type=str)
-    parser.add_argument('-l', '--list',
-                        help="(flag) List all available presets",
-                        action='store_true')
-    parser.add_argument('-e', '--experiment_name',
-                        help="(string) Experiment name to be used to store the results.",
-                        default='',
-                        type=str)
-    parser.add_argument('-r', '--render',
-                        help="(flag) Render environment",
-                        action='store_true')
-    parser.add_argument('-f', '--framework',
-                        help="(string) Neural network framework. Available values: tensorflow, neon",
-                        default='tensorflow',
-                        type=str)
-    parser.add_argument('-n', '--num_workers',
-                        help="(int) Number of workers for multi-process based agents, e.g. A3C",
-                        default='1',
-                        type=str)
-    parser.add_argument('--play',
-                        help="(flag) Play as a human by controlling the game with the keyboard. "
-                             "This option will save a replay buffer with the game play.",
-                        action='store_true')
-    parser.add_argument('--evaluate',
-                        help="(flag) Run evaluation only. This is a convenient way to disable "
-                             "training in order to evaluate an existing checkpoint.",
-                        action='store_true')
-    parser.add_argument('-v', '--verbose',
-                        help="(flag) Don't suppress TensorFlow debug prints.",
-                        action='store_true')
-    parser.add_argument('-s', '--save_model_sec',
-                        help="(int) Time in seconds between saving checkpoints of the model.",
-                        default=None,
-                        type=int)
-    parser.add_argument('-crd', '--checkpoint_restore_dir',
-                        help='(string) Path to a folder containing a checkpoint to restore the model from.',
-                        type=str)
-    parser.add_argument('-dg', '--dump_gifs',
-                        help="(flag) Enable the gif saving functionality.",
-                        action='store_true')
-    parser.add_argument('-at', '--agent_type',
-                        help="(string) Choose an agent type class to override on top of the selected preset. "
-                             "If no preset is defined, a preset can be set from the command-line by combining settings "
-                             "which are set by using --agent_type, --experiment_type, --environemnt_type",
-                        default=None,
-                        type=str)
-    parser.add_argument('-et', '--environment_type',
-                        help="(string) Choose an environment type class to override on top of the selected preset."
-                             "If no preset is defined, a preset can be set from the command-line by combining settings "
-                             "which are set by using --agent_type, --experiment_type, --environemnt_type",
-                        default=None,
-                        type=str)
-    parser.add_argument('-ept', '--exploration_policy_type',
-                        help="(string) Choose an exploration policy type class to override on top of the selected "
-                             "preset."
-                             "If no preset is defined, a preset can be set from the command-line by combining settings "
-                             "which are set by using --agent_type, --experiment_type, --environemnt_type"
-                        ,
-                        default=None,
-                        type=str)
-    parser.add_argument('-lvl', '--level',
-                        help="(string) Choose the level that will be played in the environment that was selected."
-                             "This value will override the level parameter in the environment class."
-                        ,
-                        default=None,
-                        type=str)
-    parser.add_argument('-cp', '--custom_parameter',
-                        help="(string) Semicolon separated parameters used to override specific parameters on top of"
-                             " the selected preset (or on top of the command-line assembled one). "
-                             "Whenever a parameter value is a string, it should be inputted as '\\\"string\\\"'. "
-                             "For ex.: "
-                             "\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
-                        default=None,
-                        type=str)
-    parser.add_argument('--print_parameters',
-                        help="(flag) Print tuning_parameters to stdout",
-                        action='store_true')
-    parser.add_argument('-tb', '--tensorboard',
-                        help="(flag) When using the TensorFlow backend, enable TensorBoard log dumps. ",
-                        action='store_true')
-    parser.add_argument('-ns', '--no_summary',
-                        help="(flag) Prevent Coach from printing a summary and asking questions at the end of runs",
-                        action='store_true')
+    parser.add_argument('-p', '--preset', default=None,
+                        help='(string) Name of a preset to run (as configured '
+                        'in presets.py)')
+    parser.add_argument('-l', '--list', action='store_true',
+                        help='(flag) List all available presets')
+    parser.add_argument('-e', '--experiment_name', default='',
+                        help='(string) Experiment name to be used to store '
+                        'the results.')
+    parser.add_argument('-r', '--render', action='store_true',
+                        help='(flag) Render environment')
+    parser.add_argument('-n', '--num_workers', default=1, type=int,
+                        help='(int) Number of workers for multi-process based '
+                        'agents, e.g. A3C')
+    parser.add_argument('--evaluate', action='store_true',
+                        help='(flag) Run evaluation only. This is a '
+                        'convenient way to disable training in order to '
+                        'evaluate an existing checkpoint.')
+    parser.add_argument('-v', '--verbose', action='store_true',
+                        help='(flag) Don\'t suppress TensorFlow debug prints.')
+    parser.add_argument('-s', '--save_model_sec', default=None, type=int,
+                        help='(int) Time in seconds between saving checkpoints'
+                        ' of the model.')
+    parser.add_argument('-crd', '--checkpoint_restore_dir',
+                        help='(string) Path to a folder containing a '
+                        'checkpoint to restore the model from.')
+    parser.add_argument('-dg', '--dump_gifs', action='store_true',
+                        help='(flag) Enable the gif saving functionality.')
+    parser.add_argument('-at', '--agent_type', default=None,
+                        help='(string) Choose an agent type class to override'
+                        ' on top of the selected preset. If no preset is '
+                        'defined, a preset can be set from the command-line '
+                        'by combining settings which are set by using '
+                        '--agent_type, --experiment_type, --environemnt_type')
+    parser.add_argument('-et', '--environment_type', default=None,
+                        help='(string) Choose an environment type class to '
+                        'override on top of the selected preset. If no preset'
+                        ' is defined, a preset can be set from the '
+                        'command-line by combining settings which are set by '
+                        'using --agent_type, --experiment_type, '
+                        '--environemnt_type')
+    parser.add_argument('-ept', '--exploration_policy_type', default=None,
+                        help='(string) Choose an exploration policy type '
+                        'class to override on top of the selected preset. If '
+                        'no preset is defined, a preset can be set from the '
+                        'command-line by combining settings which are set by '
+                        'using --agent_type, --experiment_type, '
+                        '--environemnt_type')
+    parser.add_argument('-lvl', '--level', default=None,
+                        help='(string) Choose the level that will be played '
+                        'in the environment that was selected. This value '
+                        'will override the level parameter in the environment '
+                        'class.')
+    parser.add_argument('-cp', '--custom_parameter', default=None,
+                        help='(string) Semicolon separated parameters used to '
+                        'override specific parameters on top of the selected '
+                        'preset (or on top of the command-line assembled '
+                        'one). Whenever a parameter value is a string, it '
+                        'should be inputted as "string". For ex.: '
+                        '"visualization.render=False; '
+                        'num_training_iterations=500; optimizer=\'rmsprop\'"')
+    parser.add_argument('-pf', '--parameters_file', default=None,
+                        help='YAML file with customized parameters, just like '
+                        '\'--custom-parameter\' bit in a file for convenience')
+    parser.add_argument('--print_parameters', action='store_true',
+                        help='(flag) Print tuning_parameters to stdout')
+    parser.add_argument('-tb', '--tensorboard', action='store_true',
+                        help='(flag) When using the TensorFlow backend, '
+                        'enable TensorBoard log dumps. ')
+    parser.add_argument('-ns', '--no_summary', action='store_true',
+                        help='(flag) Prevent Coach from printing a summary '
+                        'and asking questions at the end of runs')
     args, run_dict = check_input_and_fill_run_dict(parser)
     # turn TF debug prints off
-    if not args.verbose and args.framework.lower() == 'tensorflow':
+    if not args.verbose:
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
     # dump documentation
@@ -257,43 +203,46 @@ def main():
         # set tuning parameters
         json_run_dict_path = run_dict_to_json(run_dict)
         tuning_parameters = presets.json_to_preset(json_run_dict_path)
-        tuning_parameters.sess = set_framework(args.framework)
+        config = tf.ConfigProto()
+        config.allow_soft_placement = True
+        config.gpu_options.allow_growth = True
+        config.gpu_options.per_process_gpu_memory_fraction = 0.2
+        tuning_parameters.sess = tf.Session(config=config)
         if args.print_parameters:
             print('tuning_parameters', tuning_parameters)
         # Single-thread runs
         tuning_parameters.task_index = 0
-        env_instance = environments.create_environment(tuning_parameters)
+        env_instance = environments.create_environment(tuning_parameters)  # noqa
         agent = eval('agents.' + tuning_parameters.agent.type +
                      '(env_instance, tuning_parameters)')
         # Start the training or evaluation
         if tuning_parameters.evaluate:
-            agent.evaluate(sys.maxsize, keep_networks_synced=True)  # evaluate forever
+            # evaluate forever
+            agent.evaluate(sys.maxsize, keep_networks_synced=True)
         else:
             agent.improve()
     # Multi-threaded runs
     else:
-        assert args.framework.lower() == 'tensorflow', "Distributed training works only with TensorFlow"
-        os.environ["OMP_NUM_THREADS"]="1"
+        os.environ['OMP_NUM_THREADS'] = '1'
         # set parameter server and workers addresses
-        ps_hosts = "localhost:{}".format(utils.get_open_port())
-        worker_hosts = ",".join(["localhost:{}".format(utils.get_open_port()) for i in range(run_dict['num_threads'] + 1)])
+        ps_hosts = 'localhost:{}'.format(utils.get_open_port())
+        worker_hosts = ','.join(['localhost:{}'.format(utils.get_open_port())
+                                 for i in range(run_dict['num_threads'] + 1)])
         # Make sure to disable GPU so that all the workers will use the CPU
         utils.set_cpu()
         # create a parameter server
-        cmd = [
-            "python3",
-            "./parallel_actor.py",
-            "--ps_hosts={}".format(ps_hosts),
-            "--worker_hosts={}".format(worker_hosts),
-            "--job_name=ps",
-        ]
-        parameter_server = subprocess.Popen(cmd)
+        cmd = ["python3",
+               "./parallel_actor.py",
+               "--ps_hosts={}".format(ps_hosts),
+               "--worker_hosts={}".format(worker_hosts),
+               "--job_name=ps"]
+        subprocess.Popen(cmd)
         logger.screen.log_title("*** Distributed Training ***")
         time.sleep(1)
@@ -309,7 +258,8 @@ def main():
                 run_dict['visualization.render'] = args.render
             else:
                 run_dict['evaluate_only'] = False
-                run_dict['visualization.render'] = False # #In a parallel setting, only the evaluation agent renders
+                # In a parallel setting, only the evaluation agent renders
+                run_dict['visualization.render'] = False
             json_run_dict_path = run_dict_to_json(run_dict, i)
             workers_args = ["python3", "./parallel_actor.py",

View File

@@ -160,7 +160,6 @@ class EnvironmentParameters(Parameters):
     reward_scaling = 1.0
     reward_clipping_min = None
     reward_clipping_max = None
-    human_control = False
 class ExplorationParameters(Parameters):
@@ -257,7 +256,6 @@ class VisualizationParameters(Parameters):
     dump_signals_to_csv_every_x_episodes = 5
     render = False
     dump_gifs = True
-    max_fps_for_human_control = 10
     tensorboard = False
@@ -325,11 +323,6 @@ class Carla(EnvironmentParameters):
     allow_braking = False
-class Human(AgentParameters):
-    type = 'HumanAgent'
-    num_episodes_in_experience_replay = 10000000
 class NStepQ(AgentParameters):
     type = 'NStepQAgent'
     input_types = {'observation': InputTypes.Observation}

View File

@@ -91,11 +91,7 @@ class DoomEnvironmentWrapper(ew.EnvironmentWrapper):
         self.game.set_window_visible(False)
         self.game.add_game_args("+vid_forcesurface 1")
-        self.wait_for_explicit_human_action = True
-        if self.human_control:
-            self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_640X480)
-            self.renderer.create_screen(640, 480)
-        elif self.is_rendered:
+        if self.is_rendered:
             self.game.set_screen_resolution(vizdoom.ScreenResolution.RES_320X240)
             self.renderer.create_screen(320, 240)
         else:

View File

@@ -57,9 +57,6 @@ class EnvironmentWrapper(object):
         self.is_rendered = self.tp.visualization.render
         self.seed = self.tp.seed
         self.frame_skip = self.tp.env.frame_skip
-        self.human_control = self.tp.env.human_control
-        self.wait_for_explicit_human_action = False
-        self.is_rendered = self.is_rendered or self.human_control
         self.game_is_open = True
     @property

View File

@@ -45,8 +45,6 @@ class GymEnvironmentWrapper(ew.EnvironmentWrapper):
         if self.is_rendered:
             image = self.get_rendered_image()
             scale = 1
-            if self.human_control:
-                scale = 2
             self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
         if isinstance(self.env.observation_space, gym.spaces.Dict):

View File

@@ -221,12 +221,12 @@ class Logger(BaseLogger):
     def get_signal_value(self, time, signal_name):
         return self.data.loc[time, signal_name]
-    def dump_output_csv(self, append=True):
+    def dump_output_csv(self):
         self.data.index.name = "Episode #"
         if len(self.data.index) == 1:
             self.start_time = time.time()
-        if os.path.exists(self.csv_path) and append:
+        if os.path.exists(self.csv_path):
             self.data[self.last_line_idx_written_to_csv:].to_csv(self.csv_path, mode='a', header=False)
         else:
             self.data.to_csv(self.csv_path)
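For context, the append branch above relies on pandas' incremental CSV writing. A minimal, self-contained sketch of the same pattern, with a made-up file name and data:

import os
import pandas as pd

csv_path = 'worker_0.csv'  # hypothetical output file
data = pd.DataFrame({'Training Reward': [1.0, 2.5, 3.0]})
data.index.name = 'Episode #'
last_line_idx_written_to_csv = 1  # rows before this index were already dumped

if os.path.exists(csv_path):
    # append only the new rows, without repeating the header
    data[last_line_idx_written_to_csv:].to_csv(csv_path, mode='a', header=False)
else:
    data.to_csv(csv_path)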

View File

@@ -44,12 +44,6 @@ def json_to_preset(json_path):
     if run_dict['exploration_policy_type'] is not None:
         tuning_parameters.exploration = eval('ep.' + run_dict['exploration_policy_type'])()
-    # human control
-    if run_dict['play']:
-        tuning_parameters.agent.type = 'HumanAgent'
-        tuning_parameters.env.human_control = True
-        tuning_parameters.num_heatup_steps = 0
     if run_dict['level']:
         tuning_parameters.env.level = run_dict['level']

scripts/coach Normal file → Executable file
View File