1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00
This commit is contained in:
Gal Leibovich
2019-03-19 18:07:09 +02:00
committed by GitHub
parent 4a8451ff02
commit e3c7e526c7
38 changed files with 1003 additions and 87 deletions

View File

@@ -291,7 +291,8 @@ class NetworkParameters(Parameters):
batch_size=32,
replace_mse_with_huber_loss=False,
create_target_network=False,
tensorflow_support=True):
tensorflow_support=True,
softmax_temperature=1):
"""
:param force_cpu:
Force the neural networks to run on the CPU even if a GPU is available
@@ -374,6 +375,8 @@ class NetworkParameters(Parameters):
online network at will.
:param tensorflow_support:
A flag which specifies if the network is supported by the TensorFlow framework.
:param softmax_temperature:
If a softmax is present in the network head output, use this temperature
"""
super().__init__()
self.framework = Frameworks.tensorflow
@@ -404,17 +407,20 @@ class NetworkParameters(Parameters):
self.heads_parameters = heads_parameters
self.use_separate_networks_per_head = use_separate_networks_per_head
self.optimizer_type = optimizer_type
self.optimizer_epsilon = optimizer_epsilon
self.adam_optimizer_beta1 = adam_optimizer_beta1
self.adam_optimizer_beta2 = adam_optimizer_beta2
self.rms_prop_optimizer_decay = rms_prop_optimizer_decay
self.batch_size = batch_size
self.replace_mse_with_huber_loss = replace_mse_with_huber_loss
self.create_target_network = create_target_network
# Framework support
self.tensorflow_support = tensorflow_support
# Hyper-Parameter values
self.optimizer_epsilon = optimizer_epsilon
self.adam_optimizer_beta1 = adam_optimizer_beta1
self.adam_optimizer_beta2 = adam_optimizer_beta2
self.rms_prop_optimizer_decay = rms_prop_optimizer_decay
self.batch_size = batch_size
self.softmax_temperature = softmax_temperature
class NetworkComponentParameters(Parameters):
def __init__(self, dense_layer):
@@ -544,6 +550,7 @@ class AgentParameters(Parameters):
self.is_a_highest_level_agent = True
self.is_a_lowest_level_agent = True
self.task_parameters = None
self.is_batch_rl_training = False
@property
def path(self):