1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

Adding target reward and target sucess (#58)

* Adding target reward

* Adding target successs

* Addressing comments

* Using custom_reward_threshold and target_success_rate

* Adding exit message

* Moving success rate to environment

* Making target_success_rate optional
This commit is contained in:
Ajay Deshpande
2018-11-12 15:03:43 -08:00
committed by Balaji Subramaniam
parent 0fe583186e
commit 875d6ef017
17 changed files with 162 additions and 74 deletions

View File

@@ -34,6 +34,7 @@ from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
from rl_coach.data_stores.data_store import SyncFiles
class ScheduleParameters(Parameters):
@@ -458,12 +459,12 @@ class GraphManager(object):
"""
[manager.sync() for manager in self.level_managers]
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> None:
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> bool:
"""
Perform evaluation for several steps
:param steps: the number of steps as a tuple of steps time and steps count
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
:return: None
:return: bool, True if the target reward and target success has been reached
"""
self.verify_graph_was_created()
@@ -478,6 +479,16 @@ class GraphManager(object):
while self.current_step_counter < count_end:
self.act(EnvironmentEpisodes(1))
self.sync()
if self.should_stop():
if self.task_parameters.checkpoint_save_dir:
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = get_data_store(self.data_store_params)
data_store.save_to_store()
screen.success("Reached required success rate. Exiting.")
return True
return False
def improve(self):
"""
@@ -508,7 +519,8 @@ class GraphManager(object):
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
self.train_and_act(self.steps_between_evaluation_periods)
self.evaluate(self.evaluation_steps)
if self.evaluate(self.evaluation_steps):
break
def _restore_checkpoint_tf(self, checkpoint_dir: str):
import tensorflow as tf
@@ -609,3 +621,6 @@ class GraphManager(object):
def should_train(self) -> bool:
return any([manager.should_train() for manager in self.level_managers])
def should_stop(self) -> bool:
return all([manager.should_stop() for manager in self.level_managers])