mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Parallel agents fixes (#95)
* Parallel agents related bug fixes: checkpoint restore, tensorboard integration. Adding narrow networks support. Reference code for unlimited number of checkpoints
This commit is contained in:
@@ -128,11 +128,14 @@ if __name__ == "__main__":
    def init_fn(scaffold, session):
        session.run(init_all_op)

    #saver = tf.train.Saver(max_to_keep=None)  # uncomment to unlimit number of stored checkpoints
    scaffold = tf.train.Scaffold(init_op=init_all_op,
                                 init_fn=init_fn,
                                 ready_op=ready_op,
                                 ready_for_local_init_op=ready_for_local_init_op,
                                 local_init_op=local_init_op)
                                 #saver=saver)  # uncomment to unlimit number of stored checkpoints

    # Due to awkward tensorflow behavior where the same variable is used to decide whether to restore a model
    # (and where from), or just save the model (and where to), we employ the below. In case where a restore folder
@@ -156,6 +159,10 @@ if __name__ == "__main__":
    tuning_parameters.sess = sess
    for network in agent.networks:
        network.set_session(sess)
        # if hasattr(network.global_network, 'lock_init'):
        #     sess.run(network.global_network.lock_init)
        # if hasattr(network.global_network, 'release_init'):
        #     sess.run(network.global_network.release_init)

    if tuning_parameters.visualization.tensorboard:
        # Write the merged summaries to the current experiment directory
||||
Reference in New Issue
Block a user