Adding target reward and target success (#58)
* Adding target reward
* Adding target success
* Addressing comments
* Using custom_reward_threshold and target_success_rate
* Adding exit message
* Moving success rate to environment
* Making target_success_rate optional
Committed by: Balaji Subramaniam
Parent: 0fe583186e
Commit: 875d6ef017
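The diff below only covers the data-store side of the change; none of the target-reward / target-success-rate code the commit message describes is visible here. The following is a hypothetical sketch of that stopping condition only: the names custom_reward_threshold and target_success_rate come from the commit message, everything else is assumed and is not Coach's actual API.

def should_stop(mean_reward, success_rate,
                custom_reward_threshold=None, target_success_rate=None):
    # Stop once the evaluation reward crosses the configured threshold ...
    if custom_reward_threshold is not None and mean_reward >= custom_reward_threshold:
        print("Target reward reached, exiting.")  # the "exit message" mentioned in the commit
        return True
    # ... or once the environment-reported success rate reaches the (optional) target.
    if target_success_rate is not None and success_rate >= target_success_rate:
        print("Target success rate reached, exiting.")
        return True
    return False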
@@ -1,4 +1,6 @@
+from enum import Enum
+
 class DataStoreParameters(object):
     def __init__(self, store_type, orchestrator_type, orchestrator_params):
@@ -6,6 +8,7 @@ class DataStoreParameters(object):
         self.orchestrator_type = orchestrator_type
         self.orchestrator_params = orchestrator_params
+

 class DataStore(object):
     def __init__(self, params: DataStoreParameters):
         pass
@@ -24,3 +27,8 @@ class DataStore(object):

     def load_from_store(self):
         pass
+
+
+class SyncFiles(Enum):
+    FINISHED = ".finished"
+    LOCKFILE = ".lock"
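The SyncFiles enum is added to rl_coach.data_stores.data_store (the module the S3 store imports it from further down), so the marker-file names live in one place instead of per-store strings like self.lock_file. A minimal sketch of how a consumer of a shared checkpoint directory could use these markers; the helper functions below are hypothetical, not part of the diff:

import os
from rl_coach.data_stores.data_store import SyncFiles

def training_finished(checkpoint_dir):
    # The trainer drops a ".finished" marker when no more checkpoints will be
    # written; rollout workers can stop polling once they see it.
    return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))

def checkpoint_locked(checkpoint_dir):
    # A ".lock" marker means a checkpoint write is in progress and readers
    # should wait before loading.
    return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.LOCKFILE.value))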
@@ -58,7 +58,7 @@ class NFSDataStore(DataStore):
         pass

     def deploy_k8s_nfs(self) -> bool:
-        name = "nfs-server"
+        name = "nfs-server-{}".format(uuid.uuid4())
         container = k8sclient.V1Container(
             name=name,
             image="k8s.gcr.io/volume-nfs:0.8",
@@ -83,7 +83,7 @@ class NFSDataStore(DataStore):
             security_context=k8sclient.V1SecurityContext(privileged=True)
         )
         template = k8sclient.V1PodTemplateSpec(
-            metadata=k8sclient.V1ObjectMeta(labels={'app': 'nfs-server'}),
+            metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
             spec=k8sclient.V1PodSpec(
                 containers=[container],
                 volumes=[k8sclient.V1Volume(
@@ -96,14 +96,14 @@ class NFSDataStore(DataStore):
             replicas=1,
             template=template,
             selector=k8sclient.V1LabelSelector(
-                match_labels={'app': 'nfs-server'}
+                match_labels={'app': name}
             )
         )

         deployment = k8sclient.V1Deployment(
             api_version='apps/v1',
             kind='Deployment',
-            metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': 'nfs-server'}),
+            metadata=k8sclient.V1ObjectMeta(name=name, labels={'app': name}),
             spec=deployment_spec
         )

@@ -117,7 +117,7 @@ class NFSDataStore(DataStore):

         k8s_core_v1_api_client = k8sclient.CoreV1Api()

-        svc_name = "nfs-service"
+        svc_name = "nfs-service-{}".format(uuid.uuid4())
         service = k8sclient.V1Service(
             api_version='v1',
             kind='Service',
@@ -145,7 +145,7 @@ class NFSDataStore(DataStore):
         return True

     def create_k8s_nfs_resources(self) -> bool:
-        pv_name = "nfs-ckpt-pv"
+        pv_name = "nfs-ckpt-pv-{}".format(uuid.uuid4())
         persistent_volume = k8sclient.V1PersistentVolume(
             api_version="v1",
             kind="PersistentVolume",
@@ -171,7 +171,7 @@ class NFSDataStore(DataStore):
             print("Got exception: %s\n while creating the NFS PV", e)
             return False

-        pvc_name = "nfs-ckpt-pvc"
+        pvc_name = "nfs-ckpt-pvc-{}".format(uuid.uuid4())
         persistent_volume_claim = k8sclient.V1PersistentVolumeClaim(
             api_version="v1",
             kind="PersistentVolumeClaim",
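In the NFS data store (the class shown above; its file path is not visible in this excerpt), every fixed Kubernetes resource name ("nfs-server", "nfs-service", "nfs-ckpt-pv", "nfs-ckpt-pvc") gains a uuid4 suffix so two concurrent runs in the same namespace don't collide, and the Deployment's selector.match_labels has to change together with the pod template labels, otherwise the Deployment would no longer match its own pods. A short sketch of that pairing, assuming the standard kubernetes Python client is imported as k8sclient as in the diff:

import uuid
from kubernetes import client as k8sclient  # assumed import alias

name = "nfs-server-{}".format(uuid.uuid4())  # unique per run

template = k8sclient.V1PodTemplateSpec(
    # Pod template labels ...
    metadata=k8sclient.V1ObjectMeta(labels={'app': name}),
    spec=k8sclient.V1PodSpec(containers=[]),
)
deployment_spec = k8sclient.V1DeploymentSpec(
    replicas=1,
    template=template,
    # ... must be matched by the selector, which is why both were switched
    # from the literal 'nfs-server' to the uuid-suffixed name.
    selector=k8sclient.V1LabelSelector(match_labels={'app': name}),
)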
@@ -5,6 +5,7 @@ from minio.error import ResponseError
 from configparser import ConfigParser, Error
 from google.protobuf import text_format
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from rl_coach.data_stores.data_store import SyncFiles

 import os
 import time
@@ -20,7 +21,6 @@ class S3DataStoreParameters(DataStoreParameters):
         self.end_point = end_point
         self.bucket_name = bucket_name
         self.checkpoint_dir = checkpoint_dir
-        self.lock_file = ".lock"


 class S3DataStore(DataStore):
@@ -52,9 +52,9 @@ class S3DataStore(DataStore):

     def save_to_store(self):
         try:
-            self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
+            self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

-            self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
+            self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

             checkpoint_file = None
             for root, dirs, files in os.walk(self.params.checkpoint_dir):
@@ -70,7 +70,7 @@ class S3DataStore(DataStore):
                 rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                 self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

-            self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
+            self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

         except ResponseError as e:
             print("Got exception: %s\n while saving to S3", e)
@@ -80,7 +80,7 @@ class S3DataStore(DataStore):
         filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))

         while True:
-            objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
+            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)

             if next(objects, None) is None:
                 try:
@@ -90,6 +90,18 @@ class S3DataStore(DataStore):
                 break
             time.sleep(10)

+        # Check if there's a finished file
+        objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)
+
+        if next(objects, None) is not None:
+            try:
+                self.mc.fget_object(
+                    self.params.bucket_name, SyncFiles.FINISHED.value,
+                    os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
+                )
+            except Exception as e:
+                pass
+
         ckpt = CheckpointState()
         if os.path.exists(filename):
             contents = open(filename, 'r').read()
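In the S3 store, save_to_store and load_from_store implement the two sides of the ".lock" protocol over the MinIO client (self.mc): the writer puts a zero-byte ".lock" object, uploads the checkpoint files, then removes the lock, while the reader polls list_objects_v2 every 10 seconds until the lock disappears. What this diff does not show is who uploads the ".finished" marker that load_from_store now fetches; a hypothetical sketch of that writer-side signal, using the same minio calls as the diff:

import io
from rl_coach.data_stores.data_store import SyncFiles

def signal_finished(mc, bucket_name):
    # A zero-byte object is enough: load_from_store only checks whether the
    # key exists before downloading it next to the local checkpoints.
    mc.put_object(bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)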