mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Load and save functions for non-episodic replay buffers + CARLA improvements + network bug fixes
This commit is contained in:
@@ -61,7 +61,7 @@ class Conv2d(object):
|
||||
"""
|
||||
self.params = params
|
||||
|
||||
def __call__(self, input_layer, name: str):
|
||||
def __call__(self, input_layer, name: str=None):
|
||||
"""
|
||||
returns a tensorflow conv2d layer
|
||||
:param input_layer: previous layer
|
||||
@@ -79,7 +79,7 @@ class Dense(object):
|
||||
"""
|
||||
self.params = force_list(params)
|
||||
|
||||
def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None):
|
||||
def __call__(self, input_layer, name: str=None, kernel_initializer=None, activation=None):
|
||||
"""
|
||||
returns a tensorflow dense layer
|
||||
:param input_layer: previous layer
|
||||
|
||||
@@ -253,9 +253,11 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
else:
|
||||
# if we use a single network with multiple embedders, then the head type is the current head idx
|
||||
head_type_idx = head_idx
|
||||
|
||||
# create output head and add it to the output heads list
|
||||
self.output_heads.append(
|
||||
self.get_output_head(self.network_parameters.heads_parameters[head_type_idx],
|
||||
head_copy_idx,
|
||||
head_idx*self.network_parameters.num_output_head_copies + head_copy_idx,
|
||||
self.network_parameters.loss_weights[head_type_idx])
|
||||
)
|
||||
|
||||
|
||||
@@ -59,8 +59,8 @@ class Head(object):
|
||||
self.loss = []
|
||||
self.loss_type = []
|
||||
self.regularizations = []
|
||||
# self.loss_weight = force_list(loss_weight)
|
||||
self.loss_weight = tf.Variable(force_list(loss_weight), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
self.loss_weight = tf.Variable([float(w) for w in force_list(loss_weight)],
|
||||
trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
self.loss_weight_placeholder = tf.placeholder("float")
|
||||
self.set_loss_weight = tf.assign(self.loss_weight, self.loss_weight_placeholder)
|
||||
self.target = []
|
||||
|
||||
@@ -49,7 +49,7 @@ class PolicyHead(Head):
|
||||
# a scalar weight that penalizes low entropy values to encourage exploration
|
||||
if hasattr(agent_parameters.algorithm, 'beta_entropy'):
|
||||
# we set the beta value as a tf variable so it can be updated later if needed
|
||||
self.beta = tf.Variable(agent_parameters.algorithm.beta_entropy,
|
||||
self.beta = tf.Variable(float(agent_parameters.algorithm.beta_entropy),
|
||||
trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
self.beta_placeholder = tf.placeholder('float')
|
||||
self.set_beta = tf.assign(self.beta, self.beta_placeholder)
|
||||
|
||||
Reference in New Issue
Block a user