fix clipped ppo
@@ -36,23 +36,6 @@ class ValueOptimizationAgent(Agent):
     def get_q_values(self, prediction):
         return prediction
 
-    def tf_input_state(self, curr_state):
-        """
-        convert curr_state into input tensors tensorflow is expecting.
-
-        TODO: move this function into Agent and use in as many agent implementations as possible
-        currently, other agents will likely not work with environment measurements.
-        This will become even more important as we support more complex and varied environment states.
-        """
-        # convert to batch so we can run it through the network
-        observation = np.expand_dims(np.array(curr_state['observation']), 0)
-        if self.tp.agent.use_measurements:
-            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
-            tf_input_state = [observation, measurements]
-        else:
-            tf_input_state = observation
-        return tf_input_state
-
     def get_prediction(self, curr_state):
         return self.main_network.online_network.predict(self.tf_input_state(curr_state))
 
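For reference, the removed tf_input_state helper turns a single environment state into batched network inputs by adding a leading batch dimension, optionally pairing the observation with a measurements vector. Below is a minimal, self-contained sketch of that conversion; the standalone function signature, the use_measurements flag, and the example state dict are illustrative assumptions, not part of the repository's API.

import numpy as np

def tf_input_state(curr_state, use_measurements=False):
    """Sketch of the removed helper: batch one environment state for the network.

    curr_state is assumed to be a dict holding an 'observation' array and,
    optionally, a 'measurements' array, matching the keys used in the
    removed code.
    """
    # add a leading batch dimension of size 1 so a single state can be fed
    # through the network's batched input placeholders
    observation = np.expand_dims(np.array(curr_state['observation']), 0)
    if use_measurements:
        measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
        # multi-input network: observation head plus measurements head
        return [observation, measurements]
    return observation

# usage: an 84x84x4 observation becomes a (1, 84, 84, 4) batch
state = {'observation': np.zeros((84, 84, 4)), 'measurements': np.zeros(3)}
print(tf_input_state(state).shape)                       # (1, 84, 84, 4)
print([x.shape for x in tf_input_state(state, True)])    # [(1, 84, 84, 4), (1, 3)]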