mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
network_imporvements branch merge
This commit is contained in:
@@ -199,6 +199,16 @@ class NetworkWrapper(object):
|
||||
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
|
||||
return global_variables
|
||||
|
||||
def set_is_training(self, state: bool):
|
||||
"""
|
||||
Set the phase of the network between training and testing
|
||||
:param state: The current state (True = Training, False = Testing)
|
||||
:return: None
|
||||
"""
|
||||
self.online_network.set_is_training(state)
|
||||
if self.has_target:
|
||||
self.target_network.set_is_training(state)
|
||||
|
||||
def set_session(self, sess):
|
||||
self.sess = sess
|
||||
self.online_network.set_session(sess)
|
||||
@@ -207,3 +217,18 @@ class NetworkWrapper(object):
|
||||
if self.target_network:
|
||||
self.target_network.set_session(sess)
|
||||
|
||||
def __str__(self):
|
||||
sub_networks = []
|
||||
if self.global_network:
|
||||
sub_networks.append("global network")
|
||||
if self.online_network:
|
||||
sub_networks.append("online network")
|
||||
if self.target_network:
|
||||
sub_networks.append("target network")
|
||||
|
||||
result = []
|
||||
result.append("Network: {}, Copies: {} ({})".format(self.name, len(sub_networks), ' | '.join(sub_networks)))
|
||||
result.append("-"*len(result[-1]))
|
||||
result.append(str(self.online_network))
|
||||
result.append("")
|
||||
return '\n'.join(result)
|
||||
|
||||
Reference in New Issue
Block a user