mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Create export_for_inference API and expose core_aten as public facing API (#135912)
Differential Revision: [D62606908](https://our.internmc.facebook.com/intern/diff/D62606908) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135912 Approved by: https://github.com/avikchaudhuri ghstack dependencies: #135080
This commit is contained in:
committed by
PyTorch MergeBot
parent
382fad58b3
commit
1904b09e61
@ -677,6 +677,7 @@ API Reference
|
||||
.. autofunction:: load
|
||||
.. autofunction:: register_dataclass
|
||||
.. autofunction:: torch.export.dynamic_shapes.Dim
|
||||
.. autofunction:: torch.export.exported_program.core_aten_decompositions
|
||||
.. autofunction:: dims
|
||||
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection
|
||||
|
||||
|
@ -4762,6 +4762,56 @@ def forward(self, b_a_buffer, x):
|
||||
self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp)))
|
||||
self.assertEqual(id(state_dict), id(ep.state_dict))
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
|
||||
def test_export_for_inference_e2e(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.lin = torch.nn.Linear(10, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lin(x)
|
||||
|
||||
inp = (torch.randn(5, 10),)
|
||||
m = M()
|
||||
|
||||
decomp_table = torch.export.core_aten_decompositions()
|
||||
|
||||
def _custom_decomp_for_linear(x, weight, bias):
|
||||
return x + bias.sum()
|
||||
|
||||
decomp_table[torch.ops.aten.linear.default] = _custom_decomp_for_linear
|
||||
del decomp_table[torch.ops.aten.sum.default]
|
||||
ep = torch.export.export_for_inference(
|
||||
m, inp, decomp_table=decomp_table, dynamic_shapes={"x": {0: Dim("batch")}}
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
str(ep.graph_module.code).strip(),
|
||||
"""\
|
||||
def forward(self, p_lin_weight, p_lin_bias, x):
|
||||
sum_1 = torch.ops.aten.sum.default(p_lin_bias); p_lin_bias = None
|
||||
add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None
|
||||
return (add,)""",
|
||||
)
|
||||
|
||||
ep_core = ep.run_decompositions()
|
||||
|
||||
self.assertExpectedInline(
|
||||
str(ep_core.graph_module.code).strip(),
|
||||
"""\
|
||||
def forward(self, p_lin_weight, p_lin_bias, x):
|
||||
sum_1 = torch.ops.aten.sum.dim_IntList(p_lin_bias, []); p_lin_bias = None
|
||||
add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None
|
||||
return (add,)""",
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected input"):
|
||||
ep.module()(torch.randn(4, 12))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected input"):
|
||||
ep_core.module()(torch.randn(4, 12))
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
|
||||
def test_export_decomp_torture_case_1(self):
|
||||
class M(torch.nn.Module):
|
||||
|
@ -38,6 +38,7 @@ from torch.utils._pytree import (
|
||||
if TYPE_CHECKING:
|
||||
# Import the following modules during type checking to enable code intelligence features,
|
||||
# Do not import unconditionally, as they import sympy and importing sympy is very slow
|
||||
from torch._ops import OpOverload
|
||||
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
|
||||
|
||||
|
||||
@ -49,9 +50,11 @@ __all__ = [
|
||||
"ExportedProgram",
|
||||
"ModuleCallEntry",
|
||||
"ModuleCallSignature",
|
||||
"core_aten_decompositions",
|
||||
"dims",
|
||||
"export",
|
||||
"export_for_training",
|
||||
"export_for_inference",
|
||||
"load",
|
||||
"register_dataclass",
|
||||
"save",
|
||||
@ -62,7 +65,12 @@ __all__ = [
|
||||
|
||||
|
||||
from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
|
||||
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
|
||||
from .exported_program import (
|
||||
core_aten_decompositions,
|
||||
ExportedProgram,
|
||||
ModuleCallEntry,
|
||||
ModuleCallSignature,
|
||||
)
|
||||
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
|
||||
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
|
||||
|
||||
@ -165,6 +173,91 @@ def export_for_training(
|
||||
)
|
||||
|
||||
|
||||
def export_for_inference(
|
||||
mod: torch.nn.Module,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
|
||||
strict: bool = True,
|
||||
preserve_module_call_signature: Tuple[str, ...] = (),
|
||||
decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
:func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing
|
||||
only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
|
||||
which can subsequently be executed with different inputs or serialized. The
|
||||
traced graph (1) produces normalized operators in the ATen operator set
|
||||
(as well as any user-specified custom operators) which is customizable via decomp_table,
|
||||
(2) has eliminated all Python control flow and data structures (with certain exceptions),
|
||||
and (3) records the set of shape constraints needed to show that this normalization and control-flow
|
||||
elimination is sound for future inputs. This API is for convenience use as it combines :func:`export_for_training` and
|
||||
:func:`run_decompositions`.
|
||||
|
||||
**Soundness Guarantee**
|
||||
|
||||
See :func:`export()` docstring for more details.
|
||||
|
||||
Args:
|
||||
mod: We will trace the forward method of this module.
|
||||
|
||||
args: Example positional inputs.
|
||||
|
||||
kwargs: Optional example keyword inputs.
|
||||
|
||||
dynamic_shapes:
|
||||
An optional argument where the type should either be:
|
||||
1) a dict from argument names of ``f`` to their dynamic shape specifications,
|
||||
2) a tuple that specifies dynamic shape specifications for each input in original order.
|
||||
If you are specifying dynamism on keyword args, you will need to pass them in the order that
|
||||
is defined in the original function signature.
|
||||
|
||||
The dynamic shape of a tensor argument can be specified as either
|
||||
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
|
||||
not required to include static dimension indices in this dict, but when they are,
|
||||
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
|
||||
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
|
||||
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
|
||||
recursively specified by using mappings or sequences of contained specifications.
|
||||
|
||||
strict: When enabled (default), the export function will trace the program through
|
||||
TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the
|
||||
exported program will not validate the implicit assumptions baked into the graph and
|
||||
may cause behavior divergence between the original model and the exported one. This is
|
||||
useful when users need to workaround bugs in the tracer, or simply want incrementally
|
||||
enable safety in their models. Note that this does not affect the resulting IR spec
|
||||
to be different and the model will be serialized in the same way regardless of what value
|
||||
is passed here.
|
||||
WARNING: This option is experimental and use this at your own risk.
|
||||
|
||||
decomp_table: See :func:`run_decompositions` for more details.
|
||||
|
||||
Returns:
|
||||
An :class:`ExportedProgram` containing the traced callable.
|
||||
|
||||
**Acceptable input/output types**
|
||||
|
||||
Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
|
||||
|
||||
- Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
|
||||
- Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
|
||||
- (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
|
||||
``OrderedDict`` containing all above types.
|
||||
|
||||
"""
|
||||
|
||||
ep_for_training = export_for_training(
|
||||
mod,
|
||||
args,
|
||||
kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=strict,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
)
|
||||
|
||||
return ep_for_training.run_decompositions(decomp_table=decomp_table)
|
||||
|
||||
|
||||
def export(
|
||||
mod: torch.nn.Module,
|
||||
args: Tuple[Any, ...],
|
||||
|
@ -78,6 +78,7 @@ __all__ = [
|
||||
"ExportedProgram",
|
||||
"ModuleCallEntry",
|
||||
"ModuleCallSignature",
|
||||
"core_aten_decompositions",
|
||||
]
|
||||
|
||||
|
||||
@ -289,6 +290,17 @@ def _split_decomp_table_to_cia_and_python_decomp(
|
||||
return cia_ops_to_callable, decomp_table
|
||||
|
||||
|
||||
def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
|
||||
"""
|
||||
This is the default decomposition table which contains decomposition of
|
||||
all ATEN operators to core aten opset. Use this API together with
|
||||
:func:`run_decompositions()`
|
||||
"""
|
||||
from torch._decomp import core_aten_decompositions
|
||||
|
||||
return core_aten_decompositions()
|
||||
|
||||
|
||||
def _decompose_and_get_gm_with_new_signature_constants(
|
||||
ep,
|
||||
*,
|
||||
@ -1015,8 +1027,7 @@ class ExportedProgram:
|
||||
.. code-block:: python
|
||||
|
||||
ep = torch.export.export(model, ...)
|
||||
from torch._decomp import core_aten_decompositions
|
||||
decomp_table = core_aten_decompositions()
|
||||
decomp_table = torch.export.core_aten_decompositions()
|
||||
decomp_table[your_op] = your_custom_decomp
|
||||
ep = ep.run_decompositions(decomp_table=decomp_table)
|
||||
"""
|
||||
|
Reference in New Issue
Block a user