From c7b11f1e9a85e2bb5b3a391ca697d3d9a7cf7e45 Mon Sep 17 00:00:00 2001 From: Zach Dwiel Date: Wed, 10 Jan 2018 16:28:41 -0500 Subject: [PATCH] provide a command line option which prints the tuning_parameters to stdout --- coach.py | 10 +++++++--- configurations.py | 32 ++++++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/coach.py b/coach.py index 45b7382..171b34f 100644 --- a/coach.py +++ b/coach.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -273,6 +273,9 @@ if __name__ == "__main__": "\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"", default=None, type=str) + parser.add_argument('--print_parameters', + help="(flag) Print tuning_parameters to stdout", + action='store_true') args, run_dict = check_input_and_fill_run_dict(parser) @@ -290,6 +293,9 @@ if __name__ == "__main__": tuning_parameters = json_to_preset(json_run_dict_path) tuning_parameters.sess = set_framework(args.framework) + if args.print_parameters: + print('tuning_parameters', tuning_parameters) + # Single-thread runs tuning_parameters.task_index = 0 env_instance = create_environment(tuning_parameters) @@ -352,5 +358,3 @@ if __name__ == "__main__": # wait for all workers [w.wait() for w in workers] evaluation_worker.kill() - - diff --git a/configurations.py b/configurations.py index 8480c19..17035ed 100644 --- a/configurations.py +++ b/configurations.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ from utils import Enum import json -from logger import screen, logger +import types class Frameworks(Enum): @@ -56,7 +56,24 @@ class MiddlewareTypes(object): FC = 2 -class AgentParameters(object): +class Parameters(object): + def __str__(self): + parameters = {} + for k, v in self.__dict__.items(): + if isinstance(v, type) and issubclass(v, Parameters): + # v.__dict__ doesn't return a dictionary but a mappingproxy + # which json doesn't serialize, so convert it into a normal + # dictionary + parameters[k] = dict(v.__dict__.items()) + elif isinstance(v, types.MappingProxyType): + parameters[k] = dict(v.items()) + else: + parameters[k] = v + + return json.dumps(parameters, indent=4, default=repr) + + +class AgentParameters(Parameters): agent = '' # Architecture parameters @@ -129,7 +146,7 @@ class AgentParameters(object): share_statistics_between_workers = True -class EnvironmentParameters(object): +class EnvironmentParameters(Parameters): type = 'Doom' level = 'basic' observation_stack_size = 4 @@ -143,7 +160,7 @@ class EnvironmentParameters(object): human_control = False -class ExplorationParameters(object): +class ExplorationParameters(Parameters): # Exploration policies policy = 'EGreedy' evaluation_policy = 'Greedy' @@ -177,7 +194,7 @@ class ExplorationParameters(object): dt = 0.01 -class GeneralParameters(object): +class GeneralParameters(Parameters): train = True framework = Frameworks.TensorFlow threads = 1 @@ -224,7 +241,7 @@ class GeneralParameters(object): test_num_workers = 1 -class VisualizationParameters(object): +class VisualizationParameters(Parameters): # Visualization parameters record_video_every = 1000 video_path = '/home/llt_lab/temp/breakout-videos' @@ -587,4 +604,3 @@ class Preset(GeneralParameters): self.agent = agent self.env = env self.exploration = exploration -