mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Parallel agents fixes (#95)
* Parallel agents related bug fixes: checkpoint restore, tensorboard integration. Adding narrow networks support. Reference code for unlimited number of checkpoints
This commit is contained in:
@@ -16,13 +16,15 @@
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from configurations import EmbedderWidth
|
||||
|
||||
|
||||
class MiddlewareEmbedder(object):
|
||||
def __init__(self, activation_function=tf.nn.relu, name="middleware_embedder"):
|
||||
def __init__(self, activation_function=tf.nn.relu, embedder_width=EmbedderWidth.Wide, name="middleware_embedder"):
|
||||
self.name = name
|
||||
self.input = None
|
||||
self.output = None
|
||||
self.embedder_width = embedder_width
|
||||
self.activation_function = activation_function
|
||||
|
||||
def __call__(self, input_layer):
|
||||
@@ -70,4 +72,6 @@ class LSTM_Embedder(MiddlewareEmbedder):
|
||||
|
||||
class FC_Embedder(MiddlewareEmbedder):
|
||||
def _build_module(self):
|
||||
self.output = tf.layers.dense(self.input, 512, activation=self.activation_function, name='fc1')
|
||||
width = 512 if self.embedder_width == EmbedderWidth.Wide else 64
|
||||
self.output = tf.layers.dense(self.input, width, activation=self.activation_function, name='fc1')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user