1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

adding the missing export_onnx_graph parameter to task parameters (#73)

This commit is contained in:
Itai Caspi
2018-11-08 12:52:42 +02:00
committed by GitHub
parent 8f0415b4cc
commit 83e0b09a6a
3 changed files with 12 additions and 6 deletions

View File

@@ -19,7 +19,7 @@ import tensorflow as tf
from rl_coach.architectures.tensorflow_components.layers import Dense from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer
from rl_coach.base_parameters import AgentParameters from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
from rl_coach.core_types import ActionProbabilities from rl_coach.core_types import ActionProbabilities
from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace
from rl_coach.spaces import SpacesDefinition from rl_coach.spaces import SpacesDefinition

View File

@@ -436,7 +436,7 @@ class AgentParameters(Parameters):
class TaskParameters(Parameters): class TaskParameters(Parameters):
def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False, def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only: bool=False, use_cpu: bool=False,
experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None, experiment_path='/tmp', seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None): checkpoint_save_dir=None, export_onnx_graph: bool=False):
""" """
:param framework_type: deep learning framework type. currently only tensorflow is supported :param framework_type: deep learning framework type. currently only tensorflow is supported
:param evaluate_only: the task will be used only for evaluating the model :param evaluate_only: the task will be used only for evaluating the model
@@ -446,6 +446,7 @@ class TaskParameters(Parameters):
:param checkpoint_save_secs: the number of seconds between each checkpoint saving :param checkpoint_save_secs: the number of seconds between each checkpoint saving
:param checkpoint_restore_dir: the directory to restore the checkpoints from :param checkpoint_restore_dir: the directory to restore the checkpoints from
:param checkpoint_save_dir: the directory to store the checkpoints in :param checkpoint_save_dir: the directory to store the checkpoints in
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
""" """
self.framework_type = framework_type self.framework_type = framework_type
self.task_index = 0 # TODO: not really needed self.task_index = 0 # TODO: not really needed
@@ -456,6 +457,7 @@ class TaskParameters(Parameters):
self.checkpoint_restore_dir = checkpoint_restore_dir self.checkpoint_restore_dir = checkpoint_restore_dir
self.checkpoint_save_dir = checkpoint_save_dir self.checkpoint_save_dir = checkpoint_save_dir
self.seed = seed self.seed = seed
self.export_onnx_graph = export_onnx_graph
class DistributedTaskParameters(TaskParameters): class DistributedTaskParameters(TaskParameters):
@@ -463,7 +465,7 @@ class DistributedTaskParameters(TaskParameters):
task_index: int, evaluate_only: bool=False, num_tasks: int=None, task_index: int, evaluate_only: bool=False, num_tasks: int=None,
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None, num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None, shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None): checkpoint_save_dir=None, export_onnx_graph: bool=False):
""" """
:param framework_type: deep learning framework type. currently only tensorflow is supported :param framework_type: deep learning framework type. currently only tensorflow is supported
:param evaluate_only: the task will be used only for evaluating the model :param evaluate_only: the task will be used only for evaluating the model
@@ -481,10 +483,12 @@ class DistributedTaskParameters(TaskParameters):
:param checkpoint_save_secs: the number of seconds between each checkpoint saving :param checkpoint_save_secs: the number of seconds between each checkpoint saving
:param checkpoint_restore_dir: the directory to restore the checkpoints from :param checkpoint_restore_dir: the directory to restore the checkpoints from
:param checkpoint_save_dir: the directory to store the checkpoints in :param checkpoint_save_dir: the directory to store the checkpoints in
:param export_onnx_graph: If set to True, this will export an onnx graph each time a checkpoint is saved
""" """
super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu, super().__init__(framework_type=framework_type, evaluate_only=evaluate_only, use_cpu=use_cpu,
experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs, experiment_path=experiment_path, seed=seed, checkpoint_save_secs=checkpoint_save_secs,
checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir) checkpoint_restore_dir=checkpoint_restore_dir, checkpoint_save_dir=checkpoint_save_dir,
export_onnx_graph=export_onnx_graph)
self.parameters_server_hosts = parameters_server_hosts self.parameters_server_hosts = parameters_server_hosts
self.worker_hosts = worker_hosts self.worker_hosts = worker_hosts
self.job_type = job_type self.job_type = job_type

View File

@@ -535,7 +535,8 @@ def main():
use_cpu=args.use_cpu, use_cpu=args.use_cpu,
checkpoint_save_secs=args.checkpoint_save_secs, checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir, checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph
) )
start_graph(graph_manager=graph_manager, task_parameters=task_parameters) start_graph(graph_manager=graph_manager, task_parameters=task_parameters)
@@ -574,7 +575,8 @@ def main():
seed=args.seed+task_index if args.seed is not None else None, # each worker gets a different seed seed=args.seed+task_index if args.seed is not None else None, # each worker gets a different seed
checkpoint_save_secs=args.checkpoint_save_secs, checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir, checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph
) )
# we assume that only the evaluation workers are rendering # we assume that only the evaluation workers are rendering