1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding mxnet components to rl_coach/architectures (#60)

Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CarPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset
- Checkpointing is disabled for MXNet
This commit is contained in:
Sina Afrooze
2018-11-07 07:07:15 -08:00
committed by Itai Caspi
parent e7a91b4dc3
commit 5fadb9c18e
39 changed files with 3864 additions and 44 deletions

View File

@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import os
import time
from typing import Any, List, Tuple, Dict
import numpy as np
import tensorflow as tf
@@ -544,6 +545,26 @@ class TensorFlowArchitecture(Architecture):
output = squeeze_list(output)
return output
@staticmethod
def parallel_predict(sess: Any,
network_input_tuples: List[Tuple['TensorFlowArchitecture', Dict[str, np.ndarray]]]) ->\
List[np.ndarray]:
"""
:param sess: active session to use for prediction
:param network_input_tuples: tuple of network and corresponding input
:return: list of outputs from all networks
"""
feed_dict = {}
fetches = []
for network, input in network_input_tuples:
feed_dict.update(network.create_feed_dict(input))
fetches += network.outputs
outputs = sess.run(fetches, feed_dict)
return outputs
def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None, importance_weights=None):
"""
Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients