mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Unify base class using new-style (object).
This commit is contained in:
@@ -33,7 +33,7 @@ from architectures.tensorflow_components.shared_variables import SharedRunningSt
|
||||
from six.moves import range
|
||||
|
||||
|
||||
class Agent:
|
||||
class Agent(object):
|
||||
def __init__(self, env, tuning_parameters, replicated_device=None, task_id=0):
|
||||
"""
|
||||
:param env: An environment instance
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from configurations import Preset
|
||||
|
||||
|
||||
class Architecture:
|
||||
class Architecture(object):
|
||||
def __init__(self, tuning_parameters, name=""):
|
||||
"""
|
||||
:param tuning_parameters: A Preset class instance with all the running paramaters
|
||||
|
||||
@@ -19,7 +19,7 @@ import ngraph as ng
|
||||
from ngraph.util.names import name_scope
|
||||
|
||||
|
||||
class InputEmbedder:
|
||||
class InputEmbedder(object):
|
||||
def __init__(self, input_size, batch_size=None, activation_function=neon.Rectlin(), name="embedder"):
|
||||
self.name = name
|
||||
self.input_size = input_size
|
||||
|
||||
@@ -22,7 +22,7 @@ from utils import force_list
|
||||
from architectures.neon_components.losses import *
|
||||
|
||||
|
||||
class Head:
|
||||
class Head(object):
|
||||
def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
|
||||
self.head_idx = head_idx
|
||||
self.name = "head"
|
||||
|
||||
@@ -20,7 +20,7 @@ from ngraph.util.names import name_scope
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MiddlewareEmbedder:
|
||||
class MiddlewareEmbedder(object):
|
||||
def __init__(self, activation_function=neon.Rectlin(), name="middleware_embedder"):
|
||||
self.name = name
|
||||
self.input = None
|
||||
|
||||
@@ -29,7 +29,7 @@ except ImportError:
|
||||
failed_imports.append("Neon")
|
||||
|
||||
|
||||
class NetworkWrapper:
|
||||
class NetworkWrapper(object):
|
||||
def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
|
||||
"""
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class InputEmbedder:
|
||||
class InputEmbedder(object):
|
||||
def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"):
|
||||
self.name = name
|
||||
self.input_size = input_size
|
||||
|
||||
@@ -28,7 +28,7 @@ def normalized_columns_initializer(std=1.0):
|
||||
return _initializer
|
||||
|
||||
|
||||
class Head:
|
||||
class Head(object):
|
||||
def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
|
||||
self.head_idx = head_idx
|
||||
self.name = "head"
|
||||
|
||||
@@ -18,7 +18,7 @@ import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MiddlewareEmbedder:
|
||||
class MiddlewareEmbedder(object):
|
||||
def __init__(self, activation_function=tf.nn.relu, name="middleware_embedder"):
|
||||
self.name = name
|
||||
self.input = None
|
||||
|
||||
@@ -24,7 +24,7 @@ class Frameworks(Enum):
|
||||
Neon = 2
|
||||
|
||||
|
||||
class InputTypes:
|
||||
class InputTypes(object):
|
||||
Observation = 1
|
||||
Measurements = 2
|
||||
GoalVector = 3
|
||||
@@ -32,7 +32,7 @@ class InputTypes:
|
||||
TimedObservation = 5
|
||||
|
||||
|
||||
class OutputTypes:
|
||||
class OutputTypes(object):
|
||||
Q = 1
|
||||
DuelingQ = 2
|
||||
V = 3
|
||||
@@ -45,12 +45,12 @@ class OutputTypes:
|
||||
DistributionalQ = 10
|
||||
|
||||
|
||||
class MiddlewareTypes:
|
||||
class MiddlewareTypes(object):
|
||||
LSTM = 1
|
||||
FC = 2
|
||||
|
||||
|
||||
class AgentParameters:
|
||||
class AgentParameters(object):
|
||||
agent = ''
|
||||
|
||||
# Architecture parameters
|
||||
@@ -120,7 +120,7 @@ class AgentParameters:
|
||||
share_statistics_between_workers = True
|
||||
|
||||
|
||||
class EnvironmentParameters:
|
||||
class EnvironmentParameters(object):
|
||||
type = 'Doom'
|
||||
level = 'basic'
|
||||
observation_stack_size = 4
|
||||
@@ -133,7 +133,7 @@ class EnvironmentParameters:
|
||||
reward_clipping_max = None
|
||||
|
||||
|
||||
class ExplorationParameters:
|
||||
class ExplorationParameters(object):
|
||||
# Exploration policies
|
||||
policy = 'EGreedy'
|
||||
evaluation_policy = 'Greedy'
|
||||
@@ -167,7 +167,7 @@ class ExplorationParameters:
|
||||
dt = 0.01
|
||||
|
||||
|
||||
class GeneralParameters:
|
||||
class GeneralParameters(object):
|
||||
train = True
|
||||
framework = Frameworks.TensorFlow
|
||||
threads = 1
|
||||
@@ -212,7 +212,7 @@ class GeneralParameters:
|
||||
test_num_workers = 1
|
||||
|
||||
|
||||
class VisualizationParameters:
|
||||
class VisualizationParameters(object):
|
||||
# Visualization parameters
|
||||
record_video_every = 1000
|
||||
video_path = '/home/llt_lab/temp/breakout-videos'
|
||||
|
||||
@@ -19,7 +19,7 @@ from utils import *
|
||||
from configurations import Preset
|
||||
|
||||
|
||||
class EnvironmentWrapper:
|
||||
class EnvironmentWrapper(object):
|
||||
def __init__(self, tuning_parameters):
|
||||
"""
|
||||
:param tuning_parameters:
|
||||
|
||||
@@ -19,7 +19,7 @@ from utils import *
|
||||
from configurations import *
|
||||
|
||||
|
||||
class ExplorationPolicy:
|
||||
class ExplorationPolicy(object):
|
||||
def __init__(self, tuning_parameters):
|
||||
"""
|
||||
:param tuning_parameters: A Preset class instance with all the running paramaters
|
||||
|
||||
@@ -27,7 +27,7 @@ global failed_imports
|
||||
failed_imports = []
|
||||
|
||||
|
||||
class Colors:
|
||||
class Colors(object):
|
||||
PURPLE = '\033[95m'
|
||||
CYAN = '\033[96m'
|
||||
DARKCYAN = '\033[36m'
|
||||
@@ -51,7 +51,7 @@ class Colors:
|
||||
|
||||
|
||||
# prints to screen with a prefix identifying the origin of the print
|
||||
class ScreenLogger:
|
||||
class ScreenLogger(object):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
@@ -85,7 +85,7 @@ class ScreenLogger:
|
||||
return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
|
||||
|
||||
|
||||
class BaseLogger:
|
||||
class BaseLogger(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import numpy as np
|
||||
from annoy import AnnoyIndex
|
||||
|
||||
|
||||
class AnnoyDictionary:
|
||||
class AnnoyDictionary(object):
|
||||
def __init__(self, dict_size, key_width, new_value_shift_coefficient=0.1, batch_size=100, key_error_threshold=0.01):
|
||||
self.max_size = dict_size
|
||||
self.curr_size = 0
|
||||
|
||||
@@ -19,7 +19,7 @@ import copy
|
||||
from configurations import *
|
||||
|
||||
|
||||
class Memory:
|
||||
class Memory(object):
|
||||
def __init__(self, tuning_parameters):
|
||||
"""
|
||||
:param tuning_parameters: A Preset class instance with all the running paramaters
|
||||
@@ -43,7 +43,7 @@ class Memory:
|
||||
pass
|
||||
|
||||
|
||||
class Episode:
|
||||
class Episode(object):
|
||||
def __init__(self):
|
||||
self.transitions = []
|
||||
# a num_transitions x num_transitions table with the n step return in the n'th row
|
||||
@@ -122,7 +122,7 @@ class Episode:
|
||||
return batch
|
||||
|
||||
|
||||
class Transition:
|
||||
class Transition(object):
|
||||
def __init__(self, state, action, reward, next_state, game_over):
|
||||
self.state = copy.deepcopy(state)
|
||||
self.state['observation'] = np.array(self.state['observation'], copy=False)
|
||||
|
||||
4
utils.py
4
utils.py
@@ -24,7 +24,7 @@ from subprocess import call, Popen
|
||||
killed_processes = []
|
||||
|
||||
|
||||
class Enum:
|
||||
class Enum(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -177,7 +177,7 @@ def threaded_cmd_line_run(run_cmd, id=-1):
|
||||
return result
|
||||
|
||||
|
||||
class Signal:
|
||||
class Signal(object):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.sample_count = 0
|
||||
|
||||
Reference in New Issue
Block a user