mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Preset dependent number of csv read attempts in golden testing (#334)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
#
|
#
|
||||||
|
#
|
||||||
# Copyright (c) 2017 Intel Corporation
|
# Copyright (c) 2017 Intel Corporation
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -222,7 +223,8 @@ class PresetValidationParameters(Parameters):
|
|||||||
reward_test_level=None,
|
reward_test_level=None,
|
||||||
test_using_a_trace_test=True,
|
test_using_a_trace_test=True,
|
||||||
trace_test_levels=None,
|
trace_test_levels=None,
|
||||||
trace_max_env_steps=5000):
|
trace_max_env_steps=5000,
|
||||||
|
read_csv_tries=200):
|
||||||
"""
|
"""
|
||||||
:param test:
|
:param test:
|
||||||
A flag which specifies if the preset should be tested as part of the validation process.
|
A flag which specifies if the preset should be tested as part of the validation process.
|
||||||
@@ -245,6 +247,8 @@ class PresetValidationParameters(Parameters):
|
|||||||
:param trace_max_env_steps:
|
:param trace_max_env_steps:
|
||||||
An integer representing the maximum number of environment steps to run when running this preset as part
|
An integer representing the maximum number of environment steps to run when running this preset as part
|
||||||
of the trace tests suite.
|
of the trace tests suite.
|
||||||
|
:param read_csv_tries:
|
||||||
|
The number of retries to attempt for reading the experiment csv file, before declaring failure.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -261,6 +265,7 @@ class PresetValidationParameters(Parameters):
|
|||||||
self.test_using_a_trace_test = test_using_a_trace_test
|
self.test_using_a_trace_test = test_using_a_trace_test
|
||||||
self.trace_test_levels = trace_test_levels
|
self.trace_test_levels = trace_test_levels
|
||||||
self.trace_max_env_steps = trace_max_env_steps
|
self.trace_max_env_steps = trace_max_env_steps
|
||||||
|
self.read_csv_tries = read_csv_tries
|
||||||
|
|
||||||
|
|
||||||
class NetworkParameters(Parameters):
|
class NetworkParameters(Parameters):
|
||||||
|
|||||||
@@ -135,6 +135,7 @@ preset_validation_params = PresetValidationParameters()
|
|||||||
preset_validation_params.test = True
|
preset_validation_params.test = True
|
||||||
preset_validation_params.min_reward_threshold = 150
|
preset_validation_params.min_reward_threshold = 150
|
||||||
preset_validation_params.max_episodes_to_achieve_reward = 50
|
preset_validation_params.max_episodes_to_achieve_reward = 50
|
||||||
|
preset_validation_params.read_csv_tries = 500
|
||||||
|
|
||||||
graph_manager = BatchRLGraphManager(agent_params=agent_params,
|
graph_manager = BatchRLGraphManager(agent_params=agent_params,
|
||||||
experience_generating_agent_params=experience_generating_agent_params,
|
experience_generating_agent_params=experience_generating_agent_params,
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ def test_preset_reward(preset_name, no_progress_bar=True, time_limit=60 * 60, ve
|
|||||||
test_passed = False
|
test_passed = False
|
||||||
|
|
||||||
# get the csv with the results
|
# get the csv with the results
|
||||||
csv_paths = read_csv_paths(test_path, filename_pattern)
|
csv_paths = read_csv_paths(test_path, filename_pattern, read_csv_tries=preset_validation_params.read_csv_tries)
|
||||||
|
|
||||||
if csv_paths:
|
if csv_paths:
|
||||||
csv_path = csv_paths[0]
|
csv_path = csv_paths[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user