mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Merge pull request #42 from NervanaSystems/print_parameters
provide a command line option which prints the tuning_parameters to stdout
This commit is contained in:
10
coach.py
10
coach.py
@@ -1,5 +1,5 @@
|
|||||||
#
|
#
|
||||||
# Copyright (c) 2017 Intel Corporation
|
# Copyright (c) 2017 Intel Corporation
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -273,6 +273,9 @@ if __name__ == "__main__":
|
|||||||
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
|
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
|
||||||
default=None,
|
default=None,
|
||||||
type=str)
|
type=str)
|
||||||
|
parser.add_argument('--print_parameters',
|
||||||
|
help="(flag) Print tuning_parameters to stdout",
|
||||||
|
action='store_true')
|
||||||
|
|
||||||
args, run_dict = check_input_and_fill_run_dict(parser)
|
args, run_dict = check_input_and_fill_run_dict(parser)
|
||||||
|
|
||||||
@@ -290,6 +293,9 @@ if __name__ == "__main__":
|
|||||||
tuning_parameters = json_to_preset(json_run_dict_path)
|
tuning_parameters = json_to_preset(json_run_dict_path)
|
||||||
tuning_parameters.sess = set_framework(args.framework)
|
tuning_parameters.sess = set_framework(args.framework)
|
||||||
|
|
||||||
|
if args.print_parameters:
|
||||||
|
print('tuning_parameters', tuning_parameters)
|
||||||
|
|
||||||
# Single-thread runs
|
# Single-thread runs
|
||||||
tuning_parameters.task_index = 0
|
tuning_parameters.task_index = 0
|
||||||
env_instance = create_environment(tuning_parameters)
|
env_instance = create_environment(tuning_parameters)
|
||||||
@@ -352,5 +358,3 @@ if __name__ == "__main__":
|
|||||||
# wait for all workers
|
# wait for all workers
|
||||||
[w.wait() for w in workers]
|
[w.wait() for w in workers]
|
||||||
evaluation_worker.kill()
|
evaluation_worker.kill()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#
|
#
|
||||||
# Copyright (c) 2017 Intel Corporation
|
# Copyright (c) 2017 Intel Corporation
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
from utils import Enum
|
from utils import Enum
|
||||||
import json
|
import json
|
||||||
from logger import screen, logger
|
import types
|
||||||
|
|
||||||
|
|
||||||
class Frameworks(Enum):
|
class Frameworks(Enum):
|
||||||
@@ -56,7 +56,24 @@ class MiddlewareTypes(object):
|
|||||||
FC = 2
|
FC = 2
|
||||||
|
|
||||||
|
|
||||||
class AgentParameters(object):
|
class Parameters(object):
|
||||||
|
def __str__(self):
|
||||||
|
parameters = {}
|
||||||
|
for k, v in self.__dict__.items():
|
||||||
|
if isinstance(v, type) and issubclass(v, Parameters):
|
||||||
|
# v.__dict__ doesn't return a dictionary but a mappingproxy
|
||||||
|
# which json doesn't serialize, so convert it into a normal
|
||||||
|
# dictionary
|
||||||
|
parameters[k] = dict(v.__dict__.items())
|
||||||
|
elif isinstance(v, types.MappingProxyType):
|
||||||
|
parameters[k] = dict(v.items())
|
||||||
|
else:
|
||||||
|
parameters[k] = v
|
||||||
|
|
||||||
|
return json.dumps(parameters, indent=4, default=repr)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentParameters(Parameters):
|
||||||
agent = ''
|
agent = ''
|
||||||
|
|
||||||
# Architecture parameters
|
# Architecture parameters
|
||||||
@@ -129,7 +146,7 @@ class AgentParameters(object):
|
|||||||
share_statistics_between_workers = True
|
share_statistics_between_workers = True
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentParameters(object):
|
class EnvironmentParameters(Parameters):
|
||||||
type = 'Doom'
|
type = 'Doom'
|
||||||
level = 'basic'
|
level = 'basic'
|
||||||
observation_stack_size = 4
|
observation_stack_size = 4
|
||||||
@@ -143,7 +160,7 @@ class EnvironmentParameters(object):
|
|||||||
human_control = False
|
human_control = False
|
||||||
|
|
||||||
|
|
||||||
class ExplorationParameters(object):
|
class ExplorationParameters(Parameters):
|
||||||
# Exploration policies
|
# Exploration policies
|
||||||
policy = 'EGreedy'
|
policy = 'EGreedy'
|
||||||
evaluation_policy = 'Greedy'
|
evaluation_policy = 'Greedy'
|
||||||
@@ -177,7 +194,7 @@ class ExplorationParameters(object):
|
|||||||
dt = 0.01
|
dt = 0.01
|
||||||
|
|
||||||
|
|
||||||
class GeneralParameters(object):
|
class GeneralParameters(Parameters):
|
||||||
train = True
|
train = True
|
||||||
framework = Frameworks.TensorFlow
|
framework = Frameworks.TensorFlow
|
||||||
threads = 1
|
threads = 1
|
||||||
@@ -224,7 +241,7 @@ class GeneralParameters(object):
|
|||||||
test_num_workers = 1
|
test_num_workers = 1
|
||||||
|
|
||||||
|
|
||||||
class VisualizationParameters(object):
|
class VisualizationParameters(Parameters):
|
||||||
# Visualization parameters
|
# Visualization parameters
|
||||||
record_video_every = 1000
|
record_video_every = 1000
|
||||||
video_path = '/home/llt_lab/temp/breakout-videos'
|
video_path = '/home/llt_lab/temp/breakout-videos'
|
||||||
@@ -587,4 +604,3 @@ class Preset(GeneralParameters):
|
|||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.env = env
|
self.env = env
|
||||||
self.exploration = exploration
|
self.exploration = exploration
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user