Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 03:30:19 +01:00)
Release 0.9
Main changes are detailed below.

New features:
* CARLA 0.7 simulator integration
* Human control of the game play
* Recording of human game play and storing / loading the replay buffer
* Behavioral cloning agent and presets
* Golden tests for several presets
* Selecting between deep / shallow image embedders
* Rendering through pygame (with some boost in performance)

API changes:
* Improved environment wrapper API
* Added an evaluate flag to allow convenient evaluation of existing checkpoints
* Improved frameskip definition in Gym

Bug fixes:
* Fixed loading of checkpoints for agents with more than one network
* Fixed Python 3 compatibility of the N-Step Q learning agent
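The recording and replay-buffer items above amount to a plain pickle round trip, as the diff below shows: HumanAgent.save_replay_buffer_and_exit() pickles the agent's memory to replay_buffer.p, and Agent.__init__() loads a pickled memory when agent.load_memory_from_file_path is set. A minimal sketch of that round trip using only the standard pickle module (the helper names here are illustrative, not Coach's to_pickle/read_pickle wrappers):

import pickle

def save_replay_buffer(memory, path):
    # HumanAgent drops the tuning-parameters reference before pickling (self.memory.tp = None in the diff)
    memory.tp = None
    with open(path, 'wb') as f:
        pickle.dump(memory, f)

def load_replay_buffer(path):
    # Agent.__init__ swaps in the pickled buffer when load_memory_from_file_path is given
    with open(path, 'rb') as f:
        return pickle.load(f)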
@@ -16,6 +16,7 @@
from agents.actor_critic_agent import *
from agents.agent import *
from agents.bc_agent import *
from agents.bootstrapped_dqn_agent import *
from agents.clipped_ppo_agent import *
from agents.ddpg_agent import *
@@ -23,6 +24,8 @@ from agents.ddqn_agent import *
from agents.dfp_agent import *
from agents.dqn_agent import *
from agents.categorical_dqn_agent import *
from agents.human_agent import *
from agents.imitation_agent import *
from agents.mmc_agent import *
from agents.n_step_q_agent import *
from agents.naf_agent import *
agents/agent.py (104 changed lines)
@@ -50,6 +50,7 @@ class Agent(object):
self.task_id = task_id
self.sess = tuning_parameters.sess
self.env = tuning_parameters.env_instance = env
self.imitation = False

# i/o dimensions
if not tuning_parameters.env.desired_observation_width or not tuning_parameters.env.desired_observation_height:
@@ -61,7 +62,12 @@ class Agent(object):
self.measurements_size = tuning_parameters.env.measurements_size = (self.measurements_size[0] + 1,)

# modules
self.memory = eval(tuning_parameters.memory + '(tuning_parameters)')
if tuning_parameters.agent.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(tuning_parameters.agent.load_memory_from_file_path))
self.memory = read_pickle(tuning_parameters.agent.load_memory_from_file_path)
else:
self.memory = eval(tuning_parameters.memory + '(tuning_parameters)')
# self.architecture = eval(tuning_parameters.architecture)

self.has_global = replicated_device is not None
@@ -121,11 +127,12 @@ class Agent(object):

def log_to_screen(self, phase):
# log to screen
if self.current_episode > 0:
if phase == RunPhase.TEST:
exploration = self.evaluation_exploration_policy.get_control_param()
else:
if self.current_episode >= 0:
if phase == RunPhase.TRAIN:
exploration = self.exploration_policy.get_control_param()
else:
exploration = self.evaluation_exploration_policy.get_control_param()

screen.log_dict(
OrderedDict([
("Worker", self.task_id),
@@ -135,7 +142,7 @@ class Agent(object):
("steps", self.total_steps_counter),
("training iteration", self.training_iteration)
]),
prefix="Heatup" if self.in_heatup else "Training" if phase == RunPhase.TRAIN else "Testing"
prefix=phase
)

def update_log(self, phase=RunPhase.TRAIN):
@@ -146,7 +153,7 @@ class Agent(object):
# log all the signals to file
logger.set_current_time(self.current_episode)
logger.create_signal_value('Training Iter', self.training_iteration)
logger.create_signal_value('In Heatup', int(self.in_heatup))
logger.create_signal_value('In Heatup', int(phase == RunPhase.HEATUP))
logger.create_signal_value('ER #Transitions', self.memory.num_transitions())
logger.create_signal_value('ER #Episodes', self.memory.length())
logger.create_signal_value('Episode Length', self.current_episode_steps_counter)
@@ -197,24 +204,6 @@ class Agent(object):
network.curr_rnn_c_in = network.middleware_embedder.c_init
network.curr_rnn_h_in = network.middleware_embedder.h_init

def stack_observation(self, curr_stack, observation):
"""
Adds a new observation to an existing stack of observations from previous time-steps.
:param curr_stack: The current observations stack.
:param observation: The new observation
:return: The updated observation stack
"""

if curr_stack == []:
# starting an episode
curr_stack = np.vstack(np.expand_dims([observation] * self.tp.env.observation_stack_size, 0))
curr_stack = self.switch_axes_order(curr_stack, from_type='channels_first', to_type='channels_last')
else:
curr_stack = np.append(curr_stack, np.expand_dims(np.squeeze(observation), axis=-1), axis=-1)
curr_stack = np.delete(curr_stack, 0, -1)

return curr_stack
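The stacking logic above keeps the newest frame in the last (channels) axis and drops the oldest, so the stack always holds observation_stack_size frames; this diff removes the method in favor of the module-level stack_observation() helper that act() calls further down. A minimal standalone sketch of the same idea, with made-up frame shapes:

import numpy as np

stack_size = 4
frame = np.ones((84, 84))                                        # a single preprocessed frame (shape invented)
stack = np.repeat(frame[..., np.newaxis], stack_size, axis=-1)   # episode start: the same frame repeated
new_frame = np.zeros((84, 84))
stack = np.append(stack, new_frame[..., np.newaxis], axis=-1)    # push the newest frame at the end
stack = np.delete(stack, 0, -1)                                  # drop the oldest frame
assert stack.shape == (84, 84, stack_size)                       # the stack size never changes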

def preprocess_observation(self, observation):
"""
Preprocesses the given observation.
@@ -335,26 +324,6 @@ class Agent(object):
reward = max(reward, self.tp.env.reward_clipping_min)
return reward

def switch_axes_order(self, observation, from_type='channels_first', to_type='channels_last'):
"""
transpose an observation axes from channels_first to channels_last or vice versa
:param observation: a numpy array
:param from_type: can be 'channels_first' or 'channels_last'
:param to_type: can be 'channels_first' or 'channels_last'
:return: a new observation with the requested axes order
"""
if from_type == to_type or len(observation.shape) == 1:
return observation
assert 2 <= len(observation.shape) <= 3, 'num axes of an observation must be 2 for a vector or 3 for an image'
assert type(observation) == np.ndarray, 'observation must be a numpy array'
if len(observation.shape) == 3:
if from_type == 'channels_first' and to_type == 'channels_last':
return np.transpose(observation, (1, 2, 0))
elif from_type == 'channels_last' and to_type == 'channels_first':
return np.transpose(observation, (2, 0, 1))
else:
return np.transpose(observation, (1, 0))

def act(self, phase=RunPhase.TRAIN):
"""
Take one step in the environment according to the network prediction and store the transition in memory
@@ -370,7 +339,7 @@ class Agent(object):
is_first_transition_in_episode = (self.curr_state == [])
if is_first_transition_in_episode:
observation = self.preprocess_observation(self.env.observation)
observation = self.stack_observation([], observation)
observation = stack_observation([], observation, self.tp.env.observation_stack_size)

self.curr_state = {'observation': observation}
if self.tp.agent.use_measurements:
@@ -378,7 +347,7 @@ class Agent(object):
if self.tp.agent.use_accumulated_reward_as_measurement:
self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)

if self.in_heatup: # we do not have a stacked curr_state yet
if phase == RunPhase.HEATUP and not self.tp.heatup_using_network_decisions:
action = self.env.get_random_action()
else:
action, action_info = self.choose_action(self.curr_state, phase=phase)
@@ -394,11 +363,11 @@ class Agent(object):
observation = self.preprocess_observation(result['observation'])

# plot action values online
if self.tp.visualization.plot_action_values_online and not self.in_heatup:
if self.tp.visualization.plot_action_values_online and phase != RunPhase.HEATUP:
self.plot_action_values_online()

# initialize the next state
observation = self.stack_observation(self.curr_state['observation'], observation)
observation = stack_observation(self.curr_state['observation'], observation, self.tp.env.observation_stack_size)

next_state = {'observation': observation}
if self.tp.agent.use_measurements and 'measurements' in result.keys():
@@ -407,7 +376,7 @@ class Agent(object):
next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)

# store the transition only if we are training
if phase == RunPhase.TRAIN:
if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP:
transition = Transition(self.curr_state, result['action'], shaped_reward, next_state, result['done'])
for key in action_info.keys():
transition.info[key] = action_info[key]
@@ -427,7 +396,7 @@ class Agent(object):
self.update_log(phase=phase)
self.log_to_screen(phase=phase)

if phase == RunPhase.TRAIN:
if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP:
self.reset_game()

self.current_episode += 1
@@ -462,11 +431,12 @@ class Agent(object):
for network in self.networks:
network.sync()

if self.tp.visualization.dump_gifs and self.total_reward_in_current_episode > max_reward_achieved:
if self.total_reward_in_current_episode > max_reward_achieved:
max_reward_achieved = self.total_reward_in_current_episode
frame_skipping = int(5/self.tp.env.frame_skip)
logger.create_gif(self.last_episode_images[::frame_skipping],
name='score-{}'.format(max_reward_achieved), fps=10)
if self.tp.visualization.dump_gifs:
logger.create_gif(self.last_episode_images[::frame_skipping],
name='score-{}'.format(max_reward_achieved), fps=10)

average_evaluation_reward += self.total_reward_in_current_episode
self.reset_game()
@@ -496,7 +466,7 @@ class Agent(object):
screen.log_title("Starting heatup {}".format(self.task_id))
num_steps_required_for_one_training_batch = self.tp.batch_size * self.tp.env.observation_stack_size
for step in range(max(self.tp.num_heatup_steps, num_steps_required_for_one_training_batch)):
self.act()
self.act(phase=RunPhase.HEATUP)

# training phase
self.in_heatup = False
@@ -509,7 +479,12 @@ class Agent(object):
# evaluate
evaluate_agent = (self.last_episode_evaluation_ran is not self.current_episode) and \
(self.current_episode % self.tp.evaluate_every_x_episodes == 0)
evaluate_agent = evaluate_agent or \
(self.imitation and self.training_iteration > 0 and
self.training_iteration % self.tp.evaluate_every_x_training_iterations == 0)

if evaluate_agent:
self.env.reset()
self.last_episode_evaluation_ran = self.current_episode
self.evaluate(self.tp.evaluation_episodes)
@@ -522,14 +497,15 @@ class Agent(object):
self.save_model(model_snapshots_periods_passed)

# play and record in replay buffer
if self.tp.agent.step_until_collecting_full_episodes:
step = 0
while step < self.tp.agent.num_consecutive_playing_steps or self.memory.get_episode(-1).length() != 0:
self.act()
step += 1
else:
for step in range(self.tp.agent.num_consecutive_playing_steps):
self.act()
if self.tp.agent.collect_new_data:
if self.tp.agent.step_until_collecting_full_episodes:
step = 0
while step < self.tp.agent.num_consecutive_playing_steps or self.memory.get_episode(-1).length() != 0:
self.act()
step += 1
else:
for step in range(self.tp.agent.num_consecutive_playing_steps):
self.act()

# train
if self.tp.train:
@@ -537,6 +513,8 @@ class Agent(object):
loss = self.train()
self.loss.add_sample(loss)
self.training_iteration += 1
if self.imitation:
self.log_to_screen(RunPhase.TRAIN)
self.post_training_commands()

def save_model(self, model_id):

agents/bc_agent.py (new file, 40 lines)
@@ -0,0 +1,40 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from agents.imitation_agent import *


# Behavioral Cloning Agent
class BCAgent(ImitationAgent):
    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
        ImitationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)

    def learn_from_batch(self, batch):
        current_states, _, actions, _, _, _ = self.extract_batch(batch)

        # create the inputs for the network
        input = current_states

        # the targets for the network are the actions since this is supervised learning
        if self.env.discrete_controls:
            targets = np.eye(self.env.action_space_size)[[actions]]
        else:
            targets = actions

        result = self.main_network.train_and_sync_networks(input, targets)
        total_loss = result[0]

        return total_loss
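In the discrete-action branch above, each recorded action index is turned into a one-hot target row, so behavioral cloning reduces to ordinary supervised classification against the demonstrator's actions. A minimal sketch of that conversion (array values invented for illustration):

import numpy as np

actions = np.array([2, 0, 1])      # action indices recorded from the human player
targets = np.eye(4)[actions]       # one one-hot row per transition, 4 possible actions
# targets[0] == [0., 0., 1., 0.]   # the network is trained to reproduce these rows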

agents/distributional_dqn_agent.py (new file, 60 lines)
@@ -0,0 +1,60 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from agents.value_optimization_agent import *


# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
class DistributionalDQNAgent(ValueOptimizationAgent):
    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
        ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
        self.z_values = np.linspace(self.tp.agent.v_min, self.tp.agent.v_max, self.tp.agent.atoms)

    # prediction's format is (batch,actions,atoms)
    def get_q_values(self, prediction):
        return np.dot(prediction, self.z_values)

    def learn_from_batch(self, batch):
        current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)

        # for the action we actually took, the error is calculated by the atoms distribution
        # for all other actions, the error is 0
        distributed_q_st_plus_1 = self.main_network.target_network.predict(next_states)
        # initialize with the current prediction so that we will
        TD_targets = self.main_network.online_network.predict(current_states)

        # only update the action that we have actually done in this transition
        target_actions = np.argmax(self.get_q_values(distributed_q_st_plus_1), axis=1)
        m = np.zeros((self.tp.batch_size, self.z_values.size))

        batches = np.arange(self.tp.batch_size)
        for j in range(self.z_values.size):
            tzj = np.fmax(np.fmin(rewards + (1.0 - game_overs) * self.tp.agent.discount * self.z_values[j],
                                  self.z_values[self.z_values.size - 1]),
                          self.z_values[0])
            bj = (tzj - self.z_values[0])/(self.z_values[1] - self.z_values[0])
            u = (np.ceil(bj)).astype(int)
            l = (np.floor(bj)).astype(int)
            m[batches, l] = m[batches, l] + (distributed_q_st_plus_1[batches, target_actions, j] * (u - bj))
            m[batches, u] = m[batches, u] + (distributed_q_st_plus_1[batches, target_actions, j] * (bj - l))
        # total_loss = cross entropy between actual result above and predicted result for the given action
        TD_targets[batches, actions] = m

        result = self.main_network.train_and_sync_networks(current_states, TD_targets)
        total_loss = result[0]

        return total_loss
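The loop above is the categorical (C51) projection: every next-state atom z_j is pushed through the Bellman backup Tz_j = r + (1 - done) * discount * z_j, clipped to [v_min, v_max], and its probability mass is split linearly between the two nearest support atoms l = floor(b_j) and u = ceil(b_j), where b_j = (Tz_j - v_min) / dz. A small numeric sketch of a single atom (all values invented for illustration):

import numpy as np

z = np.linspace(-2.0, 2.0, 5)                                  # hypothetical support of 5 atoms, dz = 1
r, discount, done = 0.5, 0.9, 0.0
tz = np.clip(r + (1 - done) * discount * z[3], z[0], z[-1])    # backup of atom z[3] = 1.0 gives 1.4
b = (tz - z[0]) / (z[1] - z[0])                                # fractional index on the support: 3.4
l, u = int(np.floor(b)), int(np.ceil(b))                       # neighbouring atoms 3 and 4
print(l, u, u - b, b - l)                                      # mass p of this atom goes 0.6*p to atom 3, 0.4*p to atom 4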

agents/human_agent.py (new file, 67 lines)
@@ -0,0 +1,67 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from agents.agent import *
import pygame


class HumanAgent(Agent):
    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
        Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id)

        self.clock = pygame.time.Clock()
        self.max_fps = int(self.tp.visualization.max_fps_for_human_control)

        screen.log_title("Human Control Mode")
        available_keys = self.env.get_available_keys()
        if available_keys:
            screen.log("Use keyboard keys to move. Press escape to quit. Available keys:")
            screen.log("")
            for action, key in self.env.get_available_keys():
                screen.log("\t- {}: {}".format(action, key))
            screen.separator()

    def train(self):
        return 0

    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
        action = self.env.get_action_from_user()

        # keep constant fps
        self.clock.tick(self.max_fps)

        if not self.env.renderer.is_open:
            self.save_replay_buffer_and_exit()

        return action, {"action_value": 0}

    def save_replay_buffer_and_exit(self):
        replay_buffer_path = os.path.join(logger.experiments_path, 'replay_buffer.p')
        self.memory.tp = None
        to_pickle(self.memory, replay_buffer_path)
        screen.log_title("Replay buffer was stored in {}".format(replay_buffer_path))
        exit()

    def log_to_screen(self, phase):
        # log to screen
        screen.log_dict(
            OrderedDict([
                ("Episode", self.current_episode),
                ("total reward", self.total_reward_in_current_episode),
                ("steps", self.total_steps_counter)
            ]),
            prefix="Recording"
        )
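choose_action() above rate-limits the control loop with pygame's Clock, so keyboard input is sampled at a steady frame rate rather than as fast as the CPU allows. A minimal standalone sketch of that pattern (the 30 fps cap and loop body are invented for illustration):

import pygame

pygame.init()
clock = pygame.time.Clock()
for _ in range(90):
    # ... read the keyboard state and step the environment here ...
    clock.tick(30)   # sleep just long enough to cap the loop at ~30 iterations per second
pygame.quit()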

agents/imitation_agent.py (new file, 70 lines)
@@ -0,0 +1,70 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from agents.agent import *


# Imitation Agent
class ImitationAgent(Agent):
    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
        Agent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
        self.main_network = NetworkWrapper(tuning_parameters, False, self.has_global, 'main',
                                           self.replicated_device, self.worker_device)
        self.networks.append(self.main_network)
        self.imitation = True

    def extract_action_values(self, prediction):
        return prediction.squeeze()

    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
        # convert to batch so we can run it through the network
        observation = np.expand_dims(np.array(curr_state['observation']), 0)
        if self.tp.agent.use_measurements:
            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
            prediction = self.main_network.online_network.predict([observation, measurements])
        else:
            prediction = self.main_network.online_network.predict(observation)

        # get action values and extract the best action from it
        action_values = self.extract_action_values(prediction)
        if self.env.discrete_controls:
            # DISCRETE
            # action = np.argmax(action_values)
            action = self.evaluation_exploration_policy.get_action(action_values)
            action_value = {"action_probability": action_values[action]}
        else:
            # CONTINUOUS
            action = action_values
            action_value = {}

        return action, action_value

    def log_to_screen(self, phase):
        # log to screen
        if phase == RunPhase.TRAIN:
            # for the training phase - we log during the episode to visualize the progress in training
            screen.log_dict(
                OrderedDict([
                    ("Worker", self.task_id),
                    ("Episode", self.current_episode),
                    ("Loss", self.loss.values[-1]),
                    ("Training iteration", self.training_iteration)
                ]),
                prefix="Training"
            )
        else:
            # for the evaluation phase - logging as in regular RL
            Agent.log_to_screen(self, phase)
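choose_action() in this file first adds a leading batch axis, because the network predicts on batches even when a single state is queried; for discrete controls the squeezed output is treated as per-action scores from which the evaluation exploration policy picks an index, while for continuous controls the output is used directly as the action vector. A small sketch of the discrete path (shapes and values invented for illustration):

import numpy as np

obs = np.random.rand(84, 84, 4)                # one stacked observation (hypothetical shape)
batch = np.expand_dims(obs, 0)                 # shape (1, 84, 84, 4), what the network expects
action_scores = np.array([0.1, 0.7, 0.2])      # pretend network output after squeeze()
greedy_action = int(np.argmax(action_scores))  # 1, what the commented-out argmax line would pick
# the diff instead hands action_scores to evaluation_exploration_policy.get_action(),
# which may sample rather than always act greedily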

@@ -45,7 +45,7 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
# 1-Step Q learning
q_st_plus_1 = self.main_network.target_network.predict(next_states)

for i in reversed(xrange(num_transitions)):
for i in reversed(range(num_transitions)):
state_value_head_targets[i][actions[i]] = \
rewards[i] + (1.0 - game_overs[i]) * self.tp.agent.discount * np.max(q_st_plus_1[i], 0)

@@ -56,7 +56,7 @@ class NStepQAgent(ValueOptimizationAgent, PolicyOptimizationAgent):
else:
R = np.max(self.main_network.target_network.predict(np.expand_dims(next_states[-1], 0)))

for i in reversed(xrange(num_transitions)):
for i in reversed(range(num_transitions)):
R = rewards[i] + self.tp.agent.discount * R
state_value_head_targets[i][actions[i]] = R
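The second loop above computes the N-step return by sweeping the collected transitions backwards: R is bootstrapped from the target network at the last next-state and then one reward is folded in per step. A standalone sketch of that recursion (rewards and bootstrap value invented for illustration):

rewards = [1.0, 0.0, 2.0]        # rewards of the collected transitions, newest last
discount = 0.99
R = 5.0                          # bootstrap: max Q of the state after the last transition
returns = [0.0] * len(rewards)
for i in reversed(range(len(rewards))):
    R = rewards[i] + discount * R
    returns[i] = R
# returns == [1 + 0.99*(0 + 0.99*(2 + 0.99*5)), 0 + 0.99*(2 + 0.99*5), 2 + 0.99*5]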

@@ -58,7 +58,7 @@ class PolicyOptimizationAgent(Agent):
("steps", self.total_steps_counter),
("training iteration", self.training_iteration)
]),
prefix="Heatup" if self.in_heatup else "Training" if phase == RunPhase.TRAIN else "Testing"
prefix=phase
)

def update_episode_statistics(self, episode):