1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding target reward and target sucess (#58)

* Adding target reward

* Adding target successs

* Addressing comments

* Using custom_reward_threshold and target_success_rate

* Adding exit message

* Moving success rate to environment

* Making target_success_rate optional
This commit is contained in:
Ajay Deshpande
2018-11-12 15:03:43 -08:00
committed by Balaji Subramaniam
parent 0fe583186e
commit 875d6ef017
17 changed files with 162 additions and 74 deletions

View File

@@ -842,7 +842,5 @@ class Agent(AgentInterface):
for network in self.networks.values(): for network in self.networks.values():
network.sync() network.sync()
def get_success_rate(self) -> float:
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed

View File

@@ -1,4 +1,6 @@
from enum import Enum
class DataStoreParameters(object): class DataStoreParameters(object):
def __init__(self, store_type, orchestrator_type, orchestrator_params): def __init__(self, store_type, orchestrator_type, orchestrator_params):
@@ -6,6 +8,7 @@ class DataStoreParameters(object):
self.orchestrator_type = orchestrator_type self.orchestrator_type = orchestrator_type
self.orchestrator_params = orchestrator_params self.orchestrator_params = orchestrator_params
class DataStore(object): class DataStore(object):
def __init__(self, params: DataStoreParameters): def __init__(self, params: DataStoreParameters):
pass pass
@@ -24,3 +27,8 @@ class DataStore(object):
def load_from_store(self): def load_from_store(self):
pass pass
class SyncFiles(Enum):
FINISHED = ".finished"
LOCKFILE = ".lock"

View File

@@ -58,7 +58,7 @@ class NFSDataStore(DataStore):
pass pass
def deploy_k8s_nfs(self) -> bool: def deploy_k8s_nfs(self) -> bool:
name = "nfs-server" name = "nfs-server-{}".format(uuid.uuid4())
container = k8sclient.V1Container( container = k8sclient.V1Container(
name=name, name=name,
image="k8s.gcr.io/volume-nfs:0.8", image="k8s.gcr.io/volume-nfs:0.8",
@@ -83,7 +83,7 @@ class NFSDataStore(DataStore):
security_context=k8sclient.V1SecurityContext(privileged=True) security_context=k8sclient.V1SecurityContext(privileged=True)
) )
template = k8sclient.V1PodTemplateSpec( template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': 'nfs-server'}), metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
spec=k8sclient.V1PodSpec( spec=k8sclient.V1PodSpec(
containers=[container], containers=[container],
volumes=[k8sclient.V1Volume( volumes=[k8sclient.V1Volume(
@@ -96,14 +96,14 @@ class NFSDataStore(DataStore):
replicas=1, replicas=1,
template=template, template=template,
selector=k8sclient.V1LabelSelector( selector=k8sclient.V1LabelSelector(
match_labels={'app': 'nfs-server'} match_labels={'app': name}
) )
) )
deployment = k8sclient.V1Deployment( deployment = k8sclient.V1Deployment(
api_version='apps/v1', api_version='apps/v1',
kind='Deployment', kind='Deployment',
metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': 'nfs-server'}), metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': name}),
spec=deployment_spec spec=deployment_spec
) )
@@ -117,7 +117,7 @@ class NFSDataStore(DataStore):
k8s_core_v1_api_client = k8sclient.CoreV1Api() k8s_core_v1_api_client = k8sclient.CoreV1Api()
svc_name = "nfs-service" svc_name = "nfs-service-{}".format(uuid.uuid4())
service = k8sclient.V1Service( service = k8sclient.V1Service(
api_version='v1', api_version='v1',
kind='Service', kind='Service',
@@ -145,7 +145,7 @@ class NFSDataStore(DataStore):
return True return True
def create_k8s_nfs_resources(self) -> bool: def create_k8s_nfs_resources(self) -> bool:
pv_name = "nfs-ckpt-pv" pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
persistent_volume = k8sclient.V1PersistentVolume( persistent_volume = k8sclient.V1PersistentVolume(
api_version="v1", api_version="v1",
kind="PersistentVolume", kind="PersistentVolume",
@@ -171,7 +171,7 @@ class NFSDataStore(DataStore):
print("Got exception: %s\n while creating the NFS PV", e) print("Got exception: %s\n while creating the NFS PV", e)
return False return False
pvc_name = "nfs-ckpt-pvc" pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())
persistent_volume_claim = k8sclient.V1PersistentVolumeClaim( persistent_volume_claim = k8sclient.V1PersistentVolumeClaim(
api_version="v1", api_version="v1",
kind="PersistentVolumeClaim", kind="PersistentVolumeClaim",

View File

@@ -5,6 +5,7 @@ from minio.error import ResponseError
from configparser import ConfigParser, Error from configparser import ConfigParser, Error
from google.protobuf import text_format from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles
import os import os
import time import time
@@ -20,7 +21,6 @@ class S3DataStoreParameters(DataStoreParameters):
self.end_point = end_point self.end_point = end_point
self.bucket_name = bucket_name self.bucket_name = bucket_name
self.checkpoint_dir = checkpoint_dir self.checkpoint_dir = checkpoint_dir
self.lock_file = ".lock"
class S3DataStore(DataStore): class S3DataStore(DataStore):
@@ -52,9 +52,9 @@ class S3DataStore(DataStore):
def save_to_store(self): def save_to_store(self):
try: try:
self.mc.remove_object(self.params.bucket_name, self.params.lock_file) self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0) self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
checkpoint_file = None checkpoint_file = None
for root, dirs, files in os.walk(self.params.checkpoint_dir): for root, dirs, files in os.walk(self.params.checkpoint_dir):
@@ -70,7 +70,7 @@ class S3DataStore(DataStore):
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir) rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name) self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
self.mc.remove_object(self.params.bucket_name, self.params.lock_file) self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
except ResponseError as e: except ResponseError as e:
print("Got exception: %s\n while saving to S3", e) print("Got exception: %s\n while saving to S3", e)
@@ -80,7 +80,7 @@ class S3DataStore(DataStore):
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint")) filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
while True: while True:
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file) objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
if next(objects, None) is None: if next(objects, None) is None:
try: try:
@@ -90,6 +90,18 @@ class S3DataStore(DataStore):
break break
time.sleep(10) time.sleep(10)
# Check if there's a finished file
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)
if next(objects, None) is not None:
try:
self.mc.fget_object(
self.params.bucket_name, SyncFiles.FINISHED.value,
os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
)
except Exception as e:
pass
ckpt = CheckpointState() ckpt = CheckpointState()
if os.path.exists(filename): if os.path.exists(filename):
contents = open(filename, 'r').read() contents = open(filename, 'r').read()

View File

@@ -133,8 +133,8 @@ class CarlaEnvironment(Environment):
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality, allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
cameras: List[CameraTypes], weather_id: List[int], experiment_path: str, cameras: List[CameraTypes], weather_id: List[int], experiment_path: str,
separate_actions_for_throttle_and_brake: bool, separate_actions_for_throttle_and_brake: bool,
num_speedup_steps: int, max_speed: float, **kwargs): num_speedup_steps: int, max_speed: float, target_success_rate: float = 1.0, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters) super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
# server configuration # server configuration
self.server_height = server_height self.server_height = server_height
@@ -261,6 +261,8 @@ class CarlaEnvironment(Environment):
image = self.get_rendered_image() image = self.get_rendered_image()
self.renderer.create_screen(image.shape[1], image.shape[0]) self.renderer.create_screen(image.shape[1], image.shape[0])
self.target_success_rate = target_success_rate
def _add_cameras(self, settings, cameras, camera_width, camera_height): def _add_cameras(self, settings, cameras, camera_width, camera_height):
# add a front facing camera # add a front facing camera
if CameraTypes.FRONT in cameras: if CameraTypes.FRONT in cameras:
@@ -461,3 +463,6 @@ class CarlaEnvironment(Environment):
image = [self.state[camera.name] for camera in self.scene.sensors] image = [self.state[camera.name] for camera in self.scene.sensors]
image = np.vstack(image) image = np.vstack(image)
return image return image
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -66,10 +66,10 @@ control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING
# Environment # Environment
class ControlSuiteEnvironment(Environment): class ControlSuiteEnvironment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters, def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
seed: Union[None, int]=None, human_control: bool=False, target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
observation_type: ObservationType=ObservationType.Measurements, observation_type: ObservationType=ObservationType.Measurements,
custom_reward_threshold: Union[int, float]=None, **kwargs): custom_reward_threshold: Union[int, float]=None, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters) super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
self.observation_type = observation_type self.observation_type = observation_type
@@ -126,6 +126,8 @@ class ControlSuiteEnvironment(Environment):
if not self.native_rendering: if not self.native_rendering:
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale) self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
self.target_success_rate = target_success_rate
def _update_state(self): def _update_state(self):
self.state = {} self.state = {}
@@ -160,3 +162,6 @@ class ControlSuiteEnvironment(Environment):
def get_rendered_image(self): def get_rendered_image(self):
return self.env.physics.render(camera_id=0) return self.env.physics.render(camera_id=0)
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -124,8 +124,8 @@ class DoomEnvironment(Environment):
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool, def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters, custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
cameras: List[CameraTypes], **kwargs): cameras: List[CameraTypes], target_success_rate: float=1.0, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters) super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
self.cameras = cameras self.cameras = cameras
@@ -196,6 +196,8 @@ class DoomEnvironment(Environment):
image = self.get_rendered_image() image = self.get_rendered_image()
self.renderer.create_screen(image.shape[1], image.shape[0]) self.renderer.create_screen(image.shape[1], image.shape[0])
self.target_success_rate = target_success_rate
def _update_state(self): def _update_state(self):
# extract all data from the current state # extract all data from the current state
state = self.game.get_state() state = self.game.get_state()
@@ -227,3 +229,6 @@ class DoomEnvironment(Environment):
image = [self.state[camera.value[0]] for camera in self.cameras] image = [self.state[camera.value[0]] for camera in self.cameras]
image = np.vstack(image) image = np.vstack(image)
return image return image
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -103,6 +103,9 @@ class EnvironmentParameters(Parameters):
self.default_output_filter = None self.default_output_filter = None
self.experiment_path = None self.experiment_path = None
# Set target reward and target_success if present
self.target_success_rate = 1.0
@property @property
def path(self): def path(self):
return 'rl_coach.environments.environment:Environment' return 'rl_coach.environments.environment:Environment'
@@ -111,7 +114,7 @@ class EnvironmentParameters(Parameters):
class Environment(EnvironmentInterface): class Environment(EnvironmentInterface):
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool, def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters, custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
**kwargs): target_success_rate: float=1.0, **kwargs):
""" """
:param level: The environment level. Each environment can have multiple levels :param level: The environment level. Each environment can have multiple levels
:param seed: a seed for the random number generator of the environment :param seed: a seed for the random number generator of the environment
@@ -166,6 +169,9 @@ class Environment(EnvironmentInterface):
if not self.native_rendering: if not self.native_rendering:
self.renderer = Renderer() self.renderer = Renderer()
# Set target reward and target_success if present
self.target_success_rate = target_success_rate
@property @property
def action_space(self) -> Union[List[ActionSpace], ActionSpace]: def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
""" """
@@ -469,3 +475,5 @@ class Environment(EnvironmentInterface):
""" """
return np.transpose(self.state['observation'], [1, 2, 0]) return np.transpose(self.state['observation'], [1, 2, 0])
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -178,11 +178,11 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
# Environment # Environment
class GymEnvironment(Environment): class GymEnvironment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters, def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None, target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
human_control: bool=False, custom_reward_threshold: Union[int, float]=None, human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs): random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
visualization_parameters) visualization_parameters, target_success_rate)
self.random_initialization_steps = random_initialization_steps self.random_initialization_steps = random_initialization_steps
self.max_over_num_frames = max_over_num_frames self.max_over_num_frames = max_over_num_frames
@@ -337,6 +337,8 @@ class GymEnvironment(Environment):
self.reward_success_threshold = self.env.spec.reward_threshold self.reward_success_threshold = self.env.spec.reward_threshold
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold) self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold)
self.target_success_rate = target_success_rate
def _wrap_state(self, state): def _wrap_state(self, state):
if not isinstance(self.env.observation_space, gym.spaces.Dict): if not isinstance(self.env.observation_space, gym.spaces.Dict):
return {'observation': state} return {'observation': state}
@@ -434,3 +436,6 @@ class GymEnvironment(Environment):
if self.is_mujoco_env: if self.is_mujoco_env:
self._set_mujoco_camera(0) self._set_mujoco_camera(0)
return image return image
def get_target_success_rate(self) -> float:
return self.target_success_rate

