1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

update nec and value optimization agents to work with recurrent middleware

This commit is contained in:
Zach Dwiel
2017-11-03 13:58:42 -07:00
parent 93a54c7e8e
commit 6c79a442f2
12 changed files with 138 additions and 72 deletions

View File

@@ -1,5 +1,5 @@
#
# Copyright (c) 2017 Intel Corporation
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
# limitations under the License.
#
import numpy as np
from agents.value_optimization_agent import *
@@ -43,26 +45,34 @@ class NECAgent(ValueOptimizationAgent):
return total_loss
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
# convert to batch so we can run it through the network
observation = np.expand_dims(np.array(curr_state['observation']), 0)
"""
this method modifies the superclass's behavior in only 3 ways:
1) the embedding is saved and stored in self.current_episode_state_embeddings
2) the dnd output head is only called if it has a minimum number of entries in it
ideally, the dnd had would do this on its own, but in my attempt in encoding this
behavior in tensorflow, I ran into problems. Would definitely be worth
revisiting in the future
3) during training, actions are saved and stored in self.current_episode_actions
if behaviors 1 and 2 were handled elsewhere, this could easily be implemented
as a wrapper around super instead of overriding this method entirelysearch
"""
# get embedding
embedding = self.main_network.sess.run(self.main_network.online_network.state_embedding,
feed_dict={self.main_network.online_network.inputs[0]: observation})
self.current_episode_state_embeddings.append(embedding[0])
embedding = self.main_network.online_network.predict(
self.tf_input_state(curr_state),
outputs=self.main_network.online_network.state_embedding)
self.current_episode_state_embeddings.append(embedding)
# get action values
# TODO: support additional heads. Right now all other heads are ignored
if self.main_network.online_network.output_heads[0].DND.has_enough_entries(self.tp.agent.number_of_knn):
# if there are enough entries in the DND then we can query it to get the action values
actions_q_values = []
for action in range(self.action_space_size):
feed_dict = {
self.main_network.online_network.state_embedding: embedding,
self.main_network.online_network.output_heads[0].input[0]: np.array([action])
}
q_value = self.main_network.sess.run(
self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
actions_q_values.append(q_value[0])
# actions_q_values = []
feed_dict = {
self.main_network.online_network.state_embedding: [embedding],
}
actions_q_values = self.main_network.sess.run(
self.main_network.online_network.output_heads[0].output, feed_dict=feed_dict)
else:
# get only the embedding so we can insert it to the DND
actions_q_values = [0] * self.action_space_size
@@ -70,6 +80,8 @@ class NECAgent(ValueOptimizationAgent):
# choose action according to the exploration policy and the current phase (evaluating or training the agent)
if phase == RunPhase.TRAIN:
action = self.exploration_policy.get_action(actions_q_values)
# NOTE: this next line is not in the parent implementation
# NOTE: it could be implemented as a wrapper around the parent since action is returned
self.current_episode_actions.append(action)
else:
action = np.argmax(actions_q_values)