mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Unify base class using new-style (object).
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
from configurations import Preset
|
||||
|
||||
|
||||
class Architecture:
|
||||
class Architecture(object):
|
||||
def __init__(self, tuning_parameters, name=""):
|
||||
"""
|
||||
:param tuning_parameters: A Preset class instance with all the running paramaters
|
||||
|
||||
@@ -19,7 +19,7 @@ import ngraph as ng
|
||||
from ngraph.util.names import name_scope
|
||||
|
||||
|
||||
class InputEmbedder:
|
||||
class InputEmbedder(object):
|
||||
def __init__(self, input_size, batch_size=None, activation_function=neon.Rectlin(), name="embedder"):
|
||||
self.name = name
|
||||
self.input_size = input_size
|
||||
|
||||
@@ -22,7 +22,7 @@ from utils import force_list
|
||||
from architectures.neon_components.losses import *
|
||||
|
||||
|
||||
class Head:
|
||||
class Head(object):
|
||||
def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
|
||||
self.head_idx = head_idx
|
||||
self.name = "head"
|
||||
|
||||
@@ -20,7 +20,7 @@ from ngraph.util.names import name_scope
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MiddlewareEmbedder:
|
||||
class MiddlewareEmbedder(object):
|
||||
def __init__(self, activation_function=neon.Rectlin(), name="middleware_embedder"):
|
||||
self.name = name
|
||||
self.input = None
|
||||
|
||||
@@ -29,7 +29,7 @@ except ImportError:
|
||||
failed_imports.append("Neon")
|
||||
|
||||
|
||||
class NetworkWrapper:
|
||||
class NetworkWrapper(object):
|
||||
def __init__(self, tuning_parameters, has_target, has_global, name, replicated_device=None, worker_device=None):
|
||||
"""
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class InputEmbedder:
|
||||
class InputEmbedder(object):
|
||||
def __init__(self, input_size, activation_function=tf.nn.relu, name="embedder"):
|
||||
self.name = name
|
||||
self.input_size = input_size
|
||||
|
||||
@@ -28,7 +28,7 @@ def normalized_columns_initializer(std=1.0):
|
||||
return _initializer
|
||||
|
||||
|
||||
class Head:
|
||||
class Head(object):
|
||||
def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
|
||||
self.head_idx = head_idx
|
||||
self.name = "head"
|
||||
|
||||
@@ -18,7 +18,7 @@ import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MiddlewareEmbedder:
|
||||
class MiddlewareEmbedder(object):
|
||||
def __init__(self, activation_function=tf.nn.relu, name="middleware_embedder"):
|
||||
self.name = name
|
||||
self.input = None
|
||||
|
||||
Reference in New Issue
Block a user