1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Adding steps and waiting for new checkpoint

This commit is contained in:
Ajay Deshpande
2018-10-08 13:41:51 -07:00
committed by zach dwiel
parent 0e121c5762
commit 0f46877d7e
3 changed files with 23 additions and 16 deletions

View File

@@ -16,7 +16,7 @@ from threading import Thread
from rl_coach.base_parameters import TaskParameters
from rl_coach.coach import expand_preset
from rl_coach.core_types import EnvironmentEpisodes, RunPhase
from rl_coach.core_types import EnvironmentSteps, RunPhase
from rl_coach.utils import short_dynamic_import
from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
@@ -82,7 +82,7 @@ def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
return last_checkpoint
def rollout_worker(graph_manager, checkpoint_dir):
def rollout_worker(graph_manager, checkpoint_dir, data_store):
"""
wait for first checkpoint then perform rollouts using the model
"""
@@ -94,16 +94,26 @@ def rollout_worker(graph_manager, checkpoint_dir):
graph_manager.create_graph(task_parameters)
graph_manager.phase = RunPhase.TRAIN
error_compensation = 100
last_checkpoint = 0
for i in range(10000000):
graph_manager.act(EnvironmentEpisodes(num_steps=1))
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation
new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
print(act_steps, graph_manager.improve_steps.num_steps)
if new_checkpoint > last_checkpoint:
last_checkpoint = new_checkpoint
graph_manager.restore_checkpoint()
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
new_checkpoint = last_checkpoint + 1
while last_checkpoint < new_checkpoint:
if data_store:
data_store.load_from_store()
last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
last_checkpoint = new_checkpoint
graph_manager.restore_checkpoint()
graph_manager.phase = RunPhase.UNDEFINED
@@ -137,6 +147,7 @@ def main():
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
data_store = None
if args.memory_backend_params:
args.memory_backend_params = json.loads(args.memory_backend_params)
print(args.memory_backend_params)
@@ -156,6 +167,7 @@ def main():
rollout_worker(
graph_manager=graph_manager,
checkpoint_dir=args.checkpoint_dir,
data_store=data_store
)
if __name__ == '__main__':