Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 11:10:20 +01:00)
update nec and value optimization agents to work with recurrent middleware
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import numpy as np
+
 from agents.agent import *
 
 
@@ -30,15 +32,28 @@ class ValueOptimizationAgent(Agent):
     def get_q_values(self, prediction):
         return prediction
 
-    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+    def tf_input_state(self, curr_state):
+        """
+        convert curr_state into input tensors tensorflow is expecting.
+
+        TODO: move this function into Agent and use in as many agent implementations as possible
+        currently, other agents will likely not work with environment measurements.
+        This will become even more important as we support more complex and varied environment states.
+        """
         # convert to batch so we can run it through the network
         observation = np.expand_dims(np.array(curr_state['observation']), 0)
         if self.tp.agent.use_measurements:
             measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
-            prediction = self.main_network.online_network.predict([observation, measurements])
+            tf_input_state = [observation, measurements]
         else:
-            prediction = self.main_network.online_network.predict(observation)
+            tf_input_state = observation
+        return tf_input_state
+
+    def get_prediction(self, curr_state):
+        return self.main_network.online_network.predict(self.tf_input_state(curr_state))
+
+    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+        prediction = self.get_prediction(curr_state)
         actions_q_values = self.get_q_values(prediction)
 
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
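For context, the change splits the old choose_action into three smaller steps: tf_input_state builds the batched input arrays, get_prediction runs the forward pass, and choose_action consumes the prediction. Below is a minimal, self-contained sketch of that pattern; _FakeNetwork and ValueAgentSketch are hypothetical stand-ins for illustration and are not part of the commit:

import numpy as np

class _FakeNetwork:
    """Hypothetical stand-in for main_network.online_network; predict() echoes input shapes."""
    def predict(self, inputs):
        inputs = inputs if isinstance(inputs, list) else [inputs]
        return [x.shape for x in inputs]

class ValueAgentSketch:
    """Illustrates the tf_input_state / get_prediction split from the diff above."""
    def __init__(self, use_measurements):
        self.use_measurements = use_measurements
        self.online_network = _FakeNetwork()

    def tf_input_state(self, curr_state):
        # add a batch dimension so a single state can be run through the network
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        if self.use_measurements:
            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
            return [observation, measurements]
        return observation

    def get_prediction(self, curr_state):
        # input preparation and the forward pass are now separate steps
        return self.online_network.predict(self.tf_input_state(curr_state))

agent = ValueAgentSketch(use_measurements=True)
state = {'observation': np.zeros((84, 84)), 'measurements': np.zeros(3)}
print(agent.get_prediction(state))  # [(1, 84, 84), (1, 3)]

Judging by the commit message, isolating the forward pass behind get_prediction is presumably what lets a recurrent middleware manage its hidden state around predict() instead of every agent's choose_action having to thread it through.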