1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

update nec and value optimization agents to work with recurrent middleware

This commit is contained in:
Zach Dwiel
2017-11-03 13:58:42 -07:00
parent 93a54c7e8e
commit 6c79a442f2
12 changed files with 138 additions and 72 deletions

View File

@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
# limitations under the License.
#
import numpy as np
from agents.agent import *
@@ -30,15 +32,28 @@ class ValueOptimizationAgent(Agent):
def get_q_values(self, prediction):
return prediction
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
def tf_input_state(self, curr_state):
"""
convert curr_state into input tensors tensorflow is expecting.
TODO: move this function into Agent and use in as many agent implementations as possible
currently, other agents will likely not work with environment measurements.
This will become even more important as we support more complex and varied environment states.
"""
# convert to batch so we can run it through the network
observation = np.expand_dims(np.array(curr_state['observation']), 0)
if self.tp.agent.use_measurements:
measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
prediction = self.main_network.online_network.predict([observation, measurements])
tf_input_state = [observation, measurements]
else:
prediction = self.main_network.online_network.predict(observation)
tf_input_state = observation
return tf_input_state
def get_prediction(self, curr_state):
return self.main_network.online_network.predict(self.tf_input_state(curr_state))
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
prediction = self.get_prediction(curr_state)
actions_q_values = self.get_q_values(prediction)
# choose action according to the exploration policy and the current phase (evaluating or training the agent)