View File

@@ -107,14 +107,14 @@ class StarCraft2EnvironmentParameters(EnvironmentParameters):
# Environment # Environment
class StarCraft2Environment(Environment): class StarCraft2Environment(Environment):
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters, def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
seed: Union[None, int]=None, human_control: bool=False, target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
custom_reward_threshold: Union[int, float]=None, custom_reward_threshold: Union[int, float]=None,
screen_size: int=84, minimap_size: int=64, screen_size: int=84, minimap_size: int=64,
feature_minimap_maps_to_use: List=range(7), feature_screen_maps_to_use: List=range(17), feature_minimap_maps_to_use: List=range(7), feature_screen_maps_to_use: List=range(17),
observation_type: StarcraftObservationType=StarcraftObservationType.Features, observation_type: StarcraftObservationType=StarcraftObservationType.Features,
disable_fog: bool=False, auto_select_all_army: bool=True, disable_fog: bool=False, auto_select_all_army: bool=True,
use_full_action_space: bool=False, **kwargs): use_full_action_space: bool=False, **kwargs):
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters) super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
self.screen_size = screen_size self.screen_size = screen_size
self.minimap_size = minimap_size self.minimap_size = minimap_size
@@ -192,6 +192,8 @@ class StarCraft2Environment(Environment):
self.action_space = BoxActionSpace(2, 0, self.screen_size - 1, ["X-Axis, Y-Axis"], self.action_space = BoxActionSpace(2, 0, self.screen_size - 1, ["X-Axis, Y-Axis"],
default_action=np.array([self.screen_size/2, self.screen_size/2])) default_action=np.array([self.screen_size/2, self.screen_size/2]))
self.target_success_rate = target_success_rate
def _update_state(self): def _update_state(self):
timestep = 0 timestep = 0
self.screen = self.last_result[timestep].observation.feature_screen self.screen = self.last_result[timestep].observation.feature_screen
@@ -244,3 +246,6 @@ class StarCraft2Environment(Environment):
self.env._run_config.replay_dir = experiment_path self.env._run_config.replay_dir = experiment_path
self.env.save_replay('replays') self.env.save_replay('replays')
super().dump_video_of_last_episode() super().dump_video_of_last_episode()
def get_target_success_rate(self):
return self.target_success_rate

