
Adding target reward and target success (#58)

* Adding target reward

* Adding target success

* Addressing comments

* Using custom_reward_threshold and target_success_rate

* Adding exit message

* Moving success rate to environment

* Making target_success_rate optional (illustrated in the sketch below)
Authored by Ajay Deshpande on 2018-11-12 15:03:43 -08:00, committed by Balaji Subramaniam
parent 0fe583186e
commit 875d6ef017
17 changed files with 162 additions and 74 deletions
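The commit message above introduces custom_reward_threshold and an optional target_success_rate without showing how they interact. The following is a minimal sketch of that interaction, written against hypothetical names (SuccessTrackingEnv, record_episode, should_exit) rather than the actual rl_coach classes touched by this commit: the environment tracks a success rate, and training exits with a message once the optional target is reached.

from typing import Optional


class SuccessTrackingEnv:
    # Hypothetical stand-in for the environment-side bookkeeping described above.
    def __init__(self, custom_reward_threshold: Optional[float] = None,
                 target_success_rate: Optional[float] = None):
        # Both targets are optional; when left as None they never trigger an exit.
        self.custom_reward_threshold = custom_reward_threshold
        self.target_success_rate = target_success_rate
        self.episodes = 0
        self.successes = 0

    def record_episode(self, total_reward: float) -> None:
        # An episode counts as a success when it clears the reward threshold, if one is set.
        self.episodes += 1
        if (self.custom_reward_threshold is not None
                and total_reward >= self.custom_reward_threshold):
            self.successes += 1

    @property
    def success_rate(self) -> float:
        return self.successes / self.episodes if self.episodes else 0.0

    def should_exit(self) -> bool:
        # Exit with a message once the optional success-rate target is reached.
        if self.target_success_rate is None:
            return False
        if self.success_rate >= self.target_success_rate:
            print("Reached target success rate of {} -- exiting.".format(self.target_success_rate))
            return True
        return False

With target_success_rate left at its default of None, should_exit never fires, which matches the "optional" behavior described in the commit message.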


@@ -15,6 +15,7 @@ from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchronizationType
 from rl_coach.core_types import EnvironmentSteps, RunPhase
 from google.protobuf import text_format
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from rl_coach.data_stores.data_store import SyncFiles


 def has_checkpoint(checkpoint_dir):
@@ -68,6 +69,10 @@ def get_latest_checkpoint(checkpoint_dir):
     return int(rel_path.split('_Step')[0])


+def should_stop(checkpoint_dir):
+    return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
+
+
 def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
     """
     wait for first checkpoint then perform rollouts using the model
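The get_latest_checkpoint context above derives a checkpoint index from the part of the file name before '_Step'. Below is a small illustrative helper built on the assumption that checkpoints are saved as files named '<index>_Step-<steps>.ckpt'; the directory scan and the helper name latest_checkpoint_index are assumptions for illustration, not code from this file.

import os


def latest_checkpoint_index(checkpoint_dir):
    # Illustrative only: assumes checkpoint files are named '<index>_Step-<steps>.ckpt',
    # so the integer before '_Step' identifies the checkpoint.
    indices = [int(name.split('_Step')[0])
               for name in os.listdir(checkpoint_dir)
               if '_Step' in name]
    return max(indices) if indices else -1


# A directory holding '2_Step-800.ckpt' and '3_Step-1200.ckpt' would yield 3.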
@@ -87,12 +92,17 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
     for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
+        if should_stop(checkpoint_dir):
+            break
+
         graph_manager.act(EnvironmentSteps(num_steps=act_steps))

         new_checkpoint = get_latest_checkpoint(checkpoint_dir)
         if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
             while new_checkpoint < last_checkpoint + 1:
+                if should_stop(checkpoint_dir):
+                    break
                 if data_store:
                     data_store.load_from_store()
                 new_checkpoint = get_latest_checkpoint(checkpoint_dir)
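The rollout loop above exits as soon as should_stop finds the FINISHED marker in the checkpoint directory. The sketch below shows the counterpart a training process could use to drop that marker; creating an empty file named SyncFiles.FINISHED.value is an assumption about how the marker gets written, not part of this commit.

import os

from rl_coach.data_stores.data_store import SyncFiles


def write_finished_marker(checkpoint_dir):
    # Create the empty marker file that rollout workers poll via should_stop().
    open(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value), 'w').close()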