1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Adding mxnet components to rl_coach/architectures (#60)

Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CarPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset
- Checkpointing is disabled for MXNet
This commit is contained in:
Sina Afrooze
2018-11-07 07:07:15 -08:00
committed by Itai Caspi
parent e7a91b4dc3
commit 5fadb9c18e
39 changed files with 3864 additions and 44 deletions

View File

@@ -183,16 +183,7 @@ class NetworkWrapper(object):
target_network or global_network) and the second element is the inputs
:return: the outputs of all the networks in the same order as the inputs were given
"""
feed_dict = {}
fetches = []
for idx, (network, input) in enumerate(network_input_tuples):
feed_dict.update(network.create_feed_dict(input))
fetches += network.outputs
outputs = self.sess.run(fetches, feed_dict)
return outputs
return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
def get_local_variables(self):
"""