
fix clipped ppo

Zach Dwiel
2018-02-16 13:22:10 -05:00
parent 85afb86893
commit 39a28aba95
7 changed files with 51 additions and 39 deletions


@@ -338,6 +338,17 @@ class Agent(object):
            reward = max(reward, self.tp.env.reward_clipping_min)
        return reward

    def tf_input_state(self, curr_state):
        """
        Convert curr_state into the input tensors that TensorFlow expects.
        """
        # extract values from the state based on agent.input_types
        # and add a batch axis of length 1 onto each value
        input_state = {}
        for input_name in self.tp.agent.input_types.keys():
            input_state[input_name] = np.expand_dims(np.array(curr_state[input_name]), 0)
        return input_state

    def act(self, phase=RunPhase.TRAIN):
        """
        Take one step in the environment according to the network prediction and store the transition in memory
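
For readers skimming the diff: the new tf_input_state helper simply wraps each entry of the current state in a length-1 batch axis before it is fed to the network. Below is a minimal standalone sketch of that batching step; the key names ('observation', 'measurements') and shapes are illustrative only and not part of this commit.

    import numpy as np

    # hypothetical per-step state as an environment might return it
    # (key names and shapes are made up for illustration)
    curr_state = {
        'observation': np.zeros((84, 84, 3)),   # e.g. an RGB frame
        'measurements': np.array([0.5, 1.0]),   # e.g. low-dimensional sensor values
    }

    # mirror the batching done in Agent.tf_input_state: add a batch axis of
    # length 1 so each value matches the network's (batch, ...) placeholders
    input_state = {input_name: np.expand_dims(np.array(value), 0)
                   for input_name, value in curr_state.items()}

    print(input_state['observation'].shape)    # (1, 84, 84, 3)
    print(input_state['measurements'].shape)   # (1, 2)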