1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

provide a command line option which prints the tuning_parameters to stdout

This commit is contained in:
Zach Dwiel
2018-01-10 16:28:41 -05:00
parent 9b963c86d0
commit c7b11f1e9a
2 changed files with 31 additions and 11 deletions

View File

@@ -273,6 +273,9 @@ if __name__ == "__main__":
"\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"", "\"visualization.render=False; num_training_iterations=500; optimizer='rmsprop'\"",
default=None, default=None,
type=str) type=str)
parser.add_argument('--print_parameters',
help="(flag) Print tuning_parameters to stdout",
action='store_true')
args, run_dict = check_input_and_fill_run_dict(parser) args, run_dict = check_input_and_fill_run_dict(parser)
@@ -290,6 +293,9 @@ if __name__ == "__main__":
tuning_parameters = json_to_preset(json_run_dict_path) tuning_parameters = json_to_preset(json_run_dict_path)
tuning_parameters.sess = set_framework(args.framework) tuning_parameters.sess = set_framework(args.framework)
if args.print_parameters:
print('tuning_parameters', tuning_parameters)
# Single-thread runs # Single-thread runs
tuning_parameters.task_index = 0 tuning_parameters.task_index = 0
env_instance = create_environment(tuning_parameters) env_instance = create_environment(tuning_parameters)
@@ -352,5 +358,3 @@ if __name__ == "__main__":
# wait for all workers # wait for all workers
[w.wait() for w in workers] [w.wait() for w in workers]
evaluation_worker.kill() evaluation_worker.kill()

View File

@@ -16,7 +16,7 @@
from utils import Enum from utils import Enum
import json import json
from logger import screen, logger import types
class Frameworks(Enum): class Frameworks(Enum):
@@ -56,7 +56,24 @@ class MiddlewareTypes(object):
FC = 2 FC = 2
class AgentParameters(object): class Parameters(object):
def __str__(self):
parameters = {}
for k, v in self.__dict__.items():
if isinstance(v, type) and issubclass(v, Parameters):
# v.__dict__ doesn't return a dictionary but a mappingproxy
# which json doesn't serialize, so convert it into a normal
# dictionary
parameters[k] = dict(v.__dict__.items())
elif isinstance(v, types.MappingProxyType):
parameters[k] = dict(v.items())
else:
parameters[k] = v
return json.dumps(parameters, indent=4, default=repr)
class AgentParameters(Parameters):
agent = '' agent = ''
# Architecture parameters # Architecture parameters
@@ -129,7 +146,7 @@ class AgentParameters(object):
share_statistics_between_workers = True share_statistics_between_workers = True
class EnvironmentParameters(object): class EnvironmentParameters(Parameters):
type = 'Doom' type = 'Doom'
level = 'basic' level = 'basic'
observation_stack_size = 4 observation_stack_size = 4
@@ -143,7 +160,7 @@ class EnvironmentParameters(object):
human_control = False human_control = False
class ExplorationParameters(object): class ExplorationParameters(Parameters):
# Exploration policies # Exploration policies
policy = 'EGreedy' policy = 'EGreedy'
evaluation_policy = 'Greedy' evaluation_policy = 'Greedy'
@@ -177,7 +194,7 @@ class ExplorationParameters(object):
dt = 0.01 dt = 0.01
class GeneralParameters(object): class GeneralParameters(Parameters):
train = True train = True
framework = Frameworks.TensorFlow framework = Frameworks.TensorFlow
threads = 1 threads = 1
@@ -224,7 +241,7 @@ class GeneralParameters(object):
test_num_workers = 1 test_num_workers = 1
class VisualizationParameters(object): class VisualizationParameters(Parameters):
# Visualization parameters # Visualization parameters
record_video_every = 1000 record_video_every = 1000
video_path = '/home/llt_lab/temp/breakout-videos' video_path = '/home/llt_lab/temp/breakout-videos'
@@ -587,4 +604,3 @@ class Preset(GeneralParameters):
self.agent = agent self.agent = agent
self.env = env self.env = env
self.exploration = exploration self.exploration = exploration