mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
tests: update traces (#302)
* Traces folder removed from repo and moved to S3 * Traces jobs and update will use directly the S3 files
This commit is contained in:
@@ -22,15 +22,15 @@ import subprocess
|
||||
import multiprocessing
|
||||
import sys
|
||||
import signal
|
||||
import pandas as pd
|
||||
import time
|
||||
from configparser import ConfigParser
|
||||
from importlib import import_module
|
||||
from os import path
|
||||
sys.path.append('.')
|
||||
import pandas as pd
|
||||
import time
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
from rl_coach.logger import screen
|
||||
|
||||
|
||||
processes = []
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ def sigint_handler(signum, frame):
|
||||
os.remove(f)
|
||||
exit()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
|
||||
@@ -76,12 +77,15 @@ def run_trace_based_test(preset_name, num_env_steps, level=None):
|
||||
screen.log_title("Running test {}{}".format(preset_name, ' - ' + level if level else ''))
|
||||
log_file_name = 'trace_test_log_{preset_name}.txt'.format(preset_name=test_name[13:])
|
||||
|
||||
config_file = './tmp.cred'
|
||||
|
||||
cmd = (
|
||||
'python3 rl_coach/coach.py '
|
||||
'-p {preset_name} '
|
||||
'-e {test_name} '
|
||||
'--seed 42 '
|
||||
'-c '
|
||||
'-dcp {template}'
|
||||
'--no_summary '
|
||||
'-cp {custom_param} '
|
||||
'{level} '
|
||||
@@ -89,6 +93,7 @@ def run_trace_based_test(preset_name, num_env_steps, level=None):
|
||||
).format(
|
||||
preset_name=preset_name,
|
||||
test_name=test_name,
|
||||
template=config_file,
|
||||
log_file_name=log_file_name,
|
||||
level='-lvl ' + level if level else '',
|
||||
custom_param='\"improve_steps=EnvironmentSteps({n});'
|
||||
@@ -166,6 +171,31 @@ def wait_and_check(args, processes, force=False):
|
||||
return test_passed
|
||||
|
||||
|
||||
def generate_config(image, memory_backend, s3_end_point, s3_bucket_name, s3_creds_file, config_file):
|
||||
"""
|
||||
Generate the s3 config file to be used and also the dist-coach-config.template to be used for the test
|
||||
It reads the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` env vars and fails if they are not provided.
|
||||
"""
|
||||
# Write s3 creds
|
||||
aws_config = ConfigParser({
|
||||
'aws_access_key_id': os.environ.get('AWS_ACCESS_KEY_ID'),
|
||||
'aws_secret_access_key': os.environ.get('AWS_SECRET_ACCESS_KEY')
|
||||
}, default_section='default')
|
||||
with open(s3_creds_file, 'w') as f:
|
||||
aws_config.write(f)
|
||||
|
||||
coach_config = ConfigParser({
|
||||
'image': image,
|
||||
'memory_backend': memory_backend,
|
||||
'data_store': 's3',
|
||||
's3_end_point': s3_end_point,
|
||||
's3_bucket_name': s3_bucket_name,
|
||||
's3_creds_file': s3_creds_file
|
||||
}, default_section="coach")
|
||||
with open(config_file, 'w') as f:
|
||||
coach_config.write(f)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--preset', '--presets',
|
||||
@@ -188,12 +218,42 @@ def main():
|
||||
parser.add_argument('-prl', '--parallel',
|
||||
help="(flag) run tests in parallel",
|
||||
action='store_true')
|
||||
parser.add_argument('-ut', '--update_traces',
|
||||
help="(flag) update traces on repository",
|
||||
action='store_true')
|
||||
parser.add_argument('-mt', '--max_threads',
|
||||
help="(int) maximum number of threads to run in parallel",
|
||||
default=multiprocessing.cpu_count()-2,
|
||||
type=int)
|
||||
parser.add_argument(
|
||||
'-i', '--image', help="(string) Name of the testing image", type=str, default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
'-mb', '--memory_backend', help="(string) Name of the memory backend", type=str, default="redispubsub"
|
||||
)
|
||||
parser.add_argument(
|
||||
'-e', '--endpoint', help="(string) Name of the s3 endpoint", type=str, default='s3.amazonaws.com'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-cr', '--creds_file', help="(string) Path of the s3 creds file", type=str, default='.aws_creds'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-b', '--bucket', help="(string) Name of the bucket for s3", type=str, default=None
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.update_traces:
|
||||
if not args.bucket:
|
||||
print("bucket_name required for s3")
|
||||
exit(1)
|
||||
if not os.environ.get('AWS_ACCESS_KEY_ID') or not os.environ.get('AWS_SECRET_ACCESS_KEY'):
|
||||
print("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY env vars need to be set")
|
||||
exit(1)
|
||||
|
||||
config_file = './tmp.cred'
|
||||
generate_config(args.image, args.memory_backend, args.endpoint, args.bucket, args.creds_file, config_file)
|
||||
|
||||
if not args.parallel:
|
||||
args.max_threads = 1
|
||||
|
||||
@@ -251,7 +311,7 @@ def main():
|
||||
if fail_count == 0:
|
||||
screen.success(" Summary: " + str(test_count) + "/" + str(test_count) + " tests passed successfully")
|
||||
else:
|
||||
screen.error(" Summary: " + str(test_count - fail_count) + "/" + str(test_count) + " tests passed successfully")
|
||||
screen.error(" Summary: " + str(test_count - fail_count) + "/" + str(test_count) + " tests passed successfully", crash=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user