diff --git a/rl_coach/data_stores/s3_data_store.py b/rl_coach/data_stores/s3_data_store.py
index cc58367..7d40f8b 100644
--- a/rl_coach/data_stores/s3_data_store.py
+++ b/rl_coach/data_stores/s3_data_store.py
@@ -52,13 +52,10 @@ class S3DataStore(DataStore):
 
     def save_to_store(self):
         try:
-            print("Writing lock file")
-
             self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
 
             self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
 
-            print("saving to s3")
             checkpoint_file = None
             for root, dirs, files in os.walk(self.params.checkpoint_dir):
                 for filename in files:
@@ -73,7 +70,6 @@ class S3DataStore(DataStore):
                     rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                     self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
 
-            print("Deleting lock file")
             self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
 
         except ResponseError as e:
@@ -81,7 +77,6 @@ class S3DataStore(DataStore):
 
     def load_from_store(self):
         try:
-
             filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
 
             while True:
@@ -95,8 +90,6 @@ class S3DataStore(DataStore):
                     break
                 time.sleep(10)
 
-            print("loading from s3")
-
             ckpt = CheckpointState()
             if os.path.exists(filename):
                 contents = open(filename, 'r').read()
diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index a04d23c..8efa7c2 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -16,7 +16,7 @@
 from threading import Thread
 from rl_coach.base_parameters import TaskParameters
 from rl_coach.coach import expand_preset
-from rl_coach.core_types import EnvironmentEpisodes, RunPhase
+from rl_coach.core_types import EnvironmentSteps, RunPhase
 from rl_coach.utils import short_dynamic_import
 from rl_coach.memories.backend.memory_impl import construct_memory_params
 from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
@@ -82,7 +82,7 @@ def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
     return last_checkpoint
 
 
-def rollout_worker(graph_manager, checkpoint_dir):
+def rollout_worker(graph_manager, checkpoint_dir, data_store):
     """
     wait for first checkpoint then perform rollouts using the model
     """
@@ -94,16 +94,26 @@ def rollout_worker(graph_manager, checkpoint_dir):
     graph_manager.create_graph(task_parameters)
 
     graph_manager.phase = RunPhase.TRAIN
 
+    error_compensation = 100
+
     last_checkpoint = 0
 
-    for i in range(10000000):
-        graph_manager.act(EnvironmentEpisodes(num_steps=1))
+    act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation
 
-        new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
+    print(act_steps, graph_manager.improve_steps.num_steps)
 
-        if new_checkpoint > last_checkpoint:
-            last_checkpoint = new_checkpoint
-            graph_manager.restore_checkpoint()
+    for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
+
+        graph_manager.act(EnvironmentSteps(num_steps=act_steps))
+
+        new_checkpoint = last_checkpoint + 1
+        while last_checkpoint < new_checkpoint:
+            if data_store:
+                data_store.load_from_store()
+            last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
+
+        last_checkpoint = new_checkpoint
+        graph_manager.restore_checkpoint()
 
     graph_manager.phase = RunPhase.UNDEFINED
@@ -137,6 +147,7 @@ def main():
 
     graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
 
+    data_store = None
     if args.memory_backend_params:
         args.memory_backend_params = json.loads(args.memory_backend_params)
         print(args.memory_backend_params)
@@ -156,6 +167,7 @@ def main():
     rollout_worker(
         graph_manager=graph_manager,
         checkpoint_dir=args.checkpoint_dir,
+        data_store=data_store
     )
 
 if __name__ == '__main__':
diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py
index b68bee7..1c79726 100644
--- a/rl_coach/training_worker.py
+++ b/rl_coach/training_worker.py
@@ -31,8 +31,10 @@ def training_worker(graph_manager, checkpoint_dir):
     graph_manager.save_checkpoint()
 
     # training loop
-    while True:
+    steps = 0
+    while(steps < graph_manager.improve_steps.num_steps):
         if graph_manager.should_train():
+            steps += 1
             graph_manager.phase = core_types.RunPhase.TRAIN
             graph_manager.train(core_types.TrainingSteps(1))
             graph_manager.phase = core_types.RunPhase.UNDEFINED