mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Fix clipped PPO
This commit is contained in:
@@ -338,6 +338,17 @@ class Agent(object):
|
||||
reward = max(reward, self.tp.env.reward_clipping_min)
|
||||
return reward
|
||||
|
||||
def tf_input_state(self, curr_state):
    """
    Convert curr_state into the input tensors TensorFlow is expecting.

    Every input name declared in agent.input_types is looked up in
    curr_state, converted to an ndarray, and given a leading batch axis
    of length 1.

    :param curr_state: mapping from input name to raw state value
    :return: dict mapping each input name to a batched numpy array
    """
    # np.array(...)[np.newaxis, ...] prepends the length-1 batch axis
    # (equivalent to np.expand_dims(..., 0)).
    return {
        name: np.array(curr_state[name])[np.newaxis, ...]
        for name in self.tp.agent.input_types
    }
|
||||
|
||||
def act(self, phase=RunPhase.TRAIN):
|
||||
"""
|
||||
Take one step in the environment according to the network prediction and store the transition in memory
|
||||
|
||||
Reference in New Issue
Block a user