1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Preset dependent number of csv read attempts in golden testing (#334)

This commit is contained in:
Gal Leibovich
2019-05-28 12:19:57 +03:00
committed by GitHub
parent ddffac8570
commit 251dc9ccc0
3 changed files with 8 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
#
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -222,7 +223,8 @@ class PresetValidationParameters(Parameters):
reward_test_level=None,
test_using_a_trace_test=True,
trace_test_levels=None,
trace_max_env_steps=5000):
trace_max_env_steps=5000,
read_csv_tries=200):
"""
:param test:
A flag which specifies if the preset should be tested as part of the validation process.
@@ -245,6 +247,8 @@ class PresetValidationParameters(Parameters):
:param trace_max_env_steps:
An integer representing the maximum number of environment steps to run when running this preset as part
of the trace tests suite.
:param read_csv_tries:
The number of retries to attempt for reading the experiment csv file, before declaring failure.
"""
super().__init__()
@@ -261,6 +265,7 @@ class PresetValidationParameters(Parameters):
self.test_using_a_trace_test = test_using_a_trace_test
self.trace_test_levels = trace_test_levels
self.trace_max_env_steps = trace_max_env_steps
self.read_csv_tries = read_csv_tries
class NetworkParameters(Parameters):

View File

@@ -135,6 +135,7 @@ preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 150
preset_validation_params.max_episodes_to_achieve_reward = 50
preset_validation_params.read_csv_tries = 500
graph_manager = BatchRLGraphManager(agent_params=agent_params,
experience_generating_agent_params=experience_generating_agent_params,

View File

@@ -140,7 +140,7 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
test_passed = False
# get the csv with the results
csv_paths = read_csv_paths(test_path, filename_pattern)
csv_paths = read_csv_paths(test_path, filename_pattern, read_csv_tries=preset_validation_params.read_csv_tries)
if csv_paths:
csv_path = csv_paths[0]