1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00
This commit is contained in:
Gal Leibovich
2019-06-16 11:11:21 +03:00
committed by GitHub
parent 8df3c46756
commit 7eb884c5b2
107 changed files with 2200 additions and 495 deletions

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Type
import numpy as np
import tensorflow as tf
@@ -22,7 +21,7 @@ from rl_coach.architectures.tensorflow_components.layers import Dense, convert_l
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list
from rl_coach.architectures.tensorflow_components.utils import squeeze_tensor
# Used to initialize weights for policy and value output layers
def normalized_columns_initializer(std=1.0):
@@ -72,8 +71,9 @@ class Head(object):
:param input_layer: the input to the graph
:return: the output of the last layer and the target placeholder
"""
with tf.variable_scope(self.get_name(), initializer=tf.contrib.layers.xavier_initializer()):
self._build_module(input_layer)
self._build_module(squeeze_tensor(input_layer))
self.output = force_list(self.output)
self.target = force_list(self.target)