mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

integration test changes to reach the train part (#254)

* integration test changes: override heatup to 1000 steps and run each preset for 30 seconds (to make sure we reach the train part)

* fixes for failing presets uncovered by this change, plus changes to the golden tests to properly test BatchRL

* fix for Rainbow DQN

* fix to gym_environment (due to a change in Gym 0.12.1), another fix for Rainbow DQN, and a bug fix in utils.squeeze_list

* fix for NEC agent
Gal Leibovich authored 2019-03-27 21:14:19 +02:00 (committed by GitHub)
parent 6e08c55ad5
commit 310d31c227
8 changed files with 28 additions and 17 deletions

View File

@@ -398,7 +398,10 @@ class Agent(AgentInterface):
             self.accumulated_shaped_rewards_across_evaluation_episodes = 0
             self.num_successes_across_evaluation_episodes = 0
             self.num_evaluation_episodes_completed = 0
 
-            if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+            # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            if self.ap.is_a_highest_level_agent:
                 screen.log_title("{}: Starting evaluation phase".format(self.name))
         elif ending_evaluation:
@@ -416,7 +419,10 @@ class Agent(AgentInterface):
             self.agent_logger.create_signal_value(
                 "Success Rate",
                 success_rate)
 
-            if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+            # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            if self.ap.is_a_highest_level_agent:
                 screen.log_title("{}: Finished evaluation phase. Success rate = {}, Avg Total Reward = {}"
                                  .format(self.name, np.round(success_rate, 2), np.round(evaluation_reward, 2)))
@@ -565,7 +571,9 @@ class Agent(AgentInterface):
                 self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber:
             self.update_log()
 
-            if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            # TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
+            # if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
+            if self.ap.is_a_highest_level_agent:
                 self.log_to_screen()
 
     def reset_internal_state(self) -> None:

View File

@@ -85,7 +85,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
     def get_all_q_values_for_states(self, states: StateType):
         if self.exploration_policy.requires_action_values():
             q_values = self.get_prediction(states,
-                                           outputs=self.networks['main'].online_network.output_heads[0].q_values)
+                                           outputs=[self.networks['main'].online_network.output_heads[0].q_values])
         else:
             q_values = None
         return q_values
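Wrapping the single output head in a list makes the `outputs` argument an explicit one-element list. As a hedged illustration of why such a convention matters (the helper below is illustrative only, not Coach's actual `predict`), a function that iterates over `outputs` behaves very differently for a bare name than for a one-element list:

    import numpy as np

    def fetch_outputs(fetched, outputs):
        # illustrative stand-in for a predict()-style helper that treats `outputs`
        # as a list of requested result names
        results = [np.asarray(fetched[name]) for name in outputs]
        return results[0] if len(results) == 1 else results

    fetched = {"q_values": [[0.2, 0.8]]}
    print(fetch_outputs(fetched, outputs=["q_values"]).shape)  # (1, 2)
    # passing the bare string instead would iterate over its characters and fail:
    # fetch_outputs(fetched, outputs="q_values")  -> KeyError: 'q'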

View File

@@ -162,7 +162,7 @@ class NECAgent(ValueOptimizationAgent):
         embedding = self.networks['main'].online_network.predict(
             self.prepare_batch_for_inference(self.curr_state, 'main'),
             outputs=self.networks['main'].online_network.state_embedding)
-        self.current_episode_state_embeddings.append(embedding)
+        self.current_episode_state_embeddings.append(embedding.squeeze())
         return super().act()
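The fix stores the embedding without its leading batch dimension. A minimal numpy illustration of what `.squeeze()` changes, assuming `predict()` on a single state returns a `(1, embedding_size)` array (the size here is made up):

    import numpy as np

    embedding = np.random.rand(1, 512)   # assumed shape for a single-state prediction
    print(embedding.shape)               # (1, 512)
    print(embedding.squeeze().shape)     # (512,) -- what gets appended after this fix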

View File

@@ -321,7 +321,7 @@ class GymEnvironment(Environment):
         self.state_space = StateSpace({})
 
         # observations
-        if not isinstance(self.env.observation_space, gym.spaces.dict_space.Dict):
+        if not isinstance(self.env.observation_space, gym.spaces.dict.Dict):
             state_space = {'observation': self.env.observation_space}
         else:
             state_space = self.env.observation_space.spaces
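The isinstance check broke because Gym 0.12.1 moved the `Dict` space module from `gym.spaces.dict_space` to `gym.spaces.dict`. If both old and new Gym versions had to be supported, one option (a sketch, not what this commit does) would be to test against the class re-exported at the package level instead of the submodule path:

    import gym

    def observation_space_is_dict(env: gym.Env) -> bool:
        # gym.spaces.Dict is exported from gym.spaces itself, so this check does
        # not depend on the dict_space -> dict module rename
        return isinstance(env.observation_space, gym.spaces.Dict)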

View File

@@ -29,7 +29,7 @@ schedule_params.heatup_steps = EnvironmentSteps(DATASET_SIZE)
 #########
 # TODO add a preset which uses a dataset to train a BatchRL graph. e.g. save a cartpole dataset in a csv format.
 agent_params = DDQNAgentParameters()
-agent_params.network_wrappers['main'].batch_size = 1024
+agent_params.network_wrappers['main'].batch_size = 128
 
 # DQN params
 # agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(100)
@@ -81,11 +81,11 @@ env_params = GymVectorEnvironment(level='CartPole-v0')
 preset_validation_params = PresetValidationParameters()
 preset_validation_params.test = True
 preset_validation_params.min_reward_threshold = 150
-preset_validation_params.max_episodes_to_achieve_reward = 250
+preset_validation_params.max_episodes_to_achieve_reward = 2000
 
 graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params,
                                     schedule_params=schedule_params,
                                     vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1),
                                     preset_validation_params=preset_validation_params,
-                                    reward_model_num_epochs=50,
+                                    reward_model_num_epochs=30,
                                     train_to_eval_ratio=0.8)
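The two validation parameters relaxed here are exactly what the golden test further down reads: a preset passes only if the averaged reward reaches min_reward_threshold before the experiment CSV's step counter passes max_episodes_to_achieve_reward (or the time limit expires). A simplified sketch of that pass criterion, not the actual test code:

    import numpy as np

    def preset_passes(averaged_rewards, last_logged_step, preset_validation_params):
        # simplified mirror of the golden-test loop below: the reward threshold
        # has to be crossed while still inside the allowed episode budget
        reached = np.any(np.asarray(averaged_rewards) >= preset_validation_params.min_reward_threshold)
        in_budget = last_logged_step <= preset_validation_params.max_episodes_to_achieve_reward
        return reached and in_budget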

View File

@@ -57,14 +57,16 @@ def test_preset_runs(preset):
experiment_name = ".test-" + preset experiment_name = ".test-" + preset
params = ["python3", "rl_coach/coach.py", "-p", preset, "-ns", "-e", experiment_name] # overriding heatup steps to some small number of steps (1000), so to finish the heatup stage, and get to train
params = ["python3", "rl_coach/coach.py", "-p", preset, "-ns", "-e", experiment_name, '-cp',
'heatup_steps=EnvironmentSteps(1000)']
if level != "": if level != "":
params += ["-lvl", level] params += ["-lvl", level]
p = Popen(params) p = Popen(params)
# wait 10 seconds overhead of initialization etc. # wait 30 seconds overhead of initialization, and finishing heatup.
time.sleep(10) time.sleep(30)
return_value = p.poll() return_value = p.poll()
if return_value is None: if return_value is None:
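The new -cp argument passes a custom-parameter assignment string to the preset, which is how the test shrinks heatup without editing any preset file. A rough sketch of what applying such an override amounts to; the classes and plumbing below are illustrative, not Coach's actual implementation:

    class EnvironmentSteps:
        def __init__(self, num_steps):
            self.num_steps = num_steps

    class ScheduleParams:
        def __init__(self):
            self.heatup_steps = EnvironmentSteps(40000)   # made-up default

    schedule_params = ScheduleParams()
    override = "heatup_steps=EnvironmentSteps(1000)"      # the value passed via -cp
    name, value = override.split("=", 1)
    setattr(schedule_params, name, eval(value))           # shrink heatup so training starts quickly
    print(schedule_params.heatup_steps.num_steps)         # -> 1000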

View File

@@ -33,7 +33,7 @@ import pytest
 from rl_coach.logger import screen
 
 
-def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
+def read_csv_paths(test_path, filename_pattern, read_csv_tries=200):
     csv_paths = []
     tries_counter = 0
     while not csv_paths:
@@ -155,7 +155,7 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
     if not no_progress_bar:
         print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
 
-    while csv is None or (csv['Episode #'].values[
+    while csv is None or (csv[csv.columns[0]].values[
             -1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
         try:
             csv = pd.read_csv(csv_path)
@@ -179,10 +179,10 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
         if not no_progress_bar:
             print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
 
-        if csv['Episode #'].shape[0] - last_num_episodes <= 0:
+        if csv[csv.columns[0]].shape[0] - last_num_episodes <= 0:
             continue
 
-        last_num_episodes = csv['Episode #'].values[-1]
+        last_num_episodes = csv[csv.columns[0]].values[-1]
 
         # check if reward is enough
         if np.any(averaged_rewards >= preset_validation_params.min_reward_threshold):
@@ -213,6 +213,7 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
                          preset_validation_params.min_reward_threshold), crash=False)
             screen.error("averaged_rewards: {}".format(averaged_rewards), crash=False)
             screen.error("episode number: {}".format(csv['Episode #'].values[-1]), crash=False)
+            screen.error("training iteration: {}".format(csv['Training Iter'].values[-1]), crash=False)
         else:
             screen.error("csv file never found", crash=False)
     if verbose:
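Switching from the hard-coded 'Episode #' header to csv.columns[0] makes the progress check work for experiment CSVs whose first (time-axis) column is named differently, presumably the BatchRL runs this commit targets, which is also why the extra 'Training Iter' error line was added. A small pandas illustration (column names and values are examples):

    import pandas as pd

    # episode-based and BatchRL-style logs can name their first column differently;
    # indexing positionally reads the time axis in either case
    episode_csv = pd.DataFrame({"Episode #": [1, 2, 3], "Evaluation Reward": [20, 80, 160]})
    batchrl_csv = pd.DataFrame({"Training Iter": [0, 1, 2], "Evaluation Reward": [15, 70, 155]})

    for csv in (episode_csv, batchrl_csv):
        print(csv[csv.columns[0]].values[-1])  # last logged step on the time axis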

View File

@@ -220,7 +220,7 @@ def force_list(var):
 def squeeze_list(var):
-    if len(var) == 1:
+    if type(var) == list and len(var) == 1:
         return var[0]
     else:
         return var
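For reference, the fixed helper in runnable form with a quick check of the new guard; the note about single-row arrays is an inference from the change, not something stated in the commit:

    import numpy as np

    def squeeze_list(var):
        # only unwrap a genuine one-element list; other length-one sequences pass through
        if type(var) == list and len(var) == 1:
            return var[0]
        else:
            return var

    print(squeeze_list([42]))                    # -> 42
    print(squeeze_list([1, 2, 3]))               # -> [1, 2, 3]
    # len() == 1 alone would also have unwrapped a single-row array into its first row:
    print(squeeze_list(np.zeros((1, 4))).shape)  # -> (1, 4), now left untouched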