mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 04:54:55 +08:00
This PR creates the `GraphContext` class, which relays all graph methods to _C.Graph and implements the `g.op` method. The GraphContext object is passed into the symbolic functions in place of _C.Graph for compatibility with existing symbolic functions. This way (1) we can type-annotate all `g` args because the method is defined, (2) we can use additional context information in symbolic functions, and (3) there is no more monkey patching on `_C.Graph`. Also: - Fix the return type of `_jit_pass_fixup_onnx_controlflow_node` - Create `torchscript.py` to house torch.Graph-related functions - Change `GraphContext.op` to create nodes in the Block instead of the Graph - Create `add_op_with_blocks` to handle scenarios where we need to directly manipulate sub-blocks, and update the loop and if symbolic functions to use this function. ## Discussion Should we put all the context inside `SymbolicContext` and make it an attribute of the `GraphContext` class? This way we would only define two attributes, `GraphContext.graph` and `GraphContext.context`. Currently all context attributes are defined directly on the class. ### Decision Keep GraphContext flat and note that it will change in the future. Pull Request resolved: https://github.com/pytorch/pytorch/pull/84728 Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
325 lines
11 KiB
Python
325 lines
11 KiB
Python
"""Utilities for manipulating the torch.Graph object and the torchscript."""
|
|
|
|
# TODO(justinchuby): Move more of the symbolic helper functions here and expose
|
|
# them to the user.
|
|
|
|
import dataclasses
|
|
import re
|
|
from typing import Any, Dict, Iterable, Sequence, Tuple, Union
|
|
|
|
from typing_extensions import Protocol, runtime_checkable
|
|
|
|
import torch
|
|
from torch import _C
|
|
from torch._C import _onnx as _C_onnx
|
|
from torch.onnx._globals import GLOBALS
|
|
from torch.onnx._internal import _beartype
|
|
|
|
# Attribute keyword names have the form "<name>_<kind>": group 1 captures the
# attribute name and group 2 the type suffix (one of i, f, s, t, g, z, or "ty").
_ATTR_PATTERN = re.compile(r"^(.+)_(([ifstgz])|(ty))$")

# Attribute keys that `_create_node` skips instead of adding to the node.
_SKIP_NODE_ATTRIBUTES = {"aten", "inplace"}
|
|
|
|
|
|
@runtime_checkable
class _WithOp(Protocol):
    """Structural type for objects that create graph nodes through an ``op`` method.

    Because the protocol is runtime-checkable, ``isinstance(obj, _WithOp)`` is
    True for any object exposing an ``op`` attribute, whether or not it
    inherits from this class.
    """

    def op(
        self,
        opname: str,
        *raw_args: Union[torch.Tensor, _C.Value],
        outputs: int = 1,
        **kwargs,
    ):
        """Create an ``opname`` node from ``raw_args`` and attributes ``kwargs``."""
