mirror of https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
fix n_step_q_agent
@@ -20,17 +20,6 @@ from utils import *
 import scipy.signal


-def last_sample(state):
-    """
-    given a batch of states, return the last sample of the batch with length 1
-    batch axis.
-    """
-    return {
-        k: np.expand_dims(v[-1], 0)
-        for k, v in state.items()
-    }
-
-
 # Actor Critic - https://arxiv.org/abs/1602.01783
 class ActorCriticAgent(PolicyOptimizationAgent):
     def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0, create_target_network = False):
@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,13 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from agents.value_optimization_agent import *
-from agents.policy_optimization_agent import *
-from logger import *
-from utils import *
+import numpy as np
 import scipy.signal

+from agents.value_optimization_agent import ValueOptimizationAgent
+from agents.policy_optimization_agent import PolicyOptimizationAgent
+from logger import logger
+from utils import Signal, last_sample


 # N Step Q Learning Agent - https://arxiv.org/abs/1602.01783
 class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
@@ -56,7 +57,7 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
             if game_overs[-1]:
                 R = 0
             else:
-                R = np.max(self.main_network.target_network.predict(np.expand_dims(next_states[-1], 0)))
+                R = np.max(self.main_network.target_network.predict(last_sample(next_states)))

             for i in reversed(range(num_transitions)):
                 R = rewards[i] + self.tp.agent.discount * R
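The hunk above only changes how the bootstrap value is fetched; the N-step target itself is the backward-accumulated discounted return computed by the loop. Below is a minimal standalone sketch of that computation; the function name, argument names, and the 0.99 discount are illustrative and not taken from the repository.

import numpy as np

def n_step_targets(rewards, game_over, bootstrap_q_values, discount=0.99):
    """Backward-accumulated N-step returns, mirroring the loop in the hunk above."""
    # seed with 0 for terminal episodes, otherwise bootstrap from the greedy
    # Q-value predicted for the state that follows the last transition
    R = 0.0 if game_over else float(np.max(bootstrap_q_values))
    targets = np.zeros(len(rewards))
    # walk the transitions from newest to oldest, discounting as we go
    for i in reversed(range(len(rewards))):
        R = rewards[i] + discount * R
        targets[i] = R
    return targets

# example: three transitions, episode not over, bootstrap Q-values for two actions
print(n_step_targets([0.0, 0.0, 1.0], False, np.array([1.0, 2.0])))
# -> approximately [2.9207, 2.9502, 2.98]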
@@ -66,7 +67,7 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
             assert True, 'The available values for targets_horizon are: 1-Step, N-Step'

         # train
-        result = self.main_network.online_network.accumulate_gradients([current_states], [state_value_head_targets])
+        result = self.main_network.online_network.accumulate_gradients(current_states, [state_value_head_targets])

         # logging
         total_loss, losses, unclipped_grads = result[:3]
utils.py (11 additions)
@@ -351,3 +351,14 @@ def stack_observation(curr_stack, observation, stack_size):
         curr_stack = np.delete(curr_stack, 0, -1)

     return curr_stack
+
+
+def last_sample(state):
+    """
+    given a batch of states, return the last sample of the batch with length 1
+    batch axis.
+    """
+    return {
+        k: np.expand_dims(v[-1], 0)
+        for k, v in state.items()
+    }
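To illustrate the helper added above: it assumes a state is a dict of numpy arrays whose first axis is the batch axis, and returns the same dict restricted to the last sample while keeping a length-1 batch axis. The observation names and shapes below are invented for the example. Given dict-shaped states, indexing with -1, as the old np.expand_dims(next_states[-1], 0) call did, would not pick out the last batch element, which is presumably what the predict(last_sample(next_states)) change earlier in this commit addresses.

import numpy as np

from utils import last_sample

# hypothetical batch of 4 states; each observation is keyed by name
state = {
    'observation': np.zeros((4, 84, 84)),
    'measurements': np.zeros((4, 3)),
}

last = last_sample(state)
# each value keeps only its final sample, with a fresh length-1 batch axis
print(last['observation'].shape)   # (1, 84, 84)
print(last['measurements'].shape)  # (1, 3)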