mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding target reward and target sucess (#58)
* Adding target reward * Adding target successs * Addressing comments * Using custom_reward_threshold and target_success_rate * Adding exit message * Moving success rate to environment * Making target_success_rate optional
This commit is contained in:
committed by
Balaji Subramaniam
parent
0fe583186e
commit
875d6ef017
@@ -15,6 +15,7 @@ from rl_coach.base_parameters import TaskParameters, DistributedCoachSynchroniza
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
|
||||
def has_checkpoint(checkpoint_dir):
|
||||
@@ -68,6 +69,10 @@ def get_latest_checkpoint(checkpoint_dir):
|
||||
return int(rel_path.split('_Step')[0])
|
||||
|
||||
|
||||
def should_stop(checkpoint_dir):
|
||||
return os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value))
|
||||
|
||||
|
||||
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
||||
"""
|
||||
wait for first checkpoint then perform rollouts using the model
|
||||
@@ -87,12 +92,17 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers):
|
||||
|
||||
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
||||
|
||||
if should_stop(checkpoint_dir):
|
||||
break
|
||||
|
||||
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
||||
|
||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
|
||||
while new_checkpoint < last_checkpoint + 1:
|
||||
if should_stop(checkpoint_dir):
|
||||
break
|
||||
if data_store:
|
||||
data_store.load_from_store()
|
||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
Reference in New Issue
Block a user