mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
integration test changes to reach the train part (#254)
* integration test changes to override heatup to 1000 steps + run each preset for 30 sec (to make sure we reach the train part) * fixes to failing presets uncovered with this change + changes in the golden testing to properly test BatchRL * fix for rainbow dqn * fix to gym_environment (due to a change in Gym 0.12.1) + fix for rainbow DQN + some bug-fix in utils.squeeze_list * fix for NEC agent
This commit is contained in:
@@ -398,7 +398,10 @@ class Agent(AgentInterface):
|
||||
self.accumulated_shaped_rewards_across_evaluation_episodes = 0
|
||||
self.num_successes_across_evaluation_episodes = 0
|
||||
self.num_evaluation_episodes_completed = 0
|
||||
if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
|
||||
|
||||
# TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
|
||||
# if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
|
||||
if self.ap.is_a_highest_level_agent:
|
||||
screen.log_title("{}: Starting evaluation phase".format(self.name))
|
||||
|
||||
elif ending_evaluation:
|
||||
@@ -416,7 +419,10 @@ class Agent(AgentInterface):
|
||||
self.agent_logger.create_signal_value(
|
||||
"Success Rate",
|
||||
success_rate)
|
||||
if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
|
||||
|
||||
# TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
|
||||
# if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
|
||||
if self.ap.is_a_highest_level_agent:
|
||||
screen.log_title("{}: Finished evaluation phase. Success rate = {}, Avg Total Reward = {}"
|
||||
.format(self.name, np.round(success_rate, 2), np.round(evaluation_reward, 2)))
|
||||
|
||||
@@ -565,7 +571,9 @@ class Agent(AgentInterface):
|
||||
self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber:
|
||||
self.update_log()
|
||||
|
||||
if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
|
||||
# TODO verbosity was mistakenly removed from task_parameters on release 0.11.0, need to bring it back
|
||||
# if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
|
||||
if self.ap.is_a_highest_level_agent:
|
||||
self.log_to_screen()
|
||||
|
||||
def reset_internal_state(self) -> None:
|
||||
|
||||
@@ -85,7 +85,7 @@ class CategoricalDQNAgent(ValueOptimizationAgent):
|
||||
def get_all_q_values_for_states(self, states: StateType):
|
||||
if self.exploration_policy.requires_action_values():
|
||||
q_values = self.get_prediction(states,
|
||||
outputs=self.networks['main'].online_network.output_heads[0].q_values)
|
||||
outputs=[self.networks['main'].online_network.output_heads[0].q_values])
|
||||
else:
|
||||
q_values = None
|
||||
return q_values
|
||||
|
||||
@@ -162,7 +162,7 @@ class NECAgent(ValueOptimizationAgent):
|
||||
embedding = self.networks['main'].online_network.predict(
|
||||
self.prepare_batch_for_inference(self.curr_state, 'main'),
|
||||
outputs=self.networks['main'].online_network.state_embedding)
|
||||
self.current_episode_state_embeddings.append(embedding)
|
||||
self.current_episode_state_embeddings.append(embedding.squeeze())
|
||||
|
||||
return super().act()
|
||||
|
||||
|
||||
@@ -321,7 +321,7 @@ class GymEnvironment(Environment):
|
||||
self.state_space = StateSpace({})
|
||||
|
||||
# observations
|
||||
if not isinstance(self.env.observation_space, gym.spaces.dict_space.Dict):
|
||||
if not isinstance(self.env.observation_space, gym.spaces.dict.Dict):
|
||||
state_space = {'observation': self.env.observation_space}
|
||||
else:
|
||||
state_space = self.env.observation_space.spaces
|
||||
|
||||
@@ -29,7 +29,7 @@ schedule_params.heatup_steps = EnvironmentSteps(DATASET_SIZE)
|
||||
#########
|
||||
# TODO add a preset which uses a dataset to train a BatchRL graph. e.g. save a cartpole dataset in a csv format.
|
||||
agent_params = DDQNAgentParameters()
|
||||
agent_params.network_wrappers['main'].batch_size = 1024
|
||||
agent_params.network_wrappers['main'].batch_size = 128
|
||||
|
||||
# DQN params
|
||||
# agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(100)
|
||||
@@ -81,11 +81,11 @@ env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
preset_validation_params = PresetValidationParameters()
|
||||
preset_validation_params.test = True
|
||||
preset_validation_params.min_reward_threshold = 150
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 250
|
||||
preset_validation_params.max_episodes_to_achieve_reward = 2000
|
||||
|
||||
graph_manager = BatchRLGraphManager(agent_params=agent_params, env_params=env_params,
|
||||
schedule_params=schedule_params,
|
||||
vis_params=VisualizationParameters(dump_signals_to_csv_every_x_episodes=1),
|
||||
preset_validation_params=preset_validation_params,
|
||||
reward_model_num_epochs=50,
|
||||
reward_model_num_epochs=30,
|
||||
train_to_eval_ratio=0.8)
|
||||
|
||||
@@ -57,14 +57,16 @@ def test_preset_runs(preset):
|
||||
|
||||
experiment_name = ".test-" + preset
|
||||
|
||||
params = ["python3", "rl_coach/coach.py", "-p", preset, "-ns", "-e", experiment_name]
|
||||
# overriding heatup steps to some small number of steps (1000), so to finish the heatup stage, and get to train
|
||||
params = ["python3", "rl_coach/coach.py", "-p", preset, "-ns", "-e", experiment_name, '-cp',
|
||||
'heatup_steps=EnvironmentSteps(1000)']
|
||||
if level != "":
|
||||
params += ["-lvl", level]
|
||||
|
||||
p = Popen(params)
|
||||
|
||||
# wait 10 seconds overhead of initialization etc.
|
||||
time.sleep(10)
|
||||
# wait 30 seconds overhead of initialization, and finishing heatup.
|
||||
time.sleep(30)
|
||||
return_value = p.poll()
|
||||
|
||||
if return_value is None:
|
||||
|
||||
@@ -33,7 +33,7 @@ import pytest
|
||||
from rl_coach.logger import screen
|
||||
|
||||
|
||||
def read_csv_paths(test_path, filename_pattern, read_csv_tries=100):
|
||||
def read_csv_paths(test_path, filename_pattern, read_csv_tries=200):
|
||||
csv_paths = []
|
||||
tries_counter = 0
|
||||
while not csv_paths:
|
||||
@@ -155,7 +155,7 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
|
||||
if not no_progress_bar:
|
||||
print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
|
||||
|
||||
while csv is None or (csv['Episode #'].values[
|
||||
while csv is None or (csv[csv.columns[0]].values[
|
||||
-1] < preset_validation_params.max_episodes_to_achieve_reward and time.time() - start_time < time_limit):
|
||||
try:
|
||||
csv = pd.read_csv(csv_path)
|
||||
@@ -179,10 +179,10 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
|
||||
if not no_progress_bar:
|
||||
print_progress(averaged_rewards, last_num_episodes, preset_validation_params, start_time, time_limit)
|
||||
|
||||
if csv['Episode #'].shape[0] - last_num_episodes <= 0:
|
||||
if csv[csv.columns[0]].shape[0] - last_num_episodes <= 0:
|
||||
continue
|
||||
|
||||
last_num_episodes = csv['Episode #'].values[-1]
|
||||
last_num_episodes = csv[csv.columns[0]].values[-1]
|
||||
|
||||
# check if reward is enough
|
||||
if np.any(averaged_rewards >= preset_validation_params.min_reward_threshold):
|
||||
@@ -213,6 +213,7 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
|
||||
preset_validation_params.min_reward_threshold), crash=False)
|
||||
screen.error("averaged_rewards: {}".format(averaged_rewards), crash=False)
|
||||
screen.error("episode number: {}".format(csv['Episode #'].values[-1]), crash=False)
|
||||
screen.error("training iteration: {}".format(csv['Training Iter'].values[-1]), crash=False)
|
||||
else:
|
||||
screen.error("csv file never found", crash=False)
|
||||
if verbose:
|
||||
|
||||
@@ -220,7 +220,7 @@ def force_list(var):
|
||||
|
||||
|
||||
def squeeze_list(var):
|
||||
if len(var) == 1:
|
||||
if type(var) == list and len(var) == 1:
|
||||
return var[0]
|
||||
else:
|
||||
return var
|
||||
|
||||
Reference in New Issue
Block a user