1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 19:20:19 +01:00

restoring from a checkpoint file (#247)

This commit is contained in:
Gal Leibovich
2019-03-17 16:28:09 +02:00
committed by GitHub
parent f03bd7ad93
commit d6158a5cfc
6 changed files with 87 additions and 39 deletions

View File

@@ -25,6 +25,7 @@ from typing import Dict, List, Union
from rl_coach.core_types import TrainingSteps, EnvironmentSteps, GradientClippingMethod, RunPhase, \
SelectedPhaseOnlyDumpFilter, MaxDumpFilter
from rl_coach.filters.filter import NoInputFilter
from rl_coach.logger import screen
class Frameworks(Enum):
@@ -552,8 +553,8 @@ class AgentParameters(Parameters):
class TaskParameters(Parameters):
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: int=None, use_cpu: bool=False,
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False,
num_gpu: int=1):
checkpoint_restore_path=None, checkpoint_save_dir=None, export_onnx_graph: bool=False,
apply_stop_condition: bool=False, num_gpu: int=1):
"""
:param framework_type: deep learning framework type. currently only tensorflow is supported
:param evaluate_only: if not None, the task will be used only for evaluating the model for the given number of steps.
@@ -562,7 +563,10 @@ class TaskParameters(Parameters):
:param experiment_path: the path to the directory which will store all the experiment outputs
:param seed: a seed to use for the random numbers generator
:param checkpoint_save_secs: the number of seconds between each checkpoint saving
:param checkpoint_restore_dir: the directory to restore the checkpoints from
:param checkpoint_restore_dir:
[DEPECRATED - will be removed in one of the next releases - switch to checkpoint_restore_path]
the dir to restore the checkpoints from
:param checkpoint_restore_path: the path to restore the checkpoints from
:param checkpoint_save_dir: the directory to store the checkpoints in
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
:param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
@@ -574,7 +578,13 @@ class TaskParameters(Parameters):
self.use_cpu = use_cpu
self.experiment_path = experiment_path
self.checkpoint_save_secs = checkpoint_save_secs
self.checkpoint_restore_dir = checkpoint_restore_dir
if checkpoint_restore_dir:
screen.warning('TaskParameters.checkpoint_restore_dir is DEPECRATED and will be removed in one of the next '
'releases. Please switch to using TaskParameters.checkpoint_restore_path, with your '
'directory path. ')
self.checkpoint_restore_path = checkpoint_restore_dir
else:
self.checkpoint_restore_path = checkpoint_restore_path
self.checkpoint_save_dir = checkpoint_save_dir
self.seed = seed
self.export_onnx_graph = export_onnx_graph
@@ -586,7 +596,7 @@ class DistributedTaskParameters(TaskParameters):
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
task_index: int, evaluate_only: int=None, num_tasks: int=None,
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_path=None,
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
"""
:param framework_type: deep learning framework type. currently only tensorflow is supported
@@ -604,7 +614,7 @@ class DistributedTaskParameters(TaskParameters):
:param dnd: an external DND to use for NEC. This is a workaround needed for a shared DND not using the scratchpad.
:param seed: a seed to use for the random numbers generator
:param checkpoint_save_secs: the number of seconds between each checkpoint saving
:param checkpoint_restore_dir: the directory to restore the checkpoints from
:param checkpoint_restore_path: the path to restore the checkpoints from
:param checkpoint_save_dir: the directory to store the checkpoints in
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
:param apply_stop_condition: If set to True, this will apply the stop condition defined by reaching a target success rate
@@ -612,7 +622,7 @@ class DistributedTaskParameters(TaskParameters):
"""
super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
checkpoint_restore_path=checkpoint_restore_path, checkpoint_save_dir=checkpoint_save_dir,
export_onnx_graph=export_onnx_graph, apply_stop_condition=apply_stop_condition)
self.parameters_server_hosts = parameters_server_hosts
self.worker_hosts = worker_hosts