diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 3b799ba..db261b8 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -511,7 +511,7 @@ class GraphManager(object): ): self.save_checkpoint() - def _log_save_checkpoint(self): + def save_checkpoint(self): checkpoint_path = os.path.join(self.task_parameters.save_checkpoint_dir, "{}_Step-{}.ckpt".format( self.checkpoint_id, @@ -521,6 +521,9 @@ class GraphManager(object): else: saved_checkpoint_path = checkpoint_path + # this is required in order for agents to save additional information like a DND for example + [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers] + screen.log_dict( OrderedDict([ ("Saving in path", saved_checkpoint_path), @@ -528,12 +531,6 @@ class GraphManager(object): prefix="Checkpoint" ) - def save_checkpoint(self): - # this is required in order for agents to save additional information like a DND for example - [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers] - - self._log_save_checkpoint() - self.checkpoint_id += 1 self.last_checkpoint_saving_time = time.time() diff --git a/rl_coach/training_worker.py b/rl_coach/training_worker.py index d2445e6..1b9ddea 100644 --- a/rl_coach/training_worker.py +++ b/rl_coach/training_worker.py @@ -29,11 +29,6 @@ def training_worker(graph_manager, checkpoint_dir): # save randomly initialized graph graph_manager.save_checkpoint() - # TODO: Q: training steps passed into graph_manager.train ignored? - # TODO: specify training steps between checkpoints (in preset?) - # TODO: replace outer training loop with something general - # TODO: low priority: move evaluate out of this process - heatup(graph_manager) # training loop @@ -57,10 +52,21 @@ def main(): help='(string) Path to a folder containing a checkpoint to write the model to.', type=str, default='/checkpoint') + parser.add_argument('-r', '--redis_ip', + help="(string) IP or host for the redis server", + default='localhost', + type=str) + parser.add_argument('-rp', '--redis_port', + help="(int) Port of the redis server", + default=6379, + type=int) args = parser.parse_args() graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True) + graph_manager.agent_params.memory.redis_ip = args.redis_ip + graph_manager.agent_params.memory.redis_port = args.redis_port + training_worker( graph_manager=graph_manager, checkpoint_dir=args.checkpoint_dir,