1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

network_imporvements branch merge

This commit is contained in:
Shadi Endrawis
2018-10-02 13:41:46 +03:00
parent 72ea933384
commit 51726a5b80
110 changed files with 1639 additions and 1161 deletions

View File

@@ -199,6 +199,16 @@ class NetworkWrapper(object):
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
return global_variables
def set_is_training(self, state: bool):
"""
Set the phase of the network between training and testing
:param state: The current state (True = Training, False = Testing)
:return: None
"""
self.online_network.set_is_training(state)
if self.has_target:
self.target_network.set_is_training(state)
def set_session(self, sess):
self.sess = sess
self.online_network.set_session(sess)
@@ -207,3 +217,18 @@ class NetworkWrapper(object):
if self.target_network:
self.target_network.set_session(sess)
def __str__(self):
sub_networks = []
if self.global_network:
sub_networks.append("global network")
if self.online_network:
sub_networks.append("online network")
if self.target_network:
sub_networks.append("target network")
result = []
result.append("Network: {}, Copies: {} ({})".format(self.name, len(sub_networks), ' | '.join(sub_networks)))
result.append("-"*len(result[-1]))
result.append(str(self.online_network))
result.append("")
return '\n'.join(result)