mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures. - Supports PPO and DQN - Tested with CartPole_PPO and CarPole_DQN - Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset - Checkpointing is disabled for MXNet
This commit is contained in:
@@ -183,16 +183,7 @@ class NetworkWrapper(object):
|
||||
target_network or global_network) and the second element is the inputs
|
||||
:return: the outputs of all the networks in the same order as the inputs were given
|
||||
"""
|
||||
feed_dict = {}
|
||||
fetches = []
|
||||
|
||||
for idx, (network, input) in enumerate(network_input_tuples):
|
||||
feed_dict.update(network.create_feed_dict(input))
|
||||
fetches += network.outputs
|
||||
|
||||
outputs = self.sess.run(fetches, feed_dict)
|
||||
|
||||
return outputs
|
||||
return type(self.online_network).parallel_predict(self.sess, network_input_tuples)
|
||||
|
||||
def get_local_variables(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user