1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Multiple bug fixes in the handling of measurements, plus the CartPole_DFP preset (#92)

This commit is contained in:
Itai Caspi
2018-04-23 10:44:46 +03:00
committed by GitHub
parent 5d5562bf62
commit 52eb159f69
5 changed files with 31 additions and 9 deletions

View File

@@ -365,7 +365,10 @@ class Agent(object):
'observation': observation
}
if self.tp.agent.use_measurements:
self.curr_state['measurements'] = self.env.measurements
if 'measurements' in self.curr_state.keys():
self.curr_state['measurements'] = self.env.state['measurements']
else:
self.curr_state['measurements'] = np.zeros(0)
if self.tp.agent.use_accumulated_reward_as_measurement:
self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)
@@ -398,7 +401,7 @@ class Agent(object):
shaped_reward += action_info['action_intrinsic_reward']
# TODO: should total_reward_in_current_episode include shaped_reward?
self.total_reward_in_current_episode += result['reward']
next_state = result['state']
next_state = copy.copy(result['state'])
next_state['observation'] = self.preprocess_observation(next_state['observation'])
# plot action values online
@@ -411,8 +414,11 @@ class Agent(object):
observation = LazyStack(self.curr_stack, -1)
next_state['observation'] = observation
if self.tp.agent.use_measurements and 'measurements' in result.keys():
next_state['measurements'] = result['state']['measurements']
if self.tp.agent.use_measurements:
if 'measurements' in result['state'].keys():
next_state['measurements'] = result['state']['measurements']
else:
next_state['measurements'] = np.zeros(0)
if self.tp.agent.use_accumulated_reward_as_measurement:
next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)