Remove `enable_fake_mode` and `exporter_legacy` entirely. Even though this is BC-breaking, `enable_fake_mode` is no longer compatible with the latest version of transformers, so it is no longer useful.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161222
Approved by: https://github.com/titaiwangms
Prior to this PR, `torch.onnx.export(..., dynamo=True, verify=True, report=True)` did not support symbolic arguments. An example looks like the following:
```python
class M(torch.nn.Module):
    def forward(self, a, x):
        return a + torch.tensor(1) + x

op = torch.onnx.export(
    M(),
    (1, torch.ones(2)),
    dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}),
    dynamo=True,
    report=True,
)
```
Symbolic arguments are like constant arguments in that they don't have tensor_meta either. Besides, torch.export.export supports model inputs that contain constants, which is different from the legacy issue https://github.com/pytorch/pytorch/issues/99534, where we tried to get the FX graph directly from dynamo export. Thus, `_remove_non_tensor` is deleted from args processing.
NOTE: If a ConstantArgument shows up in the exported_program, it is kept to align the length of the inputs with the nn.Module signature, but it is irrelevant to the model graph, which is why that input is omitted in the ONNX model.
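As a rough illustration of the note above (a sketch rather than the PR's test code; the attribute access follows the ONNXProgram/onnxscript IR API), the ONNX graph from the example only exposes the tensor input `x`:

```python
# `op` is the ONNXProgram returned by the export call above; the symbolic
# argument `a` is kept in the ExportedProgram signature but does not appear
# among the ONNX graph inputs.
print([inp.name for inp in op.model.graph.inputs])  # e.g. ["x"]
```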
The test `test_constant_argument_user_input_is_omitted_in_onnx_graph` needs #157719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157734
Approved by: https://github.com/justinchuby
Prior to this PR, the exporter did not support dynamic dims when the traced inputs contained 0/1. After https://github.com/pytorch/pytorch/pull/148696, this is supported by torch.export.export; this PR adds the same patch to torch.onnx.export.
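A minimal sketch of the newly supported case (the model and shapes are illustrative, not the PR's test): a dimension that is traced as 1 can now be marked dynamic when exporting with `dynamo=True`:

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1


# The batch dimension is traced as 1 but requested to be dynamic; prior to
# this patch the exporter could not handle traced 0/1 sizes for dynamic dims.
onnx_program = torch.onnx.export(
    Model(),
    (torch.randn(1, 8),),
    dynamic_shapes=({0: torch.export.Dim.DYNAMIC},),
    dynamo=True,
)
```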
However, a known pitfall remains because of the difference between eager execution and export: the compiler needs to decide the exported shapes ahead of time, and whether the "hidden broadcasting" is applied results in a different export.
For example,
```python
import torch

class Model(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.cat((x, y), axis=1) + z

model = Model()

x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(1, 8)
model(x, y, z)

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep)
"""
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s7, s16]", y: "f32[s7, s43]", z: "f32[s7, s16 + s43]"):
            #
            sym_size_int: "Sym(s7)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_1: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_2: "Sym(s7)" = torch.ops.aten.sym_size.int(y, 0)
            sym_size_int_3: "Sym(s43)" = torch.ops.aten.sym_size.int(y, 1)
            sym_size_int_4: "Sym(s7)" = torch.ops.aten.sym_size.int(z, 0)
            sym_size_int_5: "Sym(s16 + s43)" = torch.ops.aten.sym_size.int(z, 1)
            # File: /home/titaiwang/pytorch/test_export.py:7 in forward, code: return torch.cat((x, y), axis=1) + z
            cat: "f32[s7, s16 + s43]" = torch.ops.aten.cat.default([x, y], 1); x = y = None
            #
            eq: "Sym(True)" = sym_size_int_2 == sym_size_int; sym_size_int_2 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s58, s35) on node 'eq'"); eq = _assert_scalar_default = None
            add_1: "Sym(s16 + s43)" = sym_size_int_1 + sym_size_int_3; sym_size_int_1 = sym_size_int_3 = None
            eq_1: "Sym(True)" = add_1 == sym_size_int_5; add_1 = sym_size_int_5 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_1, "Runtime assertion failed for expression Eq(s16 + s43, s23) on node 'eq_1'"); eq_1 = _assert_scalar_default_1 = None
            eq_2: "Sym(True)" = sym_size_int == sym_size_int_4; sym_size_int = sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s35, s7) on node 'eq_2'"); eq_2 = _assert_scalar_default_2 = None
            # File: /home/titaiwang/pytorch/test_export.py:7 in forward, code: return torch.cat((x, y), axis=1) + z
            add: "f32[s7, s16 + s43]" = torch.ops.aten.add.Tensor(cat, z); cat = z = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT
    # outputs
    add: USER_OUTPUT

Range constraints: {s7: VR[0, int_oo], s16: VR[0, int_oo], s43: VR[0, int_oo], s16 + s43: VR[0, int_oo]}
"""
ep.module()(x, y, z)
"""
Traceback (most recent call last):
File "/home/titaiwang/pytorch/test_export.py", line 20, in <module>
ep.module()(x, y, z)
File "/home/titaiwang/pytorch/torch/fx/graph_module.py", line 840, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/titaiwang/pytorch/torch/fx/graph_module.py", line 416, in __call__
raise e
File "/home/titaiwang/pytorch/torch/fx/graph_module.py", line 403, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/titaiwang/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/titaiwang/pytorch/torch/nn/modules/module.py", line 1873, in _call_impl
return inner()
^^^^^^^
File "/home/titaiwang/pytorch/torch/nn/modules/module.py", line 1800, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/titaiwang/pytorch/torch/_dynamo/eval_frame.py", line 895, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/titaiwang/pytorch/torch/export/_unlift.py", line 83, in _check_input_constraints_pre_hook
_check_input_constraints_for_graph(
File "/home/titaiwang/pytorch/torch/_export/utils.py", line 426, in _check_input_constraints_for_graph
_check_symint(
File "/home/titaiwang/pytorch/torch/_export/utils.py", line 338, in _check_symint
raise RuntimeError(
RuntimeError: Expected input at *args[2].shape[0] to be equal to 2, but got 1
"""
```
The explanation (from @pianpwk):
In the model we have `return torch.cat((x, y), axis=1) + z`.
Before this add is executed, the LHS has shape `[s7, s16 + s43]`, while z has shape, say, `[s8, s16 + s43]` (we don't know `s7 == s8` yet). When we execute this add, the compiler has to make a decision: does broadcasting apply or not? The choices are:
1) Yes -> then we must specialize `s8` to 1
2) No -> then this element-wise op is only valid if the shapes match up, and we assume `s7 == s8`.
Unfortunately export can only follow one of these options, and in avoiding 0/1 specialization (because a dynamic dimension was requested), it assumed case 2).
For an operation like a + b, in eager semantics it's possible to have all options (either a == 1 OR b == 1 OR a == b), but with export we need to make a decision on what the output shape of this operation is, and keeping all branches alive requires expressing the output shape with a conditional (e.g. output shape = `a if b == 1 else b`), which is pretty hard for the compiler to reason about.
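Continuing the example above (the variable name `z_matching` is just for illustration), passing a `z` whose first dimension matches `x` and `y` satisfies the assumed `s7 == s8` and runs fine, whereas the original `(1, 8)` input relied on eager broadcasting:

```python
# Export assumed case 2): no broadcasting, so z's first dim must equal s7.
z_matching = torch.randn(2, 8)
print(ep.module()(x, y, z_matching).shape)  # torch.Size([2, 8])
```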
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155717
Approved by: https://github.com/justinchuby
Previously, the strategy used for obtaining the exported program was not asserted. This led to silent errors if torch.export broke something and a fallback strategy was used. This change adds a `_capture_strategy` field to ONNXProgram and enables unit tests to assert the strategy used, to prevent fallbacks from happening.
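A minimal sketch of how a test can pin the strategy (the model and the assertion are illustrative; real tests would compare `_capture_strategy` against the concrete strategy they expect):

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1


onnx_program = torch.onnx.export(Model(), (torch.randn(2, 3),), dynamo=True)
# If a fallback strategy had silently been used, inspecting/asserting on
# _capture_strategy exposes it instead of letting the test pass.
print(onnx_program._capture_strategy)
assert onnx_program._capture_strategy is not None
```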
Fixes #147674
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148348
Approved by: https://github.com/titaiwangms, https://github.com/shubhambhokare1
Fixes #139320
### Summary:
#### (1) Add `_rename_dynamic_shapes_with_model_inputs` so that dynamic_shapes works with input_names
* Use the model forward signature to rename dynamic_shapes when dynamic_shapes is not nested and directly uses the customized names. This solves the issue that torch.export.export expects dynamic_shapes to use only the model input names (see the sketch after this list).
* If dynamic_shapes is nested, we do nothing.
#### (2) Add `_from_dynamic_shapes_to_dynamic_axes` for fallback
* We flatten dynamic_shapes with the leaves defined by `_pytree.tree_leaves()`.
~~* If a dynamic_shapes is not nested, and defined in dict. We can use the key as the input_names, since it should be renamed by `_rename_dynamic_shapes_with_model_inputs` already.~~
* If dynamic_shapes is provided, input_names is required to assign the names, because dynamic_axes needs them.
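A minimal sketch of the case item (1) targets (the model and names are illustrative): `dynamic_shapes` keyed by the customized `input_names` entry rather than the forward-signature parameter name, which the helper renames before calling torch.export.export:

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x * 2


# dynamic_shapes is keyed by the customized name "input" (from input_names)
# instead of the forward() parameter name "x"; the renaming helper maps it
# back to "x" so torch.export.export accepts it.
torch.onnx.export(
    Model(),
    (torch.randn(2, 3),),
    input_names=["input"],
    dynamic_shapes={"input": {0: torch.export.Dim.DYNAMIC}},
    dynamo=True,
)
```

If the dynamo export falls back to the legacy exporter, the same `input_names` are what anchor the dynamic_shapes-to-dynamic_axes conversion described in item (2).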
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139532
Approved by: https://github.com/justinchuby
Features:
(1) Add support for tree structures.
(2) Add a user warning before the axes-to-shapes conversion.
(3) Add a suggestion to provide `dynamic_shapes` when the conversion fails.
Notes:
(1) `input_names` is crucial to the conversion, as we don't know the ONNX graph inputs.
(2) min and max are left as defaults, so LLMs have a higher chance of failing if users use `dynamic_axes`, due to min/max constraint dependencies between `attention_mask` and `sequence_length`, etc. (found in llama-3.2-1B_Instruct). See the sketch below.
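A minimal sketch of the conversion path (the model and axis names are illustrative): the legacy `dynamic_axes` can still be passed with `dynamo=True`, and the exporter converts it to `dynamic_shapes` using `input_names`:

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x.relu()


# With dynamo=True, the provided dynamic_axes are converted to dynamic_shapes
# under the hood; input_names is what ties the axes to the graph inputs.
torch.onnx.export(
    Model(),
    (torch.randn(2, 3),),
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    dynamo=True,
)
```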
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140488
Approved by: https://github.com/justinchuby
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
The ONNX custom ops registration API.
## Design
1. Create a `custom_translation_table: dict[Callable, Sequence[Callable] | Callable]` parameter for specifying extra functions
2. Use a callable as the key to support all possible call_function targets in the fx graph
3. Allow a callable or a Sequence of callables as values.
- When there is a single callable, it is the translation function for the op
- When there is a Sequence of callables, the exporter's dispatcher dispatches to them in order based on input dtypes.
- The translation functions can be a plain python function that calls onnxscript ops (traced), or an onnxscript function.
- Complex input support: We create special type annotations for annotating real representations of complex inputs, which are needed to handle complex computation in the ONNX graph, as we don't have any ops in ONNX that handle complex inputs. The dispatcher will have knowledge of these newly created type annotations and dispatch correctly. The complex functions will be in the same overload pool as the real functions.
```py
torch.onnx.export(
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.add: [overload1, overload2],
        torch.sym_not: sym_not_onnx,
    },
)
```
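As a rough end-to-end sketch (the model, the spelled-out gelu formula, and the constant handling are illustrative, not taken from the PR), a plain Python function tracing onnxscript ops can be registered for `torch.ops.aten.gelu.default`:

```py
import torch
from onnxscript import opset18 as op


class GeluModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)


def custom_gelu(x, approximate: str = "none"):
    # gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2))), written with ONNX ops.
    # Assumes float32 inputs; a production translation would handle dtypes.
    half = op.Constant(value_float=0.5)
    one = op.Constant(value_float=1.0)
    inv_sqrt2 = op.Constant(value_float=0.7071067811865476)
    return op.Mul(op.Mul(x, half), op.Add(one, op.Erf(op.Mul(x, inv_sqrt2))))


onnx_program = torch.onnx.export(
    GeluModel(),
    (torch.randn(2, 3),),
    dynamo=True,
    custom_translation_table={torch.ops.aten.gelu.default: custom_gelu},
)
```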
Support for functions that handle complex inputs will be added in separate PRs.
fixes https://github.com/pytorch/pytorch/issues/138391
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135403
Approved by: https://github.com/titaiwangms