mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Checkpoint and evaluation optimizations
This commit is contained in:
committed by
zach dwiel
parent
b285a02023
commit
fb1039fcb5
@@ -11,6 +11,7 @@ import argparse
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
|
||||
from threading import Thread
|
||||
|
||||
@@ -69,20 +70,16 @@ def data_store_ckpt_load(data_store):
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
|
||||
def get_latest_checkpoint(checkpoint_dir):
|
||||
if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
|
||||
ckpt = CheckpointState()
|
||||
contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
|
||||
text_format.Merge(contents, ckpt)
|
||||
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
|
||||
current_checkpoint = int(rel_path.split('_Step')[0])
|
||||
if current_checkpoint > last_checkpoint:
|
||||
last_checkpoint = current_checkpoint
|
||||
|
||||
return last_checkpoint
|
||||
return int(rel_path.split('_Step')[0])
|
||||
|
||||
|
||||
def rollout_worker(graph_manager, checkpoint_dir, data_store):
|
||||
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, policy_type):
|
||||
"""
|
||||
wait for first checkpoint then perform rollouts using the model
|
||||
"""
|
||||
@@ -98,22 +95,28 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store):
|
||||
|
||||
last_checkpoint = 0
|
||||
|
||||
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation
|
||||
|
||||
print(act_steps, graph_manager.improve_steps.num_steps)
|
||||
act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
|
||||
|
||||
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
||||
|
||||
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
||||
|
||||
new_checkpoint = last_checkpoint + 1
|
||||
while last_checkpoint < new_checkpoint:
|
||||
if data_store:
|
||||
data_store.load_from_store()
|
||||
last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
|
||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
if policy_type == 'ON':
|
||||
while new_checkpoint < last_checkpoint + 1:
|
||||
if data_store:
|
||||
data_store.load_from_store()
|
||||
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
graph_manager.restore_checkpoint()
|
||||
|
||||
if policy_type == "OFF":
|
||||
|
||||
if new_checkpoint > last_checkpoint:
|
||||
graph_manager.restore_checkpoint()
|
||||
|
||||
last_checkpoint = new_checkpoint
|
||||
graph_manager.restore_checkpoint()
|
||||
|
||||
graph_manager.phase = RunPhase.UNDEFINED
|
||||
|
||||
@@ -134,6 +137,14 @@ def main():
|
||||
parser.add_argument('--data-store-params',
|
||||
help="(string) JSON string of the data store params",
|
||||
type=str)
|
||||
parser.add_argument('--num-workers',
|
||||
help="(int) The number of workers started in this pool",
|
||||
type=int,
|
||||
default=1)
|
||||
parser.add_argument('--policy-type',
|
||||
help="(string) The type of policy: OFF/ON",
|
||||
type=str,
|
||||
default='OFF')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -142,9 +153,7 @@ def main():
|
||||
data_store = None
|
||||
if args.memory_backend_params:
|
||||
args.memory_backend_params = json.loads(args.memory_backend_params)
|
||||
print(args.memory_backend_params)
|
||||
args.memory_backend_params['run_type'] = 'worker'
|
||||
print(construct_memory_params(args.memory_backend_params))
|
||||
graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))
|
||||
|
||||
if args.data_store_params:
|
||||
@@ -159,7 +168,9 @@ def main():
|
||||
rollout_worker(
|
||||
graph_manager=graph_manager,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
data_store=data_store
|
||||
data_store=data_store,
|
||||
num_workers=args.num_workers,
|
||||
policy_type=args.policy_type
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user