mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
working checkpoints
This commit is contained in:
@@ -29,4 +29,4 @@ RUN pip3 install -e .
|
|||||||
# RUN pip3 install rl_coach
|
# RUN pip3 install rl_coach
|
||||||
|
|
||||||
# CMD ["coach", "-p", "CartPole_PG", "-e", "cartpole"]
|
# CMD ["coach", "-p", "CartPole_PG", "-e", "cartpole"]
|
||||||
CMD python3 rl_coach/rollout_worker.py
|
CMD python3 rl_coach/rollout_worker.py --preset CartPole_PG
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ endif
|
|||||||
|
|
||||||
RUN_ARGUMENTS+=--rm
|
RUN_ARGUMENTS+=--rm
|
||||||
RUN_ARGUMENTS+=--net host
|
RUN_ARGUMENTS+=--net host
|
||||||
|
RUN_ARGUMENTS+=-v /tmp/checkpoint:/checkpoint
|
||||||
|
|
||||||
CONTEXT = $(realpath ..)
|
CONTEXT = $(realpath ..)
|
||||||
|
|
||||||
@@ -24,6 +25,7 @@ endif
|
|||||||
|
|
||||||
build:
|
build:
|
||||||
${DOCKER} build -f=Dockerfile -t=${IMAGE} ${BUILD_ARGUMENTS} ${CONTEXT}
|
${DOCKER} build -f=Dockerfile -t=${IMAGE} ${BUILD_ARGUMENTS} ${CONTEXT}
|
||||||
|
mkdir -p /tmp/checkpoint
|
||||||
|
|
||||||
shell: build
|
shell: build
|
||||||
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} /bin/bash
|
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} /bin/bash
|
||||||
@@ -34,5 +36,11 @@ test: build
|
|||||||
run: build
|
run: build
|
||||||
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE}
|
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE}
|
||||||
|
|
||||||
|
run_training_worker: build
|
||||||
|
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/training_worker.py --preset CartPole_PG
|
||||||
|
|
||||||
|
run_rollout_worker: build
|
||||||
|
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/rollout_worker.py --preset CartPole_PG
|
||||||
|
|
||||||
push:
|
push:
|
||||||
docker push ${IMAGE}
|
docker push ${IMAGE}
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2017 Intel Corporation
|
# Copyright (c) 2017 Intel Corporation
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|||||||
@@ -341,6 +341,16 @@ class GraphManager(object):
|
|||||||
self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] += 1
|
self.total_steps_counters[RunPhase.TRAIN][TrainingSteps] += 1
|
||||||
[manager.train() for manager in self.level_managers]
|
[manager.train() for manager in self.level_managers]
|
||||||
|
|
||||||
|
# # option 1
|
||||||
|
# for _ in StepsLoop(self.total_steps_counters, RunPhase.TRAIN, steps):
|
||||||
|
# [manager.train() for manager in self.level_managers]
|
||||||
|
#
|
||||||
|
# # option 2
|
||||||
|
# steps_loop = StepsLoop(self.total_steps_counters, RunPhase.TRAIN, steps)
|
||||||
|
# while steps_loop or other:
|
||||||
|
# [manager.train() for manager in self.level_managers]
|
||||||
|
|
||||||
|
|
||||||
def reset_internal_state(self, force_environment_reset=False) -> None:
|
def reset_internal_state(self, force_environment_reset=False) -> None:
|
||||||
"""
|
"""
|
||||||
Reset an episode for all the levels
|
Reset an episode for all the levels
|
||||||
@@ -403,6 +413,7 @@ class GraphManager(object):
|
|||||||
if result.game_over:
|
if result.game_over:
|
||||||
hold_until_a_full_episode = False
|
hold_until_a_full_episode = False
|
||||||
self.handle_episode_ended()
|
self.handle_episode_ended()
|
||||||
|
# TODO: why not just reset right now?
|
||||||
self.reset_required = True
|
self.reset_required = True
|
||||||
if keep_networks_in_sync:
|
if keep_networks_in_sync:
|
||||||
self.sync_graph()
|
self.sync_graph()
|
||||||
@@ -426,16 +437,16 @@ class GraphManager(object):
|
|||||||
# perform several steps of training interleaved with acting
|
# perform several steps of training interleaved with acting
|
||||||
if steps.num_steps > 0:
|
if steps.num_steps > 0:
|
||||||
self.phase = RunPhase.TRAIN
|
self.phase = RunPhase.TRAIN
|
||||||
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
|
||||||
self.reset_internal_state(force_environment_reset=True)
|
self.reset_internal_state(force_environment_reset=True)
|
||||||
#TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
|
#TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
|
||||||
# episodes in memory
|
# episodes in memory
|
||||||
|
count_end = self.total_steps_counters[self.phase][steps.__class__] + steps.num_steps
|
||||||
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
while self.total_steps_counters[self.phase][steps.__class__] < count_end:
|
||||||
# The actual steps being done on the environment are decided by the agents themselves.
|
# The actual steps being done on the environment are decided by the agents themselves.
|
||||||
# This is just an high-level controller.
|
# This is just an high-level controller.
|
||||||
self.act(EnvironmentSteps(1))
|
self.act(EnvironmentSteps(1))
|
||||||
self.train(TrainingSteps(1))
|
self.train(TrainingSteps(1))
|
||||||
self.save_checkpoint()
|
self.occasionally_save_checkpoint()
|
||||||
self.phase = RunPhase.UNDEFINED
|
self.phase = RunPhase.UNDEFINED
|
||||||
|
|
||||||
def sync_graph(self) -> None:
|
def sync_graph(self) -> None:
|
||||||
@@ -491,14 +502,16 @@ class GraphManager(object):
|
|||||||
for v in self.variables_to_restore:
|
for v in self.variables_to_restore:
|
||||||
self.sess.run(v.assign(variables[v.name.split(':')[0]]))
|
self.sess.run(v.assign(variables[v.name.split(':')[0]]))
|
||||||
|
|
||||||
def save_checkpoint(self):
|
def occasionally_save_checkpoint(self):
|
||||||
# only the chief process saves checkpoints
|
# only the chief process saves checkpoints
|
||||||
if self.task_parameters.save_checkpoint_secs \
|
if self.task_parameters.save_checkpoint_secs \
|
||||||
and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \
|
and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \
|
||||||
and (self.task_parameters.task_index == 0 # distributed
|
and (self.task_parameters.task_index == 0 # distributed
|
||||||
or self.task_parameters.task_index is None # single-worker
|
or self.task_parameters.task_index is None # single-worker
|
||||||
):
|
):
|
||||||
|
self.save_checkpoint()
|
||||||
|
|
||||||
|
def _log_save_checkpoint(self):
|
||||||
checkpoint_path = os.path.join(self.task_parameters.save_checkpoint_dir,
|
checkpoint_path = os.path.join(self.task_parameters.save_checkpoint_dir,
|
||||||
"{}_Step-{}.ckpt".format(
|
"{}_Step-{}.ckpt".format(
|
||||||
self.checkpoint_id,
|
self.checkpoint_id,
|
||||||
@@ -508,9 +521,6 @@ class GraphManager(object):
|
|||||||
else:
|
else:
|
||||||
saved_checkpoint_path = checkpoint_path
|
saved_checkpoint_path = checkpoint_path
|
||||||
|
|
||||||
# this is required in order for agents to save additional information like a DND for example
|
|
||||||
[manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
|
|
||||||
|
|
||||||
screen.log_dict(
|
screen.log_dict(
|
||||||
OrderedDict([
|
OrderedDict([
|
||||||
("Saving in path", saved_checkpoint_path),
|
("Saving in path", saved_checkpoint_path),
|
||||||
@@ -518,6 +528,12 @@ class GraphManager(object):
|
|||||||
prefix="Checkpoint"
|
prefix="Checkpoint"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def save_checkpoint(self):
|
||||||
|
# this is required in order for agents to save additional information like a DND for example
|
||||||
|
[manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
|
||||||
|
|
||||||
|
self._log_save_checkpoint()
|
||||||
|
|
||||||
self.checkpoint_id += 1
|
self.checkpoint_id += 1
|
||||||
self.last_checkpoint_saving_time = time.time()
|
self.last_checkpoint_saving_time = time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
"""
|
||||||
|
this rollout worker restores a model from disk, evaluates a predefined number of
|
||||||
|
episodes, and contributes them to a distributed memory
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from rl_coach.base_parameters import TaskParameters
|
from rl_coach.base_parameters import TaskParameters
|
||||||
@@ -5,30 +9,40 @@ from rl_coach.coach import expand_preset
|
|||||||
from rl_coach.core_types import EnvironmentEpisodes, RunPhase
|
from rl_coach.core_types import EnvironmentEpisodes, RunPhase
|
||||||
from rl_coach.utils import short_dynamic_import
|
from rl_coach.utils import short_dynamic_import
|
||||||
|
|
||||||
|
# Q: specify alternative distributed memory, or should this go in the preset?
|
||||||
|
# A: preset must define distributed memory to be used. we aren't going to take a non-distributed preset and automatically distribute it.
|
||||||
|
|
||||||
|
|
||||||
# TODO: acce[t preset option
|
def rollout_worker(graph_manager, checkpoint_dir):
|
||||||
# TODO: workers might need to define schedules in terms which can be synchronized: exploration(len(distributed_memory)) -> float
|
"""
|
||||||
# TODO: periodically reload policy (from disk?)
|
restore a checkpoint then perform rollouts using the restored model
|
||||||
# TODO: specify alternative distributed memory, or should this go in the preset?
|
"""
|
||||||
|
|
||||||
def rollout_worker(graph_manager):
|
|
||||||
task_parameters = TaskParameters()
|
task_parameters = TaskParameters()
|
||||||
task_parameters.checkpoint_restore_dir='/checkpoint'
|
task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir
|
||||||
graph_manager.create_graph(task_parameters)
|
graph_manager.create_graph(task_parameters)
|
||||||
graph_manager.phase = RunPhase.TRAIN
|
graph_manager.phase = RunPhase.TRAIN
|
||||||
graph_manager.act(EnvironmentEpisodes(num_steps=10))
|
graph_manager.act(EnvironmentEpisodes(num_steps=10))
|
||||||
|
graph_manager.phase = RunPhase.UNDEFINED
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-p', '--preset',
|
parser.add_argument('-p', '--preset',
|
||||||
help="(string) Name of a preset to run (class name from the 'presets' directory.)",
|
help="(string) Name of a preset to run (class name from the 'presets' directory.)",
|
||||||
type=str)
|
type=str,
|
||||||
|
required=True)
|
||||||
|
parser.add_argument('--checkpoint_dir',
|
||||||
|
help='(string) Path to a folder containing a checkpoint to restore the model from.',
|
||||||
|
type=str,
|
||||||
|
default='/checkpoint')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
||||||
rollout_worker(graph_manager)
|
|
||||||
|
rollout_worker(
|
||||||
|
graph_manager=graph_manager,
|
||||||
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user