mirror of
https://github.com/gryf/coach.git
synced 2026-02-27 12:45:52 +01:00
Batch RL Tutorial (#372)
This commit is contained in:
@@ -26,7 +26,7 @@ from rl_coach.spaces import SpacesDefinition
|
||||
class VHead(Head):
|
||||
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
|
||||
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
|
||||
dense_layer=Dense, initializer='normalized_columns'):
|
||||
dense_layer=Dense, initializer='normalized_columns', output_bias_initializer=None):
|
||||
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
|
||||
dense_layer=dense_layer)
|
||||
self.name = 'v_values_head'
|
||||
@@ -38,14 +38,17 @@ class VHead(Head):
|
||||
self.loss_type = tf.losses.mean_squared_error
|
||||
|
||||
self.initializer = initializer
|
||||
self.output_bias_initializer = output_bias_initializer
|
||||
|
||||
def _build_module(self, input_layer):
|
||||
# Standard V Network
|
||||
if self.initializer == 'normalized_columns':
|
||||
self.output = self.dense_layer(1)(input_layer, name='output',
|
||||
kernel_initializer=normalized_columns_initializer(1.0))
|
||||
kernel_initializer=normalized_columns_initializer(1.0),
|
||||
bias_initializer=self.output_bias_initializer)
|
||||
elif self.initializer == 'xavier' or self.initializer is None:
|
||||
self.output = self.dense_layer(1)(input_layer, name='output')
|
||||
self.output = self.dense_layer(1)(input_layer, name='output',
|
||||
bias_initializer=self.output_bias_initializer)
|
||||
|
||||
def __str__(self):
|
||||
result = [
|
||||
|
||||
Reference in New Issue
Block a user