
fix clipped ppo

Zach Dwiel
2018-02-16 13:22:10 -05:00
parent 85afb86893
commit 39a28aba95
7 changed files with 51 additions and 39 deletions

View File

@@ -338,6 +338,17 @@ class Agent(object):
         reward = max(reward, self.tp.env.reward_clipping_min)
         return reward

+    def tf_input_state(self, curr_state):
+        """
+        convert curr_state into input tensors tensorflow is expecting.
+        """
+        # add batch axis with length 1 onto each value
+        # extract values from the state based on agent.input_types
+        input_state = {}
+        for input_name in self.tp.agent.input_types.keys():
+            input_state[input_name] = np.expand_dims(np.array(curr_state[input_name]), 0)
+        return input_state
+
     def act(self, phase=RunPhase.TRAIN):
         """
         Take one step in the environment according to the network prediction and store the transition in memory
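
The helper added above is easiest to see in isolation. Below is a minimal standalone sketch of the same batching behavior, with a hypothetical curr_state and input_types mapping (in the repository these come from self.tp.agent and the environment), not the repository's code itself:

import numpy as np

def tf_input_state(curr_state, input_types):
    # pick each named input out of the state dict and prepend a batch axis of
    # length 1, mirroring the new Agent.tf_input_state in the hunk above
    input_state = {}
    for input_name in input_types.keys():
        input_state[input_name] = np.expand_dims(np.array(curr_state[input_name]), 0)
    return input_state

# hypothetical state with an image observation and a small measurements vector
curr_state = {'observation': np.zeros((84, 84, 4)), 'measurements': np.zeros(3)}
input_types = {'observation': None, 'measurements': None}
batched = tf_input_state(curr_state, input_types)
print({name: value.shape for name, value in batched.items()})
# {'observation': (1, 84, 84, 4), 'measurements': (1, 3)}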

View File

@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -39,7 +39,7 @@ class ClippedPPOAgent(ActorCriticAgent):
     def fill_advantages(self, batch):
         current_states, next_states, actions, rewards, game_overs, total_return = self.extract_batch(batch)

-        current_state_values = self.main_network.online_network.predict([current_states])[0]
+        current_state_values = self.main_network.online_network.predict(current_states)[0]
         current_state_values = current_state_values.squeeze()
         self.state_values.add_sample(current_state_values)
@@ -97,7 +97,7 @@ class ClippedPPOAgent(ActorCriticAgent):
                     actions = np.expand_dims(actions, -1)

                 # get old policy probabilities and distribution
-                result = self.main_network.target_network.predict([current_states])
+                result = self.main_network.target_network.predict(current_states)
                 old_policy_distribution = result[1:]

                 # calculate gradients and apply on both the local policy network and on the global policy network
@@ -106,10 +106,18 @@ class ClippedPPOAgent(ActorCriticAgent):
                 total_return = np.expand_dims(total_return, -1)
                 value_targets = gae_based_value_targets if self.tp.agent.estimate_value_using_gae else total_return

+                inputs = copy.copy(current_states)
+                # TODO: why is this output 0 and not output 1?
+                inputs['output_0_0'] = actions
+                # TODO: does old_policy_distribution really need to be represented as a list?
+                # A: yes it does, in the event of discrete controls, it has just a mean
+                #    otherwise, it has both a mean and standard deviation
+                for input_index, input in enumerate(old_policy_distribution):
+                    inputs['output_0_{}'.format(input_index + 1)] = input
+                # print('old_policy_distribution.shape', len(old_policy_distribution))
                 total_loss, policy_losses, unclipped_grads, fetch_result =\
                     self.main_network.online_network.accumulate_gradients(
-                        [current_states] + [actions] + old_policy_distribution,
-                        [total_return, advantages], additional_fetches=fetches)
+                        inputs, [total_return, advantages], additional_fetches=fetches)

                 self.value_targets.add_sample(value_targets)
                 if self.tp.distributed:
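
The hunk above replaces the positional input list passed to accumulate_gradients with a dict keyed by placeholder name. A hedged sketch of how that dict ends up looking, assuming current_states is already a dict of named network inputs and reusing the 'output_0_*' naming from the diff; the batch size and action space below are made up for illustration:

import copy
import numpy as np

def build_clipped_ppo_inputs(current_states, actions, old_policy_distribution):
    # start from the state inputs and attach the extra placeholders the loss needs
    inputs = copy.copy(current_states)
    inputs['output_0_0'] = actions
    # discrete controls: old_policy_distribution is a single array of probabilities;
    # continuous controls: a mean and a standard deviation, so one or two entries
    for input_index, dist_part in enumerate(old_policy_distribution):
        inputs['output_0_{}'.format(input_index + 1)] = dist_part
    return inputs

states = {'observation': np.zeros((2, 84, 84, 4))}      # hypothetical batch of 2
actions = np.array([[0], [2]])                          # after np.expand_dims(actions, -1)
old_dist = [np.full((2, 3), 1.0 / 3)]                   # discrete case: probabilities only
print(sorted(build_clipped_ppo_inputs(states, actions, old_dist).keys()))
# ['observation', 'output_0_0', 'output_0_1']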
@@ -177,14 +185,10 @@ class ClippedPPOAgent(ActorCriticAgent):
         self.update_log()  # should be done in order to update the data that has been accumulated * while not playing *
         return np.append(losses[0], losses[1])

-    def choose_action(self, curr_state, phase=RunPhase.TRAIN):
-        # convert to batch so we can run it through the network
-        observation = curr_state['observation']
-        observation = np.expand_dims(np.array(observation), 0)
+    def choose_action(self, current_state, phase=RunPhase.TRAIN):
         if self.env.discrete_controls:
             # DISCRETE
-            _, action_values = self.main_network.online_network.predict(observation)
+            _, action_values = self.main_network.online_network.predict(self.tf_input_state(current_state))
             action_values = action_values.squeeze()

             if phase == RunPhase.TRAIN:
@@ -195,7 +199,7 @@ class ClippedPPOAgent(ActorCriticAgent):
             # self.entropy.add_sample(-np.sum(action_values * np.log(action_values)))
         else:
             # CONTINUOUS
-            _, action_values_mean, action_values_std = self.main_network.online_network.predict(observation)
+            _, action_values_mean, action_values_std = self.main_network.online_network.predict(self.tf_input_state(current_state))
             action_values_mean = action_values_mean.squeeze()
             action_values_std = action_values_std.squeeze()

             if phase == RunPhase.TRAIN:

View File

@@ -1,5 +1,5 @@
 #
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ class PPOAgent(ActorCriticAgent):
         # * Found not to have any impact *
         # current_states_with_timestep = self.concat_state_and_timestep(batch)

-        current_state_values = self.critic_network.online_network.predict([current_states]).squeeze()
+        current_state_values = self.critic_network.online_network.predict(current_state).squeeze()

         # calculate advantages
         advantages = []
@@ -105,11 +105,11 @@ class PPOAgent(ActorCriticAgent):
             current_states_batch = current_states[i * batch_size:(i + 1) * batch_size]
             total_return_batch = total_return[i * batch_size:(i + 1) * batch_size]
             old_policy_values = force_list(self.critic_network.target_network.predict(
-                [current_states_batch]).squeeze())
+                current_states_batch).squeeze())
             if self.critic_network.online_network.optimizer_type != 'LBFGS':
                 targets = total_return_batch
             else:
-                current_values = self.critic_network.online_network.predict([current_states_batch])
+                current_values = self.critic_network.online_network.predict(current_states_batch)
                 targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction

             value_loss = self.critic_network.online_network.\

View File

@@ -36,23 +36,6 @@ class ValueOptimizationAgent(Agent):
     def get_q_values(self, prediction):
         return prediction

-    def tf_input_state(self, curr_state):
-        """
-        convert curr_state into input tensors tensorflow is expecting.
-        TODO: move this function into Agent and use in as many agent implementations as possible
-          currently, other agents will likely not work with environment measurements.
-          This will become even more important as we support more complex and varied environment states.
-        """
-        # convert to batch so we can run it through the network
-        observation = np.expand_dims(np.array(curr_state['observation']), 0)
-        if self.tp.agent.use_measurements:
-            measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
-            tf_input_state = [observation, measurements]
-        else:
-            tf_input_state = observation
-        return tf_input_state
-
     def get_prediction(self, curr_state):
         return self.main_network.online_network.predict(self.tf_input_state(curr_state))
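
The method removed here is subsumed by the generic Agent.tf_input_state added in the first hunk of this commit. A hedged before/after sketch, assuming input_types lists 'observation' and 'measurements': the old version returned a positional list whose ordering had to match the network, while the new one returns a dict keyed by input name.

import numpy as np

def old_tf_input_state(curr_state, use_measurements):
    # removed version: positional list, ordered [observation, measurements]
    observation = np.expand_dims(np.array(curr_state['observation']), 0)
    if use_measurements:
        measurements = np.expand_dims(np.array(curr_state['measurements']), 0)
        return [observation, measurements]
    return observation

def new_tf_input_state(curr_state, input_types):
    # generic version now on Agent: keyed by name, so any declared input is handled
    return {name: np.expand_dims(np.array(curr_state[name]), 0) for name in input_types}

state = {'observation': np.zeros((84, 84, 4)), 'measurements': np.zeros(3)}
old = old_tf_input_state(state, use_measurements=True)
new = new_tf_input_state(state, {'observation': None, 'measurements': None})
assert np.array_equal(old[0], new['observation']) and np.array_equal(old[1], new['measurements'])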

View File

@@ -267,10 +267,20 @@ class TensorFlowArchitecture(Architecture):
             time.sleep(0.00001)

     def _feed_dict(self, inputs):
-        return {
-            self.inputs[input_name]: input_value
-            for input_name, input_value in inputs.items()
-        }
+        feed_dict = {}
+        for input_name, input_value in inputs.items():
+            if input_name not in self.inputs:
+                raise ValueError((
+                    'input name {input_name} was provided to create a feed '
+                    'dictionary, but there is no placeholder with that name. '
+                    'placeholder names available include: {placeholder_names}'
+                ).format(
+                    input_name=input_name,
+                    placeholder_names=', '.join(self.inputs.keys())
+                ))
+            feed_dict[self.inputs[input_name]] = input_value
+        return feed_dict

     def predict(self, inputs, outputs=None):
         """

View File

@@ -327,13 +327,15 @@ if __name__ == "__main__":
     set_cpu()

     # create a parameter server
-    parameter_server = Popen([
+    cmd = [
         "python3",
         "./parallel_actor.py",
         "--ps_hosts={}".format(ps_hosts),
         "--worker_hosts={}".format(worker_hosts),
         "--job_name=ps",
-    ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)
+    ]
+    print(' '.join(cmd))
+    parameter_server = Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)

     screen.log_title("*** Distributed Training ***")
     time.sleep(1)
@@ -358,6 +360,7 @@ if __name__ == "__main__":
"--job_name=worker", "--job_name=worker",
"--load_json={}".format(json_run_dict_path)] "--load_json={}".format(json_run_dict_path)]
print(' '.join(workers_args))
p = Popen(workers_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1) p = Popen(workers_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)
if i != run_dict['num_threads']: if i != run_dict['num_threads']:
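
Both launch sites now build the argument list first and print it before calling Popen, so the exact command can be copied and re-run by hand when the parameter server or a worker misbehaves. A minimal sketch of the pattern, with placeholder host strings (the script path and flags are taken from the diff):

import subprocess
from subprocess import Popen

ps_hosts, worker_hosts = 'localhost:2222', 'localhost:2223'   # placeholders
cmd = [
    "python3",
    "./parallel_actor.py",
    "--ps_hosts={}".format(ps_hosts),
    "--worker_hosts={}".format(worker_hosts),
    "--job_name=ps",
]
print(' '.join(cmd))                                          # echo the command before launching
parameter_server = Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1)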

View File

@@ -132,6 +132,7 @@ def parse_int(value):
 def set_gpu(gpu_id):
     os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+    os.environ['NVIDIA_VISIBLE_DEVICES'] = str(gpu_id)


 def set_cpu():