mirror of
https://github.com/gryf/coach.git
synced 2026-02-28 13:25:46 +01:00
update of api docstrings across coach and tutorials [WIP] (#91)
* updating the documentation website * adding the built docs * update of api docstrings across coach and tutorials 0-2 * added some missing api documentation * New Sphinx based documentation
This commit is contained in:
@@ -57,7 +57,7 @@ class Architecture(object):
|
||||
:param initial_feed_dict: a dictionary of extra inputs for forward pass.
|
||||
:return: predictions of action or value of shape (batch_size, action_space_size) for action predictions)
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def parallel_predict(sess: Any,
|
||||
@@ -68,7 +68,7 @@ class Architecture(object):
|
||||
:param network_input_tuples: tuple of network and corresponding input
|
||||
:return: list or tuple of outputs from all networks
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def train_on_batch(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
@@ -102,7 +102,7 @@ class Architecture(object):
|
||||
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
|
||||
fetched_tensors: all values for additional_fetches
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def get_weights(self) -> List[np.ndarray]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class Architecture(object):
|
||||
|
||||
:return: list weights as ndarray
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def set_weights(self, weights: List[np.ndarray], rate: float=1.0) -> None:
|
||||
"""
|
||||
@@ -121,7 +121,7 @@ class Architecture(object):
|
||||
i.e. new_weight = rate * given_weight + (1 - rate) * old_weight
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_accumulated_gradients(self) -> None:
|
||||
"""
|
||||
@@ -130,7 +130,7 @@ class Architecture(object):
|
||||
Once gradients are reset, they must be accessible by `accumulated_gradients` property of this class,
|
||||
which must return a list of numpy ndarrays. Child class must ensure that `accumulated_gradients` is set.
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def accumulate_gradients(self,
|
||||
inputs: Dict[str, np.ndarray],
|
||||
@@ -166,7 +166,7 @@ class Architecture(object):
|
||||
norm_unclippsed_grads (float): global norm of all gradients before any gradient clipping is applied
|
||||
fetched_tensors: all values for additional_fetches
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
|
||||
"""
|
||||
@@ -177,7 +177,7 @@ class Architecture(object):
|
||||
of an identical network (either self or another identical network)
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1.) -> None:
|
||||
"""
|
||||
@@ -188,7 +188,7 @@ class Architecture(object):
|
||||
of an identical network (either self or another identical network)
|
||||
:param scaler: A scaling factor that allows rescaling the gradients before applying them
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def get_variable_value(self, variable: Any) -> np.ndarray:
|
||||
"""
|
||||
@@ -199,7 +199,7 @@ class Architecture(object):
|
||||
:param variable: variable of interest
|
||||
:return: value of the specified variable
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def set_variable_value(self, assign_op: Any, value: np.ndarray, placeholder: Any):
|
||||
"""
|
||||
@@ -212,4 +212,4 @@ class Architecture(object):
|
||||
:param value: value of the specified variable used for update
|
||||
:param placeholder: a placeholder for binding the value to assign_op.
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -34,7 +34,11 @@ except ImportError:
|
||||
|
||||
class NetworkWrapper(object):
|
||||
"""
|
||||
Contains multiple networks and managers syncing and gradient updates
|
||||
The network wrapper contains multiple copies of the same network, each one with a different set of weights which is
|
||||
updating in a different time scale. The network wrapper will always contain an online network.
|
||||
It will contain an additional slow updating target network if it was requested by the user,
|
||||
and it will contain a global network shared between different workers, if Coach is run in a single-node
|
||||
multi-process distributed mode. The network wrapper contains functionality for managing these networks and syncing
|
||||
between them.
|
||||
"""
|
||||
def __init__(self, agent_parameters: AgentParameters, has_target: bool, has_global: bool, name: str,
|
||||
@@ -98,6 +102,7 @@ class NetworkWrapper(object):
|
||||
def sync(self):
|
||||
"""
|
||||
Initializes the weights of the networks to match each other
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.update_online_network()
|
||||
@@ -106,6 +111,7 @@ class NetworkWrapper(object):
|
||||
def update_target_network(self, rate=1.0):
|
||||
"""
|
||||
Copy weights: online network >>> target network
|
||||
|
||||
:param rate: the rate of copying the weights - 1 for copying exactly
|
||||
"""
|
||||
if self.target_network:
|
||||
@@ -114,6 +120,7 @@ class NetworkWrapper(object):
|
||||
def update_online_network(self, rate=1.0):
|
||||
"""
|
||||
Copy weights: global network >>> online network
|
||||
|
||||
:param rate: the rate of copying the weights - 1 for copying exactly
|
||||
"""
|
||||
if self.global_network:
|
||||
@@ -122,6 +129,7 @@ class NetworkWrapper(object):
|
||||
def apply_gradients_to_global_network(self, gradients=None):
|
||||
"""
|
||||
Apply gradients from the online network on the global network
|
||||
|
||||
:param gradients: optional gradients that will be used instead of teh accumulated gradients
|
||||
:return:
|
||||
"""
|
||||
@@ -135,6 +143,7 @@ class NetworkWrapper(object):
|
||||
def apply_gradients_to_online_network(self, gradients=None):
|
||||
"""
|
||||
Apply gradients from the online network on itself
|
||||
|
||||
:return:
|
||||
"""
|
||||
if gradients is None:
|
||||
@@ -144,6 +153,7 @@ class NetworkWrapper(object):
|
||||
def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
|
||||
"""
|
||||
A generic training function that enables multi-threading training using a global network if necessary.
|
||||
|
||||
:param inputs: The inputs for the network.
|
||||
:param targets: The targets corresponding to the given inputs
|
||||
:param additional_fetches: Any additional tensor the user wants to fetch
|
||||
@@ -160,6 +170,7 @@ class NetworkWrapper(object):
|
||||
"""
|
||||
Applies the gradients accumulated in the online network to the global network or to itself and syncs the
|
||||
networks if necessary
|
||||
|
||||
:param reset_gradients: If set to True, the accumulated gradients wont be reset to 0 after applying them to
|
||||
the network. this is useful when the accumulated gradients are overwritten instead
|
||||
if accumulated by the accumulate_gradients function. this allows reducing time
|
||||
@@ -179,6 +190,7 @@ class NetworkWrapper(object):
|
||||
def parallel_prediction(self, network_input_tuples: List[Tuple]):
|
||||
"""
|
||||
Run several network prediction in parallel. Currently this only supports running each of the network once.
|
||||
|
||||
:param network_input_tuples: a list of tuples where the first element is the network (online_network,
|
||||
target_network or global_network) and the second element is the inputs
|
||||
:return: the outputs of all the networks in the same order as the inputs were given
|
||||
@@ -188,6 +200,7 @@ class NetworkWrapper(object):
|
||||
def get_local_variables(self):
|
||||
"""
|
||||
Get all the variables that are local to the thread
|
||||
|
||||
:return: a list of all the variables that are local to the thread
|
||||
"""
|
||||
local_variables = [v for v in tf.local_variables() if self.online_network.name in v.name]
|
||||
@@ -198,6 +211,7 @@ class NetworkWrapper(object):
|
||||
def get_global_variables(self):
|
||||
"""
|
||||
Get all the variables that are shared between threads
|
||||
|
||||
:return: a list of all the variables that are shared between threads
|
||||
"""
|
||||
global_variables = [v for v in tf.global_variables() if self.global_network.name in v.name]
|
||||
@@ -206,6 +220,7 @@ class NetworkWrapper(object):
|
||||
def set_is_training(self, state: bool):
|
||||
"""
|
||||
Set the phase of the network between training and testing
|
||||
|
||||
:param state: The current state (True = Training, False = Testing)
|
||||
:return: None
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Tuple
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
@@ -74,7 +74,12 @@ class InputEmbedder(object):
|
||||
activation_function=self.activation_function,
|
||||
dropout_rate=self.dropout_rate))
|
||||
|
||||
def __call__(self, prev_input_placeholder=None):
|
||||
def __call__(self, prev_input_placeholder: tf.placeholder=None) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Wrapper for building the module graph including scoping and loss creation
|
||||
:param prev_input_placeholder: the input to the graph
|
||||
:return: the input placeholder and the output of the last layer
|
||||
"""
|
||||
with tf.variable_scope(self.get_name()):
|
||||
if prev_input_placeholder is None:
|
||||
self.input = tf.placeholder("float", shape=[None] + self.input_size, name=self.get_name())
|
||||
@@ -84,7 +89,13 @@ class InputEmbedder(object):
|
||||
|
||||
return self.input, self.output
|
||||
|
||||
def _build_module(self):
|
||||
def _build_module(self) -> None:
|
||||
"""
|
||||
Builds the graph of the module
|
||||
This method is called early on from __call__. It is expected to store the graph
|
||||
in self.output.
|
||||
:return: None
|
||||
"""
|
||||
# NOTE: for image inputs, we expect the data format to be of type uint8, so to be memory efficient. we chose not
|
||||
# to implement the rescaling as an input filters.observation.observation_filter, as this would have caused the
|
||||
# input to the network to be float, which is 4x more expensive in memory.
|
||||
@@ -127,7 +138,11 @@ class InputEmbedder(object):
|
||||
raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
|
||||
"configurations.")
|
||||
|
||||
def get_name(self):
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Get a formatted name for the module
|
||||
:return: the formatted name
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import copy
|
||||
from typing import Union
|
||||
from typing import Union, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -64,17 +64,33 @@ class Middleware(object):
|
||||
activation_function=self.activation_function,
|
||||
dropout_rate=self.dropout_rate))
|
||||
|
||||
def __call__(self, input_layer):
|
||||
def __call__(self, input_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Wrapper for building the module graph including scoping and loss creation
|
||||
:param input_layer: the input to the graph
|
||||
:return: the input placeholder and the output of the last layer
|
||||
"""
|
||||
with tf.variable_scope(self.get_name()):
|
||||
self.input = input_layer
|
||||
self._build_module()
|
||||
|
||||
return self.input, self.output
|
||||
|
||||
def _build_module(self):
|
||||
def _build_module(self) -> None:
|
||||
"""
|
||||
Builds the graph of the module
|
||||
This method is called early on from __call__. It is expected to store the graph
|
||||
in self.output.
|
||||
:param input_layer: the input to the graph
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_name(self):
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Get a formatted name for the module
|
||||
:return: the formatted name
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user