diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index 815a78a..a84eb5b 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -1,4 +1,5 @@ # +# # Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -222,7 +223,8 @@ class PresetValidationParameters(Parameters): reward_test_level=None, test_using_a_trace_test=True, trace_test_levels=None, - trace_max_env_steps=5000): + trace_max_env_steps=5000, + read_csv_tries=200): """ :param test: A flag which specifies if the preset should be tested as part of the validation process. @@ -245,6 +247,8 @@ class PresetValidationParameters(Parameters): :param trace_max_env_steps: An integer representing the maximum number of environment steps to run when running this preset as part of the trace tests suite. + :param read_csv_tries: + The number of retries to attempt for reading the experiment csv file, before declaring failure. """ super().__init__() @@ -261,6 +265,7 @@ class PresetValidationParameters(Parameters): self.test_using_a_trace_test = test_using_a_trace_test self.trace_test_levels = trace_test_levels self.trace_max_env_steps = trace_max_env_steps + self.read_csv_tries = read_csv_tries class NetworkParameters(Parameters): diff --git a/rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py b/rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py index 91a777c..d830ba2 100644 --- a/rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py +++ b/rl_coach/presets/CartPole_DQN_BatchRL_BCQ.py @@ -135,6 +135,7 @@ preset_validation_params = PresetValidationParameters() preset_validation_params.test = True preset_validation_params.min_reward_threshold = 150 preset_validation_params.max_episodes_to_achieve_reward = 50 +preset_validation_params.read_csv_tries = 500 graph_manager = BatchRLGraphManager(agent_params=agent_params, experience_generating_agent_params=experience_generating_agent_params, diff --git a/rl_coach/tests/test_golden.py b/rl_coach/tests/test_golden.py index cd06b11..d7b9705 100644 --- a/rl_coach/tests/test_golden.py +++ b/rl_coach/tests/test_golden.py @@ -140,7 +140,7 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve test_passed = False # get the csv with the results - csv_paths = read_csv_paths(test_path, filename_pattern) + csv_paths = read_csv_paths(test_path, filename_pattern, read_csv_tries=preset_validation_params.read_csv_tries) if csv_paths: csv_path = csv_paths[0]