1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

fixes to rainbow dqn + a cartpole based golden test (#253)

This commit is contained in:
Gal Leibovich
2019-03-21 12:57:56 +02:00
committed by GitHub
parent 83741fa92a
commit abec59f367
6 changed files with 127 additions and 23 deletions

View File

@@ -18,8 +18,8 @@ from typing import Type
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.losses.losses_impl import Reduction
from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.base_parameters import AgentParameters, Parameters, NetworkComponentParameters
from rl_coach.architectures.tensorflow_components.layers import Dense, convert_layer_class
from rl_coach.base_parameters import AgentParameters
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list
@@ -63,6 +63,8 @@ class Head(object):
self.dense_layer = dense_layer
if self.dense_layer is None:
self.dense_layer = Dense
else:
self.dense_layer = convert_layer_class(self.dense_layer)
def __call__(self, input_layer):
"""