mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Adding target reward and target success (#58)
* Adding target reward * Adding target success * Addressing comments * Using custom_reward_threshold and target_success_rate * Adding exit message * Moving success rate to environment * Making target_success_rate optional
This commit is contained in:
committed by
Balaji Subramaniam
parent
0fe583186e
commit
875d6ef017
@@ -34,6 +34,7 @@ from rl_coach.logger import screen, Logger
|
||||
from rl_coach.utils import set_cpu, start_shell_command_and_wait
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store
|
||||
from rl_coach.orchestrators.kubernetes_orchestrator import RunType
|
||||
from rl_coach.data_stores.data_store import SyncFiles
|
||||
|
||||
|
||||
class ScheduleParameters(Parameters):
|
||||
@@ -458,12 +459,12 @@ class GraphManager(object):
|
||||
"""
|
||||
[manager.sync() for manager in self.level_managers]
|
||||
|
||||
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> None:
|
||||
def evaluate(self, steps: PlayingStepsType, keep_networks_in_sync: bool=False) -> bool:
|
||||
"""
|
||||
Perform evaluation for several steps
|
||||
:param steps: the number of steps as a tuple of steps time and steps count
|
||||
:param keep_networks_in_sync: sync the network parameters with the global network before each episode
|
||||
:return: None
|
||||
:return: bool, True if the target reward and target success has been reached
|
||||
"""
|
||||
self.verify_graph_was_created()
|
||||
|
||||
@@ -478,6 +479,16 @@ class GraphManager(object):
|
||||
while self.current_step_counter < count_end:
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
self.sync()
|
||||
if self.should_stop():
|
||||
if self.task_parameters.checkpoint_save_dir:
|
||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||
if hasattr(self, 'data_store_params'):
|
||||
data_store = get_data_store(self.data_store_params)
|
||||
data_store.save_to_store()
|
||||
|
||||
screen.success("Reached required success rate. Exiting.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def improve(self):
|
||||
"""
|
||||
@@ -508,7 +519,8 @@ class GraphManager(object):
|
||||
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
|
||||
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
|
||||
self.train_and_act(self.steps_between_evaluation_periods)
|
||||
self.evaluate(self.evaluation_steps)
|
||||
if self.evaluate(self.evaluation_steps):
|
||||
break
|
||||
|
||||
def _restore_checkpoint_tf(self, checkpoint_dir: str):
|
||||
import tensorflow as tf
|
||||
@@ -609,3 +621,6 @@ class GraphManager(object):
|
||||
|
||||
def should_train(self) -> bool:
|
||||
return any([manager.should_train() for manager in self.level_managers])
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
return all([manager.should_stop() for manager in self.level_managers])
|
||||
|
||||
Reference in New Issue
Block a user