View File

@@ -34,6 +34,7 @@ from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu, start_shell_command_and_wait from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.orchestrators.kubernetes_orchestrator import RunType from rl_coach.orchestrators.kubernetes_orchestrator import RunType
from rl_coach.data_stores.data_store import SyncFiles
class ScheduleParameters(Parameters): class ScheduleParameters(Parameters):
@@ -458,12 +459,12 @@ class GraphManager(object):
""" """
[manager.sync() for manager in self.level_managers] [manager.sync() for manager in self.level_managers]
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> None: def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> bool:
""" """
Perform evaluation for several steps Perform evaluation for several steps
:param steps: the number of steps as a tuple of steps time and steps count :param steps: the number of steps as a tuple of steps time and steps count
:param keep_networks_in_sync: sync the network parameters with the global network before each episode :param keep_networks_in_sync: sync the network parameters with the global network before each episode
:return: None :return: bool, True if the target reward and target success has been reached
""" """
self.verify_graph_was_created() self.verify_graph_was_created()
@@ -478,6 +479,16 @@ class GraphManager(object):
while self.current_step_counter < count_end: while self.current_step_counter < count_end:
self.act(EnvironmentEpisodes(1)) self.act(EnvironmentEpisodes(1))
self.sync() self.sync()
if self.should_stop():
if self.task_parameters.checkpoint_save_dir:
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = get_data_store(self.data_store_params)
data_store.save_to_store()
screen.success("Reached required success rate. Exiting.")
return True
return False
def improve(self): def improve(self):
""" """
@@ -508,7 +519,8 @@ class GraphManager(object):
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
while self.total_steps_counters[RunPhase.TRAIN] < count_end: while self.total_steps_counters[RunPhase.TRAIN] < count_end:
self.train_and_act(self.steps_between_evaluation_periods) self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps) if self.evaluate(self.evaluation_steps):
break
def _restore_checkpoint_tf(self, checkpoint_dir: str): def _restore_checkpoint_tf(self, checkpoint_dir: str):
import tensorflow as tf import tensorflow as tf
@@ -609,3 +621,6 @@ class GraphManager(object):
def should_train(self) -> bool: def should_train(self) -> bool:
return any([manager.should_train() for manager in self.level_managers]) return any([manager.should_train() for manager in self.level_managers])
def should_stop(self) -> bool:
return all([manager.should_stop() for manager in self.level_managers])

