mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Checkpoint and evaluation optimizations
This commit is contained in:
committed by
zach dwiel
parent
b285a02023
commit
fb1039fcb5
@@ -179,6 +179,7 @@ class Kubernetes(Deploy):
|
|||||||
|
|
||||||
worker_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
worker_params.command += ['--memory-backend-params', json.dumps(self.params.memory_backend_parameters.__dict__)]
|
||||||
worker_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
|
worker_params.command += ['--data-store-params', json.dumps(self.params.data_store_params.__dict__)]
|
||||||
|
worker_params.command += ['--num-workers', worker_params.num_replicas]
|
||||||
|
|
||||||
name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
|
name = "{}-{}".format(worker_params.run_type, uuid.uuid4())
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ from rl_coach.data_stores.nfs_data_store import NFSDataStoreParameters
|
|||||||
|
|
||||||
|
|
||||||
def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str=None, nfs_path: str=None,
|
def main(preset: str, image: str='ajaysudh/testing:coach', num_workers: int=1, nfs_server: str=None, nfs_path: str=None,
|
||||||
memory_backend: str=None, data_store: str=None, s3_end_point: str=None, s3_bucket_name: str=None):
|
memory_backend: str=None, data_store: str=None, s3_end_point: str=None, s3_bucket_name: str=None,
|
||||||
rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset]
|
policy_type: str="OFF"):
|
||||||
training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset]
|
rollout_command = ['python3', 'rl_coach/rollout_worker.py', '-p', preset, '--policy-type', policy_type]
|
||||||
|
training_command = ['python3', 'rl_coach/training_worker.py', '-p', preset, '--policy-type', policy_type]
|
||||||
|
|
||||||
memory_backend_params = None
|
memory_backend_params = None
|
||||||
if memory_backend == "redispubsub":
|
if memory_backend == "redispubsub":
|
||||||
@@ -95,6 +96,10 @@ if __name__ == '__main__':
|
|||||||
type=int,
|
type=int,
|
||||||
required=False,
|
required=False,
|
||||||
default=1)
|
default=1)
|
||||||
|
parser.add_argument('--policy-type',
|
||||||
|
help="(string) The type of policy: OFF/ON",
|
||||||
|
type=str,
|
||||||
|
default='OFF')
|
||||||
|
|
||||||
# parser.add_argument('--checkpoint_dir',
|
# parser.add_argument('--checkpoint_dir',
|
||||||
# help='(string) Path to a folder containing a checkpoint to write the model to.',
|
# help='(string) Path to a folder containing a checkpoint to write the model to.',
|
||||||
@@ -104,4 +109,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
|
main(preset=args.preset, image=args.image, nfs_server=args.nfs_server, nfs_path=args.nfs_path,
|
||||||
memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
|
memory_backend=args.memory_backend, data_store=args.data_store, s3_end_point=args.s3_end_point,
|
||||||
s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers)
|
s3_bucket_name=args.s3_bucket_name, num_workers=args.num_workers, policy_type=args.policy_type)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
@@ -69,20 +70,16 @@ def data_store_ckpt_load(data_store):
|
|||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
|
def get_latest_checkpoint(checkpoint_dir):
|
||||||
if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
|
if os.path.exists(os.path.join(checkpoint_dir, 'checkpoint')):
|
||||||
ckpt = CheckpointState()
|
ckpt = CheckpointState()
|
||||||
contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
|
contents = open(os.path.join(checkpoint_dir, 'checkpoint'), 'r').read()
|
||||||
text_format.Merge(contents, ckpt)
|
text_format.Merge(contents, ckpt)
|
||||||
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
|
rel_path = os.path.relpath(ckpt.model_checkpoint_path, checkpoint_dir)
|
||||||
current_checkpoint = int(rel_path.split('_Step')[0])
|
return int(rel_path.split('_Step')[0])
|
||||||
if current_checkpoint > last_checkpoint:
|
|
||||||
last_checkpoint = current_checkpoint
|
|
||||||
|
|
||||||
return last_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def rollout_worker(graph_manager, checkpoint_dir, data_store):
|
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, policy_type):
|
||||||
"""
|
"""
|
||||||
wait for first checkpoint then perform rollouts using the model
|
wait for first checkpoint then perform rollouts using the model
|
||||||
"""
|
"""
|
||||||
@@ -98,22 +95,28 @@ def rollout_worker(graph_manager, checkpoint_dir, data_store):
|
|||||||
|
|
||||||
last_checkpoint = 0
|
last_checkpoint = 0
|
||||||
|
|
||||||
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation
|
act_steps = math.ceil((graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation)/num_workers)
|
||||||
|
|
||||||
print(act_steps, graph_manager.improve_steps.num_steps)
|
|
||||||
|
|
||||||
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
||||||
|
|
||||||
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
||||||
|
|
||||||
new_checkpoint = last_checkpoint + 1
|
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||||
while last_checkpoint < new_checkpoint:
|
|
||||||
|
if policy_type == 'ON':
|
||||||
|
while new_checkpoint < last_checkpoint + 1:
|
||||||
if data_store:
|
if data_store:
|
||||||
data_store.load_from_store()
|
data_store.load_from_store()
|
||||||
last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
|
new_checkpoint = get_latest_checkpoint(checkpoint_dir)
|
||||||
|
|
||||||
|
graph_manager.restore_checkpoint()
|
||||||
|
|
||||||
|
if policy_type == "OFF":
|
||||||
|
|
||||||
|
if new_checkpoint > last_checkpoint:
|
||||||
|
graph_manager.restore_checkpoint()
|
||||||
|
|
||||||
last_checkpoint = new_checkpoint
|
last_checkpoint = new_checkpoint
|
||||||
graph_manager.restore_checkpoint()
|
|
||||||
|
|
||||||
graph_manager.phase = RunPhase.UNDEFINED
|
graph_manager.phase = RunPhase.UNDEFINED
|
||||||
|
|
||||||
@@ -134,6 +137,14 @@ def main():
|
|||||||
parser.add_argument('--data-store-params',
|
parser.add_argument('--data-store-params',
|
||||||
help="(string) JSON string of the data store params",
|
help="(string) JSON string of the data store params",
|
||||||
type=str)
|
type=str)
|
||||||
|
parser.add_argument('--num-workers',
|
||||||
|
help="(int) The number of workers started in this pool",
|
||||||
|
type=int,
|
||||||
|
default=1)
|
||||||
|
parser.add_argument('--policy-type',
|
||||||
|
help="(string) The type of policy: OFF/ON",
|
||||||
|
type=str,
|
||||||
|
default='OFF')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -142,9 +153,7 @@ def main():
|
|||||||
data_store = None
|
data_store = None
|
||||||
if args.memory_backend_params:
|
if args.memory_backend_params:
|
||||||
args.memory_backend_params = json.loads(args.memory_backend_params)
|
args.memory_backend_params = json.loads(args.memory_backend_params)
|
||||||
print(args.memory_backend_params)
|
|
||||||
args.memory_backend_params['run_type'] = 'worker'
|
args.memory_backend_params['run_type'] = 'worker'
|
||||||
print(construct_memory_params(args.memory_backend_params))
|
|
||||||
graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))
|
graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(args.memory_backend_params))
|
||||||
|
|
||||||
if args.data_store_params:
|
if args.data_store_params:
|
||||||
@@ -159,7 +168,9 @@ def main():
|
|||||||
rollout_worker(
|
rollout_worker(
|
||||||
graph_manager=graph_manager,
|
graph_manager=graph_manager,
|
||||||
checkpoint_dir=args.checkpoint_dir,
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
data_store=data_store
|
data_store=data_store,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
policy_type=args.policy_type
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -18,13 +18,14 @@ def data_store_ckpt_save(data_store):
|
|||||||
data_store.save_to_store()
|
data_store.save_to_store()
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
def training_worker(graph_manager, checkpoint_dir):
|
def training_worker(graph_manager, checkpoint_dir, policy_type):
|
||||||
"""
|
"""
|
||||||
restore a checkpoint then perform rollouts using the restored model
|
restore a checkpoint then perform rollouts using the restored model
|
||||||
"""
|
"""
|
||||||
# initialize graph
|
# initialize graph
|
||||||
task_parameters = TaskParameters()
|
task_parameters = TaskParameters()
|
||||||
task_parameters.__dict__['save_checkpoint_dir'] = checkpoint_dir
|
task_parameters.__dict__['save_checkpoint_dir'] = checkpoint_dir
|
||||||
|
task_parameters.__dict__['save_checkpoint_secs'] = 60
|
||||||
graph_manager.create_graph(task_parameters)
|
graph_manager.create_graph(task_parameters)
|
||||||
|
|
||||||
# save randomly initialized graph
|
# save randomly initialized graph
|
||||||
@@ -32,14 +33,26 @@ def training_worker(graph_manager, checkpoint_dir):
|
|||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
steps = 0
|
steps = 0
|
||||||
|
|
||||||
|
# evaluation offset
|
||||||
|
eval_offset = 1
|
||||||
|
|
||||||
while(steps < graph_manager.improve_steps.num_steps):
|
while(steps < graph_manager.improve_steps.num_steps):
|
||||||
if graph_manager.should_train():
|
if graph_manager.should_train():
|
||||||
steps += 1
|
steps += 1
|
||||||
|
|
||||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||||
graph_manager.train(core_types.TrainingSteps(1))
|
graph_manager.train(core_types.TrainingSteps(1))
|
||||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||||
|
|
||||||
|
if steps * graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps > graph_manager.steps_between_evaluation_periods.num_steps * eval_offset:
|
||||||
graph_manager.evaluate(graph_manager.evaluation_steps)
|
graph_manager.evaluate(graph_manager.evaluation_steps)
|
||||||
|
eval_offset += 1
|
||||||
|
|
||||||
|
if policy_type == 'ON':
|
||||||
graph_manager.save_checkpoint()
|
graph_manager.save_checkpoint()
|
||||||
|
else:
|
||||||
|
graph_manager.occasionally_save_checkpoint()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -58,6 +71,10 @@ def main():
|
|||||||
parser.add_argument('--data-store-params',
|
parser.add_argument('--data-store-params',
|
||||||
help="(string) JSON string of the data store params",
|
help="(string) JSON string of the data store params",
|
||||||
type=str)
|
type=str)
|
||||||
|
parser.add_argument('--policy-type',
|
||||||
|
help="(string) The type of policy: OFF/ON",
|
||||||
|
type=str,
|
||||||
|
default='OFF')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
||||||
@@ -78,6 +95,7 @@ def main():
|
|||||||
training_worker(
|
training_worker(
|
||||||
graph_manager=graph_manager,
|
graph_manager=graph_manager,
|
||||||
checkpoint_dir=args.checkpoint_dir,
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
policy_type=args.policy_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user