
Adding steps and waiting for new checkpoint

Ajay Deshpande
2018-10-08 13:41:51 -07:00
committed by zach dwiel
parent 0e121c5762
commit 0f46877d7e
3 changed files with 23 additions and 16 deletions


@@ -52,13 +52,10 @@ class S3DataStore(DataStore):
    def save_to_store(self):
        try:
            print("Writing lock file")
            self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
            self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
            print("saving to s3")
            checkpoint_file = None
            for root, dirs, files in os.walk(self.params.checkpoint_dir):
                for filename in files:
@@ -73,7 +70,6 @@ class S3DataStore(DataStore):
                    rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                    self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
            print("Deleting lock file")
            self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
        except ResponseError as e:
@@ -81,7 +77,6 @@ class S3DataStore(DataStore):
    def load_from_store(self):
        try:
            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
            while True:
@@ -95,8 +90,6 @@ class S3DataStore(DataStore):
                    break
                time.sleep(10)
            print("loading from s3")
            ckpt = CheckpointState()
            if os.path.exists(filename):
                contents = open(filename, 'r').read()
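
For reference, a minimal sketch of the lock-file handshake that save_to_store and load_from_store implement, assuming a minio.Minio client (mc) and illustrative bucket, key and path names that are not taken from this commit:

import io
import time

def save_with_lock(mc, bucket, lock_key, object_key, local_path):
    # An empty lock object tells readers that an upload is in progress.
    mc.put_object(bucket, lock_key, io.BytesIO(b''), 0)
    try:
        mc.fput_object(bucket, object_key, local_path)
    finally:
        # Removing the lock signals that the checkpoint upload is complete.
        mc.remove_object(bucket, lock_key)

def wait_for_unlock(mc, bucket, lock_key, poll_seconds=10):
    # Readers poll until the lock object disappears, mirroring the
    # time.sleep(10) loop in load_from_store above.
    while True:
        try:
            mc.stat_object(bucket, lock_key)
        except Exception:
            break  # lock is gone, safe to read the checkpoint
        time.sleep(poll_seconds)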


@@ -16,7 +16,7 @@ from threading import Thread
from rl_coach.base_parameters import TaskParameters
from rl_coach.coach import expand_preset
from rl_coach.core_types import EnvironmentEpisodes, RunPhase
from rl_coach.core_types import EnvironmentSteps, RunPhase
from rl_coach.utils import short_dynamic_import
from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
@@ -82,7 +82,7 @@ def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
    return last_checkpoint
def rollout_worker(graph_manager, checkpoint_dir):
def rollout_worker(graph_manager, checkpoint_dir, data_store):
"""
wait for first checkpoint then perform rollouts using the model
"""
@@ -94,16 +94,26 @@ def rollout_worker(graph_manager, checkpoint_dir):
    graph_manager.create_graph(task_parameters)
    graph_manager.phase = RunPhase.TRAIN
    error_compensation = 100
    last_checkpoint = 0
    for i in range(10000000):
        graph_manager.act(EnvironmentEpisodes(num_steps=1))
    act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation
        new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
    print(act_steps, graph_manager.improve_steps.num_steps)
        if new_checkpoint > last_checkpoint:
            last_checkpoint = new_checkpoint
            graph_manager.restore_checkpoint()
    for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
        graph_manager.act(EnvironmentSteps(num_steps=act_steps))
        new_checkpoint = last_checkpoint + 1
        while last_checkpoint < new_checkpoint:
            if data_store:
                data_store.load_from_store()
            last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
        last_checkpoint = new_checkpoint
        graph_manager.restore_checkpoint()
    graph_manager.phase = RunPhase.UNDEFINED
@@ -137,6 +147,7 @@ def main():
    graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
    data_store = None
    if args.memory_backend_params:
        args.memory_backend_params = json.loads(args.memory_backend_params)
        print(args.memory_backend_params)
@@ -156,6 +167,7 @@ def main():
    rollout_worker(
        graph_manager=graph_manager,
        checkpoint_dir=args.checkpoint_dir,
        data_store=data_store
    )
if __name__ == '__main__':
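
Taken together, the rewritten loop collects a fixed number of environment steps per iteration and then blocks until the trainer publishes a newer checkpoint. A minimal sketch of that flow, assuming the graph_manager, check_for_new_checkpoint and data_store behaviour shown in the diff (error_compensation pads each act() call so at least one full training batch is gathered):

act_steps = (graph_manager.agent_params.algorithm
             .num_consecutive_playing_steps.num_steps + error_compensation)

for _ in range(int(graph_manager.improve_steps.num_steps / act_steps)):
    # Collect roughly one training batch worth of experience.
    graph_manager.act(EnvironmentSteps(num_steps=act_steps))

    # Wait for the trainer to write a newer checkpoint, syncing the data
    # store (e.g. S3) so checkpoint_dir is up to date before each check.
    target_checkpoint = last_checkpoint + 1
    while last_checkpoint < target_checkpoint:
        if data_store:
            data_store.load_from_store()
        last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)

    graph_manager.restore_checkpoint()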


@@ -31,8 +31,10 @@ def training_worker(graph_manager, checkpoint_dir):
    graph_manager.save_checkpoint()
    # training loop
    while True:
    steps = 0
    while(steps < graph_manager.improve_steps.num_steps):
        if graph_manager.should_train():
            steps += 1
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.train(core_types.TrainingSteps(1))
            graph_manager.phase = core_types.RunPhase.UNDEFINED
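
With this change both workers run against the same improve_steps budget: the trainer counts individual training iterations, while the rollout worker divides the same budget by its per-iteration act_steps. Illustrative arithmetic with assumed values (not taken from any preset):

improve_steps = 10000             # assumed graph_manager.improve_steps.num_steps
playing_steps = 400               # assumed num_consecutive_playing_steps.num_steps
error_compensation = 100          # hard-coded in rollout_worker above

act_steps = playing_steps + error_compensation        # 500
rollout_iterations = int(improve_steps / act_steps)   # 20 act-and-restore cycles
print(act_steps, rollout_iterations)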