View File

@@ -263,3 +263,6 @@ class LevelManager(EnvironmentInterface):
def should_train(self) -> bool: def should_train(self) -> bool:
return any([agent._should_train_helper() for agent in self.agents.values()]) return any([agent._should_train_helper() for agent in self.agents.values()])
def should_stop(self) -> bool:
return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])

View File

@@ -79,6 +79,7 @@ class Kubernetes(Deploy):
self.memory_backend = get_memory_backend(self.params.memory_backend_parameters) self.memory_backend = get_memory_backend(self.params.memory_backend_parameters)
self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace} self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace}
self.params.data_store_params.namespace = self.params.namespace
self.data_store = get_data_store(self.params.data_store_params) self.data_store = get_data_store(self.params.data_store_params)
if self.params.data_store_params.store_type == "s3": if self.params.data_store_params.store_type == "s3":
@@ -137,7 +138,8 @@ class Kubernetes(Deploy):
volumes=[k8sclient.V1Volume( volumes=[k8sclient.V1Volume(
name="nfs-pvc", name="nfs-pvc",
persistent_volume_claim=self.nfs_pvc persistent_volume_claim=self.nfs_pvc
)] )],
restart_policy='OnFailure'
), ),
) )
else: else:
@@ -155,32 +157,30 @@ class Kubernetes(Deploy):
template = k8sclient.V1PodTemplateSpec( template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': name}), metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
spec=k8sclient.V1PodSpec( spec=k8sclient.V1PodSpec(
containers=[container] containers=[container],
restart_policy='OnFailure'
), ),
) )
deployment_spec = k8sclient.V1DeploymentSpec( job_spec = k8sclient.V1JobSpec(
replicas=trainer_params.num_replicas, completions=1,
template=template, template=template
selector=k8sclient.V1LabelSelector(
match_labels={'app': name}
)
) )
deployment = k8sclient.V1Deployment( job = k8sclient.V1Job(
api_version='apps/v1', api_version="batch/v1",
kind='Deployment', kind="Job",
metadata=k8sclient.V1ObjectMeta(name=name), metadata=k8sclient.V1ObjectMeta(name=name),
spec=deployment_spec spec=job_spec
) )
api_client = k8sclient.AppsV1Api() api_client = k8sclient.BatchV1Api()
try: try:
api_client.create_namespaced_deployment(self.params.namespace, deployment) api_client.create_namespaced_job(self.params.namespace, job)
trainer_params.orchestration_params['deployment_name'] = name trainer_params.orchestration_params['job_name'] = name
return True return True
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating deployment", e) print("Got exception: %s\n while creating job", e)
return False return False
def deploy_worker(self): def deploy_worker(self):
@@ -217,6 +217,7 @@ class Kubernetes(Deploy):
name="nfs-pvc", name="nfs-pvc",
persistent_volume_claim=self.nfs_pvc persistent_volume_claim=self.nfs_pvc
)], )],
restart_policy='OnFailure'
), ),
) )
else: else:
@@ -234,31 +235,31 @@ class Kubernetes(Deploy):
template = k8sclient.V1PodTemplateSpec( template = k8sclient.V1PodTemplateSpec(
metadata=k8sclient.V1ObjectMeta(labels={'app': name}), metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
spec=k8sclient.V1PodSpec( spec=k8sclient.V1PodSpec(
containers=[container] containers=[container],
restart_policy='OnFailure'
) )
) )
deployment_spec = k8sclient.V1DeploymentSpec( job_spec = k8sclient.V1JobSpec(
replicas=worker_params.num_replicas, completions=worker_params.num_replicas,
template=template, parallelism=worker_params.num_replicas,
selector=k8sclient.V1LabelSelector( template=template
match_labels={'app': name}
) )
)
deployment = k8sclient.V1Deployment( job = k8sclient.V1Job(
api_version='apps/v1', api_version="batch/v1",
kind="Deployment", kind="Job",
metadata=k8sclient.V1ObjectMeta(name=name), metadata=k8sclient.V1ObjectMeta(name=name),
spec=deployment_spec spec=job_spec
) )
api_client = k8sclient.AppsV1Api() api_client = k8sclient.BatchV1Api()
try: try:
api_client.create_namespaced_deployment(self.params.namespace, deployment) api_client.create_namespaced_job(self.params.namespace, job)
worker_params.orchestration_params['deployment_name'] = name worker_params.orchestration_params['job_name'] = name
return True return True
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while creating deployment", e) print("Got exception: %s\n while creating Job", e)
return False return False
def worker_logs(self): def worker_logs(self):
@@ -273,7 +274,7 @@ class Kubernetes(Deploy):
pod = None pod = None
try: try:
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format( pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
trainer_params.orchestration_params['deployment_name'] trainer_params.orchestration_params['job_name']
)) ))
pod = pods.items[0] pod = pods.items[0]
@@ -324,17 +325,20 @@ class Kubernetes(Deploy):
def undeploy(self): def undeploy(self):
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None) trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
api_client = k8sclient.AppsV1Api() api_client = k8sclient.BatchV1Api()
delete_options = k8sclient.V1DeleteOptions() delete_options = k8sclient.V1DeleteOptions(
propagation_policy="Foreground"
)
if trainer_params: if trainer_params:
try: try:
api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options) api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options)
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting trainer", e) print("Got exception: %s\n while deleting trainer", e)
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None) worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
if worker_params: if worker_params:
try: try:
api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options) api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options)
except k8sclient.rest.ApiException as e: except k8sclient.rest.ApiException as e:
print("Got exception: %s\n while deleting workers", e) print("Got exception: %s\n while deleting workers", e)
self.memory_backend.undeploy() self.memory_backend.undeploy()

