mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Adding steps and waiting for new checkpoint
This commit is contained in:
committed by
zach dwiel
parent
0e121c5762
commit
0f46877d7e
@@ -52,13 +52,10 @@ class S3DataStore(DataStore):
|
|||||||
|
|
||||||
def save_to_store(self):
|
def save_to_store(self):
|
||||||
try:
|
try:
|
||||||
print("Writing lock file")
|
|
||||||
|
|
||||||
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
||||||
|
|
||||||
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
|
self.mc.put_object(self.params.bucket_name, self.params.lock_file, io.BytesIO(b''), 0)
|
||||||
|
|
||||||
print("saving to s3")
|
|
||||||
checkpoint_file = None
|
checkpoint_file = None
|
||||||
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
for root, dirs, files in os.walk(self.params.checkpoint_dir):
|
||||||
for filename in files:
|
for filename in files:
|
||||||
@@ -73,7 +70,6 @@ class S3DataStore(DataStore):
|
|||||||
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
|
||||||
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
|
||||||
|
|
||||||
print("Deleting lock file")
|
|
||||||
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
self.mc.remove_object(self.params.bucket_name, self.params.lock_file)
|
||||||
|
|
||||||
except ResponseError as e:
|
except ResponseError as e:
|
||||||
@@ -81,7 +77,6 @@ class S3DataStore(DataStore):
|
|||||||
|
|
||||||
def load_from_store(self):
|
def load_from_store(self):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "checkpoint"))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -95,8 +90,6 @@ class S3DataStore(DataStore):
|
|||||||
break
|
break
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
print("loading from s3")
|
|
||||||
|
|
||||||
ckpt = CheckpointState()
|
ckpt = CheckpointState()
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
contents = open(filename, 'r').read()
|
contents = open(filename, 'r').read()
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from threading import Thread
|
|||||||
|
|
||||||
from rl_coach.base_parameters import TaskParameters
|
from rl_coach.base_parameters import TaskParameters
|
||||||
from rl_coach.coach import expand_preset
|
from rl_coach.coach import expand_preset
|
||||||
from rl_coach.core_types import EnvironmentEpisodes, RunPhase
|
from rl_coach.core_types import EnvironmentSteps, RunPhase
|
||||||
from rl_coach.utils import short_dynamic_import
|
from rl_coach.utils import short_dynamic_import
|
||||||
from rl_coach.memories.backend.memory_impl import construct_memory_params
|
from rl_coach.memories.backend.memory_impl import construct_memory_params
|
||||||
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
|
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
|
||||||
@@ -82,7 +82,7 @@ def check_for_new_checkpoint(checkpoint_dir, last_checkpoint):
|
|||||||
return last_checkpoint
|
return last_checkpoint
|
||||||
|
|
||||||
|
|
||||||
def rollout_worker(graph_manager, checkpoint_dir):
|
def rollout_worker(graph_manager, checkpoint_dir, data_store):
|
||||||
"""
|
"""
|
||||||
wait for first checkpoint then perform rollouts using the model
|
wait for first checkpoint then perform rollouts using the model
|
||||||
"""
|
"""
|
||||||
@@ -94,16 +94,26 @@ def rollout_worker(graph_manager, checkpoint_dir):
|
|||||||
graph_manager.create_graph(task_parameters)
|
graph_manager.create_graph(task_parameters)
|
||||||
graph_manager.phase = RunPhase.TRAIN
|
graph_manager.phase = RunPhase.TRAIN
|
||||||
|
|
||||||
|
error_compensation = 100
|
||||||
|
|
||||||
last_checkpoint = 0
|
last_checkpoint = 0
|
||||||
|
|
||||||
for i in range(10000000):
|
act_steps = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps + error_compensation
|
||||||
graph_manager.act(EnvironmentEpisodes(num_steps=1))
|
|
||||||
|
|
||||||
new_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
|
print(act_steps, graph_manager.improve_steps.num_steps)
|
||||||
|
|
||||||
if new_checkpoint > last_checkpoint:
|
for i in range(int(graph_manager.improve_steps.num_steps/act_steps)):
|
||||||
last_checkpoint = new_checkpoint
|
|
||||||
graph_manager.restore_checkpoint()
|
graph_manager.act(EnvironmentSteps(num_steps=act_steps))
|
||||||
|
|
||||||
|
new_checkpoint = last_checkpoint + 1
|
||||||
|
while last_checkpoint < new_checkpoint:
|
||||||
|
if data_store:
|
||||||
|
data_store.load_from_store()
|
||||||
|
last_checkpoint = check_for_new_checkpoint(checkpoint_dir, last_checkpoint)
|
||||||
|
|
||||||
|
last_checkpoint = new_checkpoint
|
||||||
|
graph_manager.restore_checkpoint()
|
||||||
|
|
||||||
graph_manager.phase = RunPhase.UNDEFINED
|
graph_manager.phase = RunPhase.UNDEFINED
|
||||||
|
|
||||||
@@ -137,6 +147,7 @@ def main():
|
|||||||
|
|
||||||
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
graph_manager = short_dynamic_import(expand_preset(args.preset), ignore_module_case=True)
|
||||||
|
|
||||||
|
data_store = None
|
||||||
if args.memory_backend_params:
|
if args.memory_backend_params:
|
||||||
args.memory_backend_params = json.loads(args.memory_backend_params)
|
args.memory_backend_params = json.loads(args.memory_backend_params)
|
||||||
print(args.memory_backend_params)
|
print(args.memory_backend_params)
|
||||||
@@ -156,6 +167,7 @@ def main():
|
|||||||
rollout_worker(
|
rollout_worker(
|
||||||
graph_manager=graph_manager,
|
graph_manager=graph_manager,
|
||||||
checkpoint_dir=args.checkpoint_dir,
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
data_store=data_store
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -31,8 +31,10 @@ def training_worker(graph_manager, checkpoint_dir):
|
|||||||
graph_manager.save_checkpoint()
|
graph_manager.save_checkpoint()
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
while True:
|
steps = 0
|
||||||
|
while(steps < graph_manager.improve_steps.num_steps):
|
||||||
if graph_manager.should_train():
|
if graph_manager.should_train():
|
||||||
|
steps += 1
|
||||||
graph_manager.phase = core_types.RunPhase.TRAIN
|
graph_manager.phase = core_types.RunPhase.TRAIN
|
||||||
graph_manager.train(core_types.TrainingSteps(1))
|
graph_manager.train(core_types.TrainingSteps(1))
|
||||||
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
graph_manager.phase = core_types.RunPhase.UNDEFINED
|
||||||
|
|||||||
Reference in New Issue
Block a user