|
|
|
|
|
|
@dataclasses.dataclass
class GraphContext(_WithOp):
    """Extra context for symbolic functions with all methods from torch.Graph.

    NOTE: This class is not meant for external consumption. Please do not depend on
    it outside of torch.onnx as the interface may evolve.

    Attribute lookups that fail on the context itself are forwarded to ``graph``,
    so a GraphContext can be passed wherever a symbolic function expects a
    ``_C.Graph``.

    Attributes:
        graph: The _C.Graph being constructed.
        block: The current _C.Block being constructed; ``op`` creates nodes here.
        opset: The opset version.
        original_node: Current node that is being converted from.
        params_dict: Mapping from graph initializer name to IValue.
        env: Mapping from Torch domain graph Value to ONNX domain graph Value.
    """

    graph: _C.Graph
    block: _C.Block
    opset: int
    original_node: _C.Node
    params_dict: Dict[str, "_C.IValue"]
    env: Dict[_C.Value, _C.Value]

    # Relay methods from _C.Graph for compatibility with symbolic functions that expect
    # a _C.Graph
    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs after normal attribute lookup fails. While
        # copy/pickle rebuild an instance, its fields are not set yet, so
        # `self.graph` below would re-enter __getattr__("graph") and recurse
        # until RecursionError. Raising AttributeError instead lets probes such
        # as hasattr(obj, "__setstate__") fail cleanly.
        if name == "graph":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'graph'"
            )
        return getattr(self.graph, name)

    @_beartype.beartype
    def op(
        self,
        opname: str,
        *raw_args: Union[torch.Tensor, _C.Value],
        outputs: int = 1,
        **kwargs,
    ):
        """Creates an ONNX operator "opname", taking "raw_args" as inputs and attributes "kwargs".

        The set of operators and the inputs/attributes they take
        is documented at https://github.com/onnx/onnx/blob/main/docs/Operators.md

        Args:
            opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
                with a namespace, e.g., `aten::add`. A name without `::` is placed
                in the `onnx` namespace.
            raw_args: The inputs to the operator; usually provided
                as arguments to the `symbolic` definition. Tensor arguments are
                first turned into `onnx::Constant` nodes.
            outputs: The number of outputs this operator returns.
                By default an operator is assumed to return a single output.
                If `outputs` is greater than one, this function returns a tuple
                of output `Value`, representing each output of the ONNX operator
                in order.
            kwargs: The attributes of the ONNX operator, whose keys are named
                according to the following convention: `alpha_f` indicates
                the `alpha` attribute with type `f`. The common type specifiers are
                `f` (float), `i` (int), `s` (string) and `t` (Tensor); `g`, `z`
                and `ty` are also accepted. A list-valued attribute uses the same
                specifier as its elements (e.g., you would say `dims_i` for a
                `dims` attribute that takes a list of integers). Attributes whose
                value is None are dropped.

        Returns:
            The value representing the single output of this operator (see the `outputs`
            keyword argument for multi-return nodes).
        """
        # FIXME(justinchuby): Add the return type back once we know how to handle mypy
        return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)

    @_beartype.beartype
    def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
        """Generates an ONNX ATen op node.

        This function is for backward compatibility with the old symbolic functions.

        Args:
            operator: Name of the ATen operator, stored as the `operator` string attribute.
            args: The inputs to the node.
            overload_name: The ATen overload name, stored as the `overload_name`
                string attribute.
            kwargs: Additional attributes, named with the same convention as `op`.

        Returns:
            The output value(s) of the created `aten::ATen` node, as returned by `op`.
        """
        return self.op(
            "aten::ATen",
            *args,
            operator_s=operator,
            overload_name_s=overload_name,
            **kwargs,
        )
|
|
|
|
|
|
@_beartype.beartype
def add_op_with_blocks(
    graph_context: GraphContext,
    opname: str,
    *inputs: _C.Value,
    outputs: int = 1,
    n_blocks: int = 1,
    **attributes,
) -> Tuple[Any, Tuple[GraphContext, ...], _C.Node]:
    """Create an ONNX operator ``opname`` and attach ``n_blocks`` new sub-blocks to it.

    Args:
        graph_context: The context for the current graph.
        opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
            with a namespace, e.g., `aten::add`.
        inputs: The inputs to the operator.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is greater than one, this function returns a tuple
            of output `Value`, representing each output of the ONNX operator
            in order.
        n_blocks: The number of sub-blocks to create in the node.
        attributes: The attributes of the ONNX operator.

    Returns:
        A tuple ``(output_values, new_contexts, node)``:
            output_values: One or more output values of this operator
                (see the `outputs` keyword argument for multi-return nodes).
            new_contexts: One graph context per created sub-block; each is a
                shallow copy of ``graph_context`` whose ``block`` is that sub-block.
            node: The node representing the operator.
    """
    output_values = graph_context.op(opname, *inputs, outputs=outputs, **attributes)

    # Multi-output ops come back as a sequence of Values; any one of them leads
    # back to the node that produced it.
    first_output = (
        output_values[0] if isinstance(output_values, Sequence) else output_values
    )
    node = first_output.node()

    # Sub-blocks are added one at a time, in order. Each context shares every
    # field with graph_context except `block`.
    new_contexts = tuple(
        dataclasses.replace(graph_context, block=node.addBlock())
        for _ in range(n_blocks)
    )
    return output_values, new_contexts, node
