1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Parallel agents fixes (#95)

* Parallel agents related bug fixes: checkpoint restore, tensorboard integration.
Adding narrow networks support.
Reference code for unlimited number of checkpoints
This commit is contained in:
Itai Caspi
2018-05-24 14:24:19 +03:00
committed by GitHub
parent 6c0b59b4de
commit d302168c8c
10 changed files with 75 additions and 41 deletions

View File

@@ -16,13 +16,15 @@
import tensorflow as tf
import numpy as np
from configurations import EmbedderWidth
class MiddlewareEmbedder(object):
def __init__(self, activation_function=tf.nn.relu, name="middleware_embedder"):
def __init__(self, activation_function=tf.nn.relu, embedder_width=EmbedderWidth.Wide, name="middleware_embedder"):
self.name = name
self.input = None
self.output = None
self.embedder_width = embedder_width
self.activation_function = activation_function
def __call__(self, input_layer):
@@ -70,4 +72,6 @@ class LSTM_Embedder(MiddlewareEmbedder):
class FC_Embedder(MiddlewareEmbedder):
def _build_module(self):
self.output = tf.layers.dense(self.input, 512, activation=self.activation_function, name='fc1')
width = 512 if self.embedder_width == EmbedderWidth.Wide else 64
self.output = tf.layers.dense(self.input, width, activation=self.activation_function, name='fc1')