1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

network_imporvements branch merge

This commit is contained in:
Shadi Endrawis
2018-10-02 13:41:46 +03:00
parent 72ea933384
commit 51726a5b80
110 changed files with 1639 additions and 1161 deletions

View File

@@ -18,7 +18,7 @@ from typing import List
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.architecture import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.core_types import InputImageEmbedding
@@ -34,9 +34,9 @@ class ImageEmbedder(InputEmbedder):
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout: bool=False,
name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None,
dense_layer=Dense):
dense_layer=Dense, is_training=False):
super().__init__(input_size, activation_function, scheme, batchnorm, dropout, name, input_rescaling,
input_offset, input_clipping, dense_layer=dense_layer)
input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
self.return_type = InputImageEmbedding
if len(input_size) != 3 and scheme != EmbedderScheme.Empty:
raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}"
@@ -50,28 +50,28 @@ class ImageEmbedder(InputEmbedder):
EmbedderScheme.Shallow:
[
Conv2d([32, 3, 1])
Conv2d(32, 3, 1)
],
# atari dqn
EmbedderScheme.Medium:
[
Conv2d([32, 8, 4]),
Conv2d([64, 4, 2]),
Conv2d([64, 3, 1])
Conv2d(32, 8, 4),
Conv2d(64, 4, 2),
Conv2d(64, 3, 1)
],
# carla
EmbedderScheme.Deep: \
[
Conv2d([32, 5, 2]),
Conv2d([32, 3, 1]),
Conv2d([64, 3, 2]),
Conv2d([64, 3, 1]),
Conv2d([128, 3, 2]),
Conv2d([128, 3, 1]),
Conv2d([256, 3, 2]),
Conv2d([256, 3, 1])
Conv2d(32, 5, 2),
Conv2d(32, 3, 1),
Conv2d(64, 3, 2),
Conv2d(64, 3, 1),
Conv2d(128, 3, 2),
Conv2d(128, 3, 1),
Conv2d(256, 3, 2),
Conv2d(256, 3, 1)
]
}