integration test changes to reach the train part (#254)
* integration test changes to override heatup to 1000 steps + run each preset for 30 sec (to make sure we reach the train part)
* fixes to failing presets uncovered with this change + changes in the golden testing to properly test BatchRL
* fix for rainbow dqn
* fix to gym_environment (due to a change in Gym 0.12.1) + fix for rainbow DQN + some bug-fix in utils.squeeze_list
* fix for NEC agent
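The first item refers to capping the heatup phase so each preset actually reaches the training part within the 30-second run. As a minimal sketch only, assuming the standard Coach preset schedule API (this is not the integration-test override added in this commit, and the non-heatup values are placeholders):

# Sketch: cap heatup so a short run reaches the training phase quickly.
from rl_coach.core_types import EnvironmentSteps, EnvironmentEpisodes, TrainingSteps
from rl_coach.graph_managers.graph_manager import ScheduleParameters

schedule_params = ScheduleParameters()
schedule_params.heatup_steps = EnvironmentSteps(1000)   # override heatup to 1000 steps
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(10)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.improve_steps = TrainingSteps(10000000)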
@@ -398,7 +398,10 @@ class Agent(AgentInterface):
             self.accumulated_shaped_rewards_across_evaluation_episodes = 0
             self.num_successes_across_evaluation_episodes = 0
             self.num_evaluation_episodes_completed = 0
-            if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+
+            # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+            # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            if self.ap.is_a_highest_level_agent:
                 screen.log_title("{}: Starting evaluation phase".format(self.name))

         elif ending_evaluation:
@@ -416,7 +419,10 @@ class Agent(AgentInterface):
             self.agent_logger.create_signal_value(
                 "Success Rate",
                 success_rate)
-            if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+
+            # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+            # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            if self.ap.is_a_highest_level_agent:
                 screen.log_title("{}: Finished evaluation phase. Success rate = {}, Avg Total Reward = {}"
                                  .format(self.name, np.round(success_rate, 2), np.round(evaluation_reward, 2)))

@@ -565,7 +571,9 @@ class Agent(AgentInterface):
                 self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber:
             self.update_log()

-        if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+        # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+        # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+        if self.ap.is_a_highest_level_agent:
             self.log_to_screen()

     def reset_internal_state(self) -> None:
@@ -85,7 +85,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
     def get_all_q_values_for_states(self, states: StateType):
         if self.exploration_policy.requires_action_values():
             q_values = self.get_prediction(states,
-                                           outputs=self.networks['main'].online_network.output_heads[0].q_values)
+                                           outputs=[self.networks['main'].online_network.output_heads[0].q_values])
         else:
             q_values = None
         return q_values
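The CategoricalDQN change wraps the requested output head in a one-element list. A hypothetical, self-contained sketch of the convention at play (this is not Coach's actual network code; get_prediction and run_fn here are stand-ins): a prediction helper that fetches one array per requested output expects a list of fetches, so passing a bare tensor breaks the pairing.

import numpy as np

def get_prediction(run_fn, outputs):
    # run_fn stands in for a session.run-style call: list of fetches in,
    # list of arrays out. Requesting a single head therefore means passing
    # a one-element list and unpacking the single result.
    fetched = run_fn(outputs)
    return fetched[0] if len(outputs) == 1 else fetched

# usage with a fake runner that returns one dummy array per fetch
fake_run = lambda fetches: [np.zeros((1, 4)) for _ in fetches]
q_values = get_prediction(fake_run, outputs=["q_values_head"])
print(q_values.shape)  # (1, 4)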
@@ -162,7 +162,7 @@ class NECAgent(ValueOptimizationAgent):
             embedding = self.networks['main'].online_network.predict(
                 self.prepare_batch_for_inference(self.curr_state, 'main'),
                 outputs=self.networks['main'].online_network.state_embedding)
-            self.current_episode_state_embeddings.append(embedding)
+            self.current_episode_state_embeddings.append(embedding.squeeze())

         return super().act()
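The NEC fix stores each state embedding with its singleton batch axis squeezed away, so the per-episode embeddings are flat vectors. A small numpy illustration of the shape difference (the 512-dimensional embedding size is an assumed placeholder):

import numpy as np

embedding = np.zeros((1, 512))      # predict() on a single state keeps a batch axis of 1
print(embedding.shape)              # (1, 512)
print(embedding.squeeze().shape)    # (512,) - the flat vector appended after the fix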