mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
update of api docstrings across coach and tutorials [WIP] (#91)
* updating the documentation website * adding the built docs * update of api docstrings across coach and tutorials 0-2 * added some missing api documentation * New Sphinx based documentation
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Tuple
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
@@ -74,7 +74,12 @@ class InputEmbedder(object):
|
||||
activation_function=self.activation_function,
|
||||
dropout_rate=self.dropout_rate))
|
||||
|
||||
def __call__(self, prev_input_placeholder=None):
|
||||
def __call__(self, prev_input_placeholder: tf.placeholder=None) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Wrapper for building the module graph including scoping and loss creation
|
||||
:param prev_input_placeholder: the input to the graph
|
||||
:return: the input placeholder and the output of the last layer
|
||||
"""
|
||||
with tf.variable_scope(self.get_name()):
|
||||
if prev_input_placeholder is None:
|
||||
self.input = tf.placeholder("float", shape=[None] + self.input_size, name=self.get_name())
|
||||
@@ -84,7 +89,13 @@ class InputEmbedder(object):
|
||||
|
||||
return self.input, self.output
|
||||
|
||||
def _build_module(self):
|
||||
def _build_module(self) -> None:
|
||||
"""
|
||||
Builds the graph of the module
|
||||
This method is called early on from __call__. It is expected to store the graph
|
||||
in self.output.
|
||||
:return: None
|
||||
"""
|
||||
# NOTE: for image inputs, we expect the data format to be of type uint8, so to be memory efficient. we chose not
|
||||
# to implement the rescaling as an input filters.observation.observation_filter, as this would have caused the
|
||||
# input to the network to be float, which is 4x more expensive in memory.
|
||||
@@ -127,7 +138,11 @@ class InputEmbedder(object):
|
||||
raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default "
|
||||
"configurations.")
|
||||
|
||||
def get_name(self):
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Get a formatted name for the module
|
||||
:return: the formatted name
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import copy
|
||||
from typing import Union
|
||||
from typing import Union, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -64,17 +64,33 @@ class Middleware(object):
|
||||
activation_function=self.activation_function,
|
||||
dropout_rate=self.dropout_rate))
|
||||
|
||||
def __call__(self, input_layer):
|
||||
def __call__(self, input_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Wrapper for building the module graph including scoping and loss creation
|
||||
:param input_layer: the input to the graph
|
||||
:return: the input placeholder and the output of the last layer
|
||||
"""
|
||||
with tf.variable_scope(self.get_name()):
|
||||
self.input = input_layer
|
||||
self._build_module()
|
||||
|
||||
return self.input, self.output
|
||||
|
||||
def _build_module(self):
|
||||
def _build_module(self) -> None:
|
||||
"""
|
||||
Builds the graph of the module
|
||||
This method is called early on from __call__. It is expected to store the graph
|
||||
in self.output.
|
||||
:param input_layer: the input to the graph
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_name(self):
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Get a formatted name for the module
|
||||
:return: the formatted name
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user