Refactor the TorchScript-based exporter logic, moving it to a single (private) location for better code management. The original public module and method APIs are preserved.
- Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly
- Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused
@albanD / @soulitzer could you review changes in `torch/csrc/autograd/python_function.cpp` and
`torch/autograd/_functions/utils.py`? Thanks!
## BC Breaking
- **Deprecated members in `torch.onnx.verification` are removed**
Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323
Approved by: https://github.com/titaiwangms, https://github.com/angelayi
Reference: https://docs.astral.sh/ruff/formatter/black/#assert-statements
> Unlike Black, Ruff prefers breaking the message over breaking the assertion, similar to how both Ruff and Black prefer breaking the assignment value over breaking the assignment target:
>
> ```python
> # Input
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
>
> # Black
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
> # Ruff
> assert len(policy_types) >= priority + num_duplicates, (
>     f"This tests needs at least {priority + num_duplicates} many types."
> )
> ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144546
Approved by: https://github.com/malfet
beartype has served us well in identifying type errors and ensuring we call internal functions with the correct arguments (thanks!). However, the value of having beartype is diminished because of the following:
1. When beartype improved its `Dict[]` type checking, it discovered typing mistakes in some functions that were previously uncaught. This caused the exporter to fail with newer versions of beartype where it used to succeed. Since we cannot fix PyTorch and release a new version just because of this, it creates confusion for `torch.onnx` users who have beartype in their environment.
2. beartype adds an additional call line in the traceback, which makes the already thick dynamo stack even larger, affecting readability when users diagnose errors with the traceback.
3. Since the typing annotations need to be evaluated, we cannot use new syntax like `|` because we need to maintain compatibility with Python 3.8. We don't want to wait for PyTorch to adopt Python 3.10 as its lowest supported version before using the new typing syntax.
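Point 3 can be illustrated with a minimal sketch. The function `scale` and the helper `try_resolve_hints` are hypothetical and not taken from `torch.onnx`. A string annotation such as `"int | float"` is harmless as long as nobody evaluates it, but a tool that resolves hints at runtime has to evaluate it, and that evaluation raises `TypeError` on interpreters older than Python 3.10.

```python
import typing


def scale(x: "int | float", factor: "int | None" = None) -> float:
    """Hypothetical helper annotated with PEP 604 union syntax (as strings)."""
    return float(x) * (1 if factor is None else factor)


def try_resolve_hints(fn):
    """Resolve annotations the way a runtime type checker has to.

    Returns the resolved hints, or the TypeError raised on interpreters
    where `int | float` cannot be evaluated (before Python 3.10).
    """
    try:
        return typing.get_type_hints(fn)
    except TypeError as exc:
        return exc


# The unevaluated annotation is just a string, so defining `scale` always works.
print(scale.__annotations__["x"])
# Resolving it is what depends on the interpreter version.
print(try_resolve_hints(scale))
```

Calling `scale(2, 3)` works on every supported interpreter; only the hint resolution step is version dependent.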
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130484
Approved by: https://github.com/titaiwangms
Using ONNX opset 14, the aten `scaled_dot_product_attention` operator can be implemented with bfloat16 support because Add-14 supports bfloat16.
This PR simply adds bfloat16 to the list of supported types.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117878
Approved by: https://github.com/BowenBao
### **Description**:
The problem is that the graph was cast to `fp32` at a certain point but never reverted to `fp16`, causing the rest of the graph to run on `fp32`. This change aims to fix that issue and improve performance.
### **Changes Made**:
- Modified the ONNX exporter code to ensure that the graph is correctly cast back to `fp16` after a necessary cast to `fp32`.
### **Why This Change is Necessary**:
This change is necessary to ensure that the exported ONNX graph remains in `fp16` where appropriate, leading to significant gains in performance and memory savings. Without this fix, the graph would run entirely in `fp32`, causing suboptimal performance.
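The effect of a missing cast can be sketched with a toy dtype-propagation pass. This is an illustration only: the op lists and the `propagate_dtypes` helper are hypothetical and do not reproduce the exporter's actual graph. Every op inherits the dtype of its predecessor unless it is an explicit cast.

```python
def propagate_dtypes(ops, input_dtype="fp16"):
    """Walk a linear list of (op_name, cast_to) pairs and track the active dtype.

    `cast_to` is None for ordinary ops, which inherit the current dtype,
    and a dtype string for cast ops, which change it.
    """
    current = input_dtype
    result = []
    for name, cast_to in ops:
        if cast_to is not None:
            current = cast_to
        result.append((name, current))
    return result


# Hypothetical op sequence: one cast to fp32 and no cast back.
without_fix = [
    ("MatMul", None), ("Cast", "fp32"), ("Sqrt", None),
    ("Mul", None), ("Softmax", None), ("MatMul", None),
]
# Same sequence with a cast back to fp16 right after the fp32-only step.
with_fix = [
    ("MatMul", None), ("Cast", "fp32"), ("Sqrt", None), ("Cast", "fp16"),
    ("Mul", None), ("Softmax", None), ("MatMul", None),
]

for label, ops in (("without fix", without_fix), ("with fix", with_fix)):
    print(label, propagate_dtypes(ops))
```

In the first list every op after the cast stays in `fp32`; in the second only the op that needs `fp32` runs in it.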
### **Testing**:
- Performed extensive testing with various models and scenarios to validate the correctness of the changes.
### **Benchmarking Results**:
Experiments ran on 8 GPUs (Tesla V100, 32 GB each).
**Before Fix: ort + 4 hidden layers + without fix**
- **Train Runtime**: 78.7088 seconds
- **Train Samples per Second**: 10.164
- **Train Steps per Second**: 1.271
- **Train Loss**: 5.624655108451844
- **Epoch**: 0.3
**After Fix: ort + 4 hidden layers + with fix**
- **Train Runtime**: 72.5636 seconds
- **Train Samples per Second**: 11.025
- **Train Steps per Second**: 1.378
- **Train Loss**: 5.6252727746963505
- **Epoch**: 0.3
The fix gives a gain of roughly 7.8% (train runtime dropped from 78.71 s to 72.56 s).
- I only ran it with 4 hidden layers due to GPU constraints; the gain is expected to be higher on the full model.
- Other models that use `_attention_scale` should see a gain as well.
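The reported figures can be cross-checked with a few lines of arithmetic. This sketch uses only the numbers listed in this description; `percent_change` is a hypothetical helper. The result differs slightly depending on whether the gain is measured as runtime saved or as throughput gained.

```python
# Numbers copied from the benchmark results in this description.
before = {"runtime_s": 78.7088, "samples_per_s": 10.164, "steps_per_s": 1.271}
after = {"runtime_s": 72.5636, "samples_per_s": 11.025, "steps_per_s": 1.378}


def percent_change(old, new):
    """Relative change from `old` to `new`, in percent."""
    return (new - old) / old * 100.0


# Runtime should go down, so report the reduction as a positive number.
runtime_reduction = -percent_change(before["runtime_s"], after["runtime_s"])
# Throughput should go up.
throughput_gain = percent_change(before["samples_per_s"], after["samples_per_s"])

print(f"runtime reduction: {runtime_reduction:.2f}%")
print(f"throughput gain:   {throughput_gain:.2f}%")
```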
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112554
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
When users define a custom attention mask with `dtype=torch.float16`, e.g.
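```
import torch
from torch.nn import functional as F

float_min = torch.finfo(torch.float16).min
# `attention_mask` is assumed to be a boolean tensor defined elsewhere;
# positions where it is True are filled with the most negative fp16 value.
attention_mask_fp16 = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(torch.float16)
attn_output = F.scaled_dot_product_attention(
    query_layer_, key_layer_, value_layer_, attention_mask_fp16, 0.0, is_causal=False
)
```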
the ONNX graph cannot be exported.
When q, k, and v are fp16, we can also support an fp16 `attn_mask` by changing the type check on `attn_mask` to
```
elif (
    _type_utils.JitScalarType.from_value(attn_mask)
    # previously: == _type_utils.JitScalarType.FLOAT
    in (_type_utils.JitScalarType.FLOAT, _type_utils.JitScalarType.HALF)
):
```
With this change the `.onnx` graph can be exported.
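The effect of the widened check can be sketched with a stand-in enum. Everything here is hypothetical: `ScalarType` only mimics a few members of `_type_utils.JitScalarType`, and the branch descriptions are illustrative rather than the exporter's actual code.

```python
from enum import Enum, auto
from typing import Optional


class ScalarType(Enum):
    """Hypothetical stand-in for a few members of `_type_utils.JitScalarType`."""
    BOOL = auto()
    FLOAT = auto()
    HALF = auto()
    INT64 = auto()


def mask_handling(mask_type: Optional[ScalarType]) -> str:
    """Describe how a mask of the given scalar type would be dispatched."""
    if mask_type is None:
        return "no mask"
    elif mask_type == ScalarType.BOOL:
        return "convert boolean mask to an additive mask"
    # Membership test instead of `== ScalarType.FLOAT`, so HALF is accepted too.
    elif mask_type in (ScalarType.FLOAT, ScalarType.HALF):
        return "add mask to the attention scores"
    raise ValueError(f"Unsupported type for attn_mask: {mask_type}")


print(mask_handling(ScalarType.HALF))
```

With an equality check against `FLOAT` only, the `HALF` case would fall through to the error branch.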
Fixes #109336
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110306
Approved by: https://github.com/titaiwangms
This is the 4th PR in the series of #83787. It enables the use of `@onnx_symbolic` across `torch.onnx`.
- **Backward breaking**: Removed some symbolic functions from `__all__` because of the use of `@onnx_symbolic` for registering the same function on multiple aten names.
- Decorate all symbolic functions with `@onnx_symbolic`
- Move Quantized and Prim ops out from classes to functions defined in the modules. Eliminate the need for `isfunction` checking, speeding up the registration process by 60%.
- Remove the outdated unit test `test_symbolic_opset9.py`
- Symbolic function registration moved from the first call to `_run_symbolic_function` to init time.
- Registration is fast:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84448
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
Enable runtime type checking for all torch.onnx public APIs, symbolic functions, and most helpers (minus two that do not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to make unit tests green.
Profile:
export `torchvision.models.alexnet(pretrained=True)`
```
with runtime type checking: 21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes
+ 2.48%
```
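The idea of decorator-based runtime type checking can be sketched as follows. This is a simplified stand-in, not beartype: `check_types` and `set_opset` are hypothetical, and only annotations that are plain classes are checked. The extra wrapper call on every invocation is also where the measured overhead comes from.

```python
import functools
import inspect


def check_types(fn):
    """Simplified runtime type checker; handles plain-class annotations only."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = sig.parameters[name].annotation
            if expected is inspect.Parameter.empty:
                continue  # unannotated parameter, nothing to check
            if isinstance(expected, type) and not isinstance(value, expected):
                raise TypeError(
                    f"{fn.__name__}: argument {name!r} expected "
                    f"{expected.__name__}, got {type(value).__name__}"
                )
        return fn(*args, **kwargs)

    return wrapper


@check_types
def set_opset(version: int, name: str = "default") -> str:
    """Hypothetical helper used only to exercise the decorator."""
    return f"{name}:{version}"


print(set_opset(14))
```

Calling `set_opset("14")` raises `TypeError` at the call site instead of failing later inside the function.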
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
Replace runtime errors in torch.onnx with `errors.SymbolicValueError` for more context around jit values.
- Extend `_unimplemented`, `_onnx_unsupported`, `_onnx_opset_unsupported`, `_onnx_opset_unsupported_detailed` errors to include JIT value information
- Replace plain RuntimeError with `errors.SymbolicValueError`
- Clean up: Use `_is_bool` to replace string comparison on jit types
- Clean up: Remove the todo `Remove type ignore after #81112`
#77316
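The benefit of attaching value information to an error can be sketched like this. `ValueContextError` and `require_bool` are hypothetical stand-ins and do not reproduce the signature or message format of `errors.SymbolicValueError`.

```python
class ValueContextError(Exception):
    """Hypothetical error that records which value triggered it."""

    def __init__(self, msg, value):
        self.value = value
        # Append a description of the offending value to the message.
        super().__init__(f"{msg} [caused by value {value!r}]")


def require_bool(value_name, value_type):
    """Toy check standing in for a symbolic function's input validation."""
    if value_type != "Bool":
        raise ValueContextError(
            f"expected a Bool input but got {value_type}", value_name
        )
    return True


try:
    require_bool("%input.3", "Float")
except ValueContextError as exc:
    print(exc)
```

Compared with a plain `RuntimeError`, the caller can see which value was at fault without re-running the export under a debugger.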
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83332
Approved by: https://github.com/AllenTiTaiWang, https://github.com/thiagocrepaldi, https://github.com/BowenBao
Cleaning up onnx module imports to prepare for updating `__init__`.
- Simplify importing the `_C` and `_C._onnx` name spaces
- Remove alias of the symbolic_helper module in imports
- Remove any module level function imports. Import modules instead
- Alias `symbolic_opsetx` as `opsetx`
- Fix some docstrings
Requires:
- https://github.com/pytorch/pytorch/pull/77448
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77423
Approved by: https://github.com/BowenBao
Reduce circular dependencies
- Lift constants and flags from `symbolic_helper` to `_constants` and `_globals`
- Standardize constant naming to make it consistent
- Make `utils` strictly dependent on `symbolic_helper`, removing inline imports from symbolic_helper
- Move side effects from `utils` to `_patch_torch`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77142
Approved by: https://github.com/garymm, https://github.com/BowenBao
Extend quantization support to per-channel quantization.
Per-channel quantized tensors carry an extra attribute, `axis`,
most commonly in the quantized weights of Convolution or Linear modules.
The PR adds support to correctly parse the `axis` attribute and map it to the
ONNX representation in `QuantizeLinear` and `DequantizeLinear`.
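What per-channel quantization along an axis computes can be sketched in pure Python. This is a reference-style illustration for a 2-D weight with `axis=0` and `uint8` output, following the ONNX `QuantizeLinear` formula (round to nearest even, add the zero point, saturate). The function names and sample values are hypothetical.

```python
def quantize_per_channel(rows, scales, zero_points, qmin=0, qmax=255):
    """Quantize a 2-D list along axis 0: row i uses scales[i], zero_points[i]."""
    out = []
    for row, scale, zp in zip(rows, scales, zero_points):
        # Python's round() rounds halves to even, matching QuantizeLinear.
        out.append([min(qmax, max(qmin, round(v / scale) + zp)) for v in row])
    return out


def dequantize_per_channel(q_rows, scales, zero_points):
    """Inverse mapping: (q - zero_point) * scale, again one scale per row."""
    return [
        [(q - zp) * scale for q in row]
        for row, scale, zp in zip(q_rows, scales, zero_points)
    ]


weights = [[0.0, 1.0, 2.0], [-1.0, 0.5, 3.0]]
scales = [0.5, 0.25]      # one scale per output channel (axis 0)
zero_points = [0, 128]    # one zero point per output channel

quantized = quantize_per_channel(weights, scales, zero_points)
restored = dequantize_per_channel(quantized, scales, zero_points)
print(quantized)
print(restored)
```

With per-tensor quantization there would be a single scale and zero point; the `axis` attribute is what tells the operator which dimension the 1-D scale and zero-point tensors run along.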
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76002
Approved by: https://github.com/garymm