Create export_for_inference API and expose core_aten as public facing API (#135912)

Differential Revision: [D62606908](https://our.internmc.facebook.com/intern/diff/D62606908)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135912
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #135080
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2024-09-14 22:28:17 -07:00
committed by PyTorch MergeBot
parent 382fad58b3
commit 1904b09e61
4 changed files with 158 additions and 3 deletions

View File

@ -677,6 +677,7 @@ API Reference
.. autofunction:: load
.. autofunction:: register_dataclass
.. autofunction:: torch.export.dynamic_shapes.Dim
.. autofunction:: torch.export.exported_program.core_aten_decompositions
.. autofunction:: dims
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection

View File

@ -4762,6 +4762,56 @@ def forward(self, b_a_buffer, x):
self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp)))
self.assertEqual(id(state_dict), id(ep.state_dict))
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
def test_export_for_inference_e2e(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.lin = torch.nn.Linear(10, 1)
def forward(self, x):
return self.lin(x)
inp = (torch.randn(5, 10),)
m = M()
decomp_table = torch.export.core_aten_decompositions()
def _custom_decomp_for_linear(x, weight, bias):
return x + bias.sum()
decomp_table[torch.ops.aten.linear.default] = _custom_decomp_for_linear
del decomp_table[torch.ops.aten.sum.default]
ep = torch.export.export_for_inference(
m, inp, decomp_table=decomp_table, dynamic_shapes={"x": {0: Dim("batch")}}
)
self.assertExpectedInline(
str(ep.graph_module.code).strip(),
"""\
def forward(self, p_lin_weight, p_lin_bias, x):
sum_1 = torch.ops.aten.sum.default(p_lin_bias); p_lin_bias = None
add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None
return (add,)""",
)
ep_core = ep.run_decompositions()
self.assertExpectedInline(
str(ep_core.graph_module.code).strip(),
"""\
def forward(self, p_lin_weight, p_lin_bias, x):
sum_1 = torch.ops.aten.sum.dim_IntList(p_lin_bias, []); p_lin_bias = None
add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None
return (add,)""",
)
with self.assertRaisesRegex(RuntimeError, "Expected input"):
ep.module()(torch.randn(4, 12))
with self.assertRaisesRegex(RuntimeError, "Expected input"):
ep_core.module()(torch.randn(4, 12))
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
def test_export_decomp_torture_case_1(self):
class M(torch.nn.Module):

View File

@ -38,6 +38,7 @@ from torch.utils._pytree import (
if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch._ops import OpOverload
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
@ -49,9 +50,11 @@ __all__ = [
"ExportedProgram",
"ModuleCallEntry",
"ModuleCallSignature",
"core_aten_decompositions",
"dims",
"export",
"export_for_training",
"export_for_inference",
"load",
"register_dataclass",
"save",
@ -62,7 +65,12 @@ __all__ = [
from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .exported_program import (
core_aten_decompositions,
ExportedProgram,
ModuleCallEntry,
ModuleCallSignature,
)
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
@ -165,6 +173,91 @@ def export_for_training(
)
def export_for_inference(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None,
) -> ExportedProgram:
"""
:func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing
only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
which can subsequently be executed with different inputs or serialized. The
traced graph (1) produces normalized operators in the ATen operator set
(as well as any user-specified custom operators) which is customizable via decomp_table,
(2) has eliminated all Python control flow and data structures (with certain exceptions),
and (3) records the set of shape constraints needed to show that this normalization and control-flow
elimination is sound for future inputs. This API is for convenience use as it combines :func:`export_for_training` and
:func:`run_decompositions`.
**Soundness Guarantee**
See :func:`export()` docstring for more details.
Args:
mod: We will trace the forward method of this module.
args: Example positional inputs.
kwargs: Optional example keyword inputs.
dynamic_shapes:
An optional argument where the type should either be:
1) a dict from argument names of ``f`` to their dynamic shape specifications,
2) a tuple that specifies dynamic shape specifications for each input in original order.
If you are specifying dynamism on keyword args, you will need to pass them in the order that
is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
not required to include static dimension indices in this dict, but when they are,
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
recursively specified by using mappings or sequences of contained specifications.
strict: When enabled (default), the export function will trace the program through
TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the
exported program will not validate the implicit assumptions baked into the graph and
may cause behavior divergence between the original model and the exported one. This is
useful when users need to workaround bugs in the tracer, or simply want incrementally
enable safety in their models. Note that this does not affect the resulting IR spec
to be different and the model will be serialized in the same way regardless of what value
is passed here.
WARNING: This option is experimental and use this at your own risk.
decomp_table: See :func:`run_decompositions` for more details.
Returns:
An :class:`ExportedProgram` containing the traced callable.
**Acceptable input/output types**
Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
- Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
- Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
- (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
``OrderedDict`` containing all above types.
"""
ep_for_training = export_for_training(
mod,
args,
kwargs,
dynamic_shapes=dynamic_shapes,
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
)
return ep_for_training.run_decompositions(decomp_table=decomp_table)
def export(
mod: torch.nn.Module,
args: Tuple[Any, ...],

View File

@ -78,6 +78,7 @@ __all__ = [
"ExportedProgram",
"ModuleCallEntry",
"ModuleCallSignature",
"core_aten_decompositions",
]
@ -289,6 +290,17 @@ def _split_decomp_table_to_cia_and_python_decomp(
return cia_ops_to_callable, decomp_table
def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
"""
This is the default decomposition table which contains decomposition of
all ATEN operators to core aten opset. Use this API together with
:func:`run_decompositions()`
"""
from torch._decomp import core_aten_decompositions
return core_aten_decompositions()
def _decompose_and_get_gm_with_new_signature_constants(
ep,
*,
@ -1015,8 +1027,7 @@ class ExportedProgram:
.. code-block:: python
ep = torch.export.export(model, ...)
from torch._decomp import core_aten_decompositions
decomp_table = core_aten_decompositions()
decomp_table = torch.export.core_aten_decompositions()
decomp_table[your_op] = your_custom_decomp
ep = ep.run_decompositions(decomp_table=decomp_table)
"""