Mirror of https://github.com/gryf/coach.git

fix n_step_q_agent

Zach Dwiel
2018-02-16 20:25:33 -05:00
parent 5cf10e5f52
commit e1ad86417f
3 changed files with 20 additions and 19 deletions

View File

@@ -20,17 +20,6 @@ from utils import *
 import scipy.signal
-
-
-def last_sample(state):
-    """
-    given a batch of states, return the last sample of the batch with length 1
-    batch axis.
-    """
-    return {
-        k: np.expand_dims(v[-1], 0)
-        for k, v in state.items()
-    }
 
 
 # Actor Critic - https://arxiv.org/abs/1602.01783
 class ActorCriticAgent(PolicyOptimizationAgent):
     def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0, create_target_network = False):

View File

@@ -13,13 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 
-from agents.value_optimization_agent import *
-from agents.policy_optimization_agent import *
-from logger import *
-from utils import *
+import numpy as np
 import scipy.signal
+from agents.value_optimization_agent import ValueOptimizationAgent
+from agents.policy_optimization_agent import PolicyOptimizationAgent
+from logger import logger
+from utils import Signal, last_sample
 
 
 # N Step Q Learning Agent - https://arxiv.org/abs/1602.01783
 class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
@@ -56,7 +57,7 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
             if game_overs[-1]:
                 R = 0
             else:
-                R = np.max(self.main_network.target_network.predict(np.expand_dims(next_states[-1], 0)))
+                R = np.max(self.main_network.target_network.predict(last_sample(next_states)))
 
             for i in reversed(range(num_transitions)):
                 R = rewards[i] + self.tp.agent.discount * R
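
For reference, the loop in the hunk above is the standard N-step bootstrapped return: R starts at 0 for a terminal rollout, or at the target network's max Q-value for the state after the last transition (now fetched via last_sample(next_states)), and is then folded backwards through the rewards. A minimal standalone sketch, with q_last standing in for that target-network prediction and discount for self.tp.agent.discount:

import numpy as np

def n_step_targets(rewards, q_last, terminal, discount):
    """Fold N-step returns backwards through a rollout, as in the hunk above."""
    # Bootstrap from the target network's value unless the episode ended.
    R = 0.0 if terminal else q_last
    targets = np.zeros(len(rewards))
    for i in reversed(range(len(rewards))):
        R = rewards[i] + discount * R
        targets[i] = R
    return targets

# Example: three transitions, bootstrapping from q_last = 10.0
print(n_step_targets([1.0, 0.0, 2.0], q_last=10.0, terminal=False, discount=0.9))
# targets: [9.91, 9.9, 11.0]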
@@ -66,7 +67,7 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
             assert True, 'The available values for targets_horizon are: 1-Step, N-Step'
 
         # train
-        result = self.main_network.online_network.accumulate_gradients([current_states], [state_value_head_targets])
+        result = self.main_network.online_network.accumulate_gradients(current_states, [state_value_head_targets])
 
         # logging
         total_loss, losses, unclipped_grads = result[:3]

View File

@@ -351,3 +351,14 @@ def stack_observation(curr_stack, observation, stack_size):
         curr_stack = np.delete(curr_stack, 0, -1)
 
     return curr_stack
+
+
+def last_sample(state):
+    """
+    given a batch of states, return the last sample of the batch with length 1
+    batch axis.
+    """
+    return {
+        k: np.expand_dims(v[-1], 0)
+        for k, v in state.items()
+    }
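
For clarity, a small usage example of the new helper; the dict keys and array shapes below are illustrative, not taken from the repository:

import numpy as np
from utils import last_sample  # the helper added above

# A batch of 4 states, where each state is a dict of named input arrays.
states = {
    'observation': np.zeros((4, 84, 84, 4)),
    'measurements': np.zeros((4, 3)),
}

last = last_sample(states)
print(last['observation'].shape)   # (1, 84, 84, 4)
print(last['measurements'].shape)  # (1, 3)

This matches how NStepQAgent now calls self.main_network.target_network.predict(last_sample(next_states)): the prediction is made on a batch of size one containing only the final state.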