1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

update nec and value optimization agents to work with recurrent middleware

This commit is contained in:
Zach Dwiel
2017-11-03 13:58:42 -07:00
parent 93a54c7e8e
commit 6c79a442f2
12 changed files with 138 additions and 72 deletions

View File

@@ -204,10 +204,11 @@ class Agent(object):
for action in self.env.actions_description:
self.episode_running_info[action] = []
plt.clf()
if self.tp.agent.middleware_type == MiddlewareTypes.LSTM:
for network in self.networks:
network.curr_rnn_c_in = network.middleware_embedder.c_init
network.curr_rnn_h_in = network.middleware_embedder.h_init
network.online_network.curr_rnn_c_in = network.online_network.middleware_embedder.c_init
network.online_network.curr_rnn_h_in = network.online_network.middleware_embedder.h_init
def preprocess_observation(self, observation):
"""