1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Adding framework for multinode tests (#149)

* Currently runs CartPole_ClippedPPO and Mujoco_ClippedPPO with inverted_pendulum level.
This commit is contained in:
Ajay Deshpande
2019-02-26 13:53:12 -08:00
committed by Balaji Subramaniam
parent b461a1b8ab
commit 2c1a9dbf20
8 changed files with 210 additions and 24 deletions

View File

@@ -504,12 +504,7 @@ class GraphManager(object):
self.act(EnvironmentEpisodes(1))
self.sync()
if self.should_stop():
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()
self.flush_finished()
screen.success("Reached required success rate. Exiting.")
return True
return False
@@ -713,3 +708,20 @@ class GraphManager(object):
"""
for env in self.environments:
env.close()
def get_current_episodes_count(self):
"""
Returns the current EnvironmentEpisodes counter
"""
return self.current_step_counter[EnvironmentEpisodes]
def flush_finished(self):
"""
To indicate the training has finished, writes a `.finished` file to the checkpoint directory and calls
the data store to updload that file.
"""
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
if hasattr(self, 'data_store_params'):
data_store = self.get_data_store(self.data_store_params)
data_store.save_to_store()