[ONNX] Create decorator to handle symbolic context (#84776)

- Create a decorator to handle old-style custom symbolic functions that require context (see the sketch below)
- Deprecate `torch.onnx.SymbolicContext` in favor of `GraphContext`, with a deprecation message
- Remove the README reference to `SymbolicContext`
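
For context, an "old style" custom symbolic function is one whose first parameter (before the graph) is annotated as `torch.onnx.SymbolicContext`; the decorator keeps such functions working. A minimal sketch of the two styles (the op bodies are hypothetical; the attribute names come from this change):

```python
import torch
from torch.onnx import SymbolicContext  # deprecated by this change

# Old style: a separate context object precedes the graph argument.
def my_symbolic_old(ctx: SymbolicContext, g, input):
    node = ctx.cur_node  # context lives on the extra ctx object
    return g.op("Relu", input)

# New style: a single GraphContext first argument carries both the
# graph-building methods (g.op) and the node context (g.original_node).
def my_symbolic_new(g, input):
    node = g.original_node
    return g.op("Relu", input)
```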

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84776
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
Authored by: Justin Chu
Date: 2022-09-28 19:52:42 +00:00
Committed by: PyTorch MergeBot
Parent: 723193ec16
Commit: c42a408baa
5 changed files with 87 additions and 65 deletions

File: torch/onnx/README.md

@@ -54,16 +54,14 @@ more robust to different graphs.

 ### Extra context for symbolic functions

-> TODO(justinchuby): Update this after #84776 is merged.
+The first argument of a symbolic function is always a `GraphContext` object.
+`GraphContext` contains all methods defined in a `torch.Graph` object and context
+for the symbolic function.

 In general, symbolic functions only require inputs and attributes to
-the original node. In rare circumstances, extra context may be required.
-For example, symbolic function for `prim::Loop` needs access to the sub-block of
-the original node.
-A symbolic function that has a first arg (before the Graph object) with the
-type annotation of torch.onnx.SymbolicContext will be called with that additional context.
-During export, it is populated from `utils._run_symbolic_function`
-to contain the context for each node being converted.
+the original node. An example of a symbolic function needing context is
+`prim::Loop`. It needs access to the sub-block of the original node.

 ### Export inplace
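
The new README text can be illustrated with a short sketch of a context-dependent symbolic function; the attribute and method names (`original_node`, `blocks()`, `outputsSize()`) appear elsewhere in this diff, while the function itself is hypothetical:

```python
def loop_like_symbolic(g, *inputs, **attrs):
    # Graph-building methods are available directly on g...
    node = g.original_node  # ...and so is the node being converted.
    sub_blocks = tuple(node.blocks())  # the extra context prim::Loop needs
    return g.op("Loop", *inputs, outputs=node.outputsSize())
```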

File: torch/onnx/__init__.py

@@ -28,6 +28,8 @@ from . import (  # usort:skip. Keep the order instead of sorting lexicographical
     symbolic_opset17,
     utils,
 )
+
+# TODO(After 1.13 release): Remove the deprecated SymbolicContext
 from ._exporter_states import ExportTypes, SymbolicContext
 from ._type_utils import JitScalarType
 from .errors import CheckerError  # Backwards compatibility
@@ -67,8 +69,6 @@ __all__ = [
     "TrainingMode",
     "TensorProtoDataType",
     "JitScalarType",
-    # Classes
-    "SymbolicContext",
     # Public functions
     "export",
     "export_to_pretty_string",
@@ -87,7 +87,6 @@ __all__ = [
 # Set namespace for exposed private names
 ExportTypes.__module__ = "torch.onnx"
-SymbolicContext.__module__ = "torch.onnx"
 JitScalarType.__module__ = "torch.onnx"

 producer_name = "pytorch"
@@ -95,7 +94,7 @@ producer_version = _C_onnx.PRODUCER_VERSION

 @_deprecation.deprecated(
-    since="1.12.0", removed_in="TBD", instructions="use `torch.onnx.export` instead"
+    since="1.12.0", removed_in="1.14", instructions="use `torch.onnx.export` instead"
 )
 def _export(*args, **kwargs):
     return utils._export(*args, **kwargs)
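
`_deprecation.deprecated` is an internal helper whose implementation is not part of this diff. A minimal sketch of what a decorator with this call signature typically does (an assumption, not PyTorch's actual code):

```python
import functools
import warnings


def deprecated(since: str, removed_in: str, instructions: str):
    """Warn on every call that fn is deprecated, with migration instructions."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"'{fn.__qualname__}' is deprecated in version {since} and will "
                f"be removed in {removed_in}. Please {instructions}.",
                category=FutureWarning,
            )
            return fn(*args, **kwargs)

        return wrapper

    return decorator
```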

File: torch/onnx/_exporter_states.py

@@ -4,6 +4,7 @@ import enum
 from typing import Dict

 from torch import _C
+from torch.onnx import _deprecation


 class ExportTypes:
@@ -25,6 +26,12 @@ class SymbolicContext:
         onnx_block (_C.Block): Current ONNX block that converted nodes are being appended to.
     """

+    @_deprecation.deprecated(
+        "1.13",
+        "1.14",
+        # TODO(justinchuby): Fix the instruction when GraphContext is public.
+        "remove the 'ctx' argument and annotate 'g: GraphContext' instead",
+    )
     def __init__(
         self,
         params_dict: Dict[str, _C.IValue],

File: torch/onnx/symbolic_opset9.py

@@ -24,9 +24,6 @@ from torch.onnx import (  # noqa: F401
     errors,
     symbolic_helper,
 )
-from torch.onnx._exporter_states import (
-    SymbolicContext,  # Special case class import for readability
-)
 from torch.onnx._globals import GLOBALS
 from torch.onnx._internal import _beartype, jit_utils, registration
 from torch.types import Number
@@ -6303,33 +6300,31 @@ def prim_tolist(g, input, dim_val, elem_ty_val):
 # -----------------------------------------------------------------------------


 @_onnx_symbolic("prim::device")
 @_beartype.beartype
-def prim_device(
-    ctx: SymbolicContext, g: jit_utils.GraphContext, *inputs, **kwargs
-) -> None:
-    output_type = ctx.cur_node.output().type()
+def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
+    output_type = g.original_node.output().type()
     if isinstance(output_type, _C.DeviceObjType):
         return None
     return symbolic_helper._unimplemented(
         "prim::device",
         f"output type should be 'DeviceObjType', not '{output_type.kind()}'",
-        ctx.cur_node.output(),
+        g.original_node.output(),
     )


 @_onnx_symbolic("prim::Loop")
 @_beartype.beartype
-def prim_loop(ctx: SymbolicContext, g, *inputs, **attrs) -> List[_C.Value]:
-    n = ctx.cur_node
-    env = ctx.env
-    params_dict = ctx.params_dict
+def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
+    node = g.original_node
+    env = g.env
+    params_dict = g.params_dict

     operator_export_type = GLOBALS.operator_export_type
     opset_version = GLOBALS.export_onnx_opset_version

-    old_blocks = tuple(n.blocks())
+    old_blocks = tuple(node.blocks())
     new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
-        g, "Loop", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
+        g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
     )

     for old_block, new_block_context in zip(old_blocks, new_block_contexts):
@@ -6358,7 +6353,7 @@ def prim_loop(ctx: SymbolicContext, g, *inputs, **attrs) -> List[_C.Value]:
             env,
             False,
         )
-    new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
+    fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
         new_node, opset_version
     )
     # Run shape type inference for Loop after subblock is converted.
@@ -6366,16 +6361,16 @@ def prim_loop(ctx: SymbolicContext, g, *inputs, **attrs) -> List[_C.Value]:
         torch._C._jit_pass_onnx_node_shape_type_inference(
             new_node, params_dict, opset_version
         )
-    return new_op_outputs
+    return fixed_outputs


 @_onnx_symbolic("prim::If")
 @_beartype.beartype
-def prim_if(ctx: SymbolicContext, g, *inputs, **attrs):
-    n = ctx.cur_node
-    block = ctx.onnx_block
-    env = ctx.env
-    params_dict = ctx.params_dict
+def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
+    n = g.original_node
+    block = g.block
+    env = g.env
+    params_dict = g.params_dict

     operator_export_type = GLOBALS.operator_export_type
     opset_version = GLOBALS.export_onnx_opset_version
@@ -6447,7 +6442,7 @@ def prim_if(ctx: SymbolicContext, g, *inputs, **attrs):
             env,
             False,
         )
-        new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
+        fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
             new_node, opset_version
         )
         # Run shape type inference for If after subblock is converted.
@@ -6455,44 +6450,44 @@ def prim_if(ctx: SymbolicContext, g, *inputs, **attrs):
         torch._C._jit_pass_onnx_node_shape_type_inference(
             new_node, params_dict, opset_version
         )
-        return new_op_outputs
+        return fixed_outputs


 @_onnx_symbolic("prim::Constant")
 @_beartype.beartype
-def prim_constant(ctx: SymbolicContext, g, *inputs, **attrs):
-    n = ctx.cur_node
+def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs):
+    node = g.original_node

-    if n.mustBeNone():
+    if node.mustBeNone():
         return None
     # This must go before checking for string values, because some device constants
     # have string values, but we want to keep them as unconverted Device types so
     # that eq() can work on them.
-    if isinstance(n.output().type(), _C.DeviceObjType):
+    if isinstance(node.output().type(), _C.DeviceObjType):
         return None
-    if n.kindOf("value") == "t":
-        return g.op("Constant", value_t=symbolic_helper._node_get(n, "value"))
-    if n.kindOf("value") == "s":
-        return g.op("Constant", value_s=symbolic_helper._node_get(n, "value"))
-    if n.output().type().isSubtypeOf(
+    if node.kindOf("value") == "t":
+        return g.op("Constant", value_t=symbolic_helper._node_get(node, "value"))
+    if node.kindOf("value") == "s":
+        return g.op("Constant", value_s=symbolic_helper._node_get(node, "value"))
+    if node.output().type().isSubtypeOf(
         _C.ListType.ofInts()
-    ) or n.output().type().isSubtypeOf(_C.ListType.ofFloats()):
+    ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()):
         return g.op(
-            "Constant", value_t=torch.tensor(symbolic_helper._node_get(n, "value"))
+            "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value"))
         )

     raise errors.SymbolicValueError(
-        f"Unsupported prim::Constant kind: `{n.kindOf('value')}`. "
+        f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. "
         f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
-        n.output(),
+        node.output(),
     )


 @_onnx_symbolic("onnx::Placeholder")
 @_beartype.beartype
-def onnx_placeholder(ctx: SymbolicContext, g, *inputs, **attrs):
-    n = ctx.cur_node
-    block = ctx.onnx_block
-    env = ctx.env
+def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs):
+    node = g.original_node
+    block = g.block
+    env = g.env

-    return torch._C._jit_onnx_convert_pattern_from_subblock(block, n, env)
+    return torch._C._jit_onnx_convert_pattern_from_subblock(block, node, env)
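
The renames in this file all follow one mapping from the deprecated `SymbolicContext` attributes to `GraphContext` attributes, which is the entire migration recipe for custom symbolic functions:

```python
# SymbolicContext (deprecated)  ->  GraphContext (the first argument, g)
# ctx.cur_node                  ->  g.original_node
# ctx.onnx_block                ->  g.block
# ctx.env                       ->  g.env
# ctx.params_dict               ->  g.params_dict
```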

File: torch/onnx/utils.py

@@ -1698,7 +1698,7 @@ def _should_aten_fallback(
 @_beartype.beartype
-def _need_symbolic_context(symbolic_fn) -> bool:
+def _need_symbolic_context(symbolic_fn: Callable) -> bool:
     """Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
     params = tuple(inspect.signature(symbolic_fn).parameters.values())
     # When the annotation is postpone-evaluated, the annotation is a string
@@ -1713,6 +1713,32 @@ def _need_symbolic_context(symbolic_fn) -> bool:
     return issubclass(param_type, _exporter_states.SymbolicContext)


+@_beartype.beartype
+def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
+    """Decorator that provides the symbolic context to the symbolic function if needed."""
+    if _need_symbolic_context(symbolic_fn):
+        # TODO(justinchuby): Update the module name of GraphContext when it is public
+        warnings.warn(
+            "The first argument to symbolic functions is deprecated in 1.13 and will be "
+            "removed in the future. Please annotate the first argument (g) as GraphContext "
+            "and use context information from the object instead.",
+            category=FutureWarning,
+        )
+
+        def wrapper(graph_context: jit_utils.GraphContext, *args, **kwargs):
+            symbolic_context = _exporter_states.SymbolicContext(
+                params_dict=graph_context.params_dict,
+                env=graph_context.env,
+                cur_node=graph_context.original_node,
+                onnx_block=graph_context.block,
+            )
+            return symbolic_fn(symbolic_context, graph_context, *args, **kwargs)
+
+        return wrapper
+    return symbolic_fn


 @_beartype.beartype
 def _get_aten_op_overload_name(n: _C.Node) -> str:
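
Tracing through the handler above: wrapping an old-style function yields a new-style one, and the `FutureWarning` fires as soon as the handler detects the `SymbolicContext` annotation. A sketch with a hypothetical op:

```python
def old_style(ctx: _exporter_states.SymbolicContext, g, x):
    return g.op("Relu", x)


# The handler sees the SymbolicContext annotation, warns, and returns a
# wrapper with the new-style (graph_context, *args, **kwargs) signature.
wrapped = _symbolic_context_handler(old_style)

# During export, the wrapper rebuilds a SymbolicContext from graph_context
# and forwards (symbolic_context, graph_context, x) to old_style.
```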
@@ -1782,16 +1808,6 @@ def _run_symbolic_function(
             attrs = {
                 k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
             }
-            if _need_symbolic_context(symbolic_fn):
-                # TODO(justinchuby): Refactor how we check for the need of the symbolic context
-                ctx = _exporter_states.SymbolicContext(
-                    _params_dict, env, node, block
-                )
-                return symbolic_fn(ctx, graph_context, *inputs, **attrs)
-            # PythonOp symbolic need access to the node to resolve the name conflict,
-            # this is inconsistent with regular op symbolic.
-            if op_name == "PythonOp":
-                inputs = (node, *inputs)
             return symbolic_fn(graph_context, *inputs, **attrs)

         attrs = {
@@ -1895,7 +1911,14 @@ def register_custom_op_symbolic(
     versions = range(
         max(_constants.ONNX_MIN_OPSET, opset_version), _constants.ONNX_MAX_OPSET + 1
     )
-    registration.custom_onnx_symbolic(symbolic_name, versions)(symbolic_fn)
+    registration.custom_onnx_symbolic(
+        symbolic_name,
+        versions,
+        decorate=[
+            _symbolic_context_handler,
+        ],
+    )(symbolic_fn)


 @_beartype.beartype
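
End to end, the `decorate` hook means `torch.onnx.register_custom_op_symbolic` now adapts old-style functions transparently. A hedged usage sketch (the `mylib::*` namespace and ops are hypothetical):

```python
import torch
from torch.onnx import SymbolicContext  # deprecated


# Old style: still accepted, but _symbolic_context_handler should emit a
# FutureWarning when the function is registered.
def my_relu(ctx: SymbolicContext, g, input):
    return g.op("Relu", input)


torch.onnx.register_custom_op_symbolic("mylib::my_relu", my_relu, opset_version=9)


# Preferred new style: a single GraphContext argument; no wrapper is attached.
def my_relu_new(g, input):
    return g.op("Relu", input)


torch.onnx.register_custom_op_symbolic("mylib::my_relu2", my_relu_new, opset_version=9)
```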