1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

Tf checkpointing using saver mechanism (#134)

This commit is contained in:
Sina Afrooze
2018-11-22 04:08:10 -08:00
committed by Gal Leibovich
parent dd18959e53
commit 16cdd9a9c1
6 changed files with 110 additions and 50 deletions

View File

@@ -21,6 +21,7 @@ import numpy as np
import tensorflow as tf
from rl_coach.architectures.architecture import Architecture
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.core_types import GradientClippingMethod
from rl_coach.saver import SaverCollection
@@ -645,8 +646,10 @@ class TensorFlowArchitecture(Architecture):
(e.g. could be name of level manager plus name of agent)
:return: checkpoint collection for the network
"""
# TODO implement returning checkpoints for tensorflow
return SaverCollection()
savers = SaverCollection()
if not self.distributed_training:
savers.add(GlobalVariableSaver(self.name))
return savers
def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None: