1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Preset dependent number of csv read attempts in golden testing (#334)

This commit is contained in:
Gal Leibovich
2019-05-28 12:19:57 +03:00
committed by GitHub
parent ddffac8570
commit 251dc9ccc0
3 changed files with 8 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
# #
#
# Copyright (c) 2017 Intel Corporation # Copyright (c) 2017 Intel Corporation
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -222,7 +223,8 @@ class PresetValidationParameters(Parameters):
reward_test_level=None, reward_test_level=None,
test_using_a_trace_test=True, test_using_a_trace_test=True,
trace_test_levels=None, trace_test_levels=None,
trace_max_env_steps=5000): trace_max_env_steps=5000,
read_csv_tries=200):
""" """
:param test: :param test:
A flag which specifies if the preset should be tested as part of the validation process. A flag which specifies if the preset should be tested as part of the validation process.
@@ -245,6 +247,8 @@ class PresetValidationParameters(Parameters):
:param trace_max_env_steps: :param trace_max_env_steps:
An integer representing the maximum number of environment steps to run when running this preset as part An integer representing the maximum number of environment steps to run when running this preset as part
of the trace tests suite. of the trace tests suite.
:param read_csv_tries:
The number of retries to attempt for reading the experiment csv file, before declaring failure.
""" """
super().__init__() super().__init__()
@@ -261,6 +265,7 @@ class PresetValidationParameters(Parameters):
self.test_using_a_trace_test = test_using_a_trace_test self.test_using_a_trace_test = test_using_a_trace_test
self.trace_test_levels = trace_test_levels self.trace_test_levels = trace_test_levels
self.trace_max_env_steps = trace_max_env_steps self.trace_max_env_steps = trace_max_env_steps
self.read_csv_tries = read_csv_tries
class NetworkParameters(Parameters): class NetworkParameters(Parameters):

View File

@@ -135,6 +135,7 @@ preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150 preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 50 preset_validation_params.max_episodes_to_achieve_reward = 50
preset_validation_params.read_csv_tries = 500
graph_manager = BatchRLGraphManager(agent_params=agent_params, graph_manager = BatchRLGraphManager(agent_params=agent_params,
experience_generating_agent_params=experience_generating_agent_params, experience_generating_agent_params=experience_generating_agent_params,

View File

@@ -140,7 +140,7 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
test_passed = False test_passed = False
# get the csv with the results # get the csv with the results
csv_paths = read_csv_paths(test_path, filename_pattern) csv_paths = read_csv_paths(test_path, filename_pattern, read_csv_tries=preset_validation_params.read_csv_tries)
if csv_paths: if csv_paths:
csv_path = csv_paths[0] csv_path = csv_paths[0]