mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding steps and waiting for new checkpoint
This commit is contained in:
committed by
zach dwiel
parent
0e121c5762
commit
0f46877d7e
@@ -52,13 +52,10 @@ class S3DataStore(DataStore):
|
||||
|
||||
def save_to_store(self):
|
||||
try:
|
||||
print("Writing lock file")
|
||||
|
||||
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
||||
|
||||
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
|
||||
|
||||
print("saving to s3")
|
||||
checkpoint_file = None
|
||||
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
||||
for filename in files:
|
||||
@@ -73,7 +70,6 @@ class S3DataStore(DataStore):
|
||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||
|
||||
print("Deleting lock file")
|
||||
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
||||
|
||||
except ResponseError as e:
|
||||
@@ -81,7 +77,6 @@ class S3DataStore(DataStore):
|
||||
|
||||
def load_from_store(self):
|
||||
try:
|
||||
|
||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
||||
|
||||
while True:
|
||||
@@ -95,8 +90,6 @@ class S3DataStore(DataStore):
|
||||
break
|
||||
time.sleep(10)
|
||||
|
||||
print("loading from s3")
|
||||
|
||||
ckpt = CheckpointState()
|
||||
if os.path.exists(filename):
|
||||
contents = open(filename, 'r').read()
|
||||
|
||||
@@ -16,7 +16,7 @@ from threading import Thread
|
||||
|
||||
from rl_coach.base_parameters import TaskParameters
|
||||
from rl_coach.coach import expand_preset
|
||||
from rl_coach.core_types import EnvironmentEpisodes, RunPhase
|
||||
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||
from rl_coach.utils import short_dynamic_import
|
||||
from rl_coach.memories.backend.memory_impl import construct_memory_params
|
||||
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
|
||||
@@ -82,7 +82,7 @@ def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
|
||||
return last_checkpoint
|
||||
|
||||
|
||||
def rollout_worker(graph_manager, checkpoint_dir):
|
||||
def rollout_worker(graph_manager, checkpoint_dir, data_store):
|
||||
"""
|
||||
wait for first checkpoint then perform rollouts using the model
|
||||
"""
|
||||
@@ -94,14 +94,24 @@ def rollout_worker(graph_manager, checkpoint_dir):
|
||||
graph_manager.create_graph(task_parameters)
|
||||
graph_manager.phase = RunPhase.TRAIN
|
||||
|
||||
error_compensation = 100
|
||||
|
||||
last_checkpoint = 0
|
||||
|
||||
for i in range(10000000):
|
||||
graph_manager.act(EnvironmentEpisodes(num_steps=1))
|
||||
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation
|
||||
|
||||
new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
|
||||
print(act_steps, graph_manager.improve_steps.num_steps)
|
||||
|
||||
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
||||
|
||||
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
||||
|
||||
new_checkpoint = last_checkpoint + 1
|
||||
while last_checkpoint < new_checkpoint:
|
||||
if data_store:
|
||||
data_store.load_from_store()
|
||||
last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
|
||||
|
||||
if new_checkpoint > last_checkpoint:
|
||||
last_checkpoint = new_checkpoint
|
||||
graph_manager.restore_checkpoint()
|
||||
|
||||
@@ -137,6 +147,7 @@ def main():
|
||||
|
||||
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
||||
|
||||
data_store = None
|
||||
if args.memory_backend_params:
|
||||
args.memory_backend_params = json.loads(args.memory_backend_params)
|
||||
print(args.memory_backend_params)
|
||||
@@ -156,6 +167,7 @@ def main():
|
||||
rollout_worker(
|
||||
graph_manager=graph_manager,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
data_store=data_store
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -31,8 +31,10 @@ def training_worker(graph_manager, checkpoint_dir):
|
||||
graph_manager.save_checkpoint()
|
||||
|
||||
# training loop
|
||||
while True:
|
||||
steps = 0
|
||||
while(steps < graph_manager.improve_steps.num_steps):
|
||||
if graph_manager.should_train():
|
||||
steps += 1
|
||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||
graph_manager.train(core_types.TrainingSteps(1))
|
||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||
|
||||
Reference in New Issue
Block a user