multiple bug fixes in dealing with measurements + CartPole_DFP preset (#92)
@@ -365,7 +365,10 @@ class Agent(object):
             'observation': observation
         }
         if self.tp.agent.use_measurements:
-            self.curr_state['measurements'] = self.env.measurements
+            if 'measurements' in self.curr_state.keys():
+                self.curr_state['measurements'] = self.env.state['measurements']
+            else:
+                self.curr_state['measurements'] = np.zeros(0)
             if self.tp.agent.use_accumulated_reward_as_measurement:
                 self.curr_state['measurements'] = np.append(self.curr_state['measurements'], 0)

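The new branch reads the measurements from the environment's composite state dict and falls back to an empty NumPy vector when the key is missing, so the later np.append call still works. A minimal standalone sketch of that pattern (the helper name and the sample dicts are illustrative, not repo code):

    import numpy as np

    # Illustrative helper mirroring the fallback above: read 'measurements'
    # from a composite state dict, or return an empty vector when absent.
    def read_measurements(state_dict):
        if 'measurements' in state_dict.keys():
            return np.asarray(state_dict['measurements'])
        return np.zeros(0)

    print(read_measurements({'observation': 'frame', 'measurements': [100.0, 26.0]}))  # [100.  26.]
    print(read_measurements({'observation': 'frame'}))                                 # []
    # the accumulated-reward measurement can still be appended to the empty fallback
    print(np.append(read_measurements({'observation': 'frame'}), 0))                   # [0.]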
@@ -398,7 +401,7 @@ class Agent(object):
             shaped_reward += action_info['action_intrinsic_reward']
         # TODO: should total_reward_in_current_episode include shaped_reward?
         self.total_reward_in_current_episode += result['reward']
-        next_state = result['state']
+        next_state = copy.copy(result['state'])
         next_state['observation'] = self.preprocess_observation(next_state['observation'])

         # plot action values online
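Copying the state dict matters because next_state is modified in place on the next line; without the shallow copy, next_state and the environment's own state dict are the same object. A small illustration of the aliasing problem (toy values, not repo code):

    import copy

    env_state = {'observation': 'raw_frame', 'measurements': [0.0]}

    aliased = env_state                        # old behaviour: both names refer to one dict
    aliased['observation'] = 'preprocessed'
    print(env_state['observation'])            # 'preprocessed' -> the env's state was clobbered

    env_state['observation'] = 'raw_frame'     # reset for the second run
    copied = copy.copy(env_state)              # new behaviour: shallow copy of the dict
    copied['observation'] = 'preprocessed'
    print(env_state['observation'])            # 'raw_frame' -> the env's state is untouched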
@@ -411,8 +414,11 @@ class Agent(object):
         observation = LazyStack(self.curr_stack, -1)

         next_state['observation'] = observation
-        if self.tp.agent.use_measurements and 'measurements' in result.keys():
-            next_state['measurements'] = result['state']['measurements']
+        if self.tp.agent.use_measurements:
+            if 'measurements' in result['state'].keys():
+                next_state['measurements'] = result['state']['measurements']
+            else:
+                next_state['measurements'] = np.zeros(0)
             if self.tp.agent.use_accumulated_reward_as_measurement:
                 next_state['measurements'] = np.append(next_state['measurements'], self.total_reward_in_current_episode)

@@ -31,7 +31,7 @@ class DFPAgent(Agent):

        # create the inputs for the network
        input = current_states
-        input.append(np.repeat(np.expand_dims(self.current_goal, 0), self.tp.batch_size, 0))
+        input['goal'] = np.repeat(np.expand_dims(self.current_goal, 0), self.tp.batch_size, 0)

        # get the current outputs of the network
        targets = self.main_network.online_network.predict(input)
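With current_states treated as a dict of named inputs rather than a list, the goal batch is attached under a 'goal' key instead of being appended positionally, and the same keyed structure is later passed to both prediction and training. A rough sketch of the resulting layout (shapes and values are made up for illustration):

    import numpy as np

    batch_size = 4
    current_goal = np.array([1.0])                 # e.g. the goal_vector from the new preset
    goal_batch = np.repeat(np.expand_dims(current_goal, 0), batch_size, 0)
    print(goal_batch.shape)                        # (4, 1): one copy of the goal per sample

    # inputs addressed by name rather than by list position (hypothetical shapes)
    network_input = {
        'observation': np.zeros((batch_size, 4)),
        'measurements': np.zeros((batch_size, 1)),
        'goal': goal_batch,
    }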
@@ -40,7 +40,7 @@ class DFPAgent(Agent):
        for i in range(self.tp.batch_size):
            targets[i, actions[i]] = batch[i].info['future_measurements'].flatten()

-        result = self.main_network.train_and_sync_networks(current_states, targets)
+        result = self.main_network.train_and_sync_networks(input, targets)
        total_loss = result[0]

        return total_loss
@@ -52,7 +52,10 @@ class DFPAgent(Agent):
        goal = np.expand_dims(self.current_goal, 0)

        # predict the future measurements
-        measurements_future_prediction = self.main_network.online_network.predict([observation, measurements, goal])[0]
+        measurements_future_prediction = self.main_network.online_network.predict({
+            "observation": observation,
+            "measurements": measurements,
+            "goal": goal})[0]
        action_values = np.zeros((self.action_space_size,))
        num_steps_used_for_objective = len(self.tp.agent.future_measurements_weights)

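The same keyed-input convention is used at action-selection time. For context, DFP-style agents score actions by weighting the predicted future measurements with the goal vector; the sketch below shows that idea in isolation and is not the repo's exact code (shapes, weights, and variable values are assumed):

    import numpy as np

    num_actions = 2
    num_future_steps = 3
    measurement_dim = 1

    # hypothetical network output: predicted future measurements per action and time step
    measurements_future_prediction = np.random.rand(num_actions, num_future_steps, measurement_dim)

    goal = np.array([1.0])                                     # maximize the single measurement
    future_measurements_weights = np.array([0.5, 0.5, 1.0])    # assumed per-step weights

    action_values = np.zeros(num_actions)
    for a in range(num_actions):
        for t in range(num_future_steps):
            action_values[a] += future_measurements_weights[t] * np.dot(goal, measurements_future_prediction[a, t])

    best_action = int(np.argmax(action_values))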
@@ -135,7 +135,7 @@ class DoomEnvironmentWrapper(EnvironmentWrapper):
        # extract all data from the current state
        state = self.game.get_state()
        if state is not None and state.screen_buffer is not None:
-            self.observation = {
+            self.state = {
                'observation': state.screen_buffer,
                'measurements': state.game_variables,
            }
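Storing the composite dict under self.state (rather than self.observation) is what lets the Agent code above read self.env.state['measurements']. A sketch of that contract with stand-in values (the shapes are invented):

    import numpy as np

    screen_buffer = np.zeros((3, 120, 160), dtype=np.uint8)    # stand-in for ViZDoom pixels
    game_variables = np.array([100.0, 50.0])                   # stand-in for e.g. health, ammo

    # the environment wrapper exposes a composite state dict;
    # 'measurements' may be absent for environments that have none
    state = {
        'observation': screen_buffer,
        'measurements': game_variables,
    }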
@@ -113,7 +113,7 @@ class GymEnvironmentWrapper(EnvironmentWrapper):
            self.timestep_limit = self.env.spec.timestep_limit
        else:
            self.timestep_limit = None
-        self.measurements_size = len(self.step(0)['info'].keys())
+        self.measurements_size = (len(self.step(0)['info'].keys()),)
        self.random_initialization_steps = self.tp.env.random_initialization_steps

    def _wrap_state(self, state):
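measurements_size becomes a one-element tuple, i.e. a shape rather than a bare count, presumably so it can be consumed directly wherever an input shape is expected (for instance by the DFP network driven from the new CartPole_DFP preset). A tiny illustration with a made-up Gym info dict:

    # hypothetical info dict returned by env.step(); its keys act as measurements
    info = {'reward_run': 0.0, 'reward_ctrl': 0.0}
    measurements_size = (len(info.keys()),)
    print(measurements_size)   # (2,) - usable directly as a 1-D shape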
presets.py
@@ -200,6 +200,19 @@ class CartPole_PAL(Preset):
         self.test_max_step_threshold = 100
         self.test_min_return_threshold = 150


+class CartPole_DFP(Preset):
+    def __init__(self):
+        Preset.__init__(self, DFP, GymVectorObservation, ExplorationParameters)
+        self.env.level = 'CartPole-v0'
+        self.agent.num_episodes_in_experience_replay = 200
+        self.learning_rate = 0.0001
+        self.num_heatup_steps = 1000
+        self.exploration.epsilon_decay_steps = 10000
+        self.agent.use_accumulated_reward_as_measurement = True
+        self.agent.goal_vector = [1.0]
+
+
 class Doom_Basic_DFP(Preset):
     def __init__(self):
         Preset.__init__(self, DFP, Doom, ExplorationParameters)
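The new preset trains DFP on CartPole-v0 with a single measurement, the accumulated episode reward (use_accumulated_reward_as_measurement), and goal_vector = [1.0], so the agent is asked purely to maximize predicted future accumulated reward. Assuming CartPole's Gym info dict is empty and the wrapper therefore exposes no measurements of its own, the measurement vector is built exactly as in the Agent hunks above (values here are made up):

    import numpy as np

    # empty fallback from the measurements guard, then one appended entry:
    # the accumulated reward, giving DFP a single measurement to predict,
    # which matches the one-element goal_vector = [1.0].
    total_reward_in_current_episode = 37.0      # made-up mid-episode value
    measurements = np.zeros(0)
    measurements = np.append(measurements, total_reward_in_current_episode)
    print(measurements)          # [37.]
    print(measurements.shape)    # (1,)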