mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding target reward and target sucess (#58)
* Adding target reward * Adding target successs * Addressing comments * Using custom_reward_threshold and target_success_rate * Adding exit message * Moving success rate to environment * Making target_success_rate optional
This commit is contained in:
committed by
Balaji Subramaniam
parent
0fe583186e
commit
875d6ef017
@@ -842,7 +842,5 @@ class Agent(AgentInterface):
|
||||
for network in self.networks.values():
|
||||
network.sync()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_success_rate(self) -> float:
|
||||
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DataStoreParameters(object):
|
||||
def __init__(self, store_type, orchestrator_type, orchestrator_params):
|
||||
@@ -6,6 +8,7 @@ class DataStoreParameters(object):
|
||||
self.orchestrator_type = orchestrator_type
|
||||
self.orchestrator_params = orchestrator_params
|
||||
|
||||
|
||||
class DataStore(object):
|
||||
def __init__(self, params: DataStoreParameters):
|
||||
pass
|
||||
@@ -24,3 +27,8 @@ class DataStore(object):
|
||||
|
||||
def load_from_store(self):
|
||||
pass
|
||||
|
||||
|
||||
class SyncFiles(Enum):
|
||||
FINISHED = ".finished"
|
||||
LOCKFILE = ".lock"
|
||||
|
||||
@@ -58,7 +58,7 @@ class NFSDataStore(DataStore):
|
||||
pass
|
||||
|
||||
def deploy_k8s_nfs(self) -> bool:
|
||||
name = "nfs-server"
|
||||
name = "nfs-server-{}".format(uuid.uuid4())
|
||||
container = k8sclient.V1Container(
|
||||
name=name,
|
||||
image="k8s.gcr.io/volume-nfs:0.8",
|
||||
@@ -83,7 +83,7 @@ class NFSDataStore(DataStore):
|
||||
security_context=k8sclient.V1SecurityContext(privileged=True)
|
||||
)
|
||||
template = k8sclient.V1PodTemplateSpec(
|
||||
metadata=k8sclient.V1ObjectMeta(labels={'app': 'nfs-server'}),
|
||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||
spec=k8sclient.V1PodSpec(
|
||||
containers=[container],
|
||||
volumes=[k8sclient.V1Volume(
|
||||
@@ -96,14 +96,14 @@ class NFSDataStore(DataStore):
|
||||
replicas=1,
|
||||
template=template,
|
||||
selector=k8sclient.V1LabelSelector(
|
||||
match_labels={'app': 'nfs-server'}
|
||||
match_labels={'app': name}
|
||||
)
|
||||
)
|
||||
|
||||
deployment = k8sclient.V1Deployment(
|
||||
api_version='apps/v1',
|
||||
kind='Deployment',
|
||||
metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': 'nfs-server'}),
|
||||
metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': name}),
|
||||
spec=deployment_spec
|
||||
)
|
||||
|
||||
@@ -117,7 +117,7 @@ class NFSDataStore(DataStore):
|
||||
|
||||
k8s_core_v1_api_client = k8sclient.CoreV1Api()
|
||||
|
||||
svc_name = "nfs-service"
|
||||
svc_name = "nfs-service-{}".format(uuid.uuid4())
|
||||
service = k8sclient.V1Service(
|
||||
api_version='v1',
|
||||
kind='Service',
|
||||
@@ -145,7 +145,7 @@ class NFSDataStore(DataStore):
|
||||
return True
|
||||
|
||||
def create_k8s_nfs_resources(self) -> bool:
|
||||
pv_name = "nfs-ckpt-pv"
|
||||
pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
|
||||
persistent_volume = k8sclient.V1PersistentVolume(
|
||||
api_version="v1",
|
||||
kind="PersistentVolume",
|
||||
@@ -171,7 +171,7 @@ class NFSDataStore(DataStore):
|
||||
print("Got exception: %s\n while creating the NFS PV", e)
|
||||
return False
|
||||
|
||||
pvc_name = "nfs-ckpt-pvc"
|
||||
pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())
|
||||
persistent_volume_claim = k8sclient.V1PersistentVolumeClaim(
|
||||
api_version="v1",
|
||||
kind="PersistentVolumeClaim",
|
||||
|
||||
@@ -5,6 +5,7 @@ from minio.error import ResponseError
|
||||
from configparser import ConfigParser, Error
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
import os
|
||||
import time
|
||||
@@ -20,7 +21,6 @@ class S3DataStoreParameters(DataStoreParameters):
|
||||
self.end_point = end_point
|
||||
self.bucket_name = bucket_name
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.lock_file = ".lock"
|
||||
|
||||
|
||||
class S3DataStore(DataStore):
|
||||
@@ -52,9 +52,9 @@ class S3DataStore(DataStore):
|
||||
|
||||
def save_to_store(self):
|
||||
try:
|
||||
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
|
||||
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
|
||||
|
||||
checkpoint_file = None
|
||||
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
||||
@@ -70,7 +70,7 @@ class S3DataStore(DataStore):
|
||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||
|
||||
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
||||
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
except ResponseError as e:
|
||||
print("Got exception: %s\n while saving to S3", e)
|
||||
@@ -80,7 +80,7 @@ class S3DataStore(DataStore):
|
||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
||||
|
||||
while True:
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||
|
||||
if next(objects, None) is None:
|
||||
try:
|
||||
@@ -90,6 +90,18 @@ class S3DataStore(DataStore):
|
||||
break
|
||||
time.sleep(10)
|
||||
|
||||
# Check if there's a finished file
|
||||
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)
|
||||
|
||||
if next(objects, None) is not None:
|
||||
try:
|
||||
self.mc.fget_object(
|
||||
self.params.bucket_name, SyncFiles.FINISHED.value,
|
||||
os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
ckpt = CheckpointState()
|
||||
if os.path.exists(filename):
|
||||
contents = open(filename, 'r').read()
|
||||
|
||||
@@ -133,8 +133,8 @@ class CarlaEnvironment(Environment):
|
||||
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
|
||||
cameras: List[CameraTypes], weather_id: List[int], experiment_path: str,
|
||||
separate_actions_for_throttle_and_brake: bool,
|
||||
num_speedup_steps: int, max_speed: float, **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
||||
num_speedup_steps: int, max_speed: float, target_success_rate: float = 1.0, **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
|
||||
|
||||
# server configuration
|
||||
self.server_height = server_height
|
||||
@@ -261,6 +261,8 @@ class CarlaEnvironment(Environment):
|
||||
image = self.get_rendered_image()
|
||||
self.renderer.create_screen(image.shape[1], image.shape[0])
|
||||
|
||||
self.target_success_rate = target_success_rate
|
||||
|
||||
def _add_cameras(self, settings, cameras, camera_width, camera_height):
|
||||
# add a front facing camera
|
||||
if CameraTypes.FRONT in cameras:
|
||||
@@ -461,3 +463,6 @@ class CarlaEnvironment(Environment):
|
||||
image = [self.state[camera.name] for camera in self.scene.sensors]
|
||||
image = np.vstack(image)
|
||||
return image
|
||||
|
||||
def get_target_success_rate(self) -> float:
|
||||
return self.target_success_rate
|
||||
|
||||
@@ -66,10 +66,10 @@ control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING
|
||||
# Environment
|
||||
class ControlSuiteEnvironment(Environment):
|
||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||
seed: Union[None, int]=None, human_control: bool=False,
|
||||
target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
|
||||
observation_type: ObservationType=ObservationType.Measurements,
|
||||
custom_reward_threshold: Union[int, float]=None, **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
|
||||
|
||||
self.observation_type = observation_type
|
||||
|
||||
@@ -126,6 +126,8 @@ class ControlSuiteEnvironment(Environment):
|
||||
if not self.native_rendering:
|
||||
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
|
||||
|
||||
self.target_success_rate = target_success_rate
|
||||
|
||||
def _update_state(self):
|
||||
self.state = {}
|
||||
|
||||
@@ -160,3 +162,6 @@ class ControlSuiteEnvironment(Environment):
|
||||
|
||||
def get_rendered_image(self):
|
||||
return self.env.physics.render(camera_id=0)
|
||||
|
||||
def get_target_success_rate(self) -> float:
|
||||
return self.target_success_rate
|
||||
@@ -124,8 +124,8 @@ class DoomEnvironment(Environment):
|
||||
|
||||
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
|
||||
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
|
||||
cameras: List[CameraTypes], **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
||||
cameras: List[CameraTypes], target_success_rate: float=1.0, **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
|
||||
|
||||
self.cameras = cameras
|
||||
|
||||
@@ -196,6 +196,8 @@ class DoomEnvironment(Environment):
|
||||
image = self.get_rendered_image()
|
||||
self.renderer.create_screen(image.shape[1], image.shape[0])
|
||||
|
||||
self.target_success_rate = target_success_rate
|
||||
|
||||
def _update_state(self):
|
||||
# extract all data from the current state
|
||||
state = self.game.get_state()
|
||||
@@ -227,3 +229,6 @@ class DoomEnvironment(Environment):
|
||||
image = [self.state[camera.value[0]] for camera in self.cameras]
|
||||
image = np.vstack(image)
|
||||
return image
|
||||
|
||||
def get_target_success_rate(self) -> float:
|
||||
return self.target_success_rate
|
||||
|
||||
@@ -103,6 +103,9 @@ class EnvironmentParameters(Parameters):
|
||||
self.default_output_filter = None
|
||||
self.experiment_path = None
|
||||
|
||||
# Set target reward and target_success if present
|
||||
self.target_success_rate = 1.0
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return 'rl_coach.environments.environment:Environment'
|
||||
@@ -111,7 +114,7 @@ class EnvironmentParameters(Parameters):
|
||||
class Environment(EnvironmentInterface):
|
||||
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
|
||||
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
|
||||
**kwargs):
|
||||
target_success_rate: float=1.0, **kwargs):
|
||||
"""
|
||||
:param level: The environment level. Each environment can have multiple levels
|
||||
:param seed: a seed for the random number generator of the environment
|
||||
@@ -166,6 +169,9 @@ class Environment(EnvironmentInterface):
|
||||
if not self.native_rendering:
|
||||
self.renderer = Renderer()
|
||||
|
||||
# Set target reward and target_success if present
|
||||
self.target_success_rate = target_success_rate
|
||||
|
||||
@property
|
||||
def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
|
||||
"""
|
||||
@@ -469,3 +475,5 @@ class Environment(EnvironmentInterface):
|
||||
"""
|
||||
return np.transpose(self.state['observation'], [1, 2, 0])
|
||||
|
||||
def get_target_success_rate(self) -> float:
|
||||
return self.target_success_rate
|
||||
|
||||
@@ -178,11 +178,11 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
|
||||
# Environment
|
||||
class GymEnvironment(Environment):
|
||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||
additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
|
||||
target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
|
||||
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
|
||||
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
|
||||
visualization_parameters)
|
||||
visualization_parameters, target_success_rate)
|
||||
|
||||
self.random_initialization_steps = random_initialization_steps
|
||||
self.max_over_num_frames = max_over_num_frames
|
||||
@@ -337,6 +337,8 @@ class GymEnvironment(Environment):
|
||||
self.reward_success_threshold = self.env.spec.reward_threshold
|
||||
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold)
|
||||
|
||||
self.target_success_rate = target_success_rate
|
||||
|
||||
def _wrap_state(self, state):
|
||||
if not isinstance(self.env.observation_space, gym.spaces.Dict):
|
||||
return {'observation': state}
|
||||
@@ -434,3 +436,6 @@ class GymEnvironment(Environment):
|
||||
if self.is_mujoco_env:
|
||||
self._set_mujoco_camera(0)
|
||||
return image
|
||||
|
||||
def get_target_success_rate(self) -> float:
|
||||
return self.target_success_rate
|
||||
|
||||
@@ -107,14 +107,14 @@ class StarCraft2EnvironmentParameters(EnvironmentParameters):
|
||||
# Environment
|
||||
class StarCraft2Environment(Environment):
|
||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||
seed: Union[None, int]=None, human_control: bool=False,
|
||||
target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
|
||||
custom_reward_threshold: Union[int, float]=None,
|
||||
screen_size: int=84, minimap_size: int=64,
|
||||
feature_minimap_maps_to_use: List=range(7), feature_screen_maps_to_use: List=range(17),
|
||||
observation_type: StarcraftObservationType=StarcraftObservationType.Features,
|
||||
disable_fog: bool=False, auto_select_all_army: bool=True,
|
||||
use_full_action_space: bool=False, **kwargs):
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
|
||||
|
||||
self.screen_size = screen_size
|
||||
self.minimap_size = minimap_size
|
||||
@@ -192,6 +192,8 @@ class StarCraft2Environment(Environment):
|
||||
self.action_space = BoxActionSpace(2, 0, self.screen_size - 1, ["X-Axis, Y-Axis"],
|
||||
default_action=np.array([self.screen_size/2, self.screen_size/2]))
|
||||
|
||||
self.target_success_rate = target_success_rate
|
||||
|
||||
def _update_state(self):
|
||||
timestep = 0
|
||||
self.screen = self.last_result[timestep].observation.feature_screen
|
||||
@@ -244,3 +246,6 @@ class StarCraft2Environment(Environment):
|
||||
self.env._run_config.replay_dir = experiment_path
|
||||
self.env.save_replay('replays')
|
||||
super().dump_video_of_last_episode()
|
||||
|
||||
def get_target_success_rate(self):
|
||||
return self.target_success_rate
|
||||
|
||||
@@ -34,6 +34,7 @@ from rl_coach.logger import screen, Logger
|
||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
|
||||
class ScheduleParameters(Parameters):
|
||||
@@ -458,12 +459,12 @@ class GraphManager(object):
|
||||
"""
|
||||
[manager.sync() for manager in self.level_managers]
|
||||
|
||||
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> None:
|
||||
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> bool:
|
||||
"""
|
||||
Perform evaluation for several steps
|
||||
:param steps: the number of steps as a tuple of steps time and steps count
|
||||
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
|
||||
:return: None
|
||||
:return: bool, True if the target reward and target success has been reached
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
@@ -478,6 +479,16 @@ class GraphManager(object):
|
||||
while self.current_step_counter < count_end:
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
self.sync()
|
||||
if self.should_stop():
|
||||
if self.task_parameters.checkpoint_save_dir:
|
||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||
if hasattr(self, 'data_store_params'):
|
||||
data_store = get_data_store(self.data_store_params)
|
||||
data_store.save_to_store()
|
||||
|
||||
screen.success("Reached required success rate. Exiting.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def improve(self):
|
||||
"""
|
||||
@@ -508,7 +519,8 @@ class GraphManager(object):
|
||||
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
|
||||
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
|
||||
self.train_and_act(self.steps_between_evaluation_periods)
|
||||
self.evaluate(self.evaluation_steps)
|
||||
if self.evaluate(self.evaluation_steps):
|
||||
break
|
||||
|
||||
def _restore_checkpoint_tf(self, checkpoint_dir: str):
|
||||
import tensorflow as tf
|
||||
@@ -609,3 +621,6 @@ class GraphManager(object):
|
||||
|
||||
def should_train(self) -> bool:
|
||||
return any([manager.should_train() for manager in self.level_managers])
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
return all([manager.should_stop() for manager in self.level_managers])
|
||||
|
||||
@@ -263,3 +263,6 @@ class LevelManager(EnvironmentInterface):
|
||||
|
||||
def should_train(self) -> bool:
|
||||
return any([agent._should_train_helper() for agent in self.agents.values()])
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])
|
||||
|
||||
@@ -79,6 +79,7 @@ class Kubernetes(Deploy):
|
||||
self.memory_backend = get_memory_backend(self.params.memory_backend_parameters)
|
||||
|
||||
self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace}
|
||||
self.params.data_store_params.namespace = self.params.namespace
|
||||
self.data_store = get_data_store(self.params.data_store_params)
|
||||
|
||||
if self.params.data_store_params.store_type == "s3":
|
||||
@@ -137,7 +138,8 @@ class Kubernetes(Deploy):
|
||||
volumes=[k8sclient.V1Volume(
|
||||
name="nfs-pvc",
|
||||
persistent_volume_claim=self.nfs_pvc
|
||||
)]
|
||||
)],
|
||||
restart_policy='OnFailure'
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -155,32 +157,30 @@ class Kubernetes(Deploy):
|
||||
template = k8sclient.V1PodTemplateSpec(
|
||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||
spec=k8sclient.V1PodSpec(
|
||||
containers=[container]
|
||||
containers=[container],
|
||||
restart_policy='OnFailure'
|
||||
),
|
||||
)
|
||||
|
||||
deployment_spec = k8sclient.V1DeploymentSpec(
|
||||
replicas=trainer_params.num_replicas,
|
||||
template=template,
|
||||
selector=k8sclient.V1LabelSelector(
|
||||
match_labels={'app': name}
|
||||
)
|
||||
job_spec = k8sclient.V1JobSpec(
|
||||
completions=1,
|
||||
template=template
|
||||
)
|
||||
|
||||
deployment = k8sclient.V1Deployment(
|
||||
api_version='apps/v1',
|
||||
kind='Deployment',
|
||||
job = k8sclient.V1Job(
|
||||
api_version="batch/v1",
|
||||
kind="Job",
|
||||
metadata=k8sclient.V1ObjectMeta(name=name),
|
||||
spec=deployment_spec
|
||||
spec=job_spec
|
||||
)
|
||||
|
||||
api_client = k8sclient.AppsV1Api()
|
||||
api_client = k8sclient.BatchV1Api()
|
||||
try:
|
||||
api_client.create_namespaced_deployment(self.params.namespace, deployment)
|
||||
trainer_params.orchestration_params['deployment_name'] = name
|
||||
api_client.create_namespaced_job(self.params.namespace, job)
|
||||
trainer_params.orchestration_params['job_name'] = name
|
||||
return True
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating deployment", e)
|
||||
print("Got exception: %s\n while creating job", e)
|
||||
return False
|
||||
|
||||
def deploy_worker(self):
|
||||
@@ -217,6 +217,7 @@ class Kubernetes(Deploy):
|
||||
name="nfs-pvc",
|
||||
persistent_volume_claim=self.nfs_pvc
|
||||
)],
|
||||
restart_policy='OnFailure'
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -234,31 +235,31 @@ class Kubernetes(Deploy):
|
||||
template = k8sclient.V1PodTemplateSpec(
|
||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||
spec=k8sclient.V1PodSpec(
|
||||
containers=[container]
|
||||
containers=[container],
|
||||
restart_policy='OnFailure'
|
||||
)
|
||||
)
|
||||
|
||||
deployment_spec = k8sclient.V1DeploymentSpec(
|
||||
replicas=worker_params.num_replicas,
|
||||
template=template,
|
||||
selector=k8sclient.V1LabelSelector(
|
||||
match_labels={'app': name}
|
||||
job_spec = k8sclient.V1JobSpec(
|
||||
completions=worker_params.num_replicas,
|
||||
parallelism=worker_params.num_replicas,
|
||||
template=template
|
||||
)
|
||||
)
|
||||
deployment = k8sclient.V1Deployment(
|
||||
api_version='apps/v1',
|
||||
kind="Deployment",
|
||||
|
||||
job = k8sclient.V1Job(
|
||||
api_version="batch/v1",
|
||||
kind="Job",
|
||||
metadata=k8sclient.V1ObjectMeta(name=name),
|
||||
spec=deployment_spec
|
||||
spec=job_spec
|
||||
)
|
||||
|
||||
api_client = k8sclient.AppsV1Api()
|
||||
api_client = k8sclient.BatchV1Api()
|
||||
try:
|
||||
api_client.create_namespaced_deployment(self.params.namespace, deployment)
|
||||
worker_params.orchestration_params['deployment_name'] = name
|
||||
api_client.create_namespaced_job(self.params.namespace, job)
|
||||
worker_params.orchestration_params['job_name'] = name
|
||||
return True
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while creating deployment", e)
|
||||
print("Got exception: %s\n while creating Job", e)
|
||||
return False
|
||||
|
||||
def worker_logs(self):
|
||||
@@ -273,7 +274,7 @@ class Kubernetes(Deploy):
|
||||
pod = None
|
||||
try:
|
||||
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
|
||||
trainer_params.orchestration_params['deployment_name']
|
||||
trainer_params.orchestration_params['job_name']
|
||||
))
|
||||
|
||||
pod = pods.items[0]
|
||||
@@ -324,17 +325,20 @@ class Kubernetes(Deploy):
|
||||
|
||||
def undeploy(self):
|
||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||
api_client = k8sclient.AppsV1Api()
|
||||
delete_options = k8sclient.V1DeleteOptions()
|
||||
api_client = k8sclient.BatchV1Api()
|
||||
delete_options = k8sclient.V1DeleteOptions(
|
||||
propagation_policy="Foreground"
|
||||
)
|
||||
|
||||
if trainer_params:
|
||||
try:
|
||||
api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
|
||||
api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting trainer", e)
|
||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||
if worker_params:
|
||||
try:
|
||||
api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
|
||||
api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options)
|
||||
except k8sclient.rest.ApiException as e:
|
||||
print("Got exception: %s\n while deleting workers", e)
|
||||
self.memory_backend.undeploy()
|
||||
|
||||
@@ -56,6 +56,10 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
||||
# Environment #
|
||||
###############
|
||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||
env_params.custom_reward_threshold = 200
|
||||
# Set the target success
|
||||
env_params.target_success_rate = 1.0
|
||||
|
||||
|
||||
########
|
||||
# Test #
|
||||
|
||||
@@ -15,6 +15,7 @@ from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchroniza
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
|
||||
def has_checkpoint(checkpoint_dir):
|
||||
@@ -68,6 +69,10 @@ def get_latest_checkpoint(checkpoint_dir):
|
||||
return int(rel_path.split('_Step')[0])
|
||||
|
||||
|
||||
def should_stop(checkpoint_dir):
|
||||
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
|
||||
|
||||
|
||||
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
||||
"""
|
||||
wait for first checkpoint then perform rollouts using the model
|
||||
@@ -87,12 +92,17 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
||||
|
||||
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
||||
|
||||
if should_stop(checkpoint_dir):
|
||||
break
|
||||
|
||||
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
||||
|
||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
||||
while new_checkpoint < last_checkpoint + 1:
|
||||
if should_stop(checkpoint_dir):
|
||||
break
|
||||
if data_store:
|
||||
data_store.load_from_store()
|
||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
@@ -40,8 +40,9 @@ def training_worker(graph_manager, checkpoint_dir):
|
||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||
|
||||
if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
|
||||
graph_manager.evaluate(graph_manager.evaluation_steps)
|
||||
eval_offset += 1
|
||||
if graph_manager.evaluate(graph_manager.evaluation_steps):
|
||||
break
|
||||
|
||||
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
||||
graph_manager.save_checkpoint()
|
||||
|
||||
Reference in New Issue
Block a user