mirror of https://github.com/gryf/coach.git synced 2025-12-18 19:50:17 +01:00

Merge branch 'master' into imports

This commit is contained in:
Roman Dobosz
2018-04-24 07:43:04 +02:00
124 changed files with 10828 additions and 17 deletions

View File

@@ -360,7 +360,10 @@ class Agent(object):
             'observation': observation
         }
         if self.tp.agent.use_measurements:
-            self.curr_state['measurements'] = self.env.measurements
+            if 'measurements' in self.curr_state.keys():
+                self.curr_state['measurements'] = self.env.state['measurements']
+            else:
+                self.curr_state['measurements'] = np.zeros(0)
             if self.tp.agent.use_accumulated_reward_as_measurement:
                 self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)
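
Note: the new branch guards against the case where no measurements are available, falling back to an empty array so that the later np.append call still works. A minimal, self-contained sketch of that fallback (build_state is a hypothetical helper, not coach code):

import numpy as np

def build_state(observation, env_state, use_accumulated_reward=False):
    """Assemble a state dict, falling back to an empty measurements
    vector when none is available (hypothetical helper, not coach code)."""
    state = {'observation': observation}
    state['measurements'] = env_state.get('measurements', np.zeros(0))
    if use_accumulated_reward:
        # The accumulated reward occupies one extra measurement slot.
        state['measurements'] = np.append(state['measurements'], 0)
    return state

print(build_state(None, {'measurements': np.array([0.5, 1.0])}, True)['measurements'])  # [0.5 1.  0. ]
print(build_state(None, {}, True)['measurements'])                                      # [0.]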
@@ -393,7 +396,7 @@ class Agent(object):
                 shaped_reward += action_info['action_intrinsic_reward']
             # TODO: should total_reward_in_current_episode include shaped_reward?
             self.total_reward_in_current_episode += result['reward']
-            next_state = result['state']
+            next_state = copy.copy(result['state'])
             next_state['observation'] = self.preprocess_observation(next_state['observation'])
             # plot action values online
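
Note: copy.copy makes only a shallow copy of result['state'], so replacing next_state['observation'] with the preprocessed frame no longer mutates the dict returned by the environment, while the stored values themselves stay shared. A small standalone illustration (toy data, not coach code):

import copy

result = {'state': {'observation': [1, 2, 3], 'measurements': [0.5]}}

next_state = copy.copy(result['state'])      # shallow copy: new dict, shared values
next_state['observation'] = 'preprocessed'   # rebinds the key only in the copy

print(result['state']['observation'])        # [1, 2, 3] -- the original dict is untouched
print(next_state['measurements'] is result['state']['measurements'])  # True -- values are shared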
@@ -406,8 +409,11 @@ class Agent(object):
                 observation = utils.LazyStack(self.curr_stack, -1)
             next_state['observation'] = observation
-            if self.tp.agent.use_measurements and 'measurements' in result.keys():
-                next_state['measurements'] = result['state']['measurements']
+            if self.tp.agent.use_measurements:
+                if 'measurements' in result['state'].keys():
+                    next_state['measurements'] = result['state']['measurements']
+                else:
+                    next_state['measurements'] = np.zeros(0)
                 if self.tp.agent.use_accumulated_reward_as_measurement:
                     next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)

View File

@@ -34,7 +34,7 @@ class DFPAgent(agent.Agent):
         # create the inputs for the network
         input = current_states
-        input.append(np.repeat(np.expand_dims(self.current_goal, 0), self.tp.batch_size, 0))
+        input['goal'] = np.repeat(np.expand_dims(self.current_goal, 0), self.tp.batch_size, 0)
         # get the current outputs of the network
         targets = self.main_network.online_network.predict(input)
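
Note: the goal vector is tiled to batch size before it joins the named inputs. The snippet below only illustrates the np.expand_dims/np.repeat combination on a toy goal; the variable names follow the diff and nothing coach-specific is assumed:

import numpy as np

current_goal = np.array([1.0, 0.0, 0.5])   # one weight per measurement (toy values)
batch_size = 4

# Add a leading batch axis, then repeat it once per transition in the batch.
tiled_goal = np.repeat(np.expand_dims(current_goal, 0), batch_size, 0)
print(tiled_goal.shape)   # (4, 3)

# Keyed by input name, matching the dict-style input in the diff.
inputs = {'goal': tiled_goal}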
@@ -43,7 +43,7 @@ class DFPAgent(agent.Agent):
         for i in range(self.tp.batch_size):
             targets[i, actions[i]] = batch[i].info['future_measurements'].flatten()
-        result = self.main_network.train_and_sync_networks(current_states, targets)
+        result = self.main_network.train_and_sync_networks(input, targets)
         total_loss = result[0]
         return total_loss
@@ -55,7 +55,10 @@ class DFPAgent(agent.Agent):
         goal = np.expand_dims(self.current_goal, 0)
         # predict the future measurements
-        measurements_future_prediction = self.main_network.online_network.predict([observation, measurements, goal])[0]
+        measurements_future_prediction = self.main_network.online_network.predict({
+            "observation": observation,
+            "measurements": measurements,
+            "goal": goal})[0]
         action_values = np.zeros((self.action_space_size,))
         num_steps_used_for_objective = len(self.tp.agent.future_measurements_weights)
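
Note: at action-selection time the same convention applies: the network receives a dict keyed by input name instead of a positional list. A minimal sketch of that call pattern, where DictInputNetwork is a hypothetical stand-in and only the dict-keyed predict mirrors the diff:

import numpy as np

class DictInputNetwork:
    """Hypothetical stand-in for the online network: it accepts a dict
    keyed by input name rather than a positional list of arrays."""
    def predict(self, inputs):
        # Toy "prediction": concatenate the named inputs along the last axis.
        return np.concatenate([inputs['observation'],
                               inputs['measurements'],
                               inputs['goal']], axis=-1)

observation = np.zeros((1, 8))
measurements = np.zeros((1, 3))
goal = np.expand_dims(np.array([1.0, 0.0, 0.5]), 0)

net = DictInputNetwork()
prediction = net.predict({'observation': observation,
                          'measurements': measurements,
                          'goal': goal})[0]
print(prediction.shape)   # (14,)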