diff --git a/docker/Dockerfile b/docker/Dockerfile
index a19f759..9966943 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -29,4 +29,4 @@ RUN pip3 install -e .
 
 # RUN pip3 install rl_coach
 # CMD ["coach", "-p", "CartPole_PG", "-e", "cartpole"]
-CMD python3 rl_coach/rollout_worker.py
+CMD python3 rl_coach/rollout_worker.py --preset CartPole_PG
diff --git a/docker/Makefile b/docker/Makefile
index 5b812ab..5174c40 100644
--- a/docker/Makefile
+++ b/docker/Makefile
@@ -15,6 +15,7 @@ endif
 
 RUN_ARGUMENTS+=--rm
 RUN_ARGUMENTS+=--net host
+RUN_ARGUMENTS+=-v /tmp/checkpoint:/checkpoint
 
 CONTEXT = $(realpath ..)
 
@@ -24,6 +25,7 @@ endif
 
 build:
 	${DOCKER} build -f=Dockerfile -t=${IMAGE} ${BUILD_ARGUMENTS} ${CONTEXT}
+	mkdir -p /tmp/checkpoint
 
 shell: build
 	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} /bin/bash
@@ -34,5 +36,11 @@ test: build
 run: build
 	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE}
 
+run_training_worker: build
+	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/training_worker.py --preset CartPole_PG
+
+run_rollout_worker: build
+	${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/rollout_worker.py --preset CartPole_PG
+
 push:
 	docker push ${IMAGE}
diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 1e6bfba..9c14003 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -1,4 +1,3 @@
-#
 # Copyright (c) 2017 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index d13259b..3b799ba 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -341,6 +341,16 @@ class GraphManager(object):
             self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] += 1
             [manager.train() for manager in self.level_managers]
 
+        # # option 1
+        # for _ in StepsLoop(self.total_steps_counters, RunPhase.TRAIN, steps):
+        #     [manager.train() for manager in self.level_managers]
+        #
+        # # option 2
+        # steps_loop = StepsLoop(self.total_steps_counters, RunPhase.TRAIN, steps)
+        # while steps_loop or other:
+        #     [manager.train() for manager in self.level_managers]
+
+
     def reset_internal_state(self, force_environment_reset=False) -> None:
         """
         Reset an episode for all the levels
@@ -403,6 +413,7 @@ class GraphManager(object):
                 if result.game_over:
                     hold_until_a_full_episode = False
                     self.handle_episode_ended()
+                    # TODO: why not just reset right now?
                     self.reset_required = True
                     if keep_networks_in_sync:
                         self.sync_graph()
@@ -426,16 +437,16 @@ class GraphManager(object):
         # perform several steps of training interleaved with acting
         if steps.num_steps > 0:
            self.phase = RunPhase.TRAIN
-            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
             self.reset_internal_state(force_environment_reset=True)
             #TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
             # episodes in memory
+            count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
             while self.total_steps_counters[self.phase][steps.__class__] < count_end:
                 # The actual steps being done on the environment are decided by the agents themselves.
                 # This is just an high-level controller.
                 self.act(EnvironmentSteps(1))
                 self.train(TrainingSteps(1))
-                self.save_checkpoint()
+                self.occasionally_save_checkpoint()
             self.phase = RunPhase.UNDEFINED
 
     def sync_graph(self) -> None:
@@ -491,35 +502,40 @@ class GraphManager(object):
             for v in self.variables_to_restore:
                 self.sess.run(v.assign(variables[v.name.split(':')[0]]))
 
-    def save_checkpoint(self):
+    def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
         if self.task_parameters.save_checkpoint_secs \
                 and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \
                 and (self.task_parameters.task_index == 0  # distributed
                      or self.task_parameters.task_index is None  # single-worker
                      ):
+            self.save_checkpoint()
 
-            checkpoint_path = os.path.join(self.task_parameters.save_checkpoint_dir,
-                                           "{}_Step-{}.ckpt".format(
-                                               self.checkpoint_id,
-                                               self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
-            if not isinstance(self.task_parameters, DistributedTaskParameters):
-                saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
-            else:
-                saved_checkpoint_path = checkpoint_path
+    def _log_save_checkpoint(self):
+        checkpoint_path = os.path.join(self.task_parameters.save_checkpoint_dir,
+                                       "{}_Step-{}.ckpt".format(
+                                           self.checkpoint_id,
+                                           self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
+        if not isinstance(self.task_parameters, DistributedTaskParameters):
+            saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
+        else:
+            saved_checkpoint_path = checkpoint_path
 
-            # this is required in order for agents to save additional information like a DND for example
-            [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
+        screen.log_dict(
+            OrderedDict([
+                ("Saving in path", saved_checkpoint_path),
+            ]),
+            prefix="Checkpoint"
+        )
 
-            screen.log_dict(
-                OrderedDict([
-                    ("Saving in path", saved_checkpoint_path),
-                ]),
-                prefix="Checkpoint"
-            )
+    def save_checkpoint(self):
+        # this is required in order for agents to save additional information like a DND for example
+        [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
 
-            self.checkpoint_id += 1
-            self.last_checkpoint_saving_time = time.time()
+        self._log_save_checkpoint()
+
+        self.checkpoint_id += 1
+        self.last_checkpoint_saving_time = time.time()
 
     def improve(self):
         """
diff --git a/rl_coach/rollout_worker.py b/rl_coach/rollout_worker.py
index bfb5be3..90035cd 100644
--- a/rl_coach/rollout_worker.py
+++ b/rl_coach/rollout_worker.py
@@ -1,3 +1,7 @@
+"""
+this rollout worker restores a model from disk, evaluates a predefined number of
+episodes, and contributes them to a distributed memory
+"""
 import argparse
 
 from rl_coach.base_parameters import TaskParameters
@@ -5,30 +9,40 @@ from rl_coach.coach import expand_preset
 from rl_coach.core_types import EnvironmentEpisodes, RunPhase
 from rl_coach.utils import short_dynamic_import
 
+# Q: specify alternative distributed memory, or should this go in the preset?
+# A: preset must define distributed memory to be used. we aren't going to take a non-distributed preset and automatically distribute it.
 
-# TODO: acce[t preset option
-# TODO: workers might need to define schedules in terms which can be synchronized: exploration(len(distributed_memory)) -> float
-# TODO: periodically reload policy (from disk?)
-# TODO: specify alternative distributed memory, or should this go in the preset?
-
-def rollout_worker(graph_manager):
+def rollout_worker(graph_manager, checkpoint_dir):
+    """
+    restore a checkpoint then perform rollouts using the restored model
+    """
     task_parameters = TaskParameters()
-    task_parameters.checkpoint_restore_dir='/checkpoint'
+    task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
 
     graph_manager.create_graph(task_parameters)
 
     graph_manager.phase = RunPhase.TRAIN
     graph_manager.act(EnvironmentEpisodes(num_steps=10))
+    graph_manager.phase = RunPhase.UNDEFINED
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-p', '--preset',
                         help="(string) Name of a preset to run (class name from the 'presets' directory.)",
-                        type=str)
+                        type=str,
+                        required=True)
+    parser.add_argument('--checkpoint_dir',
+                        help='(string) Path to a folder containing a checkpoint to restore the model from.',
+                        type=str,
+                        default='/checkpoint')
     args = parser.parse_args()
 
     graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
-    rollout_worker(graph_manager)
+
+    rollout_worker(
+        graph_manager=graph_manager,
+        checkpoint_dir=args.checkpoint_dir,
+    )
 
 
 if __name__ == '__main__':
     main()