1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding target reward and target sucess (#58)

* Adding target reward

* Adding target successs

* Addressing comments

* Using custom_reward_threshold and target_success_rate

* Adding exit message

* Moving success rate to environment

* Making target_success_rate optional
This commit is contained in:
Ajay Deshpande
2018-11-12 15:03:43 -08:00
committed by Balaji Subramaniam
parent 0fe583186e
commit 875d6ef017
17 changed files with 162 additions and 74 deletions

View File

@@ -5,6 +5,7 @@ from minio.error import ResponseError
from configparser import ConfigParser, Error
from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from rl_coach.data_stores.data_store import SyncFiles
import os
import time
@@ -20,7 +21,6 @@ class S3DataStoreParameters(DataStoreParameters):
self.end_point = end_point
self.bucket_name = bucket_name
self.checkpoint_dir = checkpoint_dir
self.lock_file = ".lock"
class S3DataStore(DataStore):
@@ -52,9 +52,9 @@ class S3DataStore(DataStore):
def save_to_store(self):
try:
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)
checkpoint_file = None
for root, dirs, files in os.walk(self.params.checkpoint_dir):
@@ -70,7 +70,7 @@ class S3DataStore(DataStore):
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)
except ResponseError as e:
print("Got exception: %s\n while saving to S3", e)
@@ -80,7 +80,7 @@ class S3DataStore(DataStore):
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
while True:
objects = self.mc.list_objects_v2(self.params.bucket_name, self.params.lock_file)
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)
if next(objects, None) is None:
try:
@@ -90,6 +90,18 @@ class S3DataStore(DataStore):
break
time.sleep(10)
# Check if there's a finished file
objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)
if next(objects, None) is not None:
try:
self.mc.fget_object(
self.params.bucket_name, SyncFiles.FINISHED.value,
os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
)
except Exception as e:
pass
ckpt = CheckpointState()
if os.path.exists(filename):
contents = open(filename, 'r').read()