mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
TD3 (#338)
This commit is contained in:
@@ -28,23 +28,28 @@ class FCMiddleware(Middleware):
|
||||
def __init__(self, activation_function=tf.nn.relu,
|
||||
scheme: MiddlewareScheme = MiddlewareScheme.Medium,
|
||||
batchnorm: bool = False, dropout_rate: float = 0.0,
|
||||
name="middleware_fc_embedder", dense_layer=Dense, is_training=False):
|
||||
name="middleware_fc_embedder", dense_layer=Dense, is_training=False, num_streams: int = 1):
|
||||
super().__init__(activation_function=activation_function, batchnorm=batchnorm,
|
||||
dropout_rate=dropout_rate, scheme=scheme, name=name, dense_layer=dense_layer,
|
||||
is_training=is_training)
|
||||
self.return_type = Middleware_FC_Embedding
|
||||
self.layers = []
|
||||
|
||||
assert(isinstance(num_streams, int) and num_streams >= 1)
|
||||
self.num_streams = num_streams
|
||||
|
||||
def _build_module(self):
|
||||
self.layers.append(self.input)
|
||||
self.output = []
|
||||
|
||||
for idx, layer_params in enumerate(self.layers_params):
|
||||
self.layers.extend(force_list(
|
||||
layer_params(self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx),
|
||||
is_training=self.is_training)
|
||||
))
|
||||
for stream_idx in range(self.num_streams):
|
||||
layers = [self.input]
|
||||
|
||||
self.output = self.layers[-1]
|
||||
for idx, layer_params in enumerate(self.layers_params):
|
||||
layers.extend(force_list(
|
||||
layer_params(layers[-1], name='{}_{}'.format(layer_params.__class__.__name__,
|
||||
idx + stream_idx * len(self.layers_params)),
|
||||
is_training=self.is_training)
|
||||
))
|
||||
self.output.append((layers[-1]))
|
||||
|
||||
@property
|
||||
def schemes(self):
|
||||
@@ -72,3 +77,15 @@ class FCMiddleware(Middleware):
|
||||
]
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
stream = [str(l) for l in self.layers_params]
|
||||
if self.layers_params:
|
||||
if self.num_streams > 1:
|
||||
stream = [''] + ['\t' + l for l in stream]
|
||||
result = stream * self.num_streams
|
||||
result[0::len(stream)] = ['Stream {}'.format(i) for i in range(self.num_streams)]
|
||||
else:
|
||||
result = stream
|
||||
return '\n'.join(result)
|
||||
else:
|
||||
return 'No layers'
|
||||
|
||||
Reference in New Issue
Block a user