mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Create decorator to handle symbolic context (#84776)
- Create decorator to handle old style custom symbolics that require context - Deprecate `torch.onnx.SymbolicContext` in favor of `GraphContext`. Added deprecation message - Remove README reference of SymbolicContext Pull Request resolved: https://github.com/pytorch/pytorch/pull/84776 Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
This commit is contained in:
committed by
PyTorch MergeBot
parent
723193ec16
commit
c42a408baa
@ -54,16 +54,14 @@ more robust to different graphs.
|
||||
|
||||
### Extra context for symbolic functions
|
||||
|
||||
> TODO(justinchuby): Update this after #84776 is merged.
|
||||
The first argument of a symbolic function is always a `GraphContext` object.
|
||||
|
||||
`GraphContext` contains all methods defined in a `torch.Graph` object and context
|
||||
for the symbolic function.
|
||||
|
||||
In general, symbolic functions only require inputs and attributes to
|
||||
the original node. In rare circumstances, extra context may be required.
|
||||
For example, symbolic function for `prim::Loop` needs access to the sub-block of
|
||||
the original node.
|
||||
A symbolic function that has a first arg (before the Graph object) with the
|
||||
type annotation of torch.onnx.SymbolicContext will be called with that additional context.
|
||||
During export, it is populated from `utils._run_symbolic_function`
|
||||
to contain the context for each node being converted.
|
||||
the original node. An example of a symbolic function needing context is
|
||||
`prim::Loop`. It needs access to the sub-block of the original node.
|
||||
|
||||
### Export inplace
|
||||
|
||||
|
@ -28,6 +28,8 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
|
||||
symbolic_opset17,
|
||||
utils,
|
||||
)
|
||||
|
||||
# TODO(After 1.13 release): Remove the deprecated SymbolicContext
|
||||
from ._exporter_states import ExportTypes, SymbolicContext
|
||||
from ._type_utils import JitScalarType
|
||||
from .errors import CheckerError # Backwards compatibility
|
||||
@ -67,8 +69,6 @@ __all__ = [
|
||||
"TrainingMode",
|
||||
"TensorProtoDataType",
|
||||
"JitScalarType",
|
||||
# Classes
|
||||
"SymbolicContext",
|
||||
# Public functions
|
||||
"export",
|
||||
"export_to_pretty_string",
|
||||
@ -87,7 +87,6 @@ __all__ = [
|
||||
|
||||
# Set namespace for exposed private names
|
||||
ExportTypes.__module__ = "torch.onnx"
|
||||
SymbolicContext.__module__ = "torch.onnx"
|
||||
JitScalarType.__module__ = "torch.onnx"
|
||||
|
||||
producer_name = "pytorch"
|
||||
@ -95,7 +94,7 @@ producer_version = _C_onnx.PRODUCER_VERSION
|
||||
|
||||
|
||||
@_deprecation.deprecated(
|
||||
since="1.12.0", removed_in="TBD", instructions="use `torch.onnx.export` instead"
|
||||
since="1.12.0", removed_in="1.14", instructions="use `torch.onnx.export` instead"
|
||||
)
|
||||
def _export(*args, **kwargs):
|
||||
return utils._export(*args, **kwargs)
|
||||
|
@ -4,6 +4,7 @@ import enum
|
||||
from typing import Dict
|
||||
|
||||
from torch import _C
|
||||
from torch.onnx import _deprecation
|
||||
|
||||
|
||||
class ExportTypes:
|
||||
@ -25,6 +26,12 @@ class SymbolicContext:
|
||||
onnx_block (_C.Block): Current ONNX block that converted nodes are being appended to.
|
||||
"""
|
||||
|
||||
@_deprecation.deprecated(
|
||||
"1.13",
|
||||
"1.14",
|
||||
# TODO(justinchuby): Fix the instruction when GraphContext is public.
|
||||
"remove the 'ctx' argument and annotate 'g: GraphContext' instead",
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
params_dict: Dict[str, _C.IValue],
|
||||
|
@ -24,9 +24,6 @@ from torch.onnx import ( # noqa: F401
|
||||
errors,
|
||||
symbolic_helper,
|
||||
)
|
||||
from torch.onnx._exporter_states import (
|
||||
SymbolicContext, # Special case class import for readability
|
||||
)
|
||||
from torch.onnx._globals import GLOBALS
|
||||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||||
from torch.types import Number
|
||||
@ -6303,33 +6300,31 @@ def prim_tolist(g, input, dim_val, elem_ty_val):
|
||||
# -----------------------------------------------------------------------------
|
||||
@_onnx_symbolic("prim::device")
|
||||
@_beartype.beartype
|
||||
def prim_device(
|
||||
ctx: SymbolicContext, g: jit_utils.GraphContext, *inputs, **kwargs
|
||||
) -> None:
|
||||
output_type = ctx.cur_node.output().type()
|
||||
def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
|
||||
output_type = g.original_node.output().type()
|
||||
if isinstance(output_type, _C.DeviceObjType):
|
||||
return None
|
||||
|
||||
return symbolic_helper._unimplemented(
|
||||
"prim::device",
|
||||
f"output type should be 'DeviceObjType', not '{output_type.kind()}'",
|
||||
ctx.cur_node.output(),
|
||||
g.original_node.output(),
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::Loop")
|
||||
@_beartype.beartype
|
||||
def prim_loop(ctx: SymbolicContext, g, *inputs, **attrs) -> List[_C.Value]:
|
||||
n = ctx.cur_node
|
||||
env = ctx.env
|
||||
params_dict = ctx.params_dict
|
||||
def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
|
||||
node = g.original_node
|
||||
env = g.env
|
||||
params_dict = g.params_dict
|
||||
|
||||
operator_export_type = GLOBALS.operator_export_type
|
||||
opset_version = GLOBALS.export_onnx_opset_version
|
||||
|
||||
old_blocks = tuple(n.blocks())
|
||||
old_blocks = tuple(node.blocks())
|
||||
new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
|
||||
g, "Loop", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
|
||||
g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
|
||||
)
|
||||
|
||||
for old_block, new_block_context in zip(old_blocks, new_block_contexts):
|
||||
@ -6358,7 +6353,7 @@ def prim_loop(ctx: SymbolicContext, g, *inputs, **attrs) -> List[_C.Value]:
|
||||
env,
|
||||
False,
|
||||
)
|
||||
new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
|
||||
fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
|
||||
new_node, opset_version
|
||||
)
|
||||
# Run shape type inference for Loop after subblock is converted.
|
||||
@ -6366,16 +6361,16 @@ def prim_loop(ctx: SymbolicContext, g, *inputs, **attrs) -> List[_C.Value]:
|
||||
torch._C._jit_pass_onnx_node_shape_type_inference(
|
||||
new_node, params_dict, opset_version
|
||||
)
|
||||
return new_op_outputs
|
||||
return fixed_outputs
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::If")
|
||||
@_beartype.beartype
|
||||
def prim_if(ctx: SymbolicContext, g, *inputs, **attrs):
|
||||
n = ctx.cur_node
|
||||
block = ctx.onnx_block
|
||||
env = ctx.env
|
||||
params_dict = ctx.params_dict
|
||||
def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
|
||||
n = g.original_node
|
||||
block = g.block
|
||||
env = g.env
|
||||
params_dict = g.params_dict
|
||||
|
||||
operator_export_type = GLOBALS.operator_export_type
|
||||
opset_version = GLOBALS.export_onnx_opset_version
|
||||
@ -6447,7 +6442,7 @@ def prim_if(ctx: SymbolicContext, g, *inputs, **attrs):
|
||||
env,
|
||||
False,
|
||||
)
|
||||
new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
|
||||
fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
|
||||
new_node, opset_version
|
||||
)
|
||||
# Run shape type inference for If after subblock is converted.
|
||||
@ -6455,44 +6450,44 @@ def prim_if(ctx: SymbolicContext, g, *inputs, **attrs):
|
||||
torch._C._jit_pass_onnx_node_shape_type_inference(
|
||||
new_node, params_dict, opset_version
|
||||
)
|
||||
return new_op_outputs
|
||||
return fixed_outputs
|
||||
|
||||
|
||||
@_onnx_symbolic("prim::Constant")
|
||||
@_beartype.beartype
|
||||
def prim_constant(ctx: SymbolicContext, g, *inputs, **attrs):
|
||||
n = ctx.cur_node
|
||||
def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs):
|
||||
node = g.original_node
|
||||
|
||||
if n.mustBeNone():
|
||||
if node.mustBeNone():
|
||||
return None
|
||||
# This must go before checking for string values, because some device constants
|
||||
# have string values, but we want to keep them as unconverted Device types so
|
||||
# that eq() can work on them.
|
||||
if isinstance(n.output().type(), _C.DeviceObjType):
|
||||
if isinstance(node.output().type(), _C.DeviceObjType):
|
||||
return None
|
||||
if n.kindOf("value") == "t":
|
||||
return g.op("Constant", value_t=symbolic_helper._node_get(n, "value"))
|
||||
if n.kindOf("value") == "s":
|
||||
return g.op("Constant", value_s=symbolic_helper._node_get(n, "value"))
|
||||
if n.output().type().isSubtypeOf(
|
||||
if node.kindOf("value") == "t":
|
||||
return g.op("Constant", value_t=symbolic_helper._node_get(node, "value"))
|
||||
if node.kindOf("value") == "s":
|
||||
return g.op("Constant", value_s=symbolic_helper._node_get(node, "value"))
|
||||
if node.output().type().isSubtypeOf(
|
||||
_C.ListType.ofInts()
|
||||
) or n.output().type().isSubtypeOf(_C.ListType.ofFloats()):
|
||||
) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()):
|
||||
return g.op(
|
||||
"Constant", value_t=torch.tensor(symbolic_helper._node_get(n, "value"))
|
||||
"Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value"))
|
||||
)
|
||||
|
||||
raise errors.SymbolicValueError(
|
||||
f"Unsupported prim::Constant kind: `{n.kindOf('value')}`. "
|
||||
f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. "
|
||||
f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
|
||||
n.output(),
|
||||
node.output(),
|
||||
)
|
||||
|
||||
|
||||
@_onnx_symbolic("onnx::Placeholder")
|
||||
@_beartype.beartype
|
||||
def onnx_placeholder(ctx: SymbolicContext, g, *inputs, **attrs):
|
||||
n = ctx.cur_node
|
||||
block = ctx.onnx_block
|
||||
env = ctx.env
|
||||
def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs):
|
||||
node = g.original_node
|
||||
block = g.block
|
||||
env = g.env
|
||||
|
||||
return torch._C._jit_onnx_convert_pattern_from_subblock(block, n, env)
|
||||
return torch._C._jit_onnx_convert_pattern_from_subblock(block, node, env)
|
||||
|
@ -1698,7 +1698,7 @@ def _should_aten_fallback(
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _need_symbolic_context(symbolic_fn) -> bool:
|
||||
def _need_symbolic_context(symbolic_fn: Callable) -> bool:
|
||||
"""Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
|
||||
params = tuple(inspect.signature(symbolic_fn).parameters.values())
|
||||
# When the annotation is postpone-evaluated, the annotation is a string
|
||||
@ -1713,6 +1713,32 @@ def _need_symbolic_context(symbolic_fn) -> bool:
|
||||
return issubclass(param_type, _exporter_states.SymbolicContext)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
|
||||
"""Decorator that provides the symbolic context to the symbolic function if needed."""
|
||||
if _need_symbolic_context(symbolic_fn):
|
||||
|
||||
# TODO(justinchuby): Update the module name of GraphContext when it is public
|
||||
warnings.warn(
|
||||
"The first argument to symbolic functions is deprecated in 1.13 and will be "
|
||||
"removed in the future. Please annotate treat the first argument (g) as GraphContext "
|
||||
"and use context information from the object instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
|
||||
def wrapper(graph_context: jit_utils.GraphContext, *args, **kwargs):
|
||||
symbolic_context = _exporter_states.SymbolicContext(
|
||||
params_dict=graph_context.params_dict,
|
||||
env=graph_context.env,
|
||||
cur_node=graph_context.original_node,
|
||||
onnx_block=graph_context.block,
|
||||
)
|
||||
return symbolic_fn(symbolic_context, graph_context, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
def _get_aten_op_overload_name(n: _C.Node) -> str:
|
||||
|
||||
@ -1782,16 +1808,6 @@ def _run_symbolic_function(
|
||||
attrs = {
|
||||
k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
|
||||
}
|
||||
if _need_symbolic_context(symbolic_fn):
|
||||
# TODO(justinchuby): Refactor how we check for the need of the symbolic context
|
||||
ctx = _exporter_states.SymbolicContext(
|
||||
_params_dict, env, node, block
|
||||
)
|
||||
return symbolic_fn(ctx, graph_context, *inputs, **attrs)
|
||||
# PythonOp symbolic need access to the node to resolve the name conflict,
|
||||
# this is inconsistent with regular op symbolic.
|
||||
if op_name == "PythonOp":
|
||||
inputs = (node, *inputs)
|
||||
return symbolic_fn(graph_context, *inputs, **attrs)
|
||||
|
||||
attrs = {
|
||||
@ -1895,7 +1911,14 @@ def register_custom_op_symbolic(
|
||||
versions = range(
|
||||
max(_constants.ONNX_MIN_OPSET, opset_version), _constants.ONNX_MAX_OPSET + 1
|
||||
)
|
||||
registration.custom_onnx_symbolic(symbolic_name, versions)(symbolic_fn)
|
||||
|
||||
registration.custom_onnx_symbolic(
|
||||
symbolic_name,
|
||||
versions,
|
||||
decorate=[
|
||||
_symbolic_context_handler,
|
||||
],
|
||||
)(symbolic_fn)
|
||||
|
||||
|
||||
@_beartype.beartype
|
||||
|
Reference in New Issue
Block a user