diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index e3b116c..9b4fef6 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -842,7 +842,5 @@ class Agent(AgentInterface): for network in self.networks.values(): network.sync() - - - - + def get_success_rate(self) -> float: + return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py index 2aa09c9..78b5460 100644 --- a/rl_coach/base_parameters.py +++ b/rl_coach/base_parameters.py @@ -168,7 +168,7 @@ class AlgorithmParameters(Parameters): # n-step returns self.n_step = -1 # calculate the total return (no bootstrap, by default) - + # Distributed Coach params self.distributed_coach_synchronization_type = None diff --git a/rl_coach/data_stores/data_store.py b/rl_coach/data_stores/data_store.py index 03718e1..b9e19ea 100644 --- a/rl_coach/data_stores/data_store.py +++ b/rl_coach/data_stores/data_store.py @@ -1,4 +1,6 @@ +from enum import Enum + class DataStoreParameters(object): def __init__(self, store_type, orchestrator_type, orchestrator_params): @@ -6,6 +8,7 @@ class DataStoreParameters(object): self.orchestrator_type = orchestrator_type self.orchestrator_params = orchestrator_params + class DataStore(object): def __init__(self, params: DataStoreParameters): pass @@ -24,3 +27,8 @@ class DataStore(object): def load_from_store(self): pass + + +class SyncFiles(Enum): + FINISHED = ".finished" + LOCKFILE = ".lock" diff --git a/rl_coach/data_stores/nfs_data_store.py b/rl_coach/data_stores/nfs_data_store.py index 375d917..ba2e057 100644 --- a/rl_coach/data_stores/nfs_data_store.py +++ b/rl_coach/data_stores/nfs_data_store.py @@ -58,7 +58,7 @@ class NFSDataStore(DataStore): pass def deploy_k8s_nfs(self) -> bool: - name = "nfs-server" + name = "nfs-server-{}".format(uuid.uuid4()) container = k8sclient.V1Container( name=name, image="k8s.gcr.io/volume-nfs:0.8", @@ -83,7 +83,7 @@ class NFSDataStore(DataStore): security_context=k8sclient.V1SecurityContext(privileged=True) ) template = k8sclient.V1PodTemplateSpec( - metadata=k8sclient.V1ObjectMeta(labels={'app': 'nfs-server'}), + metadata=k8sclient.V1ObjectMeta(labels={'app': name}), spec=k8sclient.V1PodSpec( containers=[container], volumes=[k8sclient.V1Volume( @@ -96,14 +96,14 @@ class NFSDataStore(DataStore): replicas=1, template=template, selector=k8sclient.V1LabelSelector( - match_labels={'app': 'nfs-server'} + match_labels={'app': name} ) ) deployment = k8sclient.V1Deployment( api_version='apps/v1', kind='Deployment', - metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': 'nfs-server'}), + metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': name}), spec=deployment_spec ) @@ -117,7 +117,7 @@ class NFSDataStore(DataStore): k8s_core_v1_api_client = k8sclient.CoreV1Api() - svc_name = "nfs-service" + svc_name = "nfs-service-{}".format(uuid.uuid4()) service = k8sclient.V1Service( api_version='v1', kind='Service', @@ -145,7 +145,7 @@ class NFSDataStore(DataStore): return True def create_k8s_nfs_resources(self) -> bool: - pv_name = "nfs-ckpt-pv" + pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4()) persistent_volume = k8sclient.V1PersistentVolume( api_version="v1", kind="PersistentVolume", @@ -171,7 +171,7 @@ class NFSDataStore(DataStore): print("Got exception: %s\n while creating the NFS PV", e) return False - pvc_name = "nfs-ckpt-pvc" + pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4()) persistent_volume_claim = k8sclient.V1PersistentVolumeClaim( api_version="v1", kind="PersistentVolumeClaim", diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py index 7d40f8b..11f3fe1 100644 --- a/rl_coach/data_stores/s3_data_store.py +++ b/rl_coach/data_stores/s3_data_store.py @@ -5,6 +5,7 @@ from minio.error import ResponseError from configparser import ConfigParser, Error from google.protobuf import text_format from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState +from rl_coach.data_stores.data_store import SyncFiles import os import time @@ -20,7 +21,6 @@ class S3DataStoreParameters(DataStoreParameters): self.end_point = end_point self.bucket_name = bucket_name self.checkpoint_dir = checkpoint_dir - self.lock_file = ".lock" class S3DataStore(DataStore): @@ -52,9 +52,9 @@ class S3DataStore(DataStore): def save_to_store(self): try: - self.mc.remove_object(self.params.bucket_name, self.params.lock_file) + self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value) - self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0) + self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0) checkpoint_file = None for root, dirs, files in os.walk(self.params.checkpoint_dir): @@ -70,7 +70,7 @@ class S3DataStore(DataStore): rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir) self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) - self.mc.remove_object(self.params.bucket_name, self.params.lock_file) + self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value) except ResponseError as e: print("Got exception: %s\n while saving to S3", e) @@ -80,7 +80,7 @@ class S3DataStore(DataStore): filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint")) while True: - objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file) + objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value) if next(objects, None) is None: try: @@ -90,6 +90,18 @@ class S3DataStore(DataStore): break time.sleep(10) + # Check if there's a finished file + objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value) + + if next(objects, None) is not None: + try: + self.mc.fget_object( + self.params.bucket_name, SyncFiles.FINISHED.value, + os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value)) + ) + except Exception as e: + pass + ckpt = CheckpointState() if os.path.exists(filename): contents = open(filename, 'r').read() diff --git a/rl_coach/environments/carla_environment.py b/rl_coach/environments/carla_environment.py index 21963a8..397998b 100644 --- a/rl_coach/environments/carla_environment.py +++ b/rl_coach/environments/carla_environment.py @@ -133,8 +133,8 @@ class CarlaEnvironment(Environment): allow_braking: bool, quality: CarlaEnvironmentParameters.Quality, cameras: List[CameraTypes], weather_id: List[int], experiment_path: str, separate_actions_for_throttle_and_brake: bool, - num_speedup_steps: int, max_speed: float, **kwargs): - super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters) + num_speedup_steps: int, max_speed: float, target_success_rate: float = 1.0, **kwargs): + super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate) # server configuration self.server_height = server_height @@ -261,6 +261,8 @@ class CarlaEnvironment(Environment): image = self.get_rendered_image() self.renderer.create_screen(image.shape[1], image.shape[0]) + self.target_success_rate = target_success_rate + def _add_cameras(self, settings, cameras, camera_width, camera_height): # add a front facing camera if CameraTypes.FRONT in cameras: @@ -461,3 +463,6 @@ class CarlaEnvironment(Environment): image = [self.state[camera.name] for camera in self.scene.sensors] image = np.vstack(image) return image + + def get_target_success_rate(self) -> float: + return self.target_success_rate diff --git a/rl_coach/environments/control_suite_environment.py b/rl_coach/environments/control_suite_environment.py index 27f3db7..a5667e9 100644 --- a/rl_coach/environments/control_suite_environment.py +++ b/rl_coach/environments/control_suite_environment.py @@ -66,10 +66,10 @@ control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING # Environment class ControlSuiteEnvironment(Environment): def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters, - seed: Union[None, int]=None, human_control: bool=False, + target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False, observation_type: ObservationType=ObservationType.Measurements, custom_reward_threshold: Union[int, float]=None, **kwargs): - super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters) + super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate) self.observation_type = observation_type @@ -126,6 +126,8 @@ class ControlSuiteEnvironment(Environment): if not self.native_rendering: self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale) + self.target_success_rate = target_success_rate + def _update_state(self): self.state = {} @@ -160,3 +162,6 @@ class ControlSuiteEnvironment(Environment): def get_rendered_image(self): return self.env.physics.render(camera_id=0) + + def get_target_success_rate(self) -> float: + return self.target_success_rate \ No newline at end of file diff --git a/rl_coach/environments/doom_environment.py b/rl_coach/environments/doom_environment.py index 437968b..d4269ba 100644 --- a/rl_coach/environments/doom_environment.py +++ b/rl_coach/environments/doom_environment.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2017 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -124,8 +124,8 @@ class DoomEnvironment(Environment): def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters, - cameras: List[CameraTypes], **kwargs): - super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters) + cameras: List[CameraTypes], target_success_rate: float=1.0, **kwargs): + super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate) self.cameras = cameras @@ -196,6 +196,8 @@ class DoomEnvironment(Environment): image = self.get_rendered_image() self.renderer.create_screen(image.shape[1], image.shape[0]) + self.target_success_rate = target_success_rate + def _update_state(self): # extract all data from the current state state = self.game.get_state() @@ -227,3 +229,6 @@ class DoomEnvironment(Environment): image = [self.state[camera.value[0]] for camera in self.cameras] image = np.vstack(image) return image + + def get_target_success_rate(self) -> float: + return self.target_success_rate diff --git a/rl_coach/environments/environment.py b/rl_coach/environments/environment.py index 295c168..841549a 100644 --- a/rl_coach/environments/environment.py +++ b/rl_coach/environments/environment.py @@ -103,6 +103,9 @@ class EnvironmentParameters(Parameters): self.default_output_filter = None self.experiment_path = None + # Set target reward and target_success if present + self.target_success_rate = 1.0 + @property def path(self): return 'rl_coach.environments.environment:Environment' @@ -111,7 +114,7 @@ class EnvironmentParameters(Parameters): class Environment(EnvironmentInterface): def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool, custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters, - **kwargs): + target_success_rate: float=1.0, **kwargs): """ :param level: The environment level. Each environment can have multiple levels :param seed: a seed for the random number generator of the environment @@ -166,6 +169,9 @@ class Environment(EnvironmentInterface): if not self.native_rendering: self.renderer = Renderer() + # Set target reward and target_success if present + self.target_success_rate = target_success_rate + @property def action_space(self) -> Union[List[ActionSpace], ActionSpace]: """ @@ -469,3 +475,5 @@ class Environment(EnvironmentInterface): """ return np.transpose(self.state['observation'], [1, 2, 0]) + def get_target_success_rate(self) -> float: + return self.target_success_rate diff --git a/rl_coach/environments/gym_environment.py b/rl_coach/environments/gym_environment.py index 569c537..ff3df5c 100644 --- a/rl_coach/environments/gym_environment.py +++ b/rl_coach/environments/gym_environment.py @@ -178,11 +178,11 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper): # Environment class GymEnvironment(Environment): def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters, - additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None, + target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None, human_control: bool=False, custom_reward_threshold: Union[int, float]=None, random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs): super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, - visualization_parameters) + visualization_parameters, target_success_rate) self.random_initialization_steps = random_initialization_steps self.max_over_num_frames = max_over_num_frames @@ -221,7 +221,7 @@ class GymEnvironment(Environment): try: self.env = env_class(**self.additional_simulator_parameters) except: - screen.error("Failed to instantiate Gym environment class %s with arguments %s" % + screen.error("Failed to instantiate Gym environment class %s with arguments %s" % (env_class, self.additional_simulator_parameters), crash=False) raise else: @@ -337,6 +337,8 @@ class GymEnvironment(Environment): self.reward_success_threshold = self.env.spec.reward_threshold self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold) + self.target_success_rate = target_success_rate + def _wrap_state(self, state): if not isinstance(self.env.observation_space, gym.spaces.Dict): return {'observation': state} @@ -434,3 +436,6 @@ class GymEnvironment(Environment): if self.is_mujoco_env: self._set_mujoco_camera(0) return image + + def get_target_success_rate(self) -> float: + return self.target_success_rate diff --git a/rl_coach/environments/starcraft2_environment.py b/rl_coach/environments/starcraft2_environment.py index 69a5f98..87747d7 100644 --- a/rl_coach/environments/starcraft2_environment.py +++ b/rl_coach/environments/starcraft2_environment.py @@ -107,14 +107,14 @@ class StarCraft2EnvironmentParameters(EnvironmentParameters): # Environment class StarCraft2Environment(Environment): def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters, - seed: Union[None, int]=None, human_control: bool=False, + target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False, custom_reward_threshold: Union[int, float]=None, screen_size: int=84, minimap_size: int=64, feature_minimap_maps_to_use: List=range(7), feature_screen_maps_to_use: List=range(17), observation_type: StarcraftObservationType=StarcraftObservationType.Features, disable_fog: bool=False, auto_select_all_army: bool=True, use_full_action_space: bool=False, **kwargs): - super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters) + super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate) self.screen_size = screen_size self.minimap_size = minimap_size @@ -163,11 +163,11 @@ class StarCraft2Environment(Environment): """ feature_screen: [height_map, visibility_map, creep, power, player_id, player_relative, unit_type, selected, - unit_hit_points, unit_hit_points_ratio, unit_energy, unit_energy_ratio, unit_shields, + unit_hit_points, unit_hit_points_ratio, unit_energy, unit_energy_ratio, unit_shields, unit_shields_ratio, unit_density, unit_density_aa, effects] feature_minimap: [height_map, visibility_map, creep, camera, player_id, player_relative, selecte d] - player: [player_id, minerals, vespene, food_cap, food_army, food_workers, idle_worker_dount, + player: [player_id, minerals, vespene, food_cap, food_army, food_workers, idle_worker_dount, army_count, warp_gate_count, larva_count] """ self.screen_shape = np.array(self.env.observation_spec()[0]['feature_screen']) @@ -192,6 +192,8 @@ class StarCraft2Environment(Environment): self.action_space = BoxActionSpace(2, 0, self.screen_size - 1, ["X-Axis, Y-Axis"], default_action=np.array([self.screen_size/2, self.screen_size/2])) + self.target_success_rate = target_success_rate + def _update_state(self): timestep = 0 self.screen = self.last_result[timestep].observation.feature_screen @@ -244,3 +246,6 @@ class StarCraft2Environment(Environment): self.env._run_config.replay_dir = experiment_path self.env.save_replay('replays') super().dump_video_of_last_episode() + + def get_target_success_rate(self): + return self.target_success_rate diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index b20a12c..0259941 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -34,6 +34,7 @@ from rl_coach.logger import screen, Logger from rl_coach.utils import set_cpu, start_shell_command_and_wait from rl_coach.data_stores.data_store_impl import get_data_store from rl_coach.orchestrators.kubernetes_orchestrator import RunType +from rl_coach.data_stores.data_store import SyncFiles class ScheduleParameters(Parameters): @@ -458,12 +459,12 @@ class GraphManager(object): """ [manager.sync() for manager in self.level_managers] - def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> None: + def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> bool: """ Perform evaluation for several steps :param steps: the number of steps as a tuple of steps time and steps count :param keep_networks_in_sync: sync the network parameters with the global network before each episode - :return: None + :return: bool, True if the target reward and target success has been reached """ self.verify_graph_was_created() @@ -478,6 +479,16 @@ class GraphManager(object): while self.current_step_counter < count_end: self.act(EnvironmentEpisodes(1)) self.sync() + if self.should_stop(): + if self.task_parameters.checkpoint_save_dir: + open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close() + if hasattr(self, 'data_store_params'): + data_store = get_data_store(self.data_store_params) + data_store.save_to_store() + + screen.success("Reached required success rate. Exiting.") + return True + return False def improve(self): """ @@ -508,7 +519,8 @@ class GraphManager(object): count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps while self.total_steps_counters[RunPhase.TRAIN] < count_end: self.train_and_act(self.steps_between_evaluation_periods) - self.evaluate(self.evaluation_steps) + if self.evaluate(self.evaluation_steps): + break def _restore_checkpoint_tf(self, checkpoint_dir: str): import tensorflow as tf @@ -609,3 +621,6 @@ class GraphManager(object): def should_train(self) -> bool: return any([manager.should_train() for manager in self.level_managers]) + + def should_stop(self) -> bool: + return all([manager.should_stop() for manager in self.level_managers]) diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py index 7697bf5..962fa20 100644 --- a/rl_coach/level_manager.py +++ b/rl_coach/level_manager.py @@ -263,3 +263,6 @@ class LevelManager(EnvironmentInterface): def should_train(self) -> bool: return any([agent._should_train_helper() for agent in self.agents.values()]) + + def should_stop(self) -> bool: + return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()]) diff --git a/rl_coach/orchestrators/kubernetes_orchestrator.py b/rl_coach/orchestrators/kubernetes_orchestrator.py index d2afb4d..318c9f8 100644 --- a/rl_coach/orchestrators/kubernetes_orchestrator.py +++ b/rl_coach/orchestrators/kubernetes_orchestrator.py @@ -79,6 +79,7 @@ class Kubernetes(Deploy): self.memory_backend = get_memory_backend(self.params.memory_backend_parameters) self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace} + self.params.data_store_params.namespace = self.params.namespace self.data_store = get_data_store(self.params.data_store_params) if self.params.data_store_params.store_type == "s3": @@ -137,7 +138,8 @@ class Kubernetes(Deploy): volumes=[k8sclient.V1Volume( name="nfs-pvc", persistent_volume_claim=self.nfs_pvc - )] + )], + restart_policy='OnFailure' ), ) else: @@ -155,32 +157,30 @@ class Kubernetes(Deploy): template = k8sclient.V1PodTemplateSpec( metadata=k8sclient.V1ObjectMeta(labels={'app': name}), spec=k8sclient.V1PodSpec( - containers=[container] + containers=[container], + restart_policy='OnFailure' ), ) - deployment_spec = k8sclient.V1DeploymentSpec( - replicas=trainer_params.num_replicas, - template=template, - selector=k8sclient.V1LabelSelector( - match_labels={'app': name} - ) + job_spec = k8sclient.V1JobSpec( + completions=1, + template=template ) - deployment = k8sclient.V1Deployment( - api_version='apps/v1', - kind='Deployment', + job = k8sclient.V1Job( + api_version="batch/v1", + kind="Job", metadata=k8sclient.V1ObjectMeta(name=name), - spec=deployment_spec + spec=job_spec ) - api_client = k8sclient.AppsV1Api() + api_client = k8sclient.BatchV1Api() try: - api_client.create_namespaced_deployment(self.params.namespace, deployment) - trainer_params.orchestration_params['deployment_name'] = name + api_client.create_namespaced_job(self.params.namespace, job) + trainer_params.orchestration_params['job_name'] = name return True except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while creating deployment", e) + print("Got exception: %s\n while creating job", e) return False def deploy_worker(self): @@ -217,6 +217,7 @@ class Kubernetes(Deploy): name="nfs-pvc", persistent_volume_claim=self.nfs_pvc )], + restart_policy='OnFailure' ), ) else: @@ -234,31 +235,31 @@ class Kubernetes(Deploy): template = k8sclient.V1PodTemplateSpec( metadata=k8sclient.V1ObjectMeta(labels={'app': name}), spec=k8sclient.V1PodSpec( - containers=[container] + containers=[container], + restart_policy='OnFailure' ) ) - deployment_spec = k8sclient.V1DeploymentSpec( - replicas=worker_params.num_replicas, - template=template, - selector=k8sclient.V1LabelSelector( - match_labels={'app': name} - ) - ) - deployment = k8sclient.V1Deployment( - api_version='apps/v1', - kind="Deployment", - metadata=k8sclient.V1ObjectMeta(name=name), - spec=deployment_spec + job_spec = k8sclient.V1JobSpec( + completions=worker_params.num_replicas, + parallelism=worker_params.num_replicas, + template=template ) - api_client = k8sclient.AppsV1Api() + job = k8sclient.V1Job( + api_version="batch/v1", + kind="Job", + metadata=k8sclient.V1ObjectMeta(name=name), + spec=job_spec + ) + + api_client = k8sclient.BatchV1Api() try: - api_client.create_namespaced_deployment(self.params.namespace, deployment) - worker_params.orchestration_params['deployment_name'] = name + api_client.create_namespaced_job(self.params.namespace, job) + worker_params.orchestration_params['job_name'] = name return True except k8sclient.rest.ApiException as e: - print("Got exception: %s\n while creating deployment", e) + print("Got exception: %s\n while creating Job", e) return False def worker_logs(self): @@ -273,7 +274,7 @@ class Kubernetes(Deploy): pod = None try: pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format( - trainer_params.orchestration_params['deployment_name'] + trainer_params.orchestration_params['job_name'] )) pod = pods.items[0] @@ -324,17 +325,20 @@ class Kubernetes(Deploy): def undeploy(self): trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None) - api_client = k8sclient.AppsV1Api() - delete_options = k8sclient.V1DeleteOptions() + api_client = k8sclient.BatchV1Api() + delete_options = k8sclient.V1DeleteOptions( + propagation_policy="Foreground" + ) + if trainer_params: try: - api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options) + api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options) except k8sclient.rest.ApiException as e: print("Got exception: %s\n while deleting trainer", e) worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None) if worker_params: try: - api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options) + api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options) except k8sclient.rest.ApiException as e: print("Got exception: %s\n while deleting workers", e) self.memory_backend.undeploy() diff --git a/rl_coach/presets/CartPole_ClippedPPO.py b/rl_coach/presets/CartPole_ClippedPPO.py index 7c4d3c1..f400478 100644 --- a/rl_coach/presets/CartPole_ClippedPPO.py +++ b/rl_coach/presets/CartPole_ClippedPPO.py @@ -56,6 +56,10 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) # Environment # ############### env_params = GymVectorEnvironment(level='CartPole-v0') +env_params.custom_reward_threshold = 200 +# Set the target success +env_params.target_success_rate = 1.0 + ######## # Test # diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py index d039a9e..184ccdc 100644 --- a/rl_coach/rollout_worker.py +++ b/rl_coach/rollout_worker.py @@ -15,6 +15,7 @@ from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchroniza from rl_coach.core_types import EnvironmentSteps, RunPhase from google.protobuf import text_format from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState +from rl_coach.data_stores.data_store import SyncFiles def has_checkpoint(checkpoint_dir): @@ -68,6 +69,10 @@ def get_latest_checkpoint(checkpoint_dir): return int(rel_path.split('_Step')[0]) +def should_stop(checkpoint_dir): + return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)) + + def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers): """ wait for first checkpoint then perform rollouts using the model @@ -87,12 +92,17 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers): for i in range(int(graph_manager.improve_steps.num_steps/act_steps)): + if should_stop(checkpoint_dir): + break + graph_manager.act(EnvironmentSteps(num_steps=act_steps)) new_checkpoint = get_latest_checkpoint(checkpoint_dir) if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: while new_checkpoint < last_checkpoint + 1: + if should_stop(checkpoint_dir): + break if data_store: data_store.load_from_store() new_checkpoint = get_latest_checkpoint(checkpoint_dir) diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index 3d2f9e5..fbb5640 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -40,8 +40,9 @@ def training_worker(graph_manager, checkpoint_dir): graph_manager.phase = core_types.RunPhase.UNDEFINED if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset: - graph_manager.evaluate(graph_manager.evaluation_steps) eval_offset += 1 + if graph_manager.evaluate(graph_manager.evaluation_steps): + break if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: graph_manager.save_checkpoint()