Batch RL (#238)
@@ -291,7 +291,8 @@ class NetworkParameters(Parameters):
                  batch_size=32,
                  replace_mse_with_huber_loss=False,
                  create_target_network=False,
-                 tensorflow_support=True):
+                 tensorflow_support=True,
+                 softmax_temperature=1):
         """
         :param force_cpu:
             Force the neural networks to run on the CPU even if a GPU is available
@@ -374,6 +375,8 @@ class NetworkParameters(Parameters):
             online network at will.
         :param tensorflow_support:
             A flag which specifies if the network is supported by the TensorFlow framework.
+        :param softmax_temperature:
+            If a softmax is present in the network head output, use this temperature
         """
         super().__init__()
         self.framework = Frameworks.tensorflow
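For readers unfamiliar with the new parameter: the temperature divides the logits before the softmax is applied, so values above 1 flatten the output distribution and values below 1 sharpen it. The NumPy sketch below is illustrative only (the function name is ours, not Coach's) and shows the behaviour the new default of 1 preserves:

import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    # Divide the logits by the temperature, then normalize; subtracting the max
    # keeps the exponentials numerically stable.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()
    exps = np.exp(scaled)
    return exps / exps.sum()

# temperature=1 reproduces a plain softmax (the new default above);
# a larger temperature pushes the distribution towards uniform.
print(softmax_with_temperature([2.0, 1.0, 0.1], temperature=1.0))
print(softmax_with_temperature([2.0, 1.0, 0.1], temperature=5.0))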
@@ -404,17 +407,20 @@ class NetworkParameters(Parameters):
         self.heads_parameters = heads_parameters
         self.use_separate_networks_per_head = use_separate_networks_per_head
         self.optimizer_type = optimizer_type
-        self.optimizer_epsilon = optimizer_epsilon
-        self.adam_optimizer_beta1 = adam_optimizer_beta1
-        self.adam_optimizer_beta2 = adam_optimizer_beta2
-        self.rms_prop_optimizer_decay = rms_prop_optimizer_decay
-        self.batch_size = batch_size
         self.replace_mse_with_huber_loss = replace_mse_with_huber_loss
         self.create_target_network = create_target_network

         # Framework support
         self.tensorflow_support = tensorflow_support
+
+        # Hyper-Parameter values
+        self.optimizer_epsilon = optimizer_epsilon
+        self.adam_optimizer_beta1 = adam_optimizer_beta1
+        self.adam_optimizer_beta2 = adam_optimizer_beta2
+        self.rms_prop_optimizer_decay = rms_prop_optimizer_decay
+        self.batch_size = batch_size
+        self.softmax_temperature = softmax_temperature


 class NetworkComponentParameters(Parameters):
     def __init__(self, dense_layer):
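A hedged usage sketch of the extended constructor: the keyword arguments come from the signature shown above, while the import path (rl_coach.base_parameters) is assumed from the Coach source layout rather than stated in this diff.

# Assumed import path; the diff itself does not name the module.
from rl_coach.base_parameters import NetworkParameters

net_params = NetworkParameters(
    batch_size=32,
    replace_mse_with_huber_loss=False,
    create_target_network=False,
    tensorflow_support=True,
    softmax_temperature=2.0,   # new keyword introduced by this commit
)

# The value is stored alongside the other hyper-parameters, as shown above.
print(net_params.softmax_temperature)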
@@ -544,6 +550,7 @@ class AgentParameters(Parameters):
         self.is_a_highest_level_agent = True
         self.is_a_lowest_level_agent = True
         self.task_parameters = None
+        self.is_batch_rl_training = False

     @property
     def path(self):
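The new is_batch_rl_training flag defaults to False; a batch-RL preset would set it to True so the agent trains purely from a pre-collected dataset instead of interacting with an environment. The sketch below is hypothetical (the improve helper and the stub class are ours) and only illustrates how such a flag is typically consulted:

# Hypothetical illustration; only the attribute name comes from the diff.
class AgentParameters:
    def __init__(self):
        self.is_batch_rl_training = False   # default added by this commit


def improve(agent_params, dataset):
    # Choose a training mode based on the batch-RL flag.
    if agent_params.is_batch_rl_training:
        # Batch (offline) RL: train only on the fixed, pre-collected dataset,
        # never stepping an environment to gather new experience.
        return "offline training on {} stored transitions".format(len(dataset))
    # Online RL: interleave environment interaction with training.
    return "online training with fresh environment rollouts"


params = AgentParameters()
params.is_batch_rl_training = True   # a batch-RL preset would enable the flag
print(improve(params, dataset=[("s", "a", 0.0, "s_next")] * 3))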