1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Unify base class using new-style (object).

This commit is contained in:
cxx
2017-10-26 10:16:11 +08:00
committed by Itai Caspi
parent 39cf78074c
commit f43c951c2d
16 changed files with 28 additions and 28 deletions

View File

@@ -33,7 +33,7 @@ from architectures.tensorflow_components.shared_variables import SharedRunningSt
from six.moves import range from six.moves import range
class Agent: class Agent(object):
def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0): def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0):
""" """
:param env: An environment instance :param env: An environment instance

View File

@@ -17,7 +17,7 @@
from configurations import Preset from configurations import Preset
class Architecture: class Architecture(object):
def __init__(self, tuning_parameters, name=""): def __init__(self, tuning_parameters, name=""):
""" """
:param tuning_parameters: A Preset class instance with all the running paramaters :param tuning_parameters: A Preset class instance with all the running paramaters

View File

@@ -19,7 +19,7 @@ import ngraph as ng
from ngraph.util.names import name_scope from ngraph.util.names import name_scope
class InputEmbedder: class InputEmbedder(object):
def __init__(self, input_size, batch_size=None, activation_function=neon.Rectlin(), name="embedder"): def __init__(self, input_size, batch_size=None, activation_function=neon.Rectlin(), name="embedder"):
self.name = name self.name = name
self.input_size = input_size self.input_size = input_size

View File

@@ -22,7 +22,7 @@ from utils import force_list
from architectures.neon_components.losses import * from architectures.neon_components.losses import *
class Head: class Head(object):
def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True): def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
self.head_idx = head_idx self.head_idx = head_idx
self.name = "head" self.name = "head"

View File

@@ -20,7 +20,7 @@ from ngraph.util.names import name_scope
import numpy as np import numpy as np
class MiddlewareEmbedder: class MiddlewareEmbedder(object):
def __init__(self, activation_function=neon.Rectlin(), name="middleware_embedder"): def __init__(self, activation_function=neon.Rectlin(), name="middleware_embedder"):
self.name = name self.name = name
self.input = None self.input = None

View File

@@ -29,7 +29,7 @@ except ImportError:
failed_imports.append("Neon") failed_imports.append("Neon")
class NetworkWrapper: class NetworkWrapper(object):
def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None): def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
""" """

View File

@@ -17,7 +17,7 @@
import tensorflow as tf import tensorflow as tf
class InputEmbedder: class InputEmbedder(object):
def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"): def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"):
self.name = name self.name = name
self.input_size = input_size self.input_size = input_size

View File

@@ -28,7 +28,7 @@ def normalized_columns_initializer(std=1.0):
return _initializer return _initializer
class Head: class Head(object):
def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True): def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
self.head_idx = head_idx self.head_idx = head_idx
self.name = "head" self.name = "head"

View File

@@ -18,7 +18,7 @@ import tensorflow as tf
import numpy as np import numpy as np
class MiddlewareEmbedder: class MiddlewareEmbedder(object):
def __init__(self, activation_function=tf.nn.relu, name="middleware_embedder"): def __init__(self, activation_function=tf.nn.relu, name="middleware_embedder"):
self.name = name self.name = name
self.input = None self.input = None

View File

@@ -24,7 +24,7 @@ class Frameworks(Enum):
Neon = 2 Neon = 2
class InputTypes: class InputTypes(object):
Observation = 1 Observation = 1
Measurements = 2 Measurements = 2
GoalVector = 3 GoalVector = 3
@@ -32,7 +32,7 @@ class InputTypes:
TimedObservation = 5 TimedObservation = 5
class OutputTypes: class OutputTypes(object):
Q = 1 Q = 1
DuelingQ = 2 DuelingQ = 2
V = 3 V = 3
@@ -45,12 +45,12 @@ class OutputTypes:
DistributionalQ = 10 DistributionalQ = 10
class MiddlewareTypes: class MiddlewareTypes(object):
LSTM = 1 LSTM = 1
FC = 2 FC = 2
class AgentParameters: class AgentParameters(object):
agent = '' agent = ''
# Architecture parameters # Architecture parameters
@@ -120,7 +120,7 @@ class AgentParameters:
share_statistics_between_workers = True share_statistics_between_workers = True
class EnvironmentParameters: class EnvironmentParameters(object):
type = 'Doom' type = 'Doom'
level = 'basic' level = 'basic'
observation_stack_size = 4 observation_stack_size = 4
@@ -133,7 +133,7 @@ class EnvironmentParameters:
reward_clipping_max = None reward_clipping_max = None
class ExplorationParameters: class ExplorationParameters(object):
# Exploration policies # Exploration policies
policy = 'EGreedy' policy = 'EGreedy'
evaluation_policy = 'Greedy' evaluation_policy = 'Greedy'
@@ -167,7 +167,7 @@ class ExplorationParameters:
dt = 0.01 dt = 0.01
class GeneralParameters: class GeneralParameters(object):
train = True train = True
framework = Frameworks.TensorFlow framework = Frameworks.TensorFlow
threads = 1 threads = 1
@@ -212,7 +212,7 @@ class GeneralParameters:
test_num_workers = 1 test_num_workers = 1
class VisualizationParameters: class VisualizationParameters(object):
# Visualization parameters # Visualization parameters
record_video_every = 1000 record_video_every = 1000
video_path = '/home/llt_lab/temp/breakout-videos' video_path = '/home/llt_lab/temp/breakout-videos'

