mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
network_imporvements branch merge
This commit is contained in:
@@ -21,7 +21,7 @@ from typing import Union
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.agents.agent import Agent
|
||||
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.heads.measurements_prediction_head import \
|
||||
MeasurementsPredictionHeadParameters
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
|
||||
@@ -48,28 +48,27 @@ class DFPNetworkParameters(NetworkParameters):
|
||||
'goal': InputEmbedderParameters(activation_function='leaky_relu')}
|
||||
|
||||
self.input_embedders_parameters['observation'].scheme = [
|
||||
Conv2d([32, 8, 4]),
|
||||
Conv2d([64, 4, 2]),
|
||||
Conv2d([64, 3, 1]),
|
||||
Dense([512]),
|
||||
Conv2d(32, 8, 4),
|
||||
Conv2d(64, 4, 2),
|
||||
Conv2d(64, 3, 1),
|
||||
Dense(512),
|
||||
]
|
||||
|
||||
self.input_embedders_parameters['measurements'].scheme = [
|
||||
Dense([128]),
|
||||
Dense([128]),
|
||||
Dense([128]),
|
||||
Dense(128),
|
||||
Dense(128),
|
||||
Dense(128),
|
||||
]
|
||||
|
||||
self.input_embedders_parameters['goal'].scheme = [
|
||||
Dense([128]),
|
||||
Dense([128]),
|
||||
Dense([128]),
|
||||
Dense(128),
|
||||
Dense(128),
|
||||
Dense(128),
|
||||
]
|
||||
|
||||
self.middleware_parameters = FCMiddlewareParameters(activation_function='leaky_relu',
|
||||
scheme=MiddlewareScheme.Empty)
|
||||
self.heads_parameters = [MeasurementsPredictionHeadParameters(activation_function='leaky_relu')]
|
||||
self.loss_weights = [1.0]
|
||||
self.async_training = False
|
||||
self.batch_size = 64
|
||||
self.adam_optimizer_beta1 = 0.95
|
||||
|
||||
Reference in New Issue
Block a user