mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
bug fix in dueling network + revert to TF 1.6 for CPU due to requirements compatibility issues
This commit is contained in:
@@ -39,15 +39,15 @@ class DuelingQHead(QHead):
|
|||||||
def _build_module(self, input_layer):
|
def _build_module(self, input_layer):
|
||||||
# state value tower - V
|
# state value tower - V
|
||||||
with tf.variable_scope("state_value"):
|
with tf.variable_scope("state_value"):
|
||||||
state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
|
self.state_value = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
|
||||||
state_value = self.dense_layer(1)(state_value, name='fc2')
|
self.state_value = self.dense_layer(1)(self.state_value, name='fc2')
|
||||||
# state_value = tf.expand_dims(state_value, axis=-1)
|
|
||||||
|
|
||||||
# action advantage tower - A
|
# action advantage tower - A
|
||||||
with tf.variable_scope("action_advantage"):
|
with tf.variable_scope("action_advantage"):
|
||||||
action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
|
self.action_advantage = self.dense_layer(512)(input_layer, activation=self.activation_function, name='fc1')
|
||||||
action_advantage = self.dense_layer(self.num_actions)(action_advantage, name='fc2')
|
self.action_advantage = self.dense_layer(self.num_actions)(self.action_advantage, name='fc2')
|
||||||
action_advantage = action_advantage - tf.reduce_mean(action_advantage, axis=1)
|
self.action_mean = tf.reduce_mean(self.action_advantage, axis=1, keep_dims=True)
|
||||||
|
self.action_advantage = self.action_advantage - self.action_mean
|
||||||
|
|
||||||
# merge to state-action value function Q
|
# merge to state-action value function Q
|
||||||
self.output = tf.add(state_value, action_advantage, name='output')
|
self.output = tf.add(self.state_value, self.action_advantage, name='output')
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -61,9 +61,9 @@ if not using_GPU:
|
|||||||
# For linux wth no GPU, we install the Intel optimized version of TensorFlow
|
# For linux wth no GPU, we install the Intel optimized version of TensorFlow
|
||||||
if sys.platform == "linux" or sys.platform == "linux2":
|
if sys.platform == "linux" or sys.platform == "linux2":
|
||||||
subprocess.check_call(['pip install '
|
subprocess.check_call(['pip install '
|
||||||
'https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl'],
|
'https://anaconda.org/intel/tensorflow/1.6.0/download/tensorflow-1.6.0-cp35-cp35m-linux_x86_64.whl'],
|
||||||
shell=True)
|
shell=True)
|
||||||
install_requires.append('tensorflow==1.10.0')
|
install_requires.append('tensorflow==1.6.0')
|
||||||
else:
|
else:
|
||||||
install_requires.append('tensorflow-gpu==1.10.0')
|
install_requires.append('tensorflow-gpu==1.10.0')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user