View File

@@ -19,7 +19,7 @@ from utils import *
from configurations import Preset from configurations import Preset
class EnvironmentWrapper: class EnvironmentWrapper(object):
def __init__(self, tuning_parameters): def __init__(self, tuning_parameters):
""" """
:param tuning_parameters: :param tuning_parameters:

View File

@@ -19,7 +19,7 @@ from utils import *
from configurations import * from configurations import *
class ExplorationPolicy: class ExplorationPolicy(object):
def __init__(self, tuning_parameters): def __init__(self, tuning_parameters):
""" """
:param tuning_parameters: A Preset class instance with all the running paramaters :param tuning_parameters: A Preset class instance with all the running paramaters

View File

@@ -27,7 +27,7 @@ global failed_imports
failed_imports = [] failed_imports = []
class Colors: class Colors(object):
PURPLE = '\033[95m' PURPLE = '\033[95m'
CYAN = '\033[96m' CYAN = '\033[96m'
DARKCYAN = '\033[36m' DARKCYAN = '\033[36m'
@@ -51,7 +51,7 @@ class Colors:
# prints to screen with a prefix identifying the origin of the print # prints to screen with a prefix identifying the origin of the print
class ScreenLogger: class ScreenLogger(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
@@ -85,7 +85,7 @@ class ScreenLogger:
return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END)) return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
class BaseLogger: class BaseLogger(object):
def __init__(self): def __init__(self):
pass pass

View File

@@ -18,7 +18,7 @@ import numpy as np
from annoy import AnnoyIndex from annoy import AnnoyIndex
class AnnoyDictionary: class AnnoyDictionary(object):
def __init__(self, dict_size, key_width, new_value_shift_coefficient=0.1, batch_size=100, key_error_threshold=0.01): def __init__(self, dict_size, key_width, new_value_shift_coefficient=0.1, batch_size=100, key_error_threshold=0.01):
self.max_size = dict_size self.max_size = dict_size
self.curr_size = 0 self.curr_size = 0

View File

@@ -19,7 +19,7 @@ import copy
from configurations import * from configurations import *
class Memory: class Memory(object):
def __init__(self, tuning_parameters): def __init__(self, tuning_parameters):
""" """
:param tuning_parameters: A Preset class instance with all the running paramaters :param tuning_parameters: A Preset class instance with all the running paramaters
@@ -43,7 +43,7 @@ class Memory:
pass pass
class Episode: class Episode(object):
def __init__(self): def __init__(self):
self.transitions = [] self.transitions = []
# a num_transitions x num_transitions table with the n step return in the n'th row # a num_transitions x num_transitions table with the n step return in the n'th row
@@ -122,7 +122,7 @@ class Episode:
return batch return batch
class Transition: class Transition(object):
def __init__(self, state, action, reward, next_state, game_over): def __init__(self, state, action, reward, next_state, game_over):
self.state = copy.deepcopy(state) self.state = copy.deepcopy(state)
self.state['observation'] = np.array(self.state['observation'], copy=False) self.state['observation'] = np.array(self.state['observation'], copy=False)

View File

@@ -24,7 +24,7 @@ from subprocess import call, Popen
killed_processes = [] killed_processes = []
class Enum: class Enum(object):
def __init__(self): def __init__(self):
pass pass
@@ -177,7 +177,7 @@ def threaded_cmd_line_run(run_cmd, id=-1):
return result return result
class Signal: class Signal(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.sample_count = 0 self.sample_count = 0