mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Adding framework for multinode tests (#149)
* Currently runs CartPole_ClippedPPO and Mujoco_ClippedPPO with inverted_pendulum level.
This commit is contained in:
committed by
Balaji Subramaniam
parent
b461a1b8ab
commit
2c1a9dbf20
@@ -504,12 +504,7 @@ class GraphManager(object):
|
||||
self.act(EnvironmentEpisodes(1))
|
||||
self.sync()
|
||||
if self.should_stop():
|
||||
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
|
||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||
if hasattr(self, 'data_store_params'):
|
||||
data_store = self.get_data_store(self.data_store_params)
|
||||
data_store.save_to_store()
|
||||
|
||||
self.flush_finished()
|
||||
screen.success("Reached required success rate. Exiting.")
|
||||
return True
|
||||
return False
|
||||
@@ -713,3 +708,20 @@ class GraphManager(object):
|
||||
"""
|
||||
for env in self.environments:
|
||||
env.close()
|
||||
|
||||
def get_current_episodes_count(self):
|
||||
"""
|
||||
Returns the current EnvironmentEpisodes counter
|
||||
"""
|
||||
return self.current_step_counter[EnvironmentEpisodes]
|
||||
|
||||
def flush_finished(self):
|
||||
"""
|
||||
To indicate the training has finished, writes a `.finished` file to the checkpoint directory and calls
|
||||
the data store to updload that file.
|
||||
"""
|
||||
if self.task_parameters.checkpoint_save_dir and os.path.exists(self.task_parameters.checkpoint_save_dir):
|
||||
open(os.path.join(self.task_parameters.checkpoint_save_dir, SyncFiles.FINISHED.value), 'w').close()
|
||||
if hasattr(self, 'data_store_params'):
|
||||
data_store = self.get_data_store(self.data_store_params)
|
||||
data_store.save_to_store()
|
||||
|
||||
Reference in New Issue
Block a user