View File

@@ -56,6 +56,10 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
# Environment # # Environment #
############### ###############
env_params = GymVectorEnvironment(level='CartPole-v0') env_params = GymVectorEnvironment(level='CartPole-v0')
env_params.custom_reward_threshold = 200
# Set the target success
env_params.target_success_rate = 1.0
######## ########
# Test # # Test #

View File

@@ -15,6 +15,7 @@ from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchroniza
from rl_coach.core_types import EnvironmentSteps, RunPhase from rl_coach.core_types import EnvironmentSteps, RunPhase
from google.protobuf import text_format from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles
def has_checkpoint(checkpoint_dir): def has_checkpoint(checkpoint_dir):
@@ -68,6 +69,10 @@ def get_latest_checkpoint(checkpoint_dir):
return int(rel_path.split('_Step')[0]) return int(rel_path.split('_Step')[0])
def should_stop(checkpoint_dir):
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers): def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
""" """
wait for first checkpoint then perform rollouts using the model wait for first checkpoint then perform rollouts using the model
@@ -87,12 +92,17 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)): for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
if should_stop(checkpoint_dir):
break
graph_manager.act(EnvironmentSteps(num_steps=act_steps)) graph_manager.act(EnvironmentSteps(num_steps=act_steps))
new_checkpoint = get_latest_checkpoint(checkpoint_dir) new_checkpoint = get_latest_checkpoint(checkpoint_dir)
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
while new_checkpoint < last_checkpoint + 1: while new_checkpoint < last_checkpoint + 1:
if should_stop(checkpoint_dir):
break
if data_store: if data_store:
data_store.load_from_store() data_store.load_from_store()
new_checkpoint = get_latest_checkpoint(checkpoint_dir) new_checkpoint = get_latest_checkpoint(checkpoint_dir)

View File

@@ -40,8 +40,9 @@ def training_worker(graph_manager, checkpoint_dir):
graph_manager.phase = core_types.RunPhase.UNDEFINED graph_manager.phase = core_types.RunPhase.UNDEFINED
if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset: if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
graph_manager.evaluate(graph_manager.evaluation_steps)
eval_offset += 1 eval_offset += 1
if graph_manager.evaluate(graph_manager.evaluation_steps):
break
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC: if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
graph_manager.save_checkpoint() graph_manager.save_checkpoint()