1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

fix clipped ppo

This commit is contained in:
Zach Dwiel
2018-02-16 13:22:10 -05:00
parent 85afb86893
commit 39a28aba95
7 changed files with 51 additions and 39 deletions

View File

@@ -36,23 +36,6 @@ class ValueOptimizationAgent(Agent):
def get_q_values(self, prediction):
    """
    Return the Q-values for a network prediction.

    The prediction is handed back unchanged (the very same object).

    :param prediction: the output obtained from the network
    :return: `prediction`, as-is
    """
    return prediction
def tf_input_state(self, curr_state):
    """
    Turn a single state dict into the input(s) fed to the network.

    Each entry that is used gets converted to a numpy array and is given a
    leading axis of length 1, so a single state can be run through the
    network as a batch.

    TODO: move this function into Agent and use in as many agent implementations as possible
          currently, other agents will likely not work with environment measurements.
          This will become even more important as we support more complex and varied environment states.

    :param curr_state: mapping with an 'observation' entry and, when
                       self.tp.agent.use_measurements is truthy, a
                       'measurements' entry as well
    :return: the batched observation array, or the list
             [observation, measurements] of batched arrays when
             measurements are in use
    """
    def as_batch(key):
        # prepend a batch axis of size 1 to the requested entry
        return np.expand_dims(np.array(curr_state[key]), axis=0)

    observation = as_batch('observation')
    if not self.tp.agent.use_measurements:
        return observation
    return [observation, as_batch('measurements')]
def get_prediction(self, curr_state):
    """
    Run the online network on a single state.

    The state is first converted with self.tf_input_state and the result is
    passed to the predict method of self.main_network.online_network.

    :param curr_state: the state to evaluate, in the form accepted by
                       self.tf_input_state
    :return: whatever the online network's predict call returns
    """
    predict = self.main_network.online_network.predict
    return predict(self.tf_input_state(curr_state))