mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
@@ -21,13 +21,13 @@ from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
|
||||
|
||||
class InputEmbedderParameters(NetworkComponentParameters):
|
||||
def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium,
|
||||
batchnorm: bool=False, dropout=False, name: str='embedder', input_rescaling=None, input_offset=None,
|
||||
input_clipping=None, dense_layer=None, is_training=False):
|
||||
batchnorm: bool=False, dropout_rate: float=0.0, name: str='embedder', input_rescaling=None,
|
||||
input_offset=None, input_clipping=None, dense_layer=None, is_training=False):
|
||||
super().__init__(dense_layer=dense_layer)
|
||||
self.activation_function = activation_function
|
||||
self.scheme = scheme
|
||||
self.batchnorm = batchnorm
|
||||
self.dropout = dropout
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
if input_rescaling is None:
|
||||
input_rescaling = {'image': 255.0, 'vector': 1.0}
|
||||
|
||||
Reference in New Issue
Block a user