mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Checkpoint and evaluation optimizations
This commit is contained in:
committed by
zach dwiel
parent
b285a02023
commit
fb1039fcb5
@@ -18,13 +18,14 @@ def data_store_ckpt_save(data_store):
|
||||
data_store.save_to_store()
|
||||
time.sleep(10)
|
||||
|
||||
def training_worker(graph_manager, checkpoint_dir):
|
||||
def training_worker(graph_manager, checkpoint_dir, policy_type):
|
||||
"""
|
||||
restore a checkpoint then perform rollouts using the restored model
|
||||
"""
|
||||
# initialize graph
|
||||
task_parameters = TaskParameters()
|
||||
task_parameters.__dict__['save_checkpoint_dir'] = checkpoint_dir
|
||||
task_parameters.__dict__['save_checkpoint_secs'] = 60
|
||||
graph_manager.create_graph(task_parameters)
|
||||
|
||||
# save randomly initialized graph
|
||||
@@ -32,14 +33,26 @@ def training_worker(graph_manager, checkpoint_dir):
|
||||
|
||||
# training loop
|
||||
steps = 0
|
||||
|
||||
# evaluation offset
|
||||
eval_offset = 1
|
||||
|
||||
while(steps < graph_manager.improve_steps.num_steps):
|
||||
if graph_manager.should_train():
|
||||
steps += 1
|
||||
|
||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||
graph_manager.train(core_types.TrainingSteps(1))
|
||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||
graph_manager.evaluate(graph_manager.evaluation_steps)
|
||||
graph_manager.save_checkpoint()
|
||||
|
||||
if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
|
||||
graph_manager.evaluate(graph_manager.evaluation_steps)
|
||||
eval_offset += 1
|
||||
|
||||
if policy_type == 'ON':
|
||||
graph_manager.save_checkpoint()
|
||||
else:
|
||||
graph_manager.occasionally_save_checkpoint()
|
||||
|
||||
|
||||
def main():
|
||||
@@ -58,6 +71,10 @@ def main():
|
||||
parser.add_argument('--data-store-params',
|
||||
help="(string) JSON string of the data store params",
|
||||
type=str)
|
||||
parser.add_argument('--policy-type',
|
||||
help="(string) The type of policy: OFF/ON",
|
||||
type=str,
|
||||
default='OFF')
|
||||
args = parser.parse_args()
|
||||
|
||||
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
||||
@@ -78,6 +95,7 @@ def main():
|
||||
training_worker(
|
||||
graph_manager=graph_manager,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
policy_type=args.policy_type
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user