Mirror of https://github.com/gryf/coach.git, synced 2025-12-18 03:30:19 +01:00
Tf checkpointing using saver mechanism (#134)
This commit is contained in:
committed by
Gal Leibovich
parent
dd18959e53
commit
16cdd9a9c1
@@ -21,6 +21,7 @@ import numpy as np
 import tensorflow as tf

 from rl_coach.architectures.architecture import Architecture
+from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
 from rl_coach.core_types import GradientClippingMethod
 from rl_coach.saver import SaverCollection
@@ -645,8 +646,10 @@ class TensorFlowArchitecture(Architecture):
         (e.g. could be name of level manager plus name of agent)
         :return: checkpoint collection for the network
         """
-        # TODO implement returning checkpoints for tensorflow
-        return SaverCollection()
+        savers = SaverCollection()
+        if not self.distributed_training:
+            savers.add(GlobalVariableSaver(self.name))
+        return savers


 def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None:
Reference in New Issue
Block a user