1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

update of api docstrings across coach and tutorials [WIP] (#91)

* updating the documentation website
* adding the built docs
* update of api docstrings across coach and tutorials 0-2
* added some missing api documentation
* New Sphinx based documentation
This commit is contained in:
Itai Caspi
2018-11-15 15:00:13 +02:00
committed by Gal Novik
parent 524f8436a2
commit 6d40ad1650
517 changed files with 71034 additions and 12834 deletions

View File

@@ -14,7 +14,7 @@
# limitations under the License.
#
from typing import List, Union
from typing import List, Union, Tuple
import copy
import numpy as np
@@ -74,7 +74,12 @@ class InputEmbedder(object):
activation_function=self.activation_function,
dropout_rate=self.dropout_rate))
def __call__(self, prev_input_placeholder=None):
def __call__(self, prev_input_placeholder: tf.placeholder=None) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Wrapper for building the module graph including scoping and loss creation
:param prev_input_placeholder: the input to the graph
:return: the input placeholder and the output of the last layer
"""
with tf.variable_scope(self.get_name()):
if prev_input_placeholder is None:
self.input = tf.placeholder("float", shape=[None] + self.input_size, name=self.get_name())
@@ -84,7 +89,13 @@ class InputEmbedder(object):
return self.input, self.output
def _build_module(self):
def _build_module(self) -> None:
"""
Builds the graph of the module
This method is called early on from __call__. It is expected to store the graph
in self.output.
:return: None
"""
# NOTE: for image inputs, we expect the data format to be of type uint8, so to be memory efficient. we chose not
# to implement the rescaling as an input filters.observation.observation_filter, as this would have caused the
# input to the network to be float, which is 4x more expensive in memory.
@@ -127,7 +138,11 @@ class InputEmbedder(object):
raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
"configurations.")
def get_name(self):
def get_name(self) -> str:
"""
Get a formatted name for the module
:return: the formatted name
"""
return self.name
def __str__(self):

View File

@@ -14,7 +14,7 @@
# limitations under the License.
#
import copy
from typing import Union
from typing import Union, Tuple
import tensorflow as tf
@@ -64,17 +64,33 @@ class Middleware(object):
activation_function=self.activation_function,
dropout_rate=self.dropout_rate))
def __call__(self, input_layer):
def __call__(self, input_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Wrapper for building the module graph including scoping and loss creation
:param input_layer: the input to the graph
:return: the input placeholder and the output of the last layer
"""
with tf.variable_scope(self.get_name()):
self.input = input_layer
self._build_module()
return self.input, self.output
def _build_module(self):
def _build_module(self) -> None:
"""
Builds the graph of the module
This method is called early on from __call__. It is expected to store the graph
in self.output.
:param input_layer: the input to the graph
:return: None
"""
pass
def get_name(self):
def get_name(self) -> str:
"""
Get a formatted name for the module
:return: the formatted name
"""
return self.name
@property