|
|
|
|
|
|
@_beartype.beartype
def _add_op(
    graph_context: GraphContext,
    opname: str,
    *args: Union[torch.Tensor, _C.Value],
    outputs: int = 1,
    **kwargs,
):
    """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/main/docs/Operators.md

    The node is created in ``graph_context.block`` (not directly in the graph).
    ONNX shape/type inference runs on the new node when
    ``GLOBALS.onnx_shape_inference`` is set.

    Args:
        graph_context: The graph context. Supplies the target block, plus the
            ``params_dict`` and opset version used for shape inference.
        opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
            with a namespace, e.g., `aten::add`. A name without ``::`` is
            prefixed with ``onnx::``.
        args: The inputs to the operator; usually provided
            as arguments to the `symbolic` definition. Tensor arguments are
            converted into ``onnx::Constant`` nodes first.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is greater than one, this function returns a tuple
            of output `Value`, representing each output of the ONNX operator
            in order.
        kwargs: The attributes of the ONNX operator, whose keys are named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`. The common type specifiers are
            `f` (float), `i` (int), `s` (string) and `t` (Tensor); `g`, `z` and
            `ty` are also accepted. A list-valued attribute uses the same
            specifier as its elements (e.g., you would say `dims_i` for a
            `dims` attribute that takes a list of integers). Attributes whose
            value is None are dropped.

    Returns:
        (Union[_C.Value, Tuple[_C.Value, ...]])
        The single output value when ``outputs == 1``; otherwise a tuple of all
        output values, in order.
    """
    inputs = [_const_if_tensor(graph_context, arg) for arg in args]
    # Filter out None attributes, this can be convenient client side because
    # now they can pass through None attributes, and have them not show up
    attributes = {k: v for k, v in kwargs.items() if v is not None}

    if "::" not in opname:
        opname = "onnx::" + opname

    node = _create_node(
        graph_context.block,
        opname,
        inputs,
        attributes,
        params_dict=graph_context.params_dict,
        opset_version=graph_context.opset,
        n_outputs=outputs,
        shape_inference=GLOBALS.onnx_shape_inference,
    )

    if outputs == 1:
        return node.output()
    return tuple(node.outputs())
|
|
|
|
|
|
@_beartype.beartype
def _const_if_tensor(graph_context: GraphContext, arg):
    """Return ``arg`` as a graph input, wrapping non-Value inputs in a Constant node.

    None and ``_C.Value`` objects pass through untouched. Anything else becomes
    an ``onnx::Constant`` node whose ``value_z`` attribute holds ``arg``, and
    that node's output is returned.
    """
    passthrough = arg is None or isinstance(arg, _C.Value)
    if passthrough:
        return arg
    return _add_op(graph_context, "onnx::Constant", value_z=arg)
|
|
|
|
|
|
def _create_node(
    graph_or_block: Union[_C.Graph, _C.Block],
    domain_op: str,
    inputs: Sequence,
    attributes: dict,
    params_dict: dict,
    opset_version: int,
    n_outputs: int,
    shape_inference: bool = True,
) -> _C.Node:
    """Creates a node of kind ``domain_op`` in a graph or block.

    Args:
        graph_or_block: Where to create the node. A ``_C.Graph`` gets the node
            via ``Graph.create`` + ``Graph.insertNode``; a ``_C.Block`` via
            ``Block.addNode``.
        domain_op: Fully qualified op name, e.g. ``onnx::Add`` or ``aten::add``.
        inputs: The node's input values.
        attributes: Mapping from suffixed attribute name (e.g. ``alpha_f``) to
            value. Keys in ``_SKIP_NODE_ATTRIBUTES`` are ignored.
        params_dict: Passed to ONNX shape/type inference.
        opset_version: Passed to ONNX shape/type inference.
        n_outputs: Number of outputs the node must have.
        shape_inference: Whether to run ONNX shape/type inference on the node.

    Returns:
        The created node.

    Raises:
        TypeError: If ``graph_or_block`` is neither a ``_C.Graph`` nor a ``_C.Block``.
    """
    if isinstance(graph_or_block, _C.Graph):
        graph = graph_or_block
        node = graph.insertNode(graph.create(domain_op, inputs, n_outputs))
    elif isinstance(graph_or_block, _C.Block):
        block = graph_or_block
        node = block.addNode(domain_op, inputs)
        # Block does not have `create`, so outputs beyond the first are added
        # manually (the loop assumes addNode creates the node with one output).
        for _ in range(1, n_outputs):
            node.addOutput()
    else:
        # Previously this fell through and failed with UnboundLocalError on `node`.
        raise TypeError(
            "graph_or_block must be a torch._C.Graph or torch._C.Block, "
            f"got {type(graph_or_block).__name__}"
        )

    n_created = len(tuple(node.outputs()))
    assert (
        n_created == n_outputs
    ), f"{domain_op} node has {n_created} outputs, expected {n_outputs}"

    aten = domain_op.startswith("aten::")

    # Add all attributes, in sorted key order so the result is deterministic.
    for key, value in sorted(attributes.items()):
        if key in _SKIP_NODE_ATTRIBUTES:
            continue
        _add_attribute(node, key, value, aten=aten)
    if shape_inference:
        _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
    return node
