1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

reenable redis; better error message

This commit is contained in:
Zach Dwiel
2018-09-17 19:50:03 +00:00
committed by zach dwiel
parent 009cf670f3
commit 3328b25549
3 changed files with 11 additions and 7 deletions

View File

@@ -30,4 +30,4 @@ RUN pip3 install -e .
# RUN pip3 install rl_coach
# CMD ["coach", "-p", "CartPole_PG", "-e", "cartpole"]
CMD python3 rl_coach/rollout_worker.py --preset CartPole_PG
# CMD python3 rl_coach/rollout_worker.py --preset CartPole_DQN_distributed

View File

@@ -26,6 +26,7 @@ endif
build:
${DOCKER} build -f=Dockerfile -t=${IMAGE} ${BUILD_ARGUMENTS} ${CONTEXT}
mkdir -p /tmp/checkpoint
rm -rf /tmp/checkpoint/*
shell: build
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} /bin/bash
@@ -37,10 +38,10 @@ run: build
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE}
run_training_worker: build
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/training_worker.py --preset CartPole_PG
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/training_worker.py --preset CartPole_DQN_distributed
run_rollout_worker: build
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/rollout_worker.py --preset CartPole_PG
${DOCKER} run ${RUN_ARGUMENTS} -it ${IMAGE} python3 rl_coach/rollout_worker.py --preset CartPole_DQN_distributed
push:
docker push ${IMAGE}

View File

@@ -40,7 +40,11 @@ def wait_for_checkpoint(checkpoint_dir, timeout=10):
if has_checkpoint(checkpoint_dir):
return
raise ValueError('checkpoint never found in {checkpoint_dir}'.format(
raise ValueError((
'Waited {timeout} seconds, but checkpoint never found in'
' {checkpoint_dir}'
).format(
timeout=timeout,
checkpoint_dir=checkpoint_dir,
))
@@ -82,9 +86,8 @@ def main():
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
# TODO: get this working, this expects that memory already has a redis ip and port
# graph_manager.agent_params.memory.redis_ip = args.redis_ip
# graph_manager.agent_params.memory.redis_port = args.redis_port
graph_manager.agent_params.memory.redis_ip = args.redis_ip
graph_manager.agent_params.memory.redis_port = args.redis_port
rollout_worker(
graph_manager=graph_manager,