Mirror of https://github.com/gryf/coach.git
improved API for getting / setting variables within the graph
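The change replaces raw `self.tp.sess.run(...)` calls in the agent with `get_variable_value` / `set_variable_value` methods on the online network, so the agent no longer builds its own placeholders and assign ops. Below is a minimal sketch of what those two helpers could look like, assuming the network object keeps a handle to the TensorFlow session; the wrapper class itself is illustrative and not copied from Coach:

    import tensorflow as tf

    class NetworkWrapperSketch(object):
        """Illustrative stand-in for Coach's online network object."""

        def __init__(self, sess):
            self.sess = sess  # tf.Session owned by the network

        def get_variable_value(self, variable):
            # Read the current value of any variable/tensor in the graph.
            return self.sess.run(variable)

        def set_variable_value(self, assign_op, value, placeholder):
            # Push a new value through a pre-built assign op and its feed placeholder.
            self.sess.run(assign_op, feed_dict={placeholder: value})

With this in place, reading the learning rate or the KL penalty coefficient becomes a single method call, and writing a value only needs the assign op and placeholder exposed by the output head (`assign_kl_coefficient`, `kl_coefficient_ph` in the hunks below).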
@@ -16,7 +16,6 @@
 
 from agents.actor_critic_agent import *
 from random import shuffle
-import tensorflow as tf
 
 
 # Proximal Policy Optimization - https://arxiv.org/pdf/1707.06347.pdf
@@ -35,13 +34,6 @@ class PPOAgent(ActorCriticAgent):
                                              self.replicated_device, self.worker_device)
         self.networks.append(self.policy_network)
 
-        # operations for changing the kl coefficient
-        self.kl_coefficient = tf.placeholder('float', name='kl_coefficient')
-        self.increase_kl_coefficient = tf.assign(self.policy_network.online_network.output_heads[0].kl_coefficient,
-                                                 self.kl_coefficient * 1.5)
-        self.decrease_kl_coefficient = tf.assign(self.policy_network.online_network.output_heads[0].kl_coefficient,
-                                                 self.kl_coefficient / 1.5)
-
         # signals definition
         self.value_loss = Signal('Value Loss')
         self.signals.append(self.value_loss)
@@ -180,7 +172,7 @@ class PPOAgent(ActorCriticAgent):
             loss[key] = np.mean(loss[key], 0)
 
         if self.tp.learning_rate_decay_rate != 0:
-            curr_learning_rate = self.tp.sess.run(self.tp.learning_rate)
+            curr_learning_rate = self.main_network.online_network.get_variable_value(self.tp.learning_rate)
             self.curr_learning_rate.add_sample(curr_learning_rate)
         else:
             curr_learning_rate = self.tp.learning_rate
@@ -209,15 +201,24 @@ class PPOAgent(ActorCriticAgent):
 
         # update kl coefficient
        kl_target = self.tp.agent.target_kl_divergence
-        kl_coefficient = self.tp.sess.run(self.policy_network.online_network.output_heads[0].kl_coefficient)
+        kl_coefficient = self.policy_network.online_network.get_variable_value(
+            self.policy_network.online_network.output_heads[0].kl_coefficient)
+        new_kl_coefficient = kl_coefficient
         if self.total_kl_divergence_during_training_process > 1.3 * kl_target:
             # kl too high => increase regularization
-            self.tp.sess.run(self.increase_kl_coefficient, feed_dict={self.kl_coefficient: kl_coefficient})
+            new_kl_coefficient *= 1.5
         elif self.total_kl_divergence_during_training_process < 0.7 * kl_target:
             # kl too low => decrease regularization
-            self.tp.sess.run(self.decrease_kl_coefficient, feed_dict={self.kl_coefficient: kl_coefficient})
-        screen.log_title("KL penalty coefficient change = {} -> {}".format(
-            kl_coefficient, self.tp.sess.run(self.policy_network.online_network.output_heads[0].kl_coefficient)))
+            new_kl_coefficient /= 1.5
+
+        # update the kl coefficient variable
+        if kl_coefficient != new_kl_coefficient:
+            self.policy_network.online_network.set_variable_value(
+                self.policy_network.online_network.output_heads[0].assign_kl_coefficient,
+                new_kl_coefficient,
+                self.policy_network.online_network.output_heads[0].kl_coefficient_ph)
+
+        screen.log_title("KL penalty coefficient change = {} -> {}".format(kl_coefficient, new_kl_coefficient))
 
     def post_training_commands(self):
         if self.tp.agent.use_kl_regularization:
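For the setter to work, the PPO head is expected to build the KL coefficient variable, a feed placeholder, and the assign op once at graph-construction time; the agent then only computes the new value (multiplying or dividing by 1.5 around the target KL) and pushes it through `set_variable_value`. A possible sketch of that head-side construction, with the attribute names taken from the diff but the surrounding code assumed rather than copied from Coach:

    import tensorflow as tf

    # Inside the PPO output head's graph-construction code (illustrative only):
    kl_coefficient = tf.Variable(1.0, trainable=False, name='kl_coefficient')
    kl_coefficient_ph = tf.placeholder('float', name='kl_coefficient_ph')
    assign_kl_coefficient = tf.assign(kl_coefficient, kl_coefficient_ph)

    # The agent side then updates it without touching the session directly:
    # network.set_variable_value(assign_kl_coefficient, new_kl_coefficient, kl_coefficient_ph)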