1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

network_imporvements branch merge

This commit is contained in:
Shadi Endrawis
2018-10-02 13:41:46 +03:00
parent 72ea933384
commit 51726a5b80
110 changed files with 1639 additions and 1161 deletions

View File

@@ -21,7 +21,7 @@ from typing import Union
import numpy as np
from rl_coach.agents.agent import Agent
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.heads.measurements_prediction_head import \
MeasurementsPredictionHeadParameters
from rl_coach.architectures.tensorflow_components.middlewares.fc_middleware import FCMiddlewareParameters
@@ -48,28 +48,27 @@ class DFPNetworkParameters(NetworkParameters):
'goal': InputEmbedderParameters(activation_function='leaky_relu')}
self.input_embedders_parameters['observation'].scheme = [
Conv2d([32, 8, 4]),
Conv2d([64, 4, 2]),
Conv2d([64, 3, 1]),
Dense([512]),
Conv2d(32, 8, 4),
Conv2d(64, 4, 2),
Conv2d(64, 3, 1),
Dense(512),
]
self.input_embedders_parameters['measurements'].scheme = [
Dense([128]),
Dense([128]),
Dense([128]),
Dense(128),
Dense(128),
Dense(128),
]
self.input_embedders_parameters['goal'].scheme = [
Dense([128]),
Dense([128]),
Dense([128]),
Dense(128),
Dense(128),
Dense(128),
]
self.middleware_parameters = FCMiddlewareParameters(activation_function='leaky_relu',
scheme=MiddlewareScheme.Empty)
self.heads_parameters = [MeasurementsPredictionHeadParameters(activation_function='leaky_relu')]
self.loss_weights = [1.0]
self.async_training = False
self.batch_size = 64
self.adam_optimizer_beta1 = 0.95