mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
fix n_step_q_agent
This commit is contained in:
11
utils.py
11
utils.py
@@ -351,3 +351,14 @@ def stack_observation(curr_stack, observation, stack_size):
|
||||
curr_stack = np.delete(curr_stack, 0, -1)
|
||||
|
||||
return curr_stack
|
||||
|
||||
|
||||
def last_sample(state):
|
||||
"""
|
||||
given a batch of states, return the last sample of the batch with length 1
|
||||
batch axis.
|
||||
"""
|
||||
return {
|
||||
k: np.expand_dims(v[-1], 0)
|
||||
for k, v in state.items()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user