mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
update of api docstrings across coach and tutorials [WIP] (#91)
* updating the documentation website * adding the built docs * update of api docstrings across coach and tutorials 0-2 * added some missing api documentation * New Sphinx based documentation
This commit is contained in:
@@ -34,7 +34,11 @@ except ImportError:
|
||||
|
||||
class NetworkWrapper(object):
|
||||
"""
|
||||
Contains multiple networks and manages syncing and gradient updates
|
||||
The network wrapper contains multiple copies of the same network, each one with a different set of weights which is
|
||||
updated on a different time scale. The network wrapper will always contain an online network.
|
||||
It will contain an additional slow updating target network if it was requested by the user,
|
||||
and it will contain a global network shared between different workers, if Coach is run in a single-node
|
||||
multi-process distributed mode. The network wrapper contains functionality for managing these networks and syncing
|
||||
between them.
|
||||
"""
|
||||
def __init__(self, agent_parameters: AgentParameters, has_target: bool, has_global: bool, name: str,
|
||||
@@ -98,6 +102,7 @@ class NetworkWrapper(object):
|
||||
def sync(self):
|
||||
"""
|
||||
Initializes the weights of the networks to match each other
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.update_online_network()
|
||||
@@ -106,6 +111,7 @@ class NetworkWrapper(object):
|
||||
def update_target_network(self, rate=1.0):
|
||||
"""
|
||||
Copy weights: online network >>> target network
|
||||
|
||||
:param rate: the rate of copying the weights - 1 for copying exactly
|
||||
"""
|
||||
if self.target_network:
|
||||
@@ -114,6 +120,7 @@ class NetworkWrapper(object):
|
||||
def update_online_network(self, rate=1.0):
|
||||
"""
|
||||
Copy weights: global network >>> online network
|
||||
|
||||
:param rate: the rate of copying the weights - 1 for copying exactly
|
||||
"""
|
||||
if self.global_network:
|
||||
@@ -122,6 +129,7 @@ class NetworkWrapper(object):
|
||||
def apply_gradients_to_global_network(self, gradients=None):
|
||||
"""
|
||||
Apply gradients from the online network on the global network
|
||||
|
||||
:param gradients: optional gradients that will be used instead of the accumulated gradients
|
||||
:return:
|
||||
"""
|
||||
@@ -135,6 +143,7 @@ class NetworkWrapper(object):
|
||||
def apply_gradients_to_online_network(self, gradients=None):
|
||||
"""
|
||||
Apply gradients from the online network on itself
|
||||
|
||||
:return:
|
||||
"""
|
||||
if gradients is None:
|
||||
@@ -144,6 +153,7 @@ class NetworkWrapper(object):
|
||||
def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
|
||||
"""
|
||||
A generic training function that enables multi-threading training using a global network if necessary.
|
||||
|
||||
:param inputs: The inputs for the network.
|
||||
:param targets: The targets corresponding to the given inputs
|
||||
:param additional_fetches: Any additional tensor the user wants to fetch
|
||||
@@ -160,6 +170,7 @@ class NetworkWrapper(object):
|
||||
"""
|
||||
Applies the gradients accumulated in the online network to the global network or to itself and syncs the
|
||||
networks if necessary
|
||||
|
||||
:param reset_gradients: If set to True, the accumulated gradients won't be reset to 0 after applying them to
|
||||
the network. This is useful when the accumulated gradients are overwritten instead
|
||||
of being accumulated by the accumulate_gradients function. This allows reducing time
|
||||
@@ -179,6 +190,7 @@ class NetworkWrapper(object):
|
||||
def parallel_prediction(self, network_input_tuples: List[Tuple]):
|
||||
"""
|
||||
Run several network predictions in parallel. Currently this only supports running each of the networks once.
|
||||
|
||||
:param network_input_tuples: a list of tuples where the first element is the network (online_network,
|
||||
target_network or global_network) and the second element is the inputs
|
||||
:return: the outputs of all the networks in the same order as the inputs were given
|
||||
@@ -188,6 +200,7 @@ class NetworkWrapper(object):
|
||||
def get_local_variables(self):
|
||||
"""
|
||||
Get all the variables that are local to the thread
|
||||
|
||||
:return: a list of all the variables that are local to the thread
|
||||
"""
|
||||
local_variables = [v for v in tf.local_variables() if self.online_network.name in v.name]
|
||||
@@ -198,6 +211,7 @@ class NetworkWrapper(object):
|
||||
def get_global_variables(self):
|
||||
"""
|
||||
Get all the variables that are shared between threads
|
||||
|
||||
:return: a list of all the variables that are shared between threads
|
||||
"""
|
||||
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
|
||||
@@ -206,6 +220,7 @@ class NetworkWrapper(object):
|
||||
def set_is_training(self, state: bool):
|
||||
"""
|
||||
Set the phase of the network between training and testing
|
||||
|
||||
:param state: The current state (True = Training, False = Testing)
|
||||
:return: None
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user