1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00
This commit is contained in:
Zach Dwiel
2018-02-16 15:47:16 -05:00
parent 943e41ba58
commit 98f57a0d87
3 changed files with 44 additions and 27 deletions

View File

@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import six
import numpy as np
import tensorflow as tf
from architectures.architecture import Architecture
import tensorflow as tf
from utils import force_list, squeeze_list
from configurations import Preset, MiddlewareTypes
import numpy as np
import time
def variable_summaries(var):
"""Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
@@ -269,17 +271,29 @@ class TensorFlowArchitecture(Architecture):
def _feed_dict(self, inputs):
feed_dict = {}
for input_name, input_value in inputs.items():
if input_name not in self.inputs:
if isinstance(input_name, six.string_types):
if input_name not in self.inputs:
raise ValueError((
'input name {input_name} was provided to create a feed '
'dictionary, but there is no placeholder with that name. '
'placeholder names available include: {placeholder_names}'
).format(
input_name=input_name,
placeholder_names=', '.join(self.inputs.keys())
))
feed_dict[self.inputs[input_name]] = input_value
elif isinstance(input_name, tf.Tensor) and input_name.op.type == 'Placeholder':
feed_dict[input_name] = input_value
else:
raise ValueError((
'input name {input_name} was provided to create a feed '
'dictionary, but there is no placeholder with that name. '
'placeholder names available include: {placeholder_names}'
'input dictionary expects strings or placeholders as keys, '
'but found key {key} of type {type}'
).format(
input_name=input_name,
placeholder_names=', '.join(self.inputs.keys())
key=input_name,
type=type(input_name),
))
feed_dict[self.inputs[input_name]] = input_value
return feed_dict
def predict(self, inputs, outputs=None):