mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
provide a command line option which prints the tuning_parameters to stdout
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
from utils import Enum
|
||||
import json
|
||||
from logger import screen, logger
|
||||
import types
|
||||
|
||||
|
||||
class Frameworks(Enum):
|
||||
@@ -56,7 +56,24 @@ class MiddlewareTypes(object):
|
||||
FC = 2
|
||||
|
||||
|
||||
class AgentParameters(object):
|
||||
class Parameters(object):
|
||||
def __str__(self):
|
||||
parameters = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, type) and issubclass(v, Parameters):
|
||||
# v.__dict__ doesn't return a dictionary but a mappingproxy
|
||||
# which json doesn't serialize, so convert it into a normal
|
||||
# dictionary
|
||||
parameters[k] = dict(v.__dict__.items())
|
||||
elif isinstance(v, types.MappingProxyType):
|
||||
parameters[k] = dict(v.items())
|
||||
else:
|
||||
parameters[k] = v
|
||||
|
||||
return json.dumps(parameters, indent=4, default=repr)
|
||||
|
||||
|
||||
class AgentParameters(Parameters):
|
||||
agent = ''
|
||||
|
||||
# Architecture parameters
|
||||
@@ -129,7 +146,7 @@ class AgentParameters(object):
|
||||
share_statistics_between_workers = True
|
||||
|
||||
|
||||
class EnvironmentParameters(object):
|
||||
class EnvironmentParameters(Parameters):
|
||||
type = 'Doom'
|
||||
level = 'basic'
|
||||
observation_stack_size = 4
|
||||
@@ -143,7 +160,7 @@ class EnvironmentParameters(object):
|
||||
human_control = False
|
||||
|
||||
|
||||
class ExplorationParameters(object):
|
||||
class ExplorationParameters(Parameters):
|
||||
# Exploration policies
|
||||
policy = 'EGreedy'
|
||||
evaluation_policy = 'Greedy'
|
||||
@@ -177,7 +194,7 @@ class ExplorationParameters(object):
|
||||
dt = 0.01
|
||||
|
||||
|
||||
class GeneralParameters(object):
|
||||
class GeneralParameters(Parameters):
|
||||
train = True
|
||||
framework = Frameworks.TensorFlow
|
||||
threads = 1
|
||||
@@ -224,7 +241,7 @@ class GeneralParameters(object):
|
||||
test_num_workers = 1
|
||||
|
||||
|
||||
class VisualizationParameters(object):
|
||||
class VisualizationParameters(Parameters):
|
||||
# Visualization parameters
|
||||
record_video_every = 1000
|
||||
video_path = '/home/llt_lab/temp/breakout-videos'
|
||||
@@ -587,4 +604,3 @@ class Preset(GeneralParameters):
|
||||
self.agent = agent
|
||||
self.env = env
|
||||
self.exploration = exploration
|
||||
|
||||
|
||||
Reference in New Issue
Block a user