diff --git a/agents/actor_critic_agent.py b/agents/actor_critic_agent.py
index 8b198bf..e2dc916 100644
--- a/agents/actor_critic_agent.py
+++ b/agents/actor_critic_agent.py
@@ -20,17 +20,6 @@
 from utils import *
 import scipy.signal
 
-def last_sample(state):
-    """
-    given a batch of states, return the last sample of the batch with length 1
-    batch axis.
-    """
-    return {
-        k: np.expand_dims(v[-1], 0)
-        for k, v in state.items()
-    }
-
-
 # Actor Critic - https://arxiv.org/abs/1602.01783
 class ActorCriticAgent(PolicyOptimizationAgent):
     def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0, create_target_network = False):
diff --git a/agents/n_step_q_agent.py b/agents/n_step_q_agent.py
index 9dc839a..5a74fb5 100644
--- a/agents/n_step_q_agent.py
+++ b/agents/n_step_q_agent.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation 
+# Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,13 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from agents.value_optimization_agent import *
-from agents.policy_optimization_agent import *
-from logger import *
-from utils import *
+import numpy as np
 import scipy.signal
 
+from agents.value_optimization_agent import ValueOptimizationAgent
+from agents.policy_optimization_agent import PolicyOptimizationAgent
+from logger import logger
+from utils import Signal, last_sample
+
 
 # N Step Q Learning Agent - https://arxiv.org/abs/1602.01783
 class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
@@ -56,7 +57,7 @@
             if game_overs[-1]:
                 R = 0
             else:
-                R = np.max(self.main_network.target_network.predict(np.expand_dims(next_states[-1], 0)))
+                R = np.max(self.main_network.target_network.predict(last_sample(next_states)))
 
             for i in reversed(range(num_transitions)):
                 R = rewards[i] + self.tp.agent.discount * R
@@ -66,7 +67,7 @@
             assert True, 'The available values for targets_horizon are: 1-Step, N-Step'
 
         # train
-        result = self.main_network.online_network.accumulate_gradients([current_states], [state_value_head_targets])
+        result = self.main_network.online_network.accumulate_gradients(current_states, [state_value_head_targets])
 
         # logging
         total_loss, losses, unclipped_grads = result[:3]
diff --git a/utils.py b/utils.py
index 7449819..4eecbce 100644
--- a/utils.py
+++ b/utils.py
@@ -351,3 +351,14 @@ def stack_observation(curr_stack, observation, stack_size):
         curr_stack = np.delete(curr_stack, 0, -1)
 
     return curr_stack
+
+
+def last_sample(state):
+    """
+    given a batch of states, return the last sample of the batch with length 1
+    batch axis.
+    """
+    return {
+        k: np.expand_dims(v[-1], 0)
+        for k, v in state.items()
+    }
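
For reference, a minimal sketch of what the relocated last_sample helper does once it lives in utils.py. The batch keys and shapes below are made up for illustration only; they are not taken from the repo:

    import numpy as np

    def last_sample(state):
        # Mirrors the helper moved to utils.py: take the last sample of a
        # dict-of-arrays batch and keep a length-1 batch axis on every entry.
        return {k: np.expand_dims(v[-1], 0) for k, v in state.items()}

    # Hypothetical state batch of 32 transitions with two input heads.
    batch = {
        'observation': np.zeros((32, 84, 84, 4)),
        'measurements': np.zeros((32, 3)),
    }

    single = last_sample(batch)
    print(single['observation'].shape)   # (1, 84, 84, 4)
    print(single['measurements'].shape)  # (1, 3)

In n_step_q_agent.py this replaces np.expand_dims(next_states[-1], 0) when bootstrapping the N-step return, which suggests next_states is handled as a dict keyed by input name rather than a single array.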