1
0
mirror of https://github.com/gryf/coach.git synced 2026-02-28 13:25:46 +01:00

update of api docstrings across coach and tutorials [WIP] (#91)

* updating the documentation website
* adding the built docs
* update of api docstrings across coach and tutorials 0-2
* added some missing api documentation
* New Sphinx based documentation
This commit is contained in:
Itai Caspi
2018-11-15 15:00:13 +02:00
committed by Gal Novik
parent 524f8436a2
commit 6d40ad1650
517 changed files with 71034 additions and 12834 deletions

View File

@@ -57,7 +57,7 @@ class Architecture(object):
:param initial_feed_dict: a dictionary of extra inputs for forward pass.
:return: predictions of action or value of shape (batch_size, action_space_size) for action predictions
"""
pass
raise NotImplementedError
@staticmethod
def parallel_predict(sess: Any,
@@ -68,7 +68,7 @@ class Architecture(object):
:param network_input_tuples: tuple of network and corresponding input
:return: list or tuple of outputs from all networks
"""
pass
raise NotImplementedError
def train_on_batch(self,
inputs: Dict[str, np.ndarray],
@@ -102,7 +102,7 @@ class Architecture(object):
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
fetched_tensors: all values for additional_fetches
"""
pass
raise NotImplementedError
def get_weights(self) -> List[np.ndarray]:
"""
@@ -110,7 +110,7 @@ class Architecture(object):
:return: list weights as ndarray
"""
pass
raise NotImplementedError
def set_weights(self, weights: List[np.ndarray], rate: float=1.0) -> None:
"""
@@ -121,7 +121,7 @@ class Architecture(object):
i.e. new_weight = rate * given_weight + (1 - rate) * old_weight
:return: None
"""
pass
raise NotImplementedError
def reset_accumulated_gradients(self) -> None:
"""
@@ -130,7 +130,7 @@ class Architecture(object):
Once gradients are reset, they must be accessible by `accumulated_gradients` property of this class,
which must return a list of numpy ndarrays. Child class must ensure that `accumulated_gradients` is set.
"""
pass
raise NotImplementedError
def accumulate_gradients(self,
inputs: Dict[str, np.ndarray],
@@ -166,7 +166,7 @@ class Architecture(object):
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
fetched_tensors: all values for additional_fetches
"""
pass
raise NotImplementedError
def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
"""
@@ -177,7 +177,7 @@ class Architecture(object):
of an identical network (either self or another identical network)
:param scaler: A scaling factor that allows rescaling the gradients before applying them
"""
pass
raise NotImplementedError
def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
"""
@@ -188,7 +188,7 @@ class Architecture(object):
of an identical network (either self or another identical network)
:param scaler: A scaling factor that allows rescaling the gradients before applying them
"""
pass
raise NotImplementedError
def get_variable_value(self, variable: Any) -> np.ndarray:
"""
@@ -199,7 +199,7 @@ class Architecture(object):
:param variable: variable of interest
:return: value of the specified variable
"""
pass
raise NotImplementedError
def set_variable_value(self, assign_op: Any, value: np.ndarray, placeholder: Any):
"""
@@ -212,4 +212,4 @@ class Architecture(object):
:param value: value of the specified variable used for update
:param placeholder: a placeholder for binding the value to assign_op.
"""
pass
raise NotImplementedError

View File

@@ -34,7 +34,11 @@ except ImportError:
class NetworkWrapper(object):
"""
Contains multiple networks and managers syncing and gradient updates
The network wrapper contains multiple copies of the same network, each one with a different set of weights which is
updated on a different time scale. The network wrapper will always contain an online network.
It will contain an additional slow updating target network if it was requested by the user,
and it will contain a global network shared between different workers, if Coach is run in a single-node
multi-process distributed mode. The network wrapper contains functionality for managing these networks and syncing
between them.
"""
def __init__(self, agent_parameters: AgentParameters, has_target: bool, has_global: bool, name: str,
@@ -98,6 +102,7 @@ class NetworkWrapper(object):
def sync(self):
"""
Initializes the weights of the networks to match each other
:return:
"""
self.update_online_network()
@@ -106,6 +111,7 @@ class NetworkWrapper(object):
def update_target_network(self, rate=1.0):
"""
Copy weights: online network >>> target network
:param rate: the rate of copying the weights - 1 for copying exactly
"""
if self.target_network:
@@ -114,6 +120,7 @@ class NetworkWrapper(object):
def update_online_network(self, rate=1.0):
"""
Copy weights: global network >>> online network
:param rate: the rate of copying the weights - 1 for copying exactly
"""
if self.global_network:
@@ -122,6 +129,7 @@ class NetworkWrapper(object):
def apply_gradients_to_global_network(self, gradients=None):
"""
Apply gradients from the online network on the global network
:param gradients: optional gradients that will be used instead of the accumulated gradients
:return:
"""
@@ -135,6 +143,7 @@ class NetworkWrapper(object):
def apply_gradients_to_online_network(self, gradients=None):
"""
Apply gradients from the online network on itself
:return:
"""
if gradients is None:
@@ -144,6 +153,7 @@ class NetworkWrapper(object):
def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
"""
A generic training function that enables multi-threading training using a global network if necessary.
:param inputs: The inputs for the network.
:param targets: The targets corresponding to the given inputs
:param additional_fetches: Any additional tensor the user wants to fetch
@@ -160,6 +170,7 @@ class NetworkWrapper(object):
"""
Applies the gradients accumulated in the online network to the global network or to itself and syncs the
networks if necessary
:param reset_gradients: If set to True, the accumulated gradients won't be reset to 0 after applying them to
the network. This is useful when the accumulated gradients are overwritten instead
of accumulated by the accumulate_gradients function. This allows reducing time
@@ -179,6 +190,7 @@ class NetworkWrapper(object):
def parallel_prediction(self, network_input_tuples: List[Tuple]):
"""
Run several network predictions in parallel. Currently this only supports running each of the networks once.
:param network_input_tuples: a list of tuples where the first element is the network (online_network,
target_network or global_network) and the second element is the inputs
:return: the outputs of all the networks in the same order as the inputs were given
@@ -188,6 +200,7 @@ class NetworkWrapper(object):
def get_local_variables(self):
"""
Get all the variables that are local to the thread
:return: a list of all the variables that are local to the thread
"""
local_variables = [v for v in tf.local_variables() if self.online_network.name in v.name]
@@ -198,6 +211,7 @@ class NetworkWrapper(object):
def get_global_variables(self):
"""
Get all the variables that are shared between threads
:return: a list of all the variables that are shared between threads
"""
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
@@ -206,6 +220,7 @@ class NetworkWrapper(object):
def set_is_training(self, state: bool):
"""
Set the phase of the network between training and testing
:param state: The current state (True = Training, False = Testing)
:return: None
"""

View File

@@ -14,7 +14,7 @@
# limitations under the License.
#
from typing import List, Union
from typing import List, Union, Tuple
import copy
import numpy as np
@@ -74,7 +74,12 @@ class InputEmbedder(object):
activation_function=self.activation_function,
dropout_rate=self.dropout_rate))
def __call__(self, prev_input_placeholder=None):
def __call__(self, prev_input_placeholder: tf.placeholder=None) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Wrapper for building the module graph including scoping and loss creation
:param prev_input_placeholder: the input to the graph
:return: the input placeholder and the output of the last layer
"""
with tf.variable_scope(self.get_name()):
if prev_input_placeholder is None:
self.input = tf.placeholder("float", shape=[None] + self.input_size, name=self.get_name())
@@ -84,7 +89,13 @@ class InputEmbedder(object):
return self.input, self.output
def _build_module(self):
def _build_module(self) -> None:
"""
Builds the graph of the module
This method is called early on from __call__. It is expected to store the graph
in self.output.
:return: None
"""
# NOTE: for image inputs, we expect the data format to be of type uint8, so to be memory efficient. we chose not
# to implement the rescaling as an input filters.observation.observation_filter, as this would have caused the
# input to the network to be float, which is 4x more expensive in memory.
@@ -127,7 +138,11 @@ class InputEmbedder(object):
raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
"configurations.")
def get_name(self):
def get_name(self) -> str:
"""
Get a formatted name for the module
:return: the formatted name
"""
return self.name
def __str__(self):

View File

@@ -14,7 +14,7 @@
# limitations under the License.
#
import copy
from typing import Union
from typing import Union, Tuple
import tensorflow as tf
@@ -64,17 +64,33 @@ class Middleware(object):
activation_function=self.activation_function,
dropout_rate=self.dropout_rate))
def __call__(self, input_layer):
def __call__(self, input_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Wrapper for building the module graph including scoping and loss creation
:param input_layer: the input to the graph
:return: the input placeholder and the output of the last layer
"""
with tf.variable_scope(self.get_name()):
self.input = input_layer
self._build_module()
return self.input, self.output
def _build_module(self):
def _build_module(self) -> None:
"""
Builds the graph of the module
This method is called early on from __call__. It is expected to store the graph
in self.output.
:param input_layer: the input to the graph
:return: None
"""
pass
def get_name(self):
def get_name(self) -> str:
"""
Get a formatted name for the module
:return: the formatted name
"""
return self.name
@property