Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 11:10:20 +01:00.
multiple bug fixes in dealing with measurements + CartPole_DFP preset (#92)
This commit is contained in:
@@ -365,7 +365,10 @@ class Agent(object):
             'observation': observation
         }
         if self.tp.agent.use_measurements:
-            self.curr_state['measurements'] = self.env.measurements
+            if 'measurements' in self.curr_state.keys():
+                self.curr_state['measurements'] = self.env.state['measurements']
+            else:
+                self.curr_state['measurements'] = np.zeros(0)
         if self.tp.agent.use_accumulated_reward_as_measurement:
             self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)

@@ -398,7 +401,7 @@ class Agent(object):
         shaped_reward += action_info['action_intrinsic_reward']
         # TODO: should total_reward_in_current_episode include shaped_reward?
         self.total_reward_in_current_episode += result['reward']
-        next_state = result['state']
+        next_state = copy.copy(result['state'])
         next_state['observation'] = self.preprocess_observation(next_state['observation'])

         # plot action values online
@@ -411,8 +414,11 @@ class Agent(object):
             observation = LazyStack(self.curr_stack, -1)

         next_state['observation'] = observation
-        if self.tp.agent.use_measurements and 'measurements' in result.keys():
-            next_state['measurements'] = result['state']['measurements']
+        if self.tp.agent.use_measurements:
+            if 'measurements' in result['state'].keys():
+                next_state['measurements'] = result['state']['measurements']
+            else:
+                next_state['measurements'] = np.zeros(0)
         if self.tp.agent.use_accumulated_reward_as_measurement:
             next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)

Reference in New Issue
Block a user