Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 03:30:19 +01:00)

Commit: update nec and value optimization agents to work with recurrent middleware
@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -204,10 +204,11 @@ class Agent(object):
        for action in self.env.actions_description:
            self.episode_running_info[action] = []
        plt.clf()

        if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
            for network in self.networks:
                network.curr_rnn_c_in = network.middleware_embedder.c_init
                network.curr_rnn_h_in = network.middleware_embedder.h_init
                network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
                network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init

    def preprocess_observation(self, observation):
        """
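The LSTM branch added above re-initializes each network's recurrent state (curr_rnn_c_in and curr_rnn_h_in) from the middleware embedder's c_init/h_init whenever the agent resets, so hidden state from one episode cannot leak into the next. The following is a minimal, framework-free sketch of that carry-and-reset bookkeeping; RecurrentNetwork and reset_recurrent_state are illustrative stand-ins, not Coach classes.

import numpy as np

class RecurrentNetwork:
    # Stand-in for a network whose middleware is an LSTM (illustrative only).
    def __init__(self, state_size):
        # analogous to middleware_embedder.c_init / h_init in the diff above
        self.c_init = np.zeros((1, state_size), dtype=np.float32)
        self.h_init = np.zeros((1, state_size), dtype=np.float32)
        self.curr_rnn_c_in = self.c_init
        self.curr_rnn_h_in = self.h_init

    def step(self, observation):
        # a real network would run the LSTM cell here; this only shows the state carry-over
        self.curr_rnn_c_in = self.curr_rnn_c_in + observation.mean()
        self.curr_rnn_h_in = np.tanh(self.curr_rnn_c_in)
        return self.curr_rnn_h_in

def reset_recurrent_state(networks):
    # same idea as the MiddlewareTypes.LSTM branch added to Agent above
    for network in networks:
        network.curr_rnn_c_in = network.c_init
        network.curr_rnn_h_in = network.h_init

networks = [RecurrentNetwork(state_size=4)]
networks[0].step(np.ones((1, 4), dtype=np.float32))  # state drifts during an episode
reset_recurrent_state(networks)                      # and is re-initialized on reset

Without this reset, the hidden state carried over from the previous episode would bias the first predictions of the next one.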
@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -14,6 +14,8 @@
# limitations under the License.
#

import numpy as np

from agents.value_optimization_agent import *
@@ -43,26 +45,34 @@ class NECAgent(ValueOptimizationAgent):
        return total_loss

    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
        # convert to batch so we can run it through the network
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        """
        this method modifies the superclass's behavior in only 3 ways:

        1) the embedding is saved and stored in self.current_episode_state_embeddings
        2) the dnd output head is only called if it has a minimum number of entries in it
           ideally, the dnd head would do this on its own, but in my attempt at encoding this
           behavior in tensorflow, I ran into problems. Would definitely be worth
           revisiting in the future
        3) during training, actions are saved and stored in self.current_episode_actions
           if behaviors 1 and 2 were handled elsewhere, this could easily be implemented
           as a wrapper around super instead of overriding this method entirely
        """

        # get embedding
        embedding = self.main_network.sess.run(self.main_network.online_network.state_embedding,
                                               feed_dict={self.main_network.online_network.inputs[0]: observation})
        self.current_episode_state_embeddings.append(embedding[0])
        embedding = self.main_network.online_network.predict(
            self.tf_input_state(curr_state),
            outputs=self.main_network.online_network.state_embedding)
        self.current_episode_state_embeddings.append(embedding)

        # get action values
        # TODO: support additional heads. Right now all other heads are ignored
        if self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
            # if there are enough entries in the DND then we can query it to get the action values
            actions_q_values = []
            for action in range(self.action_space_size):
                feed_dict = {
                    self.main_network.online_network.state_embedding: embedding,
                    self.main_network.online_network.output_heads[0].input[0]: np.array([action])
                }
                q_value = self.main_network.sess.run(
                    self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
                actions_q_values.append(q_value[0])
            # actions_q_values = []
            feed_dict = {
                self.main_network.online_network.state_embedding: [embedding],
            }
            actions_q_values = self.main_network.sess.run(
                self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
        else:
            # get only the embedding so we can insert it to the DND
            actions_q_values = [0] * self.action_space_size
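The branch above only queries the DND head once output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn) holds, i.e. once every action has at least k stored entries; until then the Q-values default to zeros and the embedding is merely collected so it can be inserted later. As a rough illustration of what such a readout does, here is a small numpy sketch with an inverse-distance k-NN kernel; SimpleDND is an illustrative stand-in, not Coach's DND implementation.

import numpy as np

class SimpleDND:
    # Illustrative per-action episodic memory (not Coach's DND class).
    def __init__(self, num_actions, num_knn):
        self.num_knn = num_knn
        self.keys = [[] for _ in range(num_actions)]    # stored state embeddings
        self.values = [[] for _ in range(num_actions)]  # stored value estimates

    def has_enough_entries(self, num_knn):
        # mirrors the gate used in choose_action: every action needs at least k entries
        return all(len(keys) >= num_knn for keys in self.keys)

    def add(self, action, embedding, value):
        self.keys[action].append(np.asarray(embedding, dtype=np.float32))
        self.values[action].append(float(value))

    def q_value(self, action, embedding):
        keys = np.stack(self.keys[action])
        values = np.asarray(self.values[action])
        distances = np.sum((keys - embedding) ** 2, axis=1)
        knn = np.argsort(distances)[:self.num_knn]
        weights = 1.0 / (distances[knn] + 1e-3)          # inverse-distance kernel
        return float(np.sum(weights * values[knn]) / np.sum(weights))

# usage mirroring the structure of the branch above
num_actions, num_knn = 3, 2
dnd = SimpleDND(num_actions, num_knn)
embedding = np.random.randn(8).astype(np.float32)
if dnd.has_enough_entries(num_knn):
    actions_q_values = [dnd.q_value(a, embedding) for a in range(num_actions)]
else:
    actions_q_values = [0] * num_actions

In the diff itself this lookup lives inside the TensorFlow graph: the state embedding is passed through feed_dict and the DND head's output op produces the action values.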
@@ -70,6 +80,8 @@ class NECAgent(ValueOptimizationAgent):
        # choose action according to the exploration policy and the current phase (evaluating or training the agent)
        if phase == RunPhase.TRAIN:
            action = self.exploration_policy.get_action(actions_q_values)
            # NOTE: this next line is not in the parent implementation
            # NOTE: it could be implemented as a wrapper around the parent since action is returned
            self.current_episode_actions.append(action)
        else:
            action = np.argmax(actions_q_values)
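During training the action is drawn from the exploration policy rather than taken greedily, and NEC additionally appends it to self.current_episode_actions alongside the embeddings stored earlier in choose_action. Coach supports several exploration policies, so the following ε-greedy get_action is only a hedged example of what self.exploration_policy might do; the EGreedy class here is illustrative, not Coach's implementation.

import numpy as np

class EGreedy:
    # Illustrative epsilon-greedy policy (not Coach's exploration_policy class).
    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon

    def get_action(self, actions_q_values):
        q_values = np.asarray(actions_q_values, dtype=np.float32).squeeze()
        if np.random.rand() < self.epsilon:
            return int(np.random.randint(len(q_values)))  # explore: random action
        return int(np.argmax(q_values))                   # exploit: greedy action

policy = EGreedy(epsilon=0.1)
action = policy.get_action([0.2, 0.7, 0.1])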
@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -14,6 +14,8 @@
# limitations under the License.
#

import numpy as np

from agents.agent import *
@@ -30,15 +32,28 @@ class ValueOptimizationAgent(Agent):
    def get_q_values(self, prediction):
        return prediction

    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
    def tf_input_state(self, curr_state):
        """
        convert curr_state into the input tensors tensorflow is expecting.

        TODO: move this function into Agent and use it in as many agent implementations as possible
        currently, other agents will likely not work with environment measurements.
        This will become even more important as we support more complex and varied environment states.
        """
        # convert to batch so we can run it through the network
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        if self.tp.agent.use_measurements:
            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
            prediction = self.main_network.online_network.predict([observation, measurements])
            tf_input_state = [observation, measurements]
        else:
            prediction = self.main_network.online_network.predict(observation)
            tf_input_state = observation
        return tf_input_state

    def get_prediction(self, curr_state):
        return self.main_network.online_network.predict(self.tf_input_state(curr_state))

    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
        prediction = self.get_prediction(curr_state)
        actions_q_values = self.get_q_values(prediction)

        # choose action according to the exploration policy and the current phase (evaluating or training the agent)
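The refactor above splits choose_action into three small pieces: tf_input_state batches the raw state (and measurements, when used), get_prediction runs the online network on that input, and get_q_values extracts the Q-values from the prediction. A framework-free sketch of that decomposition, with a StubNetwork standing in for main_network.online_network and all names outside the diff being illustrative:

import numpy as np

class StubNetwork:
    # Stand-in for main_network.online_network; returns random Q-values.
    def __init__(self, num_actions):
        self.num_actions = num_actions

    def predict(self, batched_input):
        batch = batched_input[0] if isinstance(batched_input, list) else batched_input
        return np.random.randn(batch.shape[0], self.num_actions)

class TinyValueAgent:
    def __init__(self, num_actions, use_measurements=False):
        self.online_network = StubNetwork(num_actions)
        self.use_measurements = use_measurements

    def tf_input_state(self, curr_state):
        # batch the observation (and measurements, if used) exactly once
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        if self.use_measurements:
            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
            return [observation, measurements]
        return observation

    def get_prediction(self, curr_state):
        return self.online_network.predict(self.tf_input_state(curr_state))

    def get_q_values(self, prediction):
        # value-optimization agents treat the head output as Q-values directly
        return prediction

    def choose_action(self, curr_state):
        actions_q_values = self.get_q_values(self.get_prediction(curr_state))
        return int(np.argmax(actions_q_values))

agent = TinyValueAgent(num_actions=4)
action = agent.choose_action({'observation': np.zeros(8, dtype=np.float32)})

Keeping input preparation separate from prediction is what lets NECAgent above reuse tf_input_state while asking the same online network for a different output (the state embedding) instead of the Q-head.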