
update nec and value optimization agents to work with recurrent middleware

Zach Dwiel
2017-11-03 13:58:42 -07:00
parent 93a54c7e8e
commit 6c79a442f2
12 changed files with 138 additions and 72 deletions


@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

agents/agent.py

@@ -204,10 +204,11 @@ class Agent(object):
             for action in self.env.actions_description:
                 self.episode_running_info[action] = []
             plt.clf()

         if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
             for network in self.networks:
-                network.curr_rnn_c_in = network.middleware_embedder.c_init
-                network.curr_rnn_h_in = network.middleware_embedder.h_init
+                network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
+                network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init

     def preprocess_observation(self, observation):
         """

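Note: the hunk above moves the recurrent-state bookkeeping from the network wrapper onto its online_network, so each episode restarts the LSTM middleware from its initial cell/hidden values. Below is a minimal, self-contained sketch of that reset pattern; LSTMMiddleware, OnlineNetwork and NetworkWrapper are illustrative stand-ins, not coach's actual classes.

# Sketch of the per-episode LSTM state reset shown above. All class names are
# illustrative stand-ins for coach's network wrapper / online network / middleware.
import numpy as np

class LSTMMiddleware:
    def __init__(self, hidden_size):
        # initial cell and hidden states, analogous to c_init / h_init above
        self.c_init = np.zeros((1, hidden_size), dtype=np.float32)
        self.h_init = np.zeros((1, hidden_size), dtype=np.float32)

class OnlineNetwork:
    def __init__(self, hidden_size=64):
        self.middleware_embedder = LSTMMiddleware(hidden_size)
        self.curr_rnn_c_in = self.middleware_embedder.c_init
        self.curr_rnn_h_in = self.middleware_embedder.h_init

class NetworkWrapper:
    def __init__(self):
        self.online_network = OnlineNetwork()

def reset_recurrent_state(networks):
    # the recurrent state lives on the online network, so that is what gets
    # reset at every episode boundary
    for network in networks:
        network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
        network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init

reset_recurrent_state([NetworkWrapper()])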
agents/nec_agent.py

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
+import numpy as np
 from agents.value_optimization_agent import *
@@ -43,26 +45,34 @@ class NECAgent(ValueOptimizationAgent):
         return total_loss

     def choose_action(self, curr_state, phase=RunPhase.TRAIN):
-        # convert to batch so we can run it through the network
-        observation = np.expand_dims(np.array(curr_state['observation']), 0)
+        """
+        this method modifies the superclass's behavior in only 3 ways:
+        1) the embedding is saved and stored in self.current_episode_state_embeddings
+        2) the dnd output head is only called if it has a minimum number of entries in it
+           ideally, the dnd head would do this on its own, but in my attempt at encoding this
+           behavior in tensorflow, I ran into problems. Would definitely be worth
+           revisiting in the future
+        3) during training, actions are saved and stored in self.current_episode_actions
+        if behaviors 1 and 2 were handled elsewhere, this could easily be implemented
+        as a wrapper around super instead of overriding this method entirely
+        """
         # get embedding
-        embedding = self.main_network.sess.run(self.main_network.online_network.state_embedding,
-                                               feed_dict={self.main_network.online_network.inputs[0]: observation})
-        self.current_episode_state_embeddings.append(embedding[0])
+        embedding = self.main_network.online_network.predict(
+            self.tf_input_state(curr_state),
+            outputs=self.main_network.online_network.state_embedding)
+        self.current_episode_state_embeddings.append(embedding)

         # get action values
         # TODO: support additional heads. Right now all other heads are ignored
         if self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
             # if there are enough entries in the DND then we can query it to get the action values
-            actions_q_values = []
-            for action in range(self.action_space_size):
-                feed_dict = {
-                    self.main_network.online_network.state_embedding: embedding,
-                    self.main_network.online_network.output_heads[0].input[0]: np.array([action])
-                }
-                q_value = self.main_network.sess.run(
-                    self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
-                actions_q_values.append(q_value[0])
+            # actions_q_values = []
+            feed_dict = {
+                self.main_network.online_network.state_embedding: [embedding],
+            }
+            actions_q_values = self.main_network.sess.run(
+                self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
         else:
             # get only the embedding so we can insert it to the DND
             actions_q_values = [0] * self.action_space_size
@@ -70,6 +80,8 @@ class NECAgent(ValueOptimizationAgent):
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
         if phase == RunPhase.TRAIN:
             action = self.exploration_policy.get_action(actions_q_values)
+            # NOTE: this next line is not in the parent implementation
+            # NOTE: it could be implemented as a wrapper around the parent since action is returned
             self.current_episode_actions.append(action)
         else:
             action = np.argmax(actions_q_values)

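Note: the change above replaces a per-action loop of session runs with a single query that returns Q-values for every action at once, and it still only queries the DND after it holds at least number_of_knn entries per action. The toy numpy sketch below illustrates that query pattern; ToyDND and its kernel-weighted kNN average are illustrative, not coach's DND implementation.

# Toy illustration of the query pattern above: one lookup returns Q-values for
# every action, instead of one graph run per action. Not coach's DND code.
import numpy as np

class ToyDND:
    def __init__(self, num_actions, knn=2):
        self.knn = knn
        self.keys = [[] for _ in range(num_actions)]    # stored embeddings per action
        self.values = [[] for _ in range(num_actions)]  # stored returns per action

    def has_enough_entries(self, knn):
        return all(len(k) >= knn for k in self.keys)

    def add(self, action, embedding, value):
        self.keys[action].append(embedding)
        self.values[action].append(value)

    def q_values(self, embedding):
        # single query: kernel-weighted average over the k nearest stored keys,
        # computed for every action in one pass
        qs = []
        for keys, values in zip(self.keys, self.values):
            dists = np.linalg.norm(np.array(keys) - embedding, axis=1)
            nearest = np.argsort(dists)[:self.knn]
            weights = 1.0 / (dists[nearest] + 1e-3)
            qs.append(float(np.sum(weights * np.array(values)[nearest]) / np.sum(weights)))
        return np.array(qs)

dnd = ToyDND(num_actions=3)
rng = np.random.RandomState(0)
for action in range(3):
    for _ in range(2):
        dnd.add(action, rng.randn(8), rng.rand())

embedding = rng.randn(8)
if dnd.has_enough_entries(knn=2):
    actions_q_values = dnd.q_values(embedding)   # one call covers every action
else:
    actions_q_values = np.zeros(3)               # fall back until the DND is warm
print(actions_q_values)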
agents/value_optimization_agent.py

@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
+import numpy as np
 from agents.agent import *
@@ -30,15 +32,28 @@ class ValueOptimizationAgent(Agent):
     def get_q_values(self, prediction):
         return prediction

-    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+    def tf_input_state(self, curr_state):
+        """
+        convert curr_state into the input tensors tensorflow is expecting.
+        TODO: move this function into Agent and use it in as many agent implementations as possible
+        currently, other agents will likely not work with environment measurements.
+        This will become even more important as we support more complex and varied environment states.
+        """
         # convert to batch so we can run it through the network
         observation = np.expand_dims(np.array(curr_state['observation']), 0)
         if self.tp.agent.use_measurements:
             measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
-            prediction = self.main_network.online_network.predict([observation, measurements])
+            tf_input_state = [observation, measurements]
         else:
-            prediction = self.main_network.online_network.predict(observation)
+            tf_input_state = observation
+        return tf_input_state
+
+    def get_prediction(self, curr_state):
+        return self.main_network.online_network.predict(self.tf_input_state(curr_state))
+
+    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
+        prediction = self.get_prediction(curr_state)
         actions_q_values = self.get_q_values(prediction)
         # choose action according to the exploration policy and the current phase (evaluating or training the agent)
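Note: the refactor above splits input conversion (tf_input_state) and network evaluation (get_prediction) out of choose_action, which is what lets NECAgent reuse the state batching while overriding only the prediction step. Below is a minimal sketch of that factoring; the class names and the fake linear head standing in for the real network are illustrative, not coach's API.

# Sketch of the tf_input_state / get_prediction / choose_action factoring above.
# ValueAgentSketch and EmbeddingAgentSketch are illustrative stand-ins.
import numpy as np

class ValueAgentSketch:
    def __init__(self, obs_size=3, num_actions=2, use_measurements=False):
        self.use_measurements = use_measurements
        self._head = np.random.RandomState(0).randn(obs_size, num_actions)  # fake network head

    def tf_input_state(self, curr_state):
        # convert a single environment state dict into batched network inputs
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        if self.use_measurements:
            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
            return [observation, measurements]
        return observation

    def get_prediction(self, curr_state):
        # stand-in for online_network.predict(): a fixed linear head maps the
        # batched observation to one fake Q-value per action
        inputs = self.tf_input_state(curr_state)
        observation = inputs[0] if isinstance(inputs, list) else inputs
        return observation @ self._head

    def choose_action(self, curr_state):
        q_values = self.get_prediction(curr_state)[0]
        return int(np.argmax(q_values))

class EmbeddingAgentSketch(ValueAgentSketch):
    # a subclass swaps only the prediction step; tf_input_state() is inherited,
    # mirroring how NECAgent above drops its own observation batching code
    def get_prediction(self, curr_state):
        inputs = self.tf_input_state(curr_state)
        observation = inputs[0] if isinstance(inputs, list) else inputs
        embedding = np.tanh(observation @ self._head)  # pretend state embedding
        return embedding                               # reused here as fake Q-values

agent = EmbeddingAgentSketch()
print(agent.choose_action({'observation': [0.1, 0.5, 0.2]}))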