|
|
|
|
|
|
@_beartype.beartype
def _is_onnx_list(value):
    """Return True if ``value`` should be emitted as a list-valued attribute.

    Strings, bytes and tensors are iterable, but each is treated as a single
    attribute value, so they are excluded.
    """
    # `torch._six` no longer exists in newer PyTorch releases; on Python 3 its
    # `string_classes` alias was exactly (str, bytes).
    return (
        not isinstance(value, (str, bytes))
        and not isinstance(value, torch.Tensor)
        and isinstance(value, Iterable)
    )
|
|
|
|
|
|
@_beartype.beartype
def _scalar(x: torch.Tensor):
    """Convert a single-element tensor into a Python value.

    Accepts a 0-d tensor or any tensor with exactly one element. Returns a
    Python number (float/int/bool, depending on the dtype).
    """
    assert x.numel() == 1, f"expected a single-element tensor, got {x.numel()} elements"
    # `x[0]` returned a 0-d *tensor* (never a Python float, so callers testing
    # isinstance(value, float) always fell into their int branch) and raised on
    # 0-d input; `item()` yields the Python scalar the contract promises.
    return x.item()
|
|
|
|
|
|
@_beartype.beartype
def _is_caffe2_aten_fallback() -> bool:
    """Report whether ATen-fallback export is active with the Caffe2 flag set.

    Requires the global operator export type to be ONNX_ATEN_FALLBACK and
    ``_C_onnx._CAFFE2_ATEN_FALLBACK`` to be truthy.
    """
    export_type = GLOBALS.operator_export_type
    if not export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return False
    return _C_onnx._CAFFE2_ATEN_FALLBACK
|
|
|
|
|
|
@_beartype.beartype
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
    """Set attribute ``key`` on ``node`` via the setter its type suffix selects.

    ``key`` must match ``_ATTR_PATTERN`` (``<name>_<kind>``). List-like values
    (per ``_is_onnx_list``) use the plural setter ``<kind>s_``. For ATen nodes
    under Caffe2 ATen fallback, a tensor value is reduced with ``_scalar`` and
    set with ``f_`` if the result is a float, otherwise with ``i_``.

    Raises:
        ValueError: If ``key`` lacks a recognised type suffix, or if a
            multi-element tensor is given on the Caffe2 fallback path.
    """
    match = _ATTR_PATTERN.match(key)
    if not match:
        raise ValueError(
            f"Invalid attribute specifier '{key}' names "
            "must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
        )
    name = match.group(1)
    kind = match.group(2) + ("s" if _is_onnx_list(value) else "")

    caffe2_tensor = (
        aten and _is_caffe2_aten_fallback() and isinstance(value, torch.Tensor)
    )
    if caffe2_tensor:
        # Caffe2 proto does not support tensor attribute.
        if value.numel() > 1:
            raise ValueError("Should not pass tensor attribute")
        value = _scalar(value)
        kind = "f" if isinstance(value, float) else "i"

    setter = getattr(node, f"{kind}_")
    return setter(name, value)
|