mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures. - Supports PPO and DQN - Tested with CartPole_PPO and CarPole_DQN - Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset - Checkpointing is disabled for MXNet
This commit is contained in:
@@ -13,8 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import os
|
||||
import time
|
||||
from typing import Any, List, Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@@ -544,6 +545,26 @@ class TensorFlowArchitecture(Architecture):
|
||||
output = squeeze_list(output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def parallel_predict(sess: Any,
|
||||
network_input_tuples: List[Tuple['TensorFlowArchitecture', Dict[str, np.ndarray]]]) ->\
|
||||
List[np.ndarray]:
|
||||
"""
|
||||
:param sess: active session to use for prediction
|
||||
:param network_input_tuples: tuple of network and corresponding input
|
||||
:return: list of outputs from all networks
|
||||
"""
|
||||
feed_dict = {}
|
||||
fetches = []
|
||||
|
||||
for network, input in network_input_tuples:
|
||||
feed_dict.update(network.create_feed_dict(input))
|
||||
fetches += network.outputs
|
||||
|
||||
outputs = sess.run(fetches, feed_dict)
|
||||
|
||||
return outputs
|
||||
|
||||
def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None, importance_weights=None):
|
||||
"""
|
||||
Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients
|
||||
|
||||
Reference in New Issue
Block a user