
update of api docstrings across coach and tutorials [WIP] (#91)

* updating the documentation website
* adding the built docs
* update of api docstrings across coach and tutorials 0-2
* added some missing api documentation
* New Sphinx based documentation
Itai Caspi
2018-11-15 15:00:13 +02:00
committed by Gal Novik
parent 524f8436a2
commit 6d40ad1650
517 changed files with 71034 additions and 12834 deletions


@@ -34,7 +34,11 @@ except ImportError:
class NetworkWrapper(object):
"""
Contains multiple networks and manages syncing and gradient updates between them.
The network wrapper contains multiple copies of the same network, each with a different set of weights that is
updated on a different time scale. The network wrapper will always contain an online network.
It will also contain a slow-updating target network if one was requested by the user,
and it will contain a global network shared between different workers if Coach is run in single-node
multi-process distributed mode. The network wrapper contains functionality for managing these networks and syncing
between them.
"""
def __init__(self, agent_parameters: AgentParameters, has_target: bool, has_global: bool, name: str,
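For orientation, a minimal usage sketch of the wrapper's three-copy layout (agent_parameters is assumed to be built elsewhere; not taken from this diff):

network = NetworkWrapper(agent_parameters=agent_parameters,
                         has_target=True,    # request a slow-updating target copy
                         has_global=False,   # single-process run, no shared global copy
                         name='main')
network.sync()  # start all copies from identical weights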
@@ -98,6 +102,7 @@ class NetworkWrapper(object):
def sync(self):
"""
Initializes the weights of the networks to match each other
:return: None
"""
self.update_online_network()
@@ -106,6 +111,7 @@ class NetworkWrapper(object):
def update_target_network(self, rate=1.0):
"""
Copy weights: online network >>> target network
:param rate: the rate of copying the weights - 1.0 copies the weights exactly
"""
if self.target_network:
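The rate acts like a Polyak averaging coefficient; a sketch of the implied update rule (inferred from the docstring, not quoted from the source):

# target_weights <- rate * online_weights + (1 - rate) * target_weights
network.update_target_network()            # rate=1.0: hard copy, online -> target
network.update_target_network(rate=0.001)  # soft update, blending a small step each call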
@@ -114,6 +120,7 @@ class NetworkWrapper(object):
def update_online_network(self, rate=1.0):
"""
Copy weights: global network >>> online network
:param rate: the rate of copying the weights - 1.0 copies the weights exactly
"""
if self.global_network:
@@ -122,6 +129,7 @@ class NetworkWrapper(object):
def apply_gradients_to_global_network(self, gradients=None):
"""
Apply gradients from the online network to the global network
:param gradients: optional gradients that will be used instead of the accumulated gradients
:return: None
"""
@@ -135,6 +143,7 @@ class NetworkWrapper(object):
def apply_gradients_to_online_network(self, gradients=None):
"""
Apply gradients from the online network to itself
:param gradients: optional gradients that will be used instead of the accumulated gradients
:return: None
"""
if gradients is None:
@@ -144,6 +153,7 @@ class NetworkWrapper(object):
def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
"""
A generic training function that enables multi-threaded training using a global network, if necessary.
:param inputs: The inputs for the network.
:param targets: The targets corresponding to the given inputs.
:param additional_fetches: Any additional tensors the user wants to fetch
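A hedged usage sketch (the input dict key and variable names are invented for illustration):

# Single training step through the wrapper; the return value is assumed to carry the loss.
result = network.train_and_sync_networks(
    inputs={'observation': states},  # dict of network inputs (key name assumed)
    targets=q_targets)               # targets matching the network's heads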
@@ -160,6 +170,7 @@ class NetworkWrapper(object):
"""
Applies the gradients accumulated in the online network to the global network or to itself, and syncs the
networks if necessary
:param reset_gradients: If set to False, the accumulated gradients won't be reset to 0 after applying them to
the network. This is useful when the accumulated gradients are overwritten instead
of accumulated by the accumulate_gradients function. This allows reducing time
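Sketched as the caller would see it (pseudocode of the implied behavior, not quoted from the source):

network.apply_gradients_and_sync_networks(reset_gradients=True)   # apply, then zero the accumulator
network.apply_gradients_and_sync_networks(reset_gradients=False)  # apply, keep the accumulator
                                                                  # (useful when the caller overwrites it anyway)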
@@ -179,6 +190,7 @@ class NetworkWrapper(object):
def parallel_prediction(self, network_input_tuples: List[Tuple]):
"""
Run several network predictions in parallel. Currently this only supports running each of the networks once.
:param network_input_tuples: a list of tuples where the first element is the network (online_network,
target_network or global_network) and the second element is the inputs
:return: the outputs of all the networks in the same order as the inputs were given
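A sketch of the expected call shape (input values and the dict key are assumptions):

# Run the online and target copies side by side on different inputs:
online_out, target_out = network.parallel_prediction([
    (network.online_network, {'observation': states}),
    (network.target_network, {'observation': next_states}),
])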
@@ -188,6 +200,7 @@ class NetworkWrapper(object):
def get_local_variables(self):
"""
Get all the variables that are local to the thread
:return: a list of all the variables that are local to the thread
"""
local_variables = [v for v in tf.local_variables() if self.online_network.name in v.name]
@@ -198,6 +211,7 @@ class NetworkWrapper(object):
def get_global_variables(self):
"""
Get all the variables that are shared between threads
:return: a list of all the variables that are shared between threads
"""
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
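Both accessors are convenient for scoping TensorFlow 1.x utilities to a single copy of the weights, for instance (a sketch, not from this commit):

# Checkpoint only the weights shared across workers:
saver = tf.train.Saver(var_list=network.get_global_variables())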
@@ -206,6 +220,7 @@ class NetworkWrapper(object):
def set_is_training(self, state: bool):
"""
Set the phase of the network between training and testing
:param state: The current state (True = Training, False = Testing)
:return: None
"""