mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding framework for multinode tests (#149)
* Currently runs CartPole_ClippedPPO and Mujoco_ClippedPPO with inverted_pendulum level.
This commit is contained in:
committed by
Balaji Subramaniam
parent
b461a1b8ab
commit
2c1a9dbf20
@@ -291,6 +291,41 @@ jobs:
|
|||||||
kubectl delete ns trace-test-mujoco-${CIRCLE_BUILD_NUM} || true
|
kubectl delete ns trace-test-mujoco-${CIRCLE_BUILD_NUM} || true
|
||||||
when: always
|
when: always
|
||||||
|
|
||||||
|
# Runs the distributed (multi-node) coach test inside the freshly built coach image,
# against a per-build S3 bucket and a per-build Kubernetes namespace.
multinode_test:
  <<: *executor_prep
  steps:
    - checkout
    - *remote_docker
    - *restore_cache
    - *aws_prep
    - *docker_prep
    - run:
        name: run multinode test
        # The bucket and the namespace are both suffixed with CIRCLE_BUILD_NUM; the cleanup
        # step below removes exactly these two resources.
        command: |
          aws s3 mb s3://coach-aws-test-${CIRCLE_BUILD_NUM}
          kubectl create ns multinode-test-${CIRCLE_BUILD_NUM}
          docker run -e CIRCLE_BUILD_NUM=$CIRCLE_BUILD_NUM -e TAG=$(git describe --tags --always --dirty) -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID -e AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach:$(git describe --tags --always --dirty) \
            /bin/bash -c 'pip install awscli; curl -o /usr/local/bin/aws-iam-authenticator https://amazon-eks.s3-us-west-2.amazonaws.com/1.10.3/2018-07-26/bin/linux/amd64/aws-iam-authenticator \
            && chmod a+x /usr/local/bin/aws-iam-authenticator \
            && aws eks update-kubeconfig --name coach-aws-cicd \
            && curl -o /usr/local/bin/kubectl https://storage.googleapis.com/kubernetes-release/release/$(curl -s https://storage.googleapis.com/kubernetes-release/release/stable.txt)/bin/linux/amd64/kubectl \
            && chmod a+x /usr/local/bin/kubectl \
            && kubectl config set-context $(kubectl config current-context) --namespace=multinode-test-${CIRCLE_BUILD_NUM} \
            && python3 rl_coach/tests/test_dist_coach.py -i 316971102342.dkr.ecr.us-west-2.amazonaws.com/coach:${TAG} -b coach-aws-test-${CIRCLE_BUILD_NUM}'
          docker ps -a -q | head -n 1 | xargs -I% docker cp %:/root/src/experiments . || true
        no_output_timeout: 30m
    - store_artifacts:
        path: ~/repo/experiments
    - run:
        name: cleanup
        # Every command is best-effort (`|| true`) and the step runs even when the test failed.
        # The pods are deleted from this job's own namespace; the previous version targeted
        # golden-test-mujoco-*, a namespace this job never creates.
        command: |
          kubectl delete --all pods --namespace=multinode-test-${CIRCLE_BUILD_NUM} || true
          aws s3 rm --recursive s3://coach-aws-test-${CIRCLE_BUILD_NUM} || true
          aws s3 rb s3://coach-aws-test-${CIRCLE_BUILD_NUM} || true
          kubectl delete ns multinode-test-${CIRCLE_BUILD_NUM} || true
        when: always
|
||||||
|
|
||||||
container_deploy:
|
container_deploy:
|
||||||
<<: *executor_prep
|
<<: *executor_prep
|
||||||
steps:
|
steps:
|
||||||
@@ -329,6 +364,13 @@ workflows:
|
|||||||
- integration_tests:
|
- integration_tests:
|
||||||
requires:
|
requires:
|
||||||
- build_base
|
- build_base
|
||||||
|
- multinode_approval:
|
||||||
|
type: approval
|
||||||
|
requires:
|
||||||
|
- build_base
|
||||||
|
- multinode_test:
|
||||||
|
requires:
|
||||||
|
- multinode_approval
|
||||||
- e2e_approval:
|
- e2e_approval:
|
||||||
type: approval
|
type: approval
|
||||||
requires:
|
requires:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ scikit-image>=0.13.0
|
|||||||
gym>=0.10.5
|
gym>=0.10.5
|
||||||
bokeh>=0.13.0
|
bokeh>=0.13.0
|
||||||
futures>=3.1.1
|
futures>=3.1.1
|
||||||
kubernetes>=8.0.0b1
|
kubernetes>=8.0.0b1,<=8.0.1
|
||||||
redis>=2.10.6
|
redis>=2.10.6
|
||||||
minio>=4.0.5
|
minio>=4.0.5
|
||||||
pytest>=3.8.2
|
pytest>=3.8.2
|
||||||
|
|||||||
@@ -103,7 +103,8 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
|
|||||||
task_parameters.checkpoint_save_dir = ckpt_inside_container
|
task_parameters.checkpoint_save_dir = ckpt_inside_container
|
||||||
training_worker(
|
training_worker(
|
||||||
graph_manager=graph_manager,
|
graph_manager=graph_manager,
|
||||||
task_parameters=task_parameters
|
task_parameters=task_parameters,
|
||||||
|
is_multi_node_test=args.is_multi_node_test
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
|
if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
|
||||||
@@ -166,30 +167,32 @@ def handle_distributed_coach_orchestrator(args):
|
|||||||
orchestrator = Kubernetes(orchestration_params)
|
orchestrator = Kubernetes(orchestration_params)
|
||||||
if not orchestrator.setup():
|
if not orchestrator.setup():
|
||||||
print("Could not setup.")
|
print("Could not setup.")
|
||||||
return
|
return 1
|
||||||
|
|
||||||
if orchestrator.deploy_trainer():
|
if orchestrator.deploy_trainer():
|
||||||
print("Successfully deployed trainer.")
|
print("Successfully deployed trainer.")
|
||||||
else:
|
else:
|
||||||
print("Could not deploy trainer.")
|
print("Could not deploy trainer.")
|
||||||
return
|
return 1
|
||||||
|
|
||||||
if orchestrator.deploy_worker():
|
if orchestrator.deploy_worker():
|
||||||
print("Successfully deployed rollout worker(s).")
|
print("Successfully deployed rollout worker(s).")
|
||||||
else:
|
else:
|
||||||
print("Could not deploy rollout worker(s).")
|
print("Could not deploy rollout worker(s).")
|
||||||
return
|
return 1
|
||||||
|
|
||||||
if args.dump_worker_logs:
|
if args.dump_worker_logs:
|
||||||
screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
|
screen.log_title("Dumping rollout worker logs in: {}".format(args.experiment_path))
|
||||||
orchestrator.worker_logs(path=args.experiment_path)
|
orchestrator.worker_logs(path=args.experiment_path)
|
||||||
|
|
||||||
|
exit_code = 1
|
||||||
try:
|
try:
|
||||||
orchestrator.trainer_logs()
|
exit_code = orchestrator.trainer_logs()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
orchestrator.undeploy()
|
orchestrator.undeploy()
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
|
||||||
class CoachLauncher(object):
|
class CoachLauncher(object):
|
||||||
@@ -331,7 +334,7 @@ class CoachLauncher(object):
|
|||||||
# if no arg is given
|
# if no arg is given
|
||||||
if len(sys.argv) == 1:
|
if len(sys.argv) == 1:
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
sys.exit(0)
|
sys.exit(1)
|
||||||
|
|
||||||
# list available presets
|
# list available presets
|
||||||
if args.list:
|
if args.list:
|
||||||
@@ -569,6 +572,9 @@ class CoachLauncher(object):
|
|||||||
parser.add_argument('--dump_worker_logs',
|
parser.add_argument('--dump_worker_logs',
|
||||||
help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
|
help="(flag) Only used in distributed coach. If set, the worker logs are saved in the experiment dir",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
parser.add_argument('--is_multi_node_test',
|
||||||
|
help=argparse.SUPPRESS,
|
||||||
|
action='store_true')
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@@ -617,8 +623,7 @@ class CoachLauncher(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
|
||||||
handle_distributed_coach_orchestrator(args)
|
exit(handle_distributed_coach_orchestrator(args))
|
||||||
return
|
|
||||||
|
|
||||||
# Single-threaded runs
|
# Single-threaded runs
|
||||||
if args.num_workers == 1:
|
if args.num_workers == 1:
|
||||||
|
|||||||
@@ -504,12 +504,7 @@ class GraphManager(object):
|
|||||||
self.act(EnvironmentEpisodes(1))
|
self.act(EnvironmentEpisodes(1))
|
||||||
self.sync()
|
self.sync()
|
||||||
if self.should_stop():
|
if self.should_stop():
|
||||||
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
|
self.flush_finished()
|
||||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
|
||||||
if hasattr(self, 'data_store_params'):
|
|
||||||
data_store = self.get_data_store(self.data_store_params)
|
|
||||||
data_store.save_to_store()
|
|
||||||
|
|
||||||
screen.success("Reached required success rate. Exiting.")
|
screen.success("Reached required success rate. Exiting.")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@@ -713,3 +708,20 @@ class GraphManager(object):
|
|||||||
"""
|
"""
|
||||||
for env in self.environments:
|
for env in self.environments:
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
def get_current_episodes_count(self):
    """
    Look up the `EnvironmentEpisodes` entry of `self.current_step_counter`.

    :return: the value stored under the `EnvironmentEpisodes` key
    """
    episodes_so_far = self.current_step_counter[EnvironmentEpisodes]
    return episodes_so_far
|
||||||
|
|
||||||
|
def flush_finished(self):
|
||||||
|
"""
|
||||||
|
To indicate the training has finished, writes a `.finished` file to the checkpoint directory and calls
|
||||||
|
the data store to upload that file.
|
||||||
|
"""
|
||||||
|
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
|
||||||
|
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||||
|
if hasattr(self, 'data_store_params'):
|
||||||
|
data_store = self.get_data_store(self.data_store_params)
|
||||||
|
data_store.save_to_store()
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ class Kubernetes(Deploy):
|
|||||||
name="nfs-pvc",
|
name="nfs-pvc",
|
||||||
persistent_volume_claim=self.nfs_pvc
|
persistent_volume_claim=self.nfs_pvc
|
||||||
)],
|
)],
|
||||||
restart_policy='OnFailure'
|
restart_policy='Never'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -185,7 +185,7 @@ class Kubernetes(Deploy):
|
|||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
spec=k8sclient.V1PodSpec(
|
spec=k8sclient.V1PodSpec(
|
||||||
containers=[container],
|
containers=[container],
|
||||||
restart_policy='OnFailure'
|
restart_policy='Never'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -247,7 +247,7 @@ class Kubernetes(Deploy):
|
|||||||
name="nfs-pvc",
|
name="nfs-pvc",
|
||||||
persistent_volume_claim=self.nfs_pvc
|
persistent_volume_claim=self.nfs_pvc
|
||||||
)],
|
)],
|
||||||
restart_policy='OnFailure'
|
restart_policy='Never'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -266,7 +266,7 @@ class Kubernetes(Deploy):
|
|||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
spec=k8sclient.V1PodSpec(
|
spec=k8sclient.V1PodSpec(
|
||||||
containers=[container],
|
containers=[container],
|
||||||
restart_policy='OnFailure'
|
restart_policy='Never'
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -316,7 +316,7 @@ class Kubernetes(Deploy):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for pod in pods.items:
|
for pod in pods.items:
|
||||||
Process(target=self._tail_log_file, args=(pod.metadata.name, api_client, self.params.namespace, path)).start()
|
Process(target=self._tail_log_file, args=(pod.metadata.name, api_client, self.params.namespace, path), daemon=True).start()
|
||||||
|
|
||||||
def _tail_log_file(self, pod_name, api_client, namespace, path):
|
def _tail_log_file(self, pod_name, api_client, namespace, path):
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
@@ -348,7 +348,7 @@ class Kubernetes(Deploy):
|
|||||||
if not pod:
|
if not pod:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.tail_log(pod.metadata.name, api_client)
|
return self.tail_log(pod.metadata.name, api_client)
|
||||||
|
|
||||||
def tail_log(self, pod_name, corev1_api):
|
def tail_log(self, pod_name, corev1_api):
|
||||||
while True:
|
while True:
|
||||||
@@ -382,9 +382,9 @@ class Kubernetes(Deploy):
|
|||||||
container_status.state.waiting.reason == 'CrashLoopBackOff' or \
|
container_status.state.waiting.reason == 'CrashLoopBackOff' or \
|
||||||
container_status.state.waiting.reason == 'ImagePullBackOff' or \
|
container_status.state.waiting.reason == 'ImagePullBackOff' or \
|
||||||
container_status.state.waiting.reason == 'ErrImagePull':
|
container_status.state.waiting.reason == 'ErrImagePull':
|
||||||
return
|
return 1
|
||||||
if container_status.state.terminated is not None:
|
if container_status.state.terminated is not None:
|
||||||
return
|
return container_status.state.terminated.exit_code
|
||||||
|
|
||||||
def undeploy(self):
|
def undeploy(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ agent_params.pre_network_filter.add_observation_filter('observation', 'normalize
|
|||||||
# Environment #
|
# Environment #
|
||||||
###############
|
###############
|
||||||
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))
|
||||||
|
# Set the target success
|
||||||
|
env_params.target_success_rate = 1.0
|
||||||
|
|
||||||
########
|
########
|
||||||
# Test #
|
# Test #
|
||||||
|
|||||||
116
rl_coach/tests/test_dist_coach.py
Normal file
116
rl_coach/tests/test_dist_coach.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
|
||||||
|
from configparser import ConfigParser
|
||||||
|
import pytest
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from rl_coach.coach import CoachLauncher
|
||||||
|
import sys
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
|
||||||
|
def generate_config(image, memory_backend, s3_end_point, s3_bucket_name, s3_creds_file, config_file):
    """
    Write the two config files the test needs: the s3 credentials file and the coach config file.

    The credentials come from the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` env vars. They are passed to
    ConfigParser exactly as `os.environ.get` returns them, so an unset variable arrives as `None`.

    :param image: value stored under `image` in the coach config
    :param memory_backend: value stored under `memory_backend` in the coach config
    :param s3_end_point: value stored under `s3_end_point` in the coach config
    :param s3_bucket_name: value stored under `s3_bucket_name` in the coach config
    :param s3_creds_file: path the credentials file is written to; also stored under `s3_creds_file`
    :param config_file: path the coach config is written to
    """
    credentials = {
        'aws_access_key_id': os.environ.get('AWS_ACCESS_KEY_ID'),
        'aws_secret_access_key': os.environ.get('AWS_SECRET_ACCESS_KEY')
    }
    coach_settings = {
        'image': image,
        'memory_backend': memory_backend,
        'data_store': 's3',
        's3_end_point': s3_end_point,
        's3_bucket_name': s3_bucket_name,
        's3_creds_file': s3_creds_file
    }

    # The credentials file is produced first, then the coach config; each parser is only built
    # right before its own file is opened.
    outputs = (
        (s3_creds_file, 'default', credentials),
        (config_file, 'coach', coach_settings),
    )
    for target_path, section_name, values in outputs:
        parser = ConfigParser(values, default_section=section_name)
        with open(target_path, 'w') as handle:
            parser.write(handle)
|
||||||
|
|
||||||
|
|
||||||
|
def test_command(command):
    """
    Run one coach invocation in-process and require that it exits with status 0.

    :param command: list of argument strings; it is installed as `sys.argv` before the launcher is started
    """
    # The command line is handed to the launcher through sys.argv.
    sys.argv = command
    coach = CoachLauncher()
    # launch() is expected to finish by raising SystemExit; returning normally fails the check.
    with pytest.raises(SystemExit) as exit_info:
        coach.launch()
    assert exit_info.value.code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def clear_bucket(s3_end_point, s3_bucket_name):
    """
    Remove every object from the bucket before the test starts.

    The cleanup is best-effort: a failure does not abort the test, but it is reported instead of being
    silently dropped, so bad credentials or a wrong endpoint are visible in the test output.

    :param s3_end_point: s3 endpoint the Minio client connects to
    :param s3_bucket_name: name of the bucket to empty
    """
    access_key = os.environ.get('AWS_ACCESS_KEY_ID')
    secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
    minio_client = Minio(s3_end_point, access_key=access_key, secret_key=secret_access_key)
    try:
        for obj in minio_client.list_objects_v2(s3_bucket_name, recursive=True):
            minio_client.remove_object(s3_bucket_name, obj.object_name)
    except Exception as e:
        # Keep going: leftovers in the bucket are not fatal, but the reason should be on record.
        print("Warning: could not clear bucket {}: {}".format(s3_bucket_name, e))
|
||||||
|
|
||||||
|
|
||||||
|
def test_dc(command, image, memory_backend, s3_end_point, s3_bucket_name, s3_creds_file, config_file):
    """
    Run a single distributed-coach test: empty the bucket, then launch the given command.

    :param command: command template; its `{template}` placeholder is replaced by `config_file`
    :param image: not used by this function
    :param memory_backend: not used by this function
    :param s3_end_point: s3 endpoint of the bucket to clear
    :param s3_bucket_name: name of the bucket to clear
    :param s3_creds_file: not used by this function
    :param config_file: path substituted for `{template}` in the command
    """
    clear_bucket(s3_end_point, s3_bucket_name)
    # The filled-in command is split on single spaces to obtain the argv list.
    filled_in = command.format(template=config_file)
    argv = filled_in.split(' ')
    test_command(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tests():
    """
    Build the command templates of all the presets to test. New presets should be added here.

    :return: list of command strings, each containing a `{template}` placeholder
    """
    # Options shared by every test command; only the preset selection differs.
    shared_options = '-dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1'
    preset_selections = (
        'CartPole_ClippedPPO',
        'Mujoco_ClippedPPO -lvl inverted_pendulum',
    )
    return ['rl_coach/coach.py -p ' + selection + ' ' + shared_options for selection in preset_selections]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    Command-line entry point of the multi-node test.

    Parses the arguments, checks the prerequisites, generates the config files and then runs every command
    returned by `get_tests()`.

    Exits with status 1 when the bucket or the image is missing, or when the `AWS_ACCESS_KEY_ID` /
    `AWS_SECRET_ACCESS_KEY` env vars are not set.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '-i', '--image', help="(string) Name of the testing image", type=str, default=None
    )
    parser.add_argument(
        '-mb', '--memory_backend', help="(string) Name of the memory backend", type=str, default="redispubsub"
    )
    parser.add_argument(
        '-e', '--endpoint', help="(string) Name of the s3 endpoint", type=str, default='s3.amazonaws.com'
    )
    parser.add_argument(
        '-cr', '--creds_file', help="(string) Path of the s3 creds file", type=str, default='.aws_creds'
    )
    parser.add_argument(
        '-b', '--bucket', help="(string) Name of the bucket for s3", type=str, default=None
    )

    args = parser.parse_args()

    # sys.exit is used rather than the exit() helper, which is only injected by the site module.
    if not args.bucket:
        print("bucket_name required for s3")
        sys.exit(1)
    # The image defaults to None; stop here rather than passing None on to the config generation.
    if not args.image:
        print("image required for the test")
        sys.exit(1)
    if not os.environ.get('AWS_ACCESS_KEY_ID') or not os.environ.get('AWS_SECRET_ACCESS_KEY'):
        print("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY env vars need to be set")
        sys.exit(1)

    config_file = './tmp.cred'
    generate_config(args.image, args.memory_backend, args.endpoint, args.bucket, args.creds_file, config_file)
    for command in get_tests():
        test_dc(command, args.image, args.memory_backend, args.endpoint, args.bucket, args.creds_file, config_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -21,6 +21,7 @@ import time
|
|||||||
|
|
||||||
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
|
from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
|
||||||
from rl_coach import core_types
|
from rl_coach import core_types
|
||||||
|
from rl_coach.logger import screen
|
||||||
|
|
||||||
|
|
||||||
def data_store_ckpt_save(data_store):
|
def data_store_ckpt_save(data_store):
|
||||||
@@ -29,9 +30,12 @@ def data_store_ckpt_save(data_store):
|
|||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
def training_worker(graph_manager, task_parameters):
|
def training_worker(graph_manager, task_parameters, is_multi_node_test):
|
||||||
"""
|
"""
|
||||||
restore a checkpoint then perform rollouts using the restored model
|
restore a checkpoint then perform rollouts using the restored model
|
||||||
|
:param graph_manager: An instance of the graph manager
|
||||||
|
:param task_parameters: An instance of task parameters
|
||||||
|
:param is_multi_node_test: If this is a multi node test instead of a normal run.
|
||||||
"""
|
"""
|
||||||
# initialize graph
|
# initialize graph
|
||||||
graph_manager.create_graph(task_parameters)
|
graph_manager.create_graph(task_parameters)
|
||||||
@@ -50,6 +54,11 @@ def training_worker(graph_manager, task_parameters):
|
|||||||
while steps < graph_manager.improve_steps.num_steps:
|
while steps < graph_manager.improve_steps.num_steps:
|
||||||
|
|
||||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||||
|
if is_multi_node_test and graph_manager.get_current_episodes_count() > graph_manager.preset_validation_params.max_episodes_to_achieve_reward:
|
||||||
|
# Test failed as it has not reached the required success rate
|
||||||
|
graph_manager.flush_finished()
|
||||||
|
screen.error("Could not reach required success by {} episodes.".format(graph_manager.preset_validation_params.max_episodes_to_achieve_reward), crash=True)
|
||||||
|
|
||||||
graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
|
graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
|
||||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user