mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding target reward and target sucess (#58)
* Adding target reward * Adding target successs * Addressing comments * Using custom_reward_threshold and target_success_rate * Adding exit message * Moving success rate to environment * Making target_success_rate optional
This commit is contained in:
committed by
Balaji Subramaniam
parent
0fe583186e
commit
875d6ef017
@@ -842,7 +842,5 @@ class Agent(AgentInterface):
|
|||||||
for network in self.networks.values():
|
for network in self.networks.values():
|
||||||
network.sync()
|
network.sync()
|
||||||
|
|
||||||
|
def get_success_rate(self) -> float:
|
||||||
|
return self.num_successes_across_evaluation_episodes / self.num_evaluation_episodes_completed
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class DataStoreParameters(object):
|
class DataStoreParameters(object):
|
||||||
def __init__(self, store_type, orchestrator_type, orchestrator_params):
|
def __init__(self, store_type, orchestrator_type, orchestrator_params):
|
||||||
@@ -6,6 +8,7 @@ class DataStoreParameters(object):
|
|||||||
self.orchestrator_type = orchestrator_type
|
self.orchestrator_type = orchestrator_type
|
||||||
self.orchestrator_params = orchestrator_params
|
self.orchestrator_params = orchestrator_params
|
||||||
|
|
||||||
|
|
||||||
class DataStore(object):
|
class DataStore(object):
|
||||||
def __init__(self, params: DataStoreParameters):
|
def __init__(self, params: DataStoreParameters):
|
||||||
pass
|
pass
|
||||||
@@ -24,3 +27,8 @@ class DataStore(object):
|
|||||||
|
|
||||||
def load_from_store(self):
|
def load_from_store(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SyncFiles(Enum):
|
||||||
|
FINISHED = ".finished"
|
||||||
|
LOCKFILE = ".lock"
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class NFSDataStore(DataStore):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def deploy_k8s_nfs(self) -> bool:
|
def deploy_k8s_nfs(self) -> bool:
|
||||||
name = "nfs-server"
|
name = "nfs-server-{}".format(uuid.uuid4())
|
||||||
container = k8sclient.V1Container(
|
container = k8sclient.V1Container(
|
||||||
name=name,
|
name=name,
|
||||||
image="k8s.gcr.io/volume-nfs:0.8",
|
image="k8s.gcr.io/volume-nfs:0.8",
|
||||||
@@ -83,7 +83,7 @@ class NFSDataStore(DataStore):
|
|||||||
security_context=k8sclient.V1SecurityContext(privileged=True)
|
security_context=k8sclient.V1SecurityContext(privileged=True)
|
||||||
)
|
)
|
||||||
template = k8sclient.V1PodTemplateSpec(
|
template = k8sclient.V1PodTemplateSpec(
|
||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': 'nfs-server'}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
spec=k8sclient.V1PodSpec(
|
spec=k8sclient.V1PodSpec(
|
||||||
containers=[container],
|
containers=[container],
|
||||||
volumes=[k8sclient.V1Volume(
|
volumes=[k8sclient.V1Volume(
|
||||||
@@ -96,14 +96,14 @@ class NFSDataStore(DataStore):
|
|||||||
replicas=1,
|
replicas=1,
|
||||||
template=template,
|
template=template,
|
||||||
selector=k8sclient.V1LabelSelector(
|
selector=k8sclient.V1LabelSelector(
|
||||||
match_labels={'app': 'nfs-server'}
|
match_labels={'app': name}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
deployment = k8sclient.V1Deployment(
|
deployment = k8sclient.V1Deployment(
|
||||||
api_version='apps/v1',
|
api_version='apps/v1',
|
||||||
kind='Deployment',
|
kind='Deployment',
|
||||||
metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': 'nfs-server'}),
|
metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': name}),
|
||||||
spec=deployment_spec
|
spec=deployment_spec
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ class NFSDataStore(DataStore):
|
|||||||
|
|
||||||
k8s_core_v1_api_client = k8sclient.CoreV1Api()
|
k8s_core_v1_api_client = k8sclient.CoreV1Api()
|
||||||
|
|
||||||
svc_name = "nfs-service"
|
svc_name = "nfs-service-{}".format(uuid.uuid4())
|
||||||
service = k8sclient.V1Service(
|
service = k8sclient.V1Service(
|
||||||
api_version='v1',
|
api_version='v1',
|
||||||
kind='Service',
|
kind='Service',
|
||||||
@@ -145,7 +145,7 @@ class NFSDataStore(DataStore):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def create_k8s_nfs_resources(self) -> bool:
|
def create_k8s_nfs_resources(self) -> bool:
|
||||||
pv_name = "nfs-ckpt-pv"
|
pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
|
||||||
persistent_volume = k8sclient.V1PersistentVolume(
|
persistent_volume = k8sclient.V1PersistentVolume(
|
||||||
api_version="v1",
|
api_version="v1",
|
||||||
kind="PersistentVolume",
|
kind="PersistentVolume",
|
||||||
@@ -171,7 +171,7 @@ class NFSDataStore(DataStore):
|
|||||||
print("Got exception: %s\n while creating the NFS PV", e)
|
print("Got exception: %s\n while creating the NFS PV", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
pvc_name = "nfs-ckpt-pvc"
|
pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())
|
||||||
persistent_volume_claim = k8sclient.V1PersistentVolumeClaim(
|
persistent_volume_claim = k8sclient.V1PersistentVolumeClaim(
|
||||||
api_version="v1",
|
api_version="v1",
|
||||||
kind="PersistentVolumeClaim",
|
kind="PersistentVolumeClaim",
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from minio.error import ResponseError
|
|||||||
from configparser import ConfigParser, Error
|
from configparser import ConfigParser, Error
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||||
|
from rl_coach.data_stores.data_store import SyncFiles
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@@ -20,7 +21,6 @@ class S3DataStoreParameters(DataStoreParameters):
|
|||||||
self.end_point = end_point
|
self.end_point = end_point
|
||||||
self.bucket_name = bucket_name
|
self.bucket_name = bucket_name
|
||||||
self.checkpoint_dir = checkpoint_dir
|
self.checkpoint_dir = checkpoint_dir
|
||||||
self.lock_file = ".lock"
|
|
||||||
|
|
||||||
|
|
||||||
class S3DataStore(DataStore):
|
class S3DataStore(DataStore):
|
||||||
@@ -52,9 +52,9 @@ class S3DataStore(DataStore):
|
|||||||
|
|
||||||
def save_to_store(self):
|
def save_to_store(self):
|
||||||
try:
|
try:
|
||||||
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||||
|
|
||||||
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
|
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
|
||||||
|
|
||||||
checkpoint_file = None
|
checkpoint_file = None
|
||||||
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
||||||
@@ -70,7 +70,7 @@ class S3DataStore(DataStore):
|
|||||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||||
|
|
||||||
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||||
|
|
||||||
except ResponseError as e:
|
except ResponseError as e:
|
||||||
print("Got exception: %s\n while saving to S3", e)
|
print("Got exception: %s\n while saving to S3", e)
|
||||||
@@ -80,7 +80,7 @@ class S3DataStore(DataStore):
|
|||||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
|
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
|
||||||
|
|
||||||
if next(objects, None) is None:
|
if next(objects, None) is None:
|
||||||
try:
|
try:
|
||||||
@@ -90,6 +90,18 @@ class S3DataStore(DataStore):
|
|||||||
break
|
break
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
|
# Check if there's a finished file
|
||||||
|
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)
|
||||||
|
|
||||||
|
if next(objects, None) is not None:
|
||||||
|
try:
|
||||||
|
self.mc.fget_object(
|
||||||
|
self.params.bucket_name, SyncFiles.FINISHED.value,
|
||||||
|
os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
ckpt = CheckpointState()
|
ckpt = CheckpointState()
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
contents = open(filename, 'r').read()
|
contents = open(filename, 'r').read()
|
||||||
|
|||||||
@@ -133,8 +133,8 @@ class CarlaEnvironment(Environment):
|
|||||||
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
|
allow_braking: bool, quality: CarlaEnvironmentParameters.Quality,
|
||||||
cameras: List[CameraTypes], weather_id: List[int], experiment_path: str,
|
cameras: List[CameraTypes], weather_id: List[int], experiment_path: str,
|
||||||
separate_actions_for_throttle_and_brake: bool,
|
separate_actions_for_throttle_and_brake: bool,
|
||||||
num_speedup_steps: int, max_speed: float, **kwargs):
|
num_speedup_steps: int, max_speed: float, target_success_rate: float = 1.0, **kwargs):
|
||||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
|
||||||
|
|
||||||
# server configuration
|
# server configuration
|
||||||
self.server_height = server_height
|
self.server_height = server_height
|
||||||
@@ -261,6 +261,8 @@ class CarlaEnvironment(Environment):
|
|||||||
image = self.get_rendered_image()
|
image = self.get_rendered_image()
|
||||||
self.renderer.create_screen(image.shape[1], image.shape[0])
|
self.renderer.create_screen(image.shape[1], image.shape[0])
|
||||||
|
|
||||||
|
self.target_success_rate = target_success_rate
|
||||||
|
|
||||||
def _add_cameras(self, settings, cameras, camera_width, camera_height):
|
def _add_cameras(self, settings, cameras, camera_width, camera_height):
|
||||||
# add a front facing camera
|
# add a front facing camera
|
||||||
if CameraTypes.FRONT in cameras:
|
if CameraTypes.FRONT in cameras:
|
||||||
@@ -461,3 +463,6 @@ class CarlaEnvironment(Environment):
|
|||||||
image = [self.state[camera.name] for camera in self.scene.sensors]
|
image = [self.state[camera.name] for camera in self.scene.sensors]
|
||||||
image = np.vstack(image)
|
image = np.vstack(image)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
def get_target_success_rate(self) -> float:
|
||||||
|
return self.target_success_rate
|
||||||
|
|||||||
@@ -66,10 +66,10 @@ control_suite_envs = {':'.join(env): ':'.join(env) for env in suite.BENCHMARKING
|
|||||||
# Environment
|
# Environment
|
||||||
class ControlSuiteEnvironment(Environment):
|
class ControlSuiteEnvironment(Environment):
|
||||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||||
seed: Union[None, int]=None, human_control: bool=False,
|
target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
|
||||||
observation_type: ObservationType=ObservationType.Measurements,
|
observation_type: ObservationType=ObservationType.Measurements,
|
||||||
custom_reward_threshold: Union[int, float]=None, **kwargs):
|
custom_reward_threshold: Union[int, float]=None, **kwargs):
|
||||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
|
||||||
|
|
||||||
self.observation_type = observation_type
|
self.observation_type = observation_type
|
||||||
|
|
||||||
@@ -126,6 +126,8 @@ class ControlSuiteEnvironment(Environment):
|
|||||||
if not self.native_rendering:
|
if not self.native_rendering:
|
||||||
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
|
self.renderer.create_screen(image.shape[1]*scale, image.shape[0]*scale)
|
||||||
|
|
||||||
|
self.target_success_rate = target_success_rate
|
||||||
|
|
||||||
def _update_state(self):
|
def _update_state(self):
|
||||||
self.state = {}
|
self.state = {}
|
||||||
|
|
||||||
@@ -160,3 +162,6 @@ class ControlSuiteEnvironment(Environment):
|
|||||||
|
|
||||||
def get_rendered_image(self):
|
def get_rendered_image(self):
|
||||||
return self.env.physics.render(camera_id=0)
|
return self.env.physics.render(camera_id=0)
|
||||||
|
|
||||||
|
def get_target_success_rate(self) -> float:
|
||||||
|
return self.target_success_rate
|
||||||
@@ -124,8 +124,8 @@ class DoomEnvironment(Environment):
|
|||||||
|
|
||||||
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
|
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
|
||||||
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
|
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
|
||||||
cameras: List[CameraTypes], **kwargs):
|
cameras: List[CameraTypes], target_success_rate: float=1.0, **kwargs):
|
||||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
|
||||||
|
|
||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
|
|
||||||
@@ -196,6 +196,8 @@ class DoomEnvironment(Environment):
|
|||||||
image = self.get_rendered_image()
|
image = self.get_rendered_image()
|
||||||
self.renderer.create_screen(image.shape[1], image.shape[0])
|
self.renderer.create_screen(image.shape[1], image.shape[0])
|
||||||
|
|
||||||
|
self.target_success_rate = target_success_rate
|
||||||
|
|
||||||
def _update_state(self):
|
def _update_state(self):
|
||||||
# extract all data from the current state
|
# extract all data from the current state
|
||||||
state = self.game.get_state()
|
state = self.game.get_state()
|
||||||
@@ -227,3 +229,6 @@ class DoomEnvironment(Environment):
|
|||||||
image = [self.state[camera.value[0]] for camera in self.cameras]
|
image = [self.state[camera.value[0]] for camera in self.cameras]
|
||||||
image = np.vstack(image)
|
image = np.vstack(image)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
def get_target_success_rate(self) -> float:
|
||||||
|
return self.target_success_rate
|
||||||
|
|||||||
@@ -103,6 +103,9 @@ class EnvironmentParameters(Parameters):
|
|||||||
self.default_output_filter = None
|
self.default_output_filter = None
|
||||||
self.experiment_path = None
|
self.experiment_path = None
|
||||||
|
|
||||||
|
# Set target reward and target_success if present
|
||||||
|
self.target_success_rate = 1.0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def path(self):
|
def path(self):
|
||||||
return 'rl_coach.environments.environment:Environment'
|
return 'rl_coach.environments.environment:Environment'
|
||||||
@@ -111,7 +114,7 @@ class EnvironmentParameters(Parameters):
|
|||||||
class Environment(EnvironmentInterface):
|
class Environment(EnvironmentInterface):
|
||||||
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
|
def __init__(self, level: LevelSelection, seed: int, frame_skip: int, human_control: bool,
|
||||||
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
|
custom_reward_threshold: Union[int, float], visualization_parameters: VisualizationParameters,
|
||||||
**kwargs):
|
target_success_rate: float=1.0, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param level: The environment level. Each environment can have multiple levels
|
:param level: The environment level. Each environment can have multiple levels
|
||||||
:param seed: a seed for the random number generator of the environment
|
:param seed: a seed for the random number generator of the environment
|
||||||
@@ -166,6 +169,9 @@ class Environment(EnvironmentInterface):
|
|||||||
if not self.native_rendering:
|
if not self.native_rendering:
|
||||||
self.renderer = Renderer()
|
self.renderer = Renderer()
|
||||||
|
|
||||||
|
# Set target reward and target_success if present
|
||||||
|
self.target_success_rate = target_success_rate
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
|
def action_space(self) -> Union[List[ActionSpace], ActionSpace]:
|
||||||
"""
|
"""
|
||||||
@@ -469,3 +475,5 @@ class Environment(EnvironmentInterface):
|
|||||||
"""
|
"""
|
||||||
return np.transpose(self.state['observation'], [1, 2, 0])
|
return np.transpose(self.state['observation'], [1, 2, 0])
|
||||||
|
|
||||||
|
def get_target_success_rate(self) -> float:
|
||||||
|
return self.target_success_rate
|
||||||
|
|||||||
@@ -178,11 +178,11 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
|
|||||||
# Environment
|
# Environment
|
||||||
class GymEnvironment(Environment):
|
class GymEnvironment(Environment):
|
||||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||||
additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
|
target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
|
||||||
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
|
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
|
||||||
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
|
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
|
||||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
|
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
|
||||||
visualization_parameters)
|
visualization_parameters, target_success_rate)
|
||||||
|
|
||||||
self.random_initialization_steps = random_initialization_steps
|
self.random_initialization_steps = random_initialization_steps
|
||||||
self.max_over_num_frames = max_over_num_frames
|
self.max_over_num_frames = max_over_num_frames
|
||||||
@@ -337,6 +337,8 @@ class GymEnvironment(Environment):
|
|||||||
self.reward_success_threshold = self.env.spec.reward_threshold
|
self.reward_success_threshold = self.env.spec.reward_threshold
|
||||||
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold)
|
self.reward_space = RewardSpace(1, reward_success_threshold=self.reward_success_threshold)
|
||||||
|
|
||||||
|
self.target_success_rate = target_success_rate
|
||||||
|
|
||||||
def _wrap_state(self, state):
|
def _wrap_state(self, state):
|
||||||
if not isinstance(self.env.observation_space, gym.spaces.Dict):
|
if not isinstance(self.env.observation_space, gym.spaces.Dict):
|
||||||
return {'observation': state}
|
return {'observation': state}
|
||||||
@@ -434,3 +436,6 @@ class GymEnvironment(Environment):
|
|||||||
if self.is_mujoco_env:
|
if self.is_mujoco_env:
|
||||||
self._set_mujoco_camera(0)
|
self._set_mujoco_camera(0)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
def get_target_success_rate(self) -> float:
|
||||||
|
return self.target_success_rate
|
||||||
|
|||||||
@@ -107,14 +107,14 @@ class StarCraft2EnvironmentParameters(EnvironmentParameters):
|
|||||||
# Environment
|
# Environment
|
||||||
class StarCraft2Environment(Environment):
|
class StarCraft2Environment(Environment):
|
||||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||||
seed: Union[None, int]=None, human_control: bool=False,
|
target_success_rate: float=1.0, seed: Union[None, int]=None, human_control: bool=False,
|
||||||
custom_reward_threshold: Union[int, float]=None,
|
custom_reward_threshold: Union[int, float]=None,
|
||||||
screen_size: int=84, minimap_size: int=64,
|
screen_size: int=84, minimap_size: int=64,
|
||||||
feature_minimap_maps_to_use: List=range(7), feature_screen_maps_to_use: List=range(17),
|
feature_minimap_maps_to_use: List=range(7), feature_screen_maps_to_use: List=range(17),
|
||||||
observation_type: StarcraftObservationType=StarcraftObservationType.Features,
|
observation_type: StarcraftObservationType=StarcraftObservationType.Features,
|
||||||
disable_fog: bool=False, auto_select_all_army: bool=True,
|
disable_fog: bool=False, auto_select_all_army: bool=True,
|
||||||
use_full_action_space: bool=False, **kwargs):
|
use_full_action_space: bool=False, **kwargs):
|
||||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)
|
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate)
|
||||||
|
|
||||||
self.screen_size = screen_size
|
self.screen_size = screen_size
|
||||||
self.minimap_size = minimap_size
|
self.minimap_size = minimap_size
|
||||||
@@ -192,6 +192,8 @@ class StarCraft2Environment(Environment):
|
|||||||
self.action_space = BoxActionSpace(2, 0, self.screen_size - 1, ["X-Axis, Y-Axis"],
|
self.action_space = BoxActionSpace(2, 0, self.screen_size - 1, ["X-Axis, Y-Axis"],
|
||||||
default_action=np.array([self.screen_size/2, self.screen_size/2]))
|
default_action=np.array([self.screen_size/2, self.screen_size/2]))
|
||||||
|
|
||||||
|
self.target_success_rate = target_success_rate
|
||||||
|
|
||||||
def _update_state(self):
|
def _update_state(self):
|
||||||
timestep = 0
|
timestep = 0
|
||||||
self.screen = self.last_result[timestep].observation.feature_screen
|
self.screen = self.last_result[timestep].observation.feature_screen
|
||||||
@@ -244,3 +246,6 @@ class StarCraft2Environment(Environment):
|
|||||||
self.env._run_config.replay_dir = experiment_path
|
self.env._run_config.replay_dir = experiment_path
|
||||||
self.env.save_replay('replays')
|
self.env.save_replay('replays')
|
||||||
super().dump_video_of_last_episode()
|
super().dump_video_of_last_episode()
|
||||||
|
|
||||||
|
def get_target_success_rate(self):
|
||||||
|
return self.target_success_rate
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from rl_coach.logger import screen, Logger
|
|||||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||||
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
|
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
|
||||||
|
from rl_coach.data_stores.data_store import SyncFiles
|
||||||
|
|
||||||
|
|
||||||
class ScheduleParameters(Parameters):
|
class ScheduleParameters(Parameters):
|
||||||
@@ -458,12 +459,12 @@ class GraphManager(object):
|
|||||||
"""
|
"""
|
||||||
[manager.sync() for manager in self.level_managers]
|
[manager.sync() for manager in self.level_managers]
|
||||||
|
|
||||||
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> None:
|
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> bool:
|
||||||
"""
|
"""
|
||||||
Perform evaluation for several steps
|
Perform evaluation for several steps
|
||||||
:param steps: the number of steps as a tuple of steps time and steps count
|
:param steps: the number of steps as a tuple of steps time and steps count
|
||||||
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
|
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
|
||||||
:return: None
|
:return: bool, True if the target reward and target success has been reached
|
||||||
"""
|
"""
|
||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
|
|
||||||
@@ -478,6 +479,16 @@ class GraphManager(object):
|
|||||||
while self.current_step_counter < count_end:
|
while self.current_step_counter < count_end:
|
||||||
self.act(EnvironmentEpisodes(1))
|
self.act(EnvironmentEpisodes(1))
|
||||||
self.sync()
|
self.sync()
|
||||||
|
if self.should_stop():
|
||||||
|
if self.task_parameters.checkpoint_save_dir:
|
||||||
|
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||||
|
if hasattr(self, 'data_store_params'):
|
||||||
|
data_store = get_data_store(self.data_store_params)
|
||||||
|
data_store.save_to_store()
|
||||||
|
|
||||||
|
screen.success("Reached required success rate. Exiting.")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def improve(self):
|
def improve(self):
|
||||||
"""
|
"""
|
||||||
@@ -508,7 +519,8 @@ class GraphManager(object):
|
|||||||
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
|
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
|
||||||
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
|
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
|
||||||
self.train_and_act(self.steps_between_evaluation_periods)
|
self.train_and_act(self.steps_between_evaluation_periods)
|
||||||
self.evaluate(self.evaluation_steps)
|
if self.evaluate(self.evaluation_steps):
|
||||||
|
break
|
||||||
|
|
||||||
def _restore_checkpoint_tf(self, checkpoint_dir: str):
|
def _restore_checkpoint_tf(self, checkpoint_dir: str):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -609,3 +621,6 @@ class GraphManager(object):
|
|||||||
|
|
||||||
def should_train(self) -> bool:
|
def should_train(self) -> bool:
|
||||||
return any([manager.should_train() for manager in self.level_managers])
|
return any([manager.should_train() for manager in self.level_managers])
|
||||||
|
|
||||||
|
def should_stop(self) -> bool:
|
||||||
|
return all([manager.should_stop() for manager in self.level_managers])
|
||||||
|
|||||||
@@ -263,3 +263,6 @@ class LevelManager(EnvironmentInterface):
|
|||||||
|
|
||||||
def should_train(self) -> bool:
|
def should_train(self) -> bool:
|
||||||
return any([agent._should_train_helper() for agent in self.agents.values()])
|
return any([agent._should_train_helper() for agent in self.agents.values()])
|
||||||
|
|
||||||
|
def should_stop(self) -> bool:
|
||||||
|
return all([agent.get_success_rate() >= self.environment.get_target_success_rate() for agent in self.agents.values()])
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ class Kubernetes(Deploy):
|
|||||||
self.memory_backend = get_memory_backend(self.params.memory_backend_parameters)
|
self.memory_backend = get_memory_backend(self.params.memory_backend_parameters)
|
||||||
|
|
||||||
self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace}
|
self.params.data_store_params.orchestrator_params = {'namespace': self.params.namespace}
|
||||||
|
self.params.data_store_params.namespace = self.params.namespace
|
||||||
self.data_store = get_data_store(self.params.data_store_params)
|
self.data_store = get_data_store(self.params.data_store_params)
|
||||||
|
|
||||||
if self.params.data_store_params.store_type == "s3":
|
if self.params.data_store_params.store_type == "s3":
|
||||||
@@ -137,7 +138,8 @@ class Kubernetes(Deploy):
|
|||||||
volumes=[k8sclient.V1Volume(
|
volumes=[k8sclient.V1Volume(
|
||||||
name="nfs-pvc",
|
name="nfs-pvc",
|
||||||
persistent_volume_claim=self.nfs_pvc
|
persistent_volume_claim=self.nfs_pvc
|
||||||
)]
|
)],
|
||||||
|
restart_policy='OnFailure'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -155,32 +157,30 @@ class Kubernetes(Deploy):
|
|||||||
template = k8sclient.V1PodTemplateSpec(
|
template = k8sclient.V1PodTemplateSpec(
|
||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
spec=k8sclient.V1PodSpec(
|
spec=k8sclient.V1PodSpec(
|
||||||
containers=[container]
|
containers=[container],
|
||||||
|
restart_policy='OnFailure'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
deployment_spec = k8sclient.V1DeploymentSpec(
|
job_spec = k8sclient.V1JobSpec(
|
||||||
replicas=trainer_params.num_replicas,
|
completions=1,
|
||||||
template=template,
|
template=template
|
||||||
selector=k8sclient.V1LabelSelector(
|
|
||||||
match_labels={'app': name}
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
deployment = k8sclient.V1Deployment(
|
job = k8sclient.V1Job(
|
||||||
api_version='apps/v1',
|
api_version="batch/v1",
|
||||||
kind='Deployment',
|
kind="Job",
|
||||||
metadata=k8sclient.V1ObjectMeta(name=name),
|
metadata=k8sclient.V1ObjectMeta(name=name),
|
||||||
spec=deployment_spec
|
spec=job_spec
|
||||||
)
|
)
|
||||||
|
|
||||||
api_client = k8sclient.AppsV1Api()
|
api_client = k8sclient.BatchV1Api()
|
||||||
try:
|
try:
|
||||||
api_client.create_namespaced_deployment(self.params.namespace, deployment)
|
api_client.create_namespaced_job(self.params.namespace, job)
|
||||||
trainer_params.orchestration_params['deployment_name'] = name
|
trainer_params.orchestration_params['job_name'] = name
|
||||||
return True
|
return True
|
||||||
except k8sclient.rest.ApiException as e:
|
except k8sclient.rest.ApiException as e:
|
||||||
print("Got exception: %s\n while creating deployment", e)
|
print("Got exception: %s\n while creating job", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def deploy_worker(self):
|
def deploy_worker(self):
|
||||||
@@ -217,6 +217,7 @@ class Kubernetes(Deploy):
|
|||||||
name="nfs-pvc",
|
name="nfs-pvc",
|
||||||
persistent_volume_claim=self.nfs_pvc
|
persistent_volume_claim=self.nfs_pvc
|
||||||
)],
|
)],
|
||||||
|
restart_policy='OnFailure'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -234,31 +235,31 @@ class Kubernetes(Deploy):
|
|||||||
template = k8sclient.V1PodTemplateSpec(
|
template = k8sclient.V1PodTemplateSpec(
|
||||||
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
|
||||||
spec=k8sclient.V1PodSpec(
|
spec=k8sclient.V1PodSpec(
|
||||||
containers=[container]
|
containers=[container],
|
||||||
|
restart_policy='OnFailure'
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
deployment_spec = k8sclient.V1DeploymentSpec(
|
job_spec = k8sclient.V1JobSpec(
|
||||||
replicas=worker_params.num_replicas,
|
completions=worker_params.num_replicas,
|
||||||
template=template,
|
parallelism=worker_params.num_replicas,
|
||||||
selector=k8sclient.V1LabelSelector(
|
template=template
|
||||||
match_labels={'app': name}
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
deployment = k8sclient.V1Deployment(
|
job = k8sclient.V1Job(
|
||||||
api_version='apps/v1',
|
api_version="batch/v1",
|
||||||
kind="Deployment",
|
kind="Job",
|
||||||
metadata=k8sclient.V1ObjectMeta(name=name),
|
metadata=k8sclient.V1ObjectMeta(name=name),
|
||||||
spec=deployment_spec
|
spec=job_spec
|
||||||
)
|
)
|
||||||
|
|
||||||
api_client = k8sclient.AppsV1Api()
|
api_client = k8sclient.BatchV1Api()
|
||||||
try:
|
try:
|
||||||
api_client.create_namespaced_deployment(self.params.namespace, deployment)
|
api_client.create_namespaced_job(self.params.namespace, job)
|
||||||
worker_params.orchestration_params['deployment_name'] = name
|
worker_params.orchestration_params['job_name'] = name
|
||||||
return True
|
return True
|
||||||
except k8sclient.rest.ApiException as e:
|
except k8sclient.rest.ApiException as e:
|
||||||
print("Got exception: %s\n while creating deployment", e)
|
print("Got exception: %s\n while creating Job", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def worker_logs(self):
|
def worker_logs(self):
|
||||||
@@ -273,7 +274,7 @@ class Kubernetes(Deploy):
|
|||||||
pod = None
|
pod = None
|
||||||
try:
|
try:
|
||||||
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
|
pods = api_client.list_namespaced_pod(self.params.namespace, label_selector='app={}'.format(
|
||||||
trainer_params.orchestration_params['deployment_name']
|
trainer_params.orchestration_params['job_name']
|
||||||
))
|
))
|
||||||
|
|
||||||
pod = pods.items[0]
|
pod = pods.items[0]
|
||||||
@@ -324,17 +325,20 @@ class Kubernetes(Deploy):
|
|||||||
|
|
||||||
def undeploy(self):
|
def undeploy(self):
|
||||||
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
trainer_params = self.params.run_type_params.get(str(RunType.TRAINER), None)
|
||||||
api_client = k8sclient.AppsV1Api()
|
api_client = k8sclient.BatchV1Api()
|
||||||
delete_options = k8sclient.V1DeleteOptions()
|
delete_options = k8sclient.V1DeleteOptions(
|
||||||
|
propagation_policy="Foreground"
|
||||||
|
)
|
||||||
|
|
||||||
if trainer_params:
|
if trainer_params:
|
||||||
try:
|
try:
|
||||||
api_client.delete_namespaced_deployment(trainer_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
|
api_client.delete_namespaced_job(trainer_params.orchestration_params['job_name'], self.params.namespace, delete_options)
|
||||||
except k8sclient.rest.ApiException as e:
|
except k8sclient.rest.ApiException as e:
|
||||||
print("Got exception: %s\n while deleting trainer", e)
|
print("Got exception: %s\n while deleting trainer", e)
|
||||||
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
worker_params = self.params.run_type_params.get(str(RunType.ROLLOUT_WORKER), None)
|
||||||
if worker_params:
|
if worker_params:
|
||||||
try:
|
try:
|
||||||
api_client.delete_namespaced_deployment(worker_params.orchestration_params['deployment_name'], self.params.namespace, delete_options)
|
api_client.delete_namespaced_job(worker_params.orchestration_params['job_name'], self.params.namespace, delete_options)
|
||||||
except k8sclient.rest.ApiException as e:
|
except k8sclient.rest.ApiException as e:
|
||||||
print("Got exception: %s\n while deleting workers", e)
|
print("Got exception: %s\n while deleting workers", e)
|
||||||
self.memory_backend.undeploy()
|
self.memory_backend.undeploy()
|
||||||
|
|||||||
@@ -56,6 +56,10 @@ agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000)
|
|||||||
# Environment #
|
# Environment #
|
||||||
###############
|
###############
|
||||||
env_params = GymVectorEnvironment(level='CartPole-v0')
|
env_params = GymVectorEnvironment(level='CartPole-v0')
|
||||||
|
env_params.custom_reward_threshold = 200
|
||||||
|
# Set the target success
|
||||||
|
env_params.target_success_rate = 1.0
|
||||||
|
|
||||||
|
|
||||||
########
|
########
|
||||||
# Test #
|
# Test #
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchroniza
|
|||||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||||
|
from rl_coach.data_stores.data_store import SyncFiles
|
||||||
|
|
||||||
|
|
||||||
def has_checkpoint(checkpoint_dir):
|
def has_checkpoint(checkpoint_dir):
|
||||||
@@ -68,6 +69,10 @@ def get_latest_checkpoint(checkpoint_dir):
|
|||||||
return int(rel_path.split('_Step')[0])
|
return int(rel_path.split('_Step')[0])
|
||||||
|
|
||||||
|
|
||||||
|
def should_stop(checkpoint_dir):
|
||||||
|
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
|
||||||
|
|
||||||
|
|
||||||
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
||||||
"""
|
"""
|
||||||
wait for first checkpoint then perform rollouts using the model
|
wait for first checkpoint then perform rollouts using the model
|
||||||
@@ -87,12 +92,17 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
|||||||
|
|
||||||
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
||||||
|
|
||||||
|
if should_stop(checkpoint_dir):
|
||||||
|
break
|
||||||
|
|
||||||
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
||||||
|
|
||||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||||
|
|
||||||
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
||||||
while new_checkpoint < last_checkpoint + 1:
|
while new_checkpoint < last_checkpoint + 1:
|
||||||
|
if should_stop(checkpoint_dir):
|
||||||
|
break
|
||||||
if data_store:
|
if data_store:
|
||||||
data_store.load_from_store()
|
data_store.load_from_store()
|
||||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||||
|
|||||||
@@ -40,8 +40,9 @@ def training_worker(graph_manager, checkpoint_dir):
|
|||||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||||
|
|
||||||
if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
|
if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
|
||||||
graph_manager.evaluate(graph_manager.evaluation_steps)
|
|
||||||
eval_offset += 1
|
eval_offset += 1
|
||||||
|
if graph_manager.evaluate(graph_manager.evaluation_steps):
|
||||||
|
break
|
||||||
|
|
||||||
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
||||||
graph_manager.save_checkpoint()
|
graph_manager.save_checkpoint()
|
||||||
|
|||||||
Reference in New Issue
Block a user