Compare commits

...

132 Commits

Author SHA1 Message Date
778d522b96 document distributed apis 2025-10-10 15:40:20 -07:00
50c338c2da [DeviceMesh] Move global state into class method (#164510)
This PR moves the bookkeeping state maps from MeshEnv to DeviceMesh class members. The reason is that, in general, global variables are thread-local and can cause issues.

We will also need to do DTensor CPU overhead benchmark for this change.

3-5% CPU overhead in DTensor has been observed:

before:
<img width="1147" height="535" alt="image" src="https://github.com/user-attachments/assets/9e4ac018-ec0a-46a4-8f2c-64b4dbec465c" />

After:
<img width="1114" height="576" alt="image" src="https://github.com/user-attachments/assets/eaf83660-652b-4c6b-8591-f6049ccdd14c" />

running the benchmark mentioned here: https://github.com/pytorch/pytorch/issues/159169

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164510
Approved by: https://github.com/lw, https://github.com/fegin
2025-10-10 21:37:17 +00:00
3faee20067 [opaque_obj_v2] PyObject custom op schema type (#165004)
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Now we just need to do the following:

Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type.
```python
class OpaqueQueue:
    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
        super().__init__()
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, tensor: torch.Tensor) -> None:
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

    def size(self) -> int:
        return len(self.queue)

register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
```

When creating the custom op, the schema will then use the unique name:
```python
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")

torch.library.define(
    "_TestOpaqueObject::queue_push",
    "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
    tags=torch.Tag.pt2_compliant_tag,
    lib=self.lib,
)

@torch.library.impl(
    "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
    assert isinstance(queue, OpaqueQueue)
    queue.push(b)
```

Using the custom op:
```python
queue = OpaqueQueue([], torch.zeros(3))
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3))
self.assertEqual(queue.size(), 1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004
Approved by: https://github.com/albanD
2025-10-10 21:31:56 +00:00
cafca357fb Fix h100 daily inductor running dispatch (#165185)
Caused by merged PR: e7ed1a00eb

The if condition should also be updated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165185
Approved by: https://github.com/malfet, https://github.com/huydhn
2025-10-10 21:28:58 +00:00
1e35b3c4e0 Augment DebugMode to support attributes reporting (#165109)
DebugMode reports tensor type, its shape and placements while active. This change augments reporting with tensor attributes from a configured set. This feature is intended to make the debug string easier to understand when dealing with larger outputs. For example, before running the forward pass of a model we can annotate each of its parameters and buffers with their fully qualified names, so that we can see which ops are being executed against specific tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165109
Approved by: https://github.com/ezyang, https://github.com/pianpwk
2025-10-10 21:27:05 +00:00
f363114852 [Bugfix][Inductor][Dynamo] Fix stride incorrectness issues for stride 0 tensor (#164897)
Fixes #164814 - we update to include cases where we know the symbolic expression is statically one. There are two errors here: the first is in graph capture, where a tensor with size 0 yet symbolic stride would attempt to keep the symbolic stride, resulting in a mismatch. The second is in inductor codegen, where we only checked in squeeze if size == 1, missing the case where a symbolic stride equals 1.

Also fixes #164924 (@bobrenjc93  for fuzzer finding an issue affecting users : )

### Test plan:
```
python test/dynamo/test_aot_autograd.py AotAutogradFallbackTests
```

Results in:
```
..
----------------------------------------------------------------------
Ran 49 tests in 45.622s

OK (expected failures=1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164897
Approved by: https://github.com/laithsakka
2025-10-10 21:26:57 +00:00
0ec0120b19 Move aws OIDC credentials steps into setup-rocm.yml (#164769)
The AWS ECR login step needs `id-token: write` permissions. We move the steps to get OIDC-based credentials from `_rocm-test.yml` to `setup-rocm.yml`. This lays the groundwork to enable access to AWS ECR in workflows in other repos such as torchtitan that use [linux_job_v2.yml](https://github.com/pytorch/test-infra/blob/main/.github/workflows/linux_job_v2.yml), which also uses [setup-rocm.yml](335f4f80a0/.github/workflows/linux_job_v2.yml (L168)).

Any caller workflows that eventually execute `setup-rocm` action will thus need to provide the `id-token: write` permission.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164769
Approved by: https://github.com/huydhn
2025-10-10 21:24:29 +00:00
8360f34c36 [ROCm] hotfix test scaled matmul cuda (#165104)
Refactoring of scaled mm APIs and related tests caused previously passing tests on ROCm to start failing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165104
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-10 21:06:58 +00:00
370b1c12d2 [CI] Put the no gpu tests on machines that don't have gpus (#165183)
I think this is just a copy paste error?

NS: Introduced by https://github.com/pytorch/pytorch/pull/161013

Not sure where it got copied from though, the other set of no gpu tests for the other cuda version already have cpu runners
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165183
Approved by: https://github.com/malfet
2025-10-10 20:59:09 +00:00
6fd1ca28e1 [lint] Run full lint on ciflow/trunk (#165169)
Add some naming stuff to differentiate between full + partial

If we find that partial always == full, then we can get rid of it

https://github.com/pytorch/pytorch/issues/165168
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165169
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-10-10 20:38:51 +00:00
0055f07997 Disable failing test_int8_woq_mm_cuda on slow grad check (#165147)
Failing due to a memory leak, e.g.
https://github.com/pytorch/pytorch/actions/runs/18401518298/job/52434584458

```
2025-10-10T11:07:42.9485277Z _ TestSelectAlgorithmCudaCUDA.test_int8_woq_mm_cuda_batch_size_32_mid_dim_8_in_features_144_out_features_65_cuda_bfloat16 _
2025-10-10T11:07:42.9485389Z Traceback (most recent call last):
2025-10-10T11:07:42.9485869Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3278, in wrapper
2025-10-10T11:07:42.9485966Z     method(*args, **kwargs)
2025-10-10T11:07:42.9486365Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3278, in wrapper
2025-10-10T11:07:42.9486454Z     method(*args, **kwargs)
2025-10-10T11:07:42.9486849Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3277, in wrapper
2025-10-10T11:07:42.9486933Z     with policy():
2025-10-10T11:07:42.9487380Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2654, in __exit__
2025-10-10T11:07:42.9487473Z     raise RuntimeError(msg)
2025-10-10T11:07:42.9488533Z RuntimeError: CUDA driver API confirmed a leak in __main__.TestSelectAlgorithmCudaCUDA.test_int8_woq_mm_cuda_batch_size_32_mid_dim_8_in_features_144_out_features_65_cuda_bfloat16! Caching allocator allocated memory was 19456 and is now reported as 29184 on device 0. CUDA driver allocated memory was 356712448 and is now 358809600.
2025-10-10T11:07:42.9488543Z
2025-10-10T11:07:42.9488722Z To execute this test, run the following from the base repo dir:
2025-10-10T11:07:42.9489520Z     PYTORCH_TEST_CUDA_MEM_LEAK_CHECK=1 PYTORCH_TEST_WITH_SLOW_GRADCHECK=1 python test/inductor/test_cuda_select_algorithm.py TestSelectAlgorithmCudaCUDA.test_int8_woq_mm_cuda_batch_size_32_mid_dim_8_in_features_144_out_features_65_cuda_bfloat16
2025-10-10T11:07:42.9489525Z
2025-10-10T11:07:42.9489748Z This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```

Got added in #161680

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165147
Approved by: https://github.com/bbeckca
2025-10-10 20:26:31 +00:00
4f8a986b8f Make LOCK_TIMEOUT in codecache configurable (#165030)
- Introduce file_lock_timeout in config (defaults to current value of 600)
- Use the above config instead of hardcoded 600 config.

This is useful when running stress tests.
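A minimal sketch of the pattern, assuming a hypothetical `Config` holder and `acquire_with_timeout` helper (only the option name `file_lock_timeout` and its default of 600 come from the description above; a `threading.Lock` stands in for the file lock):

```python
import threading
from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical config object; the default keeps the previous hardcoded value.
    file_lock_timeout: int = 600


config = Config()


def acquire_with_timeout(lock) -> bool:
    """Try to take the lock, waiting at most config.file_lock_timeout seconds."""
    # Reading the timeout at call time lets a stress test lower it beforehand.
    return lock.acquire(timeout=config.file_lock_timeout)


demo_lock = threading.Lock()
first_attempt = acquire_with_timeout(demo_lock)  # uncontended, returns at once
```

A stress test could then set `config.file_lock_timeout` to a small value so that contended locks fail fast instead of waiting ten minutes.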

Differential Revision:
D84109142

Privacy Context Container: L1297311

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165030
Approved by: https://github.com/hl475
2025-10-10 20:22:11 +00:00
5c3fe9fb30 Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit a6fa4f9c283971c0fb6f60a89674a1f35370ac79.

Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/izaitsevfb due to introduces numeric issues internally, see [D84326613](https://www.internalfb.com/diff/D84326613) ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3392203314))
2025-10-10 20:21:12 +00:00
306b344a18 [dynamo][DebugMode] mask python keys in dispatch_key_set guard checks (#164992)
I found that running any compiled function under DebugMode more than once will trigger recompilations, e.g. with the really simple modified test case in `test_compile`:
```
[0/1] [__recompiles] Recompiling function f in /data/users/pianpwk/ptclone/pytorch/test/distributed/tensor/debug/test_debug_mode.py:268
[0/1] [__recompiles]     triggered by the following guard failure(s):
[0/1] [__recompiles]     - 0/0:
[0/2] [__recompiles] Recompiling function f in /data/users/pianpwk/ptclone/pytorch/test/distributed/tensor/debug/test_debug_mode.py:268
[0/2] [__recompiles]     triggered by the following guard failure(s):
[0/2] [__recompiles]     - 0/1:
[0/2] [__recompiles]     - 0/0:
```

Digging deeper, the guard failures were due to TENSOR_MATCH guards failing on dispatch key set checks (seemingly on the Python dispatch key):
5a1fbf45ad/torch/csrc/dynamo/guards.cpp (L199-L203)

This seems to be due to the `ignore_compile_internals=True` flag on custom dispatch modes being on, which causes these modes to "hide" themselves during compilation, making dynamo guard on the Python dispatch key being off.

The (maybe imperfect) solution is to mask out the Python keys for guard comparisons. This might be fine because custom dispatch modes won't appear here during compilation - `ignore_compile_internals=True` hides them, and `ignore_compile_internals=False` disables compile entirely?
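The masking idea can be sketched with plain integer bitsets. The bit positions and names below are hypothetical and do not reflect the real dispatch key layout; the point is only that Python-related bits are cleared on both sides before comparing.

```python
# Hypothetical bit assignments, for illustration only.
PYTHON_KEY = 1 << 0
PYTHON_SNAPSHOT_KEY = 1 << 1
CPU_KEY = 1 << 2
AUTOGRAD_KEY = 1 << 3

PYTHON_KEYS_MASK = PYTHON_KEY | PYTHON_SNAPSHOT_KEY


def key_sets_match(guarded: int, current: int) -> bool:
    """Compare two key sets while ignoring the Python-related bits."""
    return (guarded & ~PYTHON_KEYS_MASK) == (current & ~PYTHON_KEYS_MASK)


# At compile time the mode is hidden, so the Python bit is off.
compile_time_keys = CPU_KEY | AUTOGRAD_KEY
# At run time the mode is active, so the Python bit is on.
run_time_keys = compile_time_keys | PYTHON_KEY

same_after_masking = key_sets_match(compile_time_keys, run_time_keys)
```

Without the mask, the two key sets differ in one bit and a guard comparing them would fail on every call under the mode.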

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164992
Approved by: https://github.com/williamwen42
2025-10-10 20:00:28 +00:00
94e634942a Fix int32 overflow in embedding_dense_backward (#165095)
If `max_partial_segment` is large we can overflow `gid` and cause a bunch of IMA.
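As a rough illustration of this kind of overflow, the sketch below emulates a flattened index computed in 32-bit versus 64-bit arithmetic. The formula and the names (`block_idx`, `block_dim`, `thread_idx`) are assumptions for the example, not the actual kernel code. Note also that signed overflow is undefined behaviour in C/C++, so the wraparound shown here is only one possible outcome.

```python
import ctypes


def flat_index_int32(block_idx: int, block_dim: int, thread_idx: int) -> int:
    """Emulate the index computation truncated to a signed 32-bit integer."""
    return ctypes.c_int32(block_idx * block_dim + thread_idx).value


def flat_index_int64(block_idx: int, block_dim: int, thread_idx: int) -> int:
    """Same computation with a signed 64-bit integer."""
    return ctypes.c_int64(block_idx * block_dim + thread_idx).value


# A large number of blocks pushes the product past 2**31 - 1.
narrow = flat_index_int32(20_000_000, 128, 5)
wide = flat_index_int64(20_000_000, 128, 5)
```

The 32-bit result wraps to a negative value, which as an array index would point outside the buffer; the 64-bit result keeps the intended value.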
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165095
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-10 19:47:38 +00:00
a4925c0ce0 [testing] Print something for log classifier to better differentiate reruns vs real failures (#165163)
The normal pytest/unittest failure patterns also match flaky tests (specifically I think tests that fail -> succeed on rerun in a new subprocess)

So print something specifically for log classifier that it can match against
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165163
Approved by: https://github.com/izaitsevfb
2025-10-10 19:28:13 +00:00
d16627f4d0 Revert "[dynamo][executorch] Do not trace into exeuctorch LoweredBackendModule (#165126)"
This reverts commit 41936f4cf6ff93b70d81f6a23811d43a0647f1e1.

Reverted https://github.com/pytorch/pytorch/pull/165126 on behalf of https://github.com/anijain2305 due to https://github.com/pytorch/pytorch/pull/165172 is the right way ([comment](https://github.com/pytorch/pytorch/pull/165126#issuecomment-3391975498))
2025-10-10 19:21:41 +00:00
8f78999d77 [Inductor][ATen] Fix stride rounding on Blockwise128x128 to accommodate for small shapes (#164953)
Summary: Fix rounding issue on `Blockwise128x128` to accommodate for small shapes. The original implementation rounded all strides to 4, which caused failures for `test_fp8.py` tests as well as `test_scaled_matmul_cuda.py::test_scaled_mm_vs_emulated_block_wise` tests ([GitHub PR](https://github.com/pytorch/pytorch/pull/164259)).

Test Plan:
`test_fp8.py`
`test_scaled_matmul_cuda.py`

Differential Revision: D84103213

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164953
Approved by: https://github.com/slayton58, https://github.com/eqy
2025-10-10 19:12:58 +00:00
7cddda1234 Update asan in slow to linux.2xlarge.memory
Followup after f2ae7084eb
2025-10-10 12:02:29 -07:00
98b53961b9 [torchfuzz] add more context to xfail test file (#165149)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165149
Approved by: https://github.com/PaulZhang12
ghstack dependencies: #165116
2025-10-10 18:51:51 +00:00
a3eb275d3c Add torch compile check for ZeroBubble (#162511)
Fix https://github.com/pytorch/pytorch/issues/161904

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162511
Approved by: https://github.com/fegin
2025-10-10 18:49:45 +00:00
6f31406723 [Code Clean] Replace std::runtime_error with TORCH_CHECK (#163927)
Fixes part of  #148114

Including:

- aten/src/ATen/InferSize.h
- aten/src/ATen/functorch
- aten/src/ATen/cudnn/Types.cpp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163927
Approved by: https://github.com/FFFrog, https://github.com/albanD

Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
2025-10-10 18:23:27 +00:00
f2ae7084eb [BE] Use linux.2xlarge.memory for ASAN builds (#165164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165164
Approved by: https://github.com/janeyx99
2025-10-10 18:13:42 +00:00
12d7cc5cd3 [BE] Set commit hooks to 3.10 2025-10-10 11:09:13 -07:00
a2e2e1d8c0 Add pytorch_version and mast_application_packages to pt2 compile scuba logging (#165018)
Summary: Two more fields requested for conda-on-mast jobs

Differential Revision: D84214442

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165018
Approved by: https://github.com/c00w
2025-10-10 17:57:40 +00:00
b67785d9eb Revert "C++ API handle optimizer defaults (#161825)"
This reverts commit f33201729416ed17467228e80b04d01d4d02b5f3.

Reverted https://github.com/pytorch/pytorch/pull/161825 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/161825#issuecomment-3391506427))
2025-10-10 17:56:11 +00:00
4cd06dc82c [PT2 Archive] Use tensor dtype while deduping/grouping weights (state_dict/constants) (#165090)
Summary: While saving state_dict tensors, deduping is done to reduce the amount of tensor data. The storage pointer is used for this. But when a tensor is empty, its storage pointer is 0, while the dtypes of such tensors could differ. The existing logic considers all such tensors identical, which fails the model later when a different dtype is expected. This change also includes the dtype while deduping. For non-empty tensors this should have no effect, as the storage pointer will be unique.
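A minimal sketch of the grouping key change, using a hypothetical `TensorInfo` record in place of real tensors (names and field layout are assumptions for illustration):

```python
from typing import NamedTuple


class TensorInfo(NamedTuple):
    """Hypothetical stand-in for the tensor properties used when deduping."""

    name: str
    storage_ptr: int
    dtype: str


def group_weights(tensors, include_dtype: bool):
    """Group tensor names that would share one saved payload."""
    groups = {}
    for t in tensors:
        key = (t.storage_ptr, t.dtype) if include_dtype else (t.storage_ptr,)
        groups.setdefault(key, []).append(t.name)
    return list(groups.values())


weights = [
    TensorInfo("empty_fp32", 0, "float32"),  # empty tensor, pointer is 0
    TensorInfo("empty_int64", 0, "int64"),   # empty tensor, pointer is 0
    TensorInfo("w", 4096, "float32"),
    TensorInfo("w_alias", 4096, "float32"),  # genuinely shares storage with w
]

old_groups = group_weights(weights, include_dtype=False)
new_groups = group_weights(weights, include_dtype=True)
```

With the pointer-only key the two empty tensors collapse into one group despite their different dtypes; adding the dtype separates them while real aliases stay grouped.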

Test Plan: TBD

Differential Revision: D84243094

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165090
Approved by: https://github.com/yiming0416
2025-10-10 17:51:43 +00:00
41936f4cf6 [dynamo][executorch] Do not trace into exeuctorch LoweredBackendModule (#165126)
Required for https://github.com/pytorch/pytorch/pull/164691. Comments inline.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165126
Approved by: https://github.com/tugsbayasgalan
2025-10-10 17:41:33 +00:00
dec9a59992 [dynamo][logging] Add most recent bytecode to graph break with torch._dynamo.graph_break() and verbose (#164422)
Implements the feature described in https://github.com/pytorch/pytorch/issues/162858.

This adds the latest 20 (or viable user frame) bytecode instructions to the existing graph break log. The scenario is a graph break that happens without errors, which is the case when the user calls torch._dynamo.graph_break().

Meanwhile, in testing, one can find that the frame generated based on step() is not deterministic: sometimes it reaches the maximum amount and sometimes it generates fewer than that. The bytecode generation is also Python-version dependent. Thus, the test plan excludes the bytecode output and uses the total bytecode line count instead.
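The "keep only the most recent N instructions" idea can be sketched with the standard library alone. This is not the dynamo implementation: it walks a function's static bytecode with `dis` rather than the instructions actually traced, and `last_instructions` and `sample` are hypothetical names.

```python
import dis
from collections import deque


def last_instructions(func, limit: int = 20) -> list:
    """Return at most `limit` of the most recent bytecode instructions of func."""
    recent = deque(maxlen=limit)  # older entries fall off automatically
    for instr in dis.get_instructions(func):
        recent.append(f"{instr.offset:>4} {instr.opname} {instr.argrepr}")
    return list(recent)


def sample(x):
    y = x + 1
    return y * 2


tail = last_instructions(sample)
short_tail = last_instructions(sample, limit=2)
```

Because opcode names and counts vary between Python versions, a test over output like this is better off asserting on the number of lines than on their content.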

This is a helpful process to understand bytecode transformation, symbolic convert, and convert frame. It is a helpful task to provide hands-on experience with dynamo workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164422
Approved by: https://github.com/williamwen42, https://github.com/mlazos

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-10 17:33:06 +00:00
f975bd58af Revert "Warn if AccumulateGrad stream does not match producer node stream (#165065)"
This reverts commit a70ef954b919e990ebaba715b4072e76352867bf.

Reverted https://github.com/pytorch/pytorch/pull/165065 on behalf of https://github.com/izaitsevfb due to breaks lint ([comment](https://github.com/pytorch/pytorch/pull/165065#issuecomment-3391387386))
2025-10-10 17:29:29 +00:00
af42256db4 Fix missing brackets (#165138)
As stated in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165138
Approved by: https://github.com/Aidyn-A, https://github.com/Skylion007
2025-10-10 17:23:31 +00:00
39161e73fc [Fix] missing lambda in torch._check (#165043)
Fixes more missing lambdas in `torch._check` calls in the source code. Inspired by #164225.
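For context, `torch._check` takes its message as a zero-argument callable so that the string is only built on failure. The sketch below uses a hypothetical pure-Python `check` helper with the same `(cond, message)` shape to show the difference; it is not the PyTorch implementation.

```python
from collections.abc import Callable
from typing import Optional


def check(cond: bool, message: Optional[Callable[[], str]] = None) -> None:
    """Hypothetical stand-in with the same (cond, message-callable) shape."""
    if cond:
        return
    # The message is only built on failure, and it must be callable.
    raise RuntimeError(message() if message is not None else "check failed")


built = []


def describe() -> str:
    built.append("called")  # track how often the message is constructed
    return "size mismatch"


check(True, describe)  # passes, so describe() is never invoked
```

Passing a plain string instead of a callable would only blow up on the failure path, when the helper tries to call it, which is why such mistakes can go unnoticed until a check actually fails.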

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165043
Approved by: https://github.com/FFFrog, https://github.com/Skylion007
2025-10-10 17:11:55 +00:00
3ed90f5a09 outline various stages from aot stage2 compile (#164808)
Splits the training and inference paths for aot stage2 compile.
1. Split `aot_stage2_autograd` into `_aot_stage2a_partition`, `_aot_stage2b_fw_compile` and `_aot_stage2b_bw_compile`, and rest.
2. Split `aot_stage2_inference` into `_aot_stage2b_inference_compile` and rest.
I'm leaving these as functions with underscore names since the I/O interfaces and the exact boundaries of these splits are somewhat in the air.

Differential Revision: D84028203

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164808
Approved by: https://github.com/SherlockNoMad
2025-10-10 17:04:36 +00:00
d41aa187ec Add more B200 smoke test (#165133)
A follow up to #159494. This PR adds additional `test_scaled_matmul_cuda` to smoke tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165133
Approved by: https://github.com/drisspg
2025-10-10 16:46:26 +00:00
8b2137e74a Don't use C++ CIA decomps if there's a Python one (#164970)
Some more context at https://github.com/pytorch/pytorch/pull/164939

The basic point here is that Python decomps are guaranteed to be functional, whereas C++ ones are not. If we have a Python decomp, we should prefer it over the C++ one. This currently doesn't matter too much as CIA decomps will get functionalized, but it matters after the quoted PR because we now run these decompositions very late (to make it easy for things like aot_eager to get the fused versions of operators in proxy tensor).

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164970
Approved by: https://github.com/bdhirsh
2025-10-10 16:46:09 +00:00
a70ef954b9 Warn if AccumulateGrad stream does not match producer node stream (#165065)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165065
Approved by: https://github.com/ngimel
ghstack dependencies: #162815
2025-10-10 16:46:01 +00:00
01a2812f48 [ROCm] Adjust grid size for non-unit stride backwards indexing (#165026)
Adjust grid size for non-unit stride backwards indexing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165026
Approved by: https://github.com/jeffdaily
2025-10-10 16:36:38 +00:00
3f27100d3e [torchfuzz] remove fixed xfail (#165116)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165116
Approved by: https://github.com/PaulZhang12
2025-10-10 16:31:27 +00:00
253fd765bd bf16 support for fake_quantize_learnable_per_channel_affine (#165098)
Adding bf16 support for `torch._fake_quantize_learnable_per_channel_affine()` op by relaxing the type check on scale

TODO: need to add bf16 support to `per_tensor_affine_` as `torch._fake_quantize_learnable_per_tensor_affine_backward` gets called in the backward pass

**Test**
Modified unit test in `test_workflow_ops.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165098
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14
2025-10-10 16:24:52 +00:00
abb2f7179e Revert "Fix truediv numerics between eager and compile (#164144)"
This reverts commit 68913d8f2a953bdbada4033101b04f6e8d49dabe.

Reverted https://github.com/pytorch/pytorch/pull/164144 on behalf of https://github.com/malfet due to It breaks CI again, why was it landed for 3 times in a row without any changes? ([comment](https://github.com/pytorch/pytorch/pull/164144#issuecomment-3390973016))
2025-10-10 16:10:25 +00:00
b57ab9a3f2 Fix #165125: Type "str" is not assignable to return type "None" (#165128)
Fixes #165125

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165128
Approved by: https://github.com/malfet
2025-10-10 16:05:07 +00:00
fb64da0791 [2/N] Use "is" in python type comparison (#165142)
This is a follow-up to #165037. It is generally recommended to use `is/is not` to compare types. Therefore this series of changes applies this suggestion across the code base, with the aim of finally enabling the related linter checks.
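A small standalone example of why identity is the safer comparison: `==` on types dispatches to the metaclass's `__eq__`, which can be overridden, whereas `is` cannot be fooled. `LooseMeta` and `Odd` are hypothetical names made up for the demonstration.

```python
class LooseMeta(type):
    """Hypothetical metaclass whose equality is deliberately too permissive."""

    def __eq__(cls, other):
        return True

    __hash__ = type.__hash__


class Odd(metaclass=LooseMeta):
    pass


value = Odd()

equal_by_eq = type(value) == int  # True, because LooseMeta.__eq__ says so
equal_by_is = type(value) is int  # False, identity cannot be overridden

# Exact type checks also differ from isinstance for subclasses such as bool.
exact_int = type(True) is int
instance_of_int = isinstance(True, int)
```

When subclasses should be accepted, `isinstance` remains the right tool; `is` is for exact type matches.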

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
2025-10-10 15:36:44 +00:00
10a9fb641b Switch build jobs from linux.4xlarge to c7i (#165057)
Switch build jobs that use linux.4xlarge, which uses c5 instance types, to the c7i variant. This should improve performance by ~15-20% while cutting costs by ~10-15%.

Relates to pytorch/test-infra#7175
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165057
Approved by: https://github.com/huydhn
2025-10-10 15:13:40 +00:00
9420944033 Revert "[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)"
This reverts commit 960b0d5f0d0efb1f1962bddcf62e2a698e26edd2.

Reverted https://github.com/pytorch/pytorch/pull/163446 on behalf of https://github.com/izaitsevfb due to breaks autocast tests on linux and mac ([comment](https://github.com/pytorch/pytorch/pull/163446#issuecomment-3390688642))
2025-10-10 15:12:46 +00:00
55f01a48af [ROCm] Enable and fix several FSDP + Inductor distributed unit tests (#165011)
This PR enables a number of distributed unit tests and applies necessary fixes to ensure they pass on ROCm platforms. The changes have been successfully tested on both MI200 and MI300 hardware.

This work addresses the following issues:
**https://github.com/ROCm/frameworks-internal/issues/13586
https://github.com/ROCm/frameworks-internal/issues/13578**

**Enabled Tests**

The following tests have been enabled and are now passing:
1. test_compiled_autograd_ctx
2. test_simple_mlp_fullgraph_backend_aot_eager
3. test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition
4. test_simple_mlp_fullgraph_backend_inductor
5. test_nested_fully_shard_backend_aot_eager
6. test_nested_fully_shard_backend_aot_eager_decomp_partition
7. test_nested_fully_shard_backend_inductor_fullgraph_True
8. test_nested_fully_shard_backend_inductor_fullgraph_True_graph_partition
9. test_transformer_backend_aot_eager
10. test_transformer_backend_aot_eager_decomp_partition
11. test_storage_resize_zero_gpu
12. test_storage_resize_nonzero_gpu
13. test_fake_distributed_inductor

**Tests skipped due to upstream issues:**
1. test_nested_fully_shard_backend_inductor_fullgraph_False
2. test_transformer_backend_inductor_fullgraph_True
3. test_transformer_backend_inductor_fullgraph_True_graph_partition
4. test_transformer_backend_inductor_fullgraph_False

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165011
Approved by: https://github.com/jeffdaily
2025-10-10 14:10:54 +00:00
68913d8f2a Fix truediv numerics between eager and compile (#164144)
Addresses numeric differences between eager and compile in https://github.com/pytorch/pytorch/issues/141753

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164144
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/ngimel
2025-10-10 14:00:46 +00:00
b8be796a57 Revert "[2/N] More ruff SIM fixes (#165031)"
This reverts commit 38095fbd1323ee4a9541fbcbb9b28bd20f2cd956.

Reverted https://github.com/pytorch/pytorch/pull/165031 on behalf of https://github.com/albanD due to One of the changed line started to fail on trunk ([comment](https://github.com/pytorch/pytorch/pull/165031#issuecomment-3390190870))
2025-10-10 13:42:14 +00:00
238dd5517d [PP] Move profiler record_function in schedule (#164976)
Better engineering to move the `record_function` call to also encompass the custom callback. This line is the only change: https://github.com/pytorch/pytorch/pull/164976/files#diff-1d3d91f53db88fb886901fb178d69e47776e71b8103f85688fa9ca64cc55d068R2147; the rest is just formatting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164976
Approved by: https://github.com/fegin
ghstack dependencies: #162016, #164962
2025-10-10 13:09:23 +00:00
d272ed4b3e Fix identity expansion (#165066)
In some cases, we wrap indexing with `Identity` to prevent expansion from int32 -> int64 range. There are some checks in codegen which intend to check for constants, which did not handle Identity. Update these checks and update Identity so that it recursively prints inputs.

Fix for https://github.com/pytorch/pytorch/issues/164700

Replaces https://github.com/pytorch/pytorch/pull/160190 cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @njriasan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165066
Approved by: https://github.com/njriasan, https://github.com/shunting314, https://github.com/jansel
2025-10-10 13:07:15 +00:00
70925bdf82 [1/N] Use "is" in python type comparison (#165037)
It is generally recommended to use `is/is not` to compare types. Therefore this series of changes applies this suggestion across the code base, with the aim of finally enabling the related linter checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
2025-10-10 12:36:50 +00:00
960b0d5f0d [AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)
## Description:

This PR refactors the autocast context manager in `autocast_mode.py` to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now, a single mapping `device_supported_dtypes` is used to associate device types with their supported dtypes, and the validation logic is unified.

In my view, this makes the code easier to maintain and extend for new devices.

Please share any suggestions and comments with me.

BTW, in the original `xla` branch, the `supported_dtype` are `[torch.float16, torch.bfloat16]`, 5d8a226e23/torch/amp/autocast_mode.py (L358-L363) but the warning message has only `torch.bfloat16`.
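The unified validation described above could look roughly like the following sketch. It is standalone and uses dtype names as strings; only the mapping name `device_supported_dtypes` and the `xla` entry are taken from this description, while the `cpu` entry, the function name, and the warning text are assumptions.

```python
import warnings

# Device type -> supported autocast dtypes (strings keep the sketch standalone).
device_supported_dtypes = {
    "xla": ["float16", "bfloat16"],
    "cpu": ["bfloat16", "float16"],  # illustrative assumption
}


def validate_autocast_dtype(device_type: str, dtype: str) -> bool:
    """Return True if autocast may stay enabled for this device/dtype pair."""
    supported = device_supported_dtypes.get(device_type)
    if supported is None:
        return True  # no restriction recorded for this device in the sketch
    if dtype not in supported:
        # Building the message from the same list keeps the text and the check in sync.
        warnings.warn(
            f"{device_type} autocast only supports {', '.join(supported)}; "
            f"got {dtype}. Disabling autocast."
        )
        return False
    return True
```

Deriving the warning from the mapping avoids the kind of mismatch noted for the `xla` branch, where the check and the message listed different dtypes.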

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163446
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-10 12:30:06 +00:00
e0abcee3b5 [Code Clean] Remove support of python3.9 (#163846)
As the title stated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163846
Approved by: https://github.com/ezyang
2025-10-10 11:11:56 +00:00
77bf23d85c Add an option to put store large mmap weights on disk (#164526)
As title

On Windows, we cannot modify the .dll to append weights at the end; the Windows .dll loader will complain that it's not a valid .dll file. So we store the weight blob as a separate file.

1. We add the following APIs, which allow passing in a pointer to the weight blob and getting the size of the weight blob.

```cpp
AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize(
    AOTInductorModelContainerHandle container_handle,
    uint64_t* ret_size);

// Load weights from a single blob in weight_blob_ptr
AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob(
    AOTInductorModelContainerHandle container_handle,
    const uint8_t* weight_blob_ptr);
```

2. We also add a method in ModelContainerRunner to load the weight:

If the runner sees that there is a `.blob` file in the package, it will mmap the .blob file and use its content to load the constants.

3. We also add the `USE_MMAP_EXTERNAL` macro. When this macro is defined, the model expects to load the weights from external mmap'd weights.
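To make the external-blob idea concrete, here is a small Python sketch that memory-maps a separate weights file and slices constants out of it. It is only an analogy for the C++ runner: the function name, the file name, and the assumption that constants are packed back to back with known sizes are all hypothetical.

```python
import mmap
import os
import tempfile


def load_constants_from_blob(path: str, sizes: list) -> list:
    """Map the blob read-only and cut it into consecutive constants."""
    constants = []
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as blob:
            offset = 0
            for size in sizes:
                # Slicing an mmap copies just this range into a bytes object.
                constants.append(blob[offset : offset + size])
                offset += size
    return constants


with tempfile.TemporaryDirectory() as tmp:
    blob_path = os.path.join(tmp, "model_weights.blob")
    with open(blob_path, "wb") as f:
        f.write(b"\x01\x02" + b"\x03\x04\x05\x06")  # two fake constants
    loaded = load_constants_from_blob(blob_path, sizes=[2, 4])
```

Keeping the weights in their own file leaves the shared library untouched, which is the property the Windows loader requires.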

Test Plan:

```
buck run @mode/dev-nosan caffe2/test/inductor:test_aot_inductor -- -r test_large_mmaped_weights_on_disk
```

Also tested for windows-cross compilation with 6542566585/demo/main_voxtral.cpp

```
Loaded model.dll
audio_encoder loaded
C:\Users\shangdiy\source\repos\torchnative\demo\token_embedding\data\aotinductor\model\model.wrapper.so
Loaded model.dll
token_embedding loaded
C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper.so
Loaded model.dll
Loading weights from C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper_weights.blob
text_decoder loaded
Load latency (ms):
  audio_encoder: 1011.234
    archive extraction: 0.000
    .so loading: 1011.197
  token_embedding: 525.773
    archive extraction: 0.000
    .so loading: 525.704
  text_decoder: 3324.130
    archive extraction: 0.000
    .so loading: 3323.979
Run latency (ms):
  audio_encoder: 285.958
    audio_encoder output: dtype=bfloat16, shape=[1, 1125, 3072], numel=3456000
  token_embedding: 6.676
    token_embedding output: dtype=bfloat16, shape=[1, 1138, 3072], numel=3495936
  text_decoder: 576.519
    text_decoder output: dtype=bfloat16, shape=[1, 1138, 131072], numel=149159936
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164526
Approved by: https://github.com/desertfire
2025-10-10 07:53:57 +00:00
d2cb183344 Revert "[inductor] verify determinism with inductor benchmark script (#164904)"
This reverts commit a3c700656f9a666eb33074b60333a23eb7e99a15.

Reverted https://github.com/pytorch/pytorch/pull/164904 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but there seems to be some failed vLLM failures coming out of this ([comment](https://github.com/pytorch/pytorch/pull/164904#issuecomment-3388443678))
2025-10-10 06:23:07 +00:00
38095fbd13 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.
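
A small sketch of the kind of simplification described above; the dictionary and keys are made up for illustration.

```python
config = {"backend": "inductor"}

# Before: the explicit None default is redundant.
explicit = config.get("mode", None)

# After: dict.get already returns None when the key is missing.
implicit = config.get("mode")
```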

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-10 05:37:46 +00:00
ffc9559d9f [7/N] Apply ruff UP035 rule (#164653)
This PR is follow-up of #164438 to continue applying `UP035` rule. All changes are about proper `Callable` importation.
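
For reference, a minimal sketch of the import style this rule steers toward; `apply_twice` is a hypothetical function used only to show the annotation.

```python
# Before (flagged by UP035): from typing import Callable
from collections.abc import Callable


def apply_twice(fn: Callable[[int], int], x: int) -> int:
    """Apply fn to x two times."""
    return fn(fn(x))


result = apply_twice(lambda v: v + 3, 1)
```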

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164653
Approved by: https://github.com/aorenste
2025-10-10 05:16:17 +00:00
172d6ed8b8 Refactor _scaled_grouped_mm_cuda dispatch (#165060)
Summary:

* Clean & simplify different scaling recipe dispatch
* Split out recipes into separate dispatch functions

Test Plan:

```
pytest -svv -k grouped  test/test_scaled_matmul_cuda.py
```

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165060
Approved by: https://github.com/danielvegamyhre, https://github.com/ngimel
2025-10-10 04:44:25 +00:00
9a3c4b917e [CMake] Remove forcing of -O2 from torch_compile_options (#164894)
That was introduced by 75a65ffe0f
Hat tip to @jathu for alerting me about the issue. As a result, all our PyTorch builds were shipped with `-O2` for almost all of PyTorch's modern history.

Partially undoes the damage introduced by https://github.com/pytorch/pytorch/pull/128406, which causes cross-ISA symbol leaks; to be properly followed up in https://github.com/pytorch/pytorch/issues/165123

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164894
Approved by: https://github.com/ezyang
2025-10-10 04:43:53 +00:00
df514a6d5a Revert "[inductor][eazy] change how torch.use_deterministic_algorithms affect inductor (#164905)"
This reverts commit 344e6365a0068c2d2847fcec0c55dd53291d475e.

Reverted https://github.com/pytorch/pytorch/pull/164905 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but there seems to be some failed vLLM failures coming out of this ([comment](https://github.com/pytorch/pytorch/pull/164905#issuecomment-3388258660))
2025-10-10 04:37:09 +00:00
48fe858fef Fix error, remove file from pyrefly checking (#165094)
Reported issue with formatting and parsing.

Removing suppressions and excluding this file from future type checking until we can get a more complete fix in.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165094
Approved by: https://github.com/albanD
2025-10-10 04:34:51 +00:00
7ab00c7c17 Revert "Hotfix test scaled matmul cuda (#165104)"
This reverts commit 9aa92f246fa5fe5cfda17970d41d167b19a0612a.

Reverted https://github.com/pytorch/pytorch/pull/165104 on behalf of https://github.com/malfet due to Looks like it broke cuda tests, isn't it, see 44b1ff54e9/1 ([comment](https://github.com/pytorch/pytorch/pull/165104#issuecomment-3388247886))
2025-10-10 04:32:18 +00:00
44b1ff54e9 [CD] Do not propagate download.pytorch.org IP into container (#165075)
Followup after https://github.com/pytorch/pytorch/pull/164969

Should fix binary build test failures
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165075
Approved by: https://github.com/seemethere, https://github.com/huydhn
ghstack dependencies: #164968, #164969
2025-10-10 04:27:29 +00:00
daea35df5c Revert "[CD] Do not propagate download.pytorch.org IP into container (#165075)"
This reverts commit 6d27a8e5093ee2a21d44dceeeffcb272e6e0f655.

Reverted https://github.com/pytorch/pytorch/pull/165075 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165075#issuecomment-3388228013))
2025-10-10 04:20:51 +00:00
7f2a902ea2 more sizelike deprecation (#164889)
Remove `expect_size` C++ bindings and usages.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164889
Approved by: https://github.com/mlazos
ghstack dependencies: #164884, #164885, #164886, #164887, #164888
2025-10-10 03:45:06 +00:00
9c057d9863 [BE] Refresh documentation for stable ABI / API (#163899)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163899
Approved by: https://github.com/janeyx99
2025-10-10 03:26:28 +00:00
938869e7d3 [DTensor] Improve sharding propagation error msg in DTensor dispatch (#164623)
Fixes #164543

This PR improves the `__str__` method of DTensor's `OpSchema` so that dispatch failures produce a more readable error message, since the error message prints `{op_info.schema}`.

example 1 `aten.embedding`
```
aten.embedding.default(Spec(f32[2048, 256](S(0))), Spec(i64[16, 2048](S(0)R))) on DeviceMesh((dp=2, tp=2), 'cuda', stride=(2, 1)))
```

example 2 `aten.mm`
```
aten.mm.default(Spec(f32[1024, 512](S(1))), Spec(f32[512, 256](S(0)))) on DeviceMesh((tp=4), 'cuda', stride=(1,)))
```

example 3 `aten._scaled_dot_product_flash_attention`
```
aten._scaled_dot_product_flash_attention.default(Spec(f16[8, 16, 128, 64](RS(1))), Spec(f16[8, 16, 128, 64](RS(1))), Spec(f16[8, 16, 128, 64](RS(1)))) on DeviceMesh((dp=2, tp=4), 'cuda', stride=(4, 1)))
```

Added test
```
python test/distributed/tensor/test_dtensor_ops.py -k test_embedding_error_msg
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164623
Approved by: https://github.com/zpcore
2025-10-10 03:16:04 +00:00
ce6b589545 Enable B904 check of flake8 (#165047)
The description of `B904` is `Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling. `
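
A minimal sketch of the pattern B904 asks for; `parse_port` is a hypothetical helper, not code from the PR.

```python
def parse_port(text: str) -> int:
    """Convert text to an int, re-raising with an explicit cause on failure."""
    try:
        return int(text)
    except ValueError as err:
        # Chaining with "from err" records the original exception as __cause__.
        raise RuntimeError(f"invalid port: {text!r}") from err


try:
    parse_port("http")
except RuntimeError as exc:
    cause = exc.__cause__
```

Using `from None` instead would suppress the original exception in the traceback, which is the other form the rule accepts.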

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165047
Approved by: https://github.com/Lucaskabela
2025-10-10 03:08:01 +00:00
ae25dd51fc Simplifying computation of the final result for equals op on DTensor (#164999)
Instead of collecting local results using all_gather_object followed by a local reduction, this change uses a single all_reduce with the MIN reduction operation to compute the final equals result.

This change is needed to enable the LocalTensor work (all_gather_object introduces challenges for DTensor and LocalTensor integration).
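
The idea can be sketched without any distributed setup. In the toy code below, `simulated_all_reduce_min` is a stand-in for a real all_reduce with MIN (it is not a PyTorch API), and each rank is assumed to contribute 1 when its local comparison succeeds and 0 otherwise.

```python
def simulated_all_reduce_min(per_rank_values):
    """Stand-in for an all_reduce with MIN: every rank sees the global minimum."""
    return min(per_rank_values)


def dtensor_equal(local_results):
    """Return True only if every rank reported a local match."""
    # Encode each rank's local result as an integer flag.
    flags = [1 if ok else 0 for ok in local_results]
    # The minimum is 1 only when no rank contributed a 0.
    return simulated_all_reduce_min(flags) == 1


all_match = dtensor_equal([True, True, True, True])
one_mismatch = dtensor_equal([True, False, True, True])
```

A single mismatching rank pulls the minimum down to 0, so the MIN reduction behaves like a logical AND across ranks.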

topic: not user facing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164999
Approved by: https://github.com/ezyang
2025-10-10 03:01:28 +00:00
a61d0de9f9 [hop] support local_map filtered gradients (#164437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164437
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431, #164433
2025-10-10 02:34:27 +00:00
3ad88924ad [hop] support local_map None placements (#164433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164433
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431
2025-10-10 02:34:27 +00:00
3241b9c15f [hop] support local_map None gradients (#164431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164431
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602
2025-10-10 02:34:27 +00:00
25d4d5107e [dynamo] trace local_map with local shapes for AP (#163602)
Context is in https://www.internalfb.com/excalidraw/EX519691 and https://docs.google.com/document/d/1qnuXLZk_GYt_PksHTwkn7L2ELRDnYlIRPkHAlXTyuhw/edit?tab=t.0. And the description of the previous PR: https://github.com/pytorch/pytorch/pull/164340.

The previous PR adds the support on the HOP side for eager execution and AOTAutograd. Dynamo is still passing the HOP a subgraph with wrong shapes. This PR fixes that. This is similar to the HOP implementation, however we additionally need to manually keep the TensorVariable metadata in sync.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163602
Approved by: https://github.com/ydwu4
ghstack dependencies: #164296, #164321, #164419, #164420, #164340
2025-10-10 02:34:27 +00:00
e4fe811be8 [hop] trace local_map with local shapes in fake key (#164340)
Context is in https://www.internalfb.com/excalidraw/EX519691 and https://docs.google.com/document/d/1qnuXLZk_GYt_PksHTwkn7L2ELRDnYlIRPkHAlXTyuhw/edit?tab=t.0.

For the Autoparallel initial trace, we want to trace the graph with global shapes. But for the local_map region, we are forced to trace with the expected local tensors. To the tracers, this looks weird, because it's a plain tensor input (representing DTensor's full tensor .to_local()) that we need to "redistribute".

After hacking a miserable version that had cross-key dependencies, @ydwu4 proposed this simpler approach to override the fake key. This means the shape conversion will be invisible to all dispatch keys above fake, which covers all current tracing mechanisms. This manifests as the joint graph for the HOP body being traced with local shapes:
```python
# HOP forward, note local shapes (10, 80)
class GraphModule(torch.nn.Module):
    def forward(self, primals_0: "f32[10, 80]"):
        # No stacktrace found for following nodes
        view: "f32[800]" = torch.ops.aten.view.default(primals_0, [-1]);  primals_0 = None
        add: "f32[800]" = torch.ops.aten.add.Tensor(view, 10);  view = None
        view_1: "f32[10, 80]" = torch.ops.aten.view.default(add, [10, 80]);  add = None
        return (view_1,)

# HOP backward, note local shapes (10, 80)
class GraphModule(torch.nn.Module):
    def forward(self, tangents_0: "f32[10, 80]"):
        # No stacktrace found for following nodes
        clone: "f32[10, 80]" = torch.ops.aten.clone.default(tangents_0);  tangents_0 = None
        return (clone,)
```

while the rest of the graph is still traced with global shapes:
```python
# Parent graph joint, note global shapes (80, 80)
class inner_f(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "f32[80, 80]"; tangents_1: "f32[80, 80]";

        primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
         # File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:597 in forward, code: return fn(x)
        call_local_map = torch._higher_order_ops.local_map.call_local_map(primals_1);  primals_1 = None
        getitem: "f32[80, 80]" = call_local_map[0];  call_local_map = None
        call_local_map_1 = torch._higher_order_ops.local_map.call_local_map(tangents_1);  tangents_1 = None
        getitem_1: "f32[80, 80]" = call_local_map_1[0];  call_local_map_1 = None
        return pytree.tree_unflatten([getitem, getitem_1], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164340
Approved by: https://github.com/ydwu4
ghstack dependencies: #164296, #164321, #164419, #164420
2025-10-10 02:34:27 +00:00
82c71af59a [hop] local_map validate partitioned fw/bw wrt placements (#164420)
Reviewed GPT-5 Summary:

**Summary / Goal**
Add validation that partitioned forward/backward graphs respect placements.

**Details**
- Validates placement alignment in local_map.
- The HOP's autograd key gets called when we are tracing the joint, so we need to validate:
  - the inputs to the HOP's fwd gm (typically this is the dynamo rewritten inputs)
  - the inputs to the HOP partitioned fwd/bwd gm
  - the outputs of the HOP partitioned fwd/bwd gm

**Motivation**
Catch mismatch errors earlier, improve debugging.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164420
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419
2025-10-10 02:34:27 +00:00
7bd704a346 [hop] local_map fix fw_gm/bw_gm naming (#164419)
Reviewed GPT5 summary:

**Summary / Goal**
Fix inconsistent variable naming for forward/backward graphs.

**Details**
- Those methods are actually for both fw and bw graphs now that we reuse the same op for fw/bw

**Motivation**
Improves clarity, avoids confusion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164419
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164296, #164321
2025-10-10 02:34:27 +00:00
ae139b73e0 [dynamo] Better error message for local_map subgraph mismatches number of inputs/outputs with placement info (#164321)
Reviewed GPT5 summary:

**Summary / Goal**
Improve error reporting when local_map subgraph input/output counts mismatch placement info.

**Details**
- Adds descriptive runtime error messages.

**Motivation**
Helps debug local_map misalignments.

```python
AssertionError: Expecting 2 inputs to local_map function based on placements, but found 1. If the count matches for eager, Dynamo may have flattened inputs to the function or found additional tensors used via closures. Please adjust the input placements to match what the traced graph sees:
class GraphModule(torch.nn.Module):
    def forward(self, l_args_0_: "f32[8, 8, 16]"):
         # File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:523 in mismatch_input, code: return x + scalar, scalar
        child: "f32[8, 8, 16]" = l_args_0_ + 10;  l_args_0_ = None
        return (child,)
        .
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164321
Approved by: https://github.com/ezyang, https://github.com/mlazos
ghstack dependencies: #164296
2025-10-10 02:34:27 +00:00
cbaa07e438 [dtensor] add util to compute expected local sizes/strides for even sharding (#164296)
Reviewed GPT5 summary:

**Summary / Goal**
Add a utility to compute expected local tensor sizes and strides under *even sharding* in dtensor.

**Details**
- New function in `torch/distributed/tensor/_utils.py`.
- Computes local sizes/strides given global shape, mesh, and placements.
- Enforces divisibility of global dimension by mesh size (strict even sharding).
- Complements `compute_global_tensor_info`.

**Motivation**
Ensures correctness for stride/layout computations in distributed tensors.
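
To make the even-sharding arithmetic concrete, here is a rough standalone sketch. It is not the utility added in `torch/distributed/tensor/_utils.py`: the `Placement` class, the function name `local_sizes_and_strides`, and the choice to return contiguous row-major strides are all assumptions for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Placement:
    """Hypothetical placement: shard_dim=None means replicated on that mesh dim."""
    shard_dim: Optional[int] = None


def local_sizes_and_strides(global_shape, mesh_shape, placements):
    """Compute local sizes and contiguous strides, requiring even sharding."""
    if len(mesh_shape) != len(placements):
        raise ValueError("need exactly one placement per mesh dimension")
    local = list(global_shape)
    for mesh_size, placement in zip(mesh_shape, placements):
        if placement.shard_dim is None:
            continue  # replicated: local size is unchanged
        if local[placement.shard_dim] % mesh_size != 0:
            raise ValueError("uneven sharding is not allowed")
        local[placement.shard_dim] //= mesh_size
    # Assumption: the local tensor is contiguous (row-major).
    strides = [1] * len(local)
    for i in range(len(local) - 2, -1, -1):
        strides[i] = strides[i + 1] * local[i + 1]
    return tuple(local), tuple(strides)


sizes, strides = local_sizes_and_strides(
    (80, 80), (2, 4), [Placement(0), Placement(1)]
)
```

With a 2x4 mesh, sharding dim 0 across the first mesh dim and dim 1 across the second turns a global (80, 80) tensor into local (40, 20) shards.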

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164296
Approved by: https://github.com/ezyang
2025-10-10 02:34:27 +00:00
bc0e2a0d2b Fix a condition error in torch/_inductor/codegen/debug_utils.py (#165033)
This PR fixes the condition
```
if arg_signatures is None and self.kernel_type == "cpp" or "extern"
```
which is interpreted as
```
if (arg_signatures is None and self.kernel_type == "cpp") or ("extern"):
```
and so it always evaluates as truthy, because the non-empty string `"extern"` is truthy. According to the context, the intention was
```
if arg_signatures is None and (self.kernel_type == "cpp" or self.kernel_type == "extern"):
```
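
The precedence issue can be reproduced in isolation. The functions below are a standalone sketch, not the code in `debug_utils.py`; `buggy` and `fixed` are hypothetical names, and `fixed` uses a tuple membership test as an equivalent spelling of the intended condition.

```python
def buggy(arg_signatures, kernel_type):
    # Parsed as (A and B) or "extern"; a non-empty string is always truthy.
    return bool(arg_signatures is None and kernel_type == "cpp" or "extern")


def fixed(arg_signatures, kernel_type):
    # Both comparisons are grouped explicitly.
    return arg_signatures is None and kernel_type in ("cpp", "extern")


buggy_result = buggy(["sig"], "triton")
fixed_result = fixed(["sig"], "triton")
```

Even with signatures present and a non-matching kernel type, `buggy` reports a match, while `fixed` does not.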

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165033
Approved by: https://github.com/Skylion007
2025-10-10 02:20:00 +00:00
0747d95994 Add Loads from fixed inputs (#162031)
## TODO
Check on multi indices
```Python

    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
        in_ptr4 = buffers[0]
        tmp0 = tSrS_ssa
        tmp1 = b_idx
        tmp2 = h_idx
        tmp3 = cute.make_fragment(1, cutlass.Int32)
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6])
        tmp8 = (tmp5.load()).to(cutlass.Float32)
        tmp9 = (tmp0 + tmp8)
        tSrS_ssa = tmp9

        return tSrS_ssa

 ```

I don't think that
```
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6])

```

 is right, since this tmp6 value will be larger than the actual index dim (in this case B). See if it's possible to use a 1D index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162031
Approved by: https://github.com/v0i0
ghstack dependencies: #161118
2025-10-10 01:23:37 +00:00
0a2cde2f06 Add Flash Attention support to FlexAttention (#161118)
Relies on this PR in Flash Attention: https://github.com/Dao-AILab/flash-attention/pull/1840

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161118
Approved by: https://github.com/v0i0
2025-10-10 01:23:37 +00:00
c7b57d9349 Add gfx1100 to build target for ROCm docker builds (#165103)
Fixes issue of gfx1100 test jobs timing out

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165103
Approved by: https://github.com/jeffdaily
2025-10-10 01:18:56 +00:00
7614338b69 Revert "Add SVE128 ISA (#158932)"
This reverts commit 92284fb2ff44f09a9c7df0d8cf6cac9903e376a4.

Reverted https://github.com/pytorch/pytorch/pull/158932 on behalf of https://github.com/malfet due to Hmm, but from OSS point of view, this is a no-op ([comment](https://github.com/pytorch/pytorch/pull/158932#issuecomment-3387961238))
2025-10-10 01:17:02 +00:00
a6fa4f9c28 Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now preserve the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
2025-10-10 00:15:00 +00:00
344e6365a0 [inductor][eazy] change how torch.use_deterministic_algorithms affect inductor (#164905)
Previously, when torch.are_deterministic_algorithms_enabled() is True, Inductor would
- skip autotuning pointwise kernels
- pick a fixed (and quite arbitrary) config for reduction

This PR changes the behavior to
- for pointwise kernels, we still do autotuning
- for reduction kernels, we use the recently added heuristic to pick a config

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164905
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532, #164904
2025-10-10 00:00:58 +00:00
a3c700656f [inductor] verify determinism with inductor benchmark script (#164904)
Verify the deterministic mode with torch.compile benchmark scripts.

Here is what my testing script does (pasted at the end):
- run a model in default mode, save its result
- run the model again in default mode, but distort the benchmarking results. Compare it with the saved result.
- Do the above again in deterministic mode.

I tested a few models
- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchmark result in the default mode. The non-determinism is gone in the deterministic mode.
- DistillGPT2: I cannot repro the numeric change by distorting the benchmarking result in the default mode. This does not surprise me much: a reduction order change does not always cause a numeric change.

```
model=GoogleFnet

export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1

# Non deterministic mode
# --float32 rather than --amp to make it easier to repro non-deterministic
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl

echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl

echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl

echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532
2025-10-10 00:00:58 +00:00
600db525bd [easy][while_loop] use copy_input instead of clone in _clone_aliased_inputs (#164955)
Compared with clone, ExternKernel.copy_input additionally realizes the buffer; downstream code assumes the input buffers are realized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164955
Approved by: https://github.com/BoyuanFeng
2025-10-09 23:39:00 +00:00
f6de195616 [dynamo][trace_rules] Add ao.quantization (#165069)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165069
Approved by: https://github.com/tugsbayasgalan, https://github.com/mlazos
2025-10-09 23:08:42 +00:00
4a0df39f81 Symintify fused_scaled_matmul_reduce_scatter (#165086)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165086
Approved by: https://github.com/zou3519, https://github.com/Skylion007
2025-10-09 23:07:40 +00:00
34ac9b61cb Revert "[export] Turn on install_free_tensors flag (#164691)"
This reverts commit 0e9b3a772ab96e998ab85591d5b2a9c1d41bacb0.

Reverted https://github.com/pytorch/pytorch/pull/164691 on behalf of https://github.com/izaitsevfb due to breaks tests internally, author asked to revert, see [D84230990](https://www.internalfb.com/diff/D84230990) ([comment](https://github.com/pytorch/pytorch/pull/164691#issuecomment-3387718323))
2025-10-09 22:53:50 +00:00
9aa92f246f Hotfix test scaled matmul cuda (#165104)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165104
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-09 22:51:30 +00:00
a57a14868d Better handling of restore_state_dict (#164401)
After lean export, we might want to be able to restore the original fqn. This PR refactors one util function in export that sort of does this. Note that strict_export has some complicated logic of updating the graph signature as well which we don't want. I think we can gradually make this util more refined by handling constants, non persistent buffers etc and change how strict_export does it today.

Differential Revision: [D83687844](https://www.internalfb.com/diff/D83687844)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164401
Approved by: https://github.com/avikchaudhuri
2025-10-09 22:39:11 +00:00
47956196d9 Revert "Call internal log_compilation_event if it exists (#164855)"
This reverts commit 98a081a24c22072362dc536afd39a469e28939d4.

Reverted https://github.com/pytorch/pytorch/pull/164855 on behalf of https://github.com/albanD due to We should not land this kind of code in core ([comment](https://github.com/pytorch/pytorch/pull/164855#issuecomment-3387692988))
2025-10-09 22:38:45 +00:00
6d27a8e509 [CD] Do not propagate download.pytorch.org IP into container (#165075)
Followup after https://github.com/pytorch/pytorch/pull/164969

Should fix binary build test failures
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165075
Approved by: https://github.com/seemethere, https://github.com/huydhn
ghstack dependencies: #164968, #164969
2025-10-09 21:59:31 +00:00
cd62a73dcb [cuDNN][SDPA] Handle noncontig nested tensors in cuDNN SDPA (#164958)
Previously we hardcoded the assumption in cuDNN that the inputs would be dense, which breaks when, e.g., the user chunks tensors, yielding noncontiguous inputs.

New test added in `test/test_transformers.py` to check this when `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1` is set.

One issue I noticed was that the old gating of nested tensor in `sdp_utils.cpp` seems to be a no-op? All of the inputs are reported as "dense" by the time that function is called in the nested tensor tests in `test/test_nestedtensor.py -k sdpa`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164958
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2025-10-09 21:58:54 +00:00
4d7f9f3aed Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit 8e1f409b8ccf64b2cf3933ece13587ad57e9d8a9.

Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/jeffdaily due to broke cuda and rocm ci ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3387558806))
2025-10-09 21:36:10 +00:00
2b9ff99535 [flex attention] change "==" to "is" in inspect parameter comparison (#165003)
Patch for https://github.com/pytorch/pytorch/issues/164760.

This doesn't actually fix the underlying torch function issue though.

Explanation: `is` is traced differently compared to `__eq__`, so we end up avoiding the issue where we attempt to evaluate `torch.eq(tensor, inspect._empty)` in the first place.
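
The difference can be seen with a toy class that overloads `__eq__` the way tensor-like objects do. `ElementwiseEq` is a hypothetical stand-in, not a tensor, and this sketch says nothing about how Dynamo traces either operator; it only shows that `is` never invokes `__eq__`.

```python
import inspect


class ElementwiseEq:
    """Toy stand-in for a tensor: == returns a new object instead of a bool."""

    def __init__(self):
        self.eq_calls = 0

    def __eq__(self, other):
        self.eq_calls += 1
        return ElementwiseEq()


default = ElementwiseEq()

# Identity comparison: no user-defined method is called.
identity_check = default is inspect.Parameter.empty
calls_after_is = default.eq_calls

# Equality comparison: dispatches to the overloaded __eq__.
_ = default == inspect.Parameter.empty
calls_after_eq = default.eq_calls
```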

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165003
Approved by: https://github.com/mlazos
2025-10-09 21:18:05 +00:00
98a081a24c Call internal log_compilation_event if it exists (#164855)
Summary: For internal conda on mast jobs, call the internal version of log_compilation_event if it exists.

Test Plan: Ran a simple test job that just calls the API: https://fburl.com/scuba/dynamo_compile/dqx8d10g
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164855
Approved by: https://github.com/c00w
2025-10-09 21:15:11 +00:00
6c0125dbc0 Mark functions const in CUDACachingAllocator (#165007)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165007
Approved by: https://github.com/eqy
2025-10-09 20:53:58 +00:00
0fd976b65c Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)
This change removes the Windows requirement for mimalloc builds, and makes mimalloc the default c10 system allocator for AArch64 builds. This significantly improves the performance of AArch64 builds of PyTorch as large allocations are better cached by mimalloc than glibc.

**Updated Results**

Torchbench FP32 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-fp32-diff" src="https://github.com/user-attachments/assets/7fe3ea0c-3b52-42e7-879b-612444479c90" />

Torchbench BF16 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-bf16-diff" src="https://github.com/user-attachments/assets/56469a72-9e06-4d57-ae2a-aeb139ca79a3" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164741
Approved by: https://github.com/fadara01, https://github.com/aditew01, https://github.com/malfet
2025-10-09 20:49:46 +00:00
9944cac6e6 Add suppressions to torch/_inductor (#165062)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Split this directory into two PRs to keep them from being too large.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165062
Approved by: https://github.com/oulgen, https://github.com/mlazos
2025-10-09 20:34:20 +00:00
e7fd296930 [CI] Add full debug build to trunk (#164974)
Does not run tests, just imports torch, as a regression test for https://github.com/pytorch/pytorch/issues/164297

Test plan: Re-apply #164974 on top of this change and observe the failure in the workflows: https://github.com/pytorch/pytorch/actions/runs/18383302153/job/52375282838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164974
Approved by: https://github.com/seemethere, https://github.com/clee2000, https://github.com/atalman
ghstack dependencies: #164968, #164969
2025-10-09 20:12:16 +00:00
fac85fcfb5 [inductor] custom_graph_pass.get_hash_for_files: don't hash paths (#165020)
Summary: We have an internal user where caching broke because the paths that are unzipped are probably different per host. We can't think of a use case where a path change matters when the file content has not changed, so this PR removes the path from the hash.
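
A content-only hash of this kind could look like the sketch below. It is not the implementation of `get_hash_for_files`; the helper name `hash_file_contents` and the use of SHA-256 are assumptions chosen for the example.

```python
import hashlib
import os
import tempfile


def hash_file_contents(paths):
    """Hash only the bytes of each file, in order, ignoring the file paths."""
    digest = hashlib.sha256()
    for path in paths:
        with open(path, "rb") as f:
            digest.update(f.read())
    return digest.hexdigest()


# Two directories stand in for two hosts that unzip to different locations.
with tempfile.TemporaryDirectory() as host_a, tempfile.TemporaryDirectory() as host_b:
    file_a = os.path.join(host_a, "pass.py")
    file_b = os.path.join(host_b, "pass.py")
    for p in (file_a, file_b):
        with open(p, "wb") as f:
            f.write(b"def my_pass(graph): return graph\n")
    same_hash = hash_file_contents([file_a]) == hash_file_contents([file_b])
```

Because the path never enters the digest, identical files in different locations produce the same cache key.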

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165020
Approved by: https://github.com/oulgen
2025-10-09 20:07:53 +00:00
228973df7f Fix channels-last dimension mapping in CUDA parallel_cat (#165023)
Fixes #164849
`dimension` was updated in-place, so for more than one batch of channels-last tensors the concat `dimension` for the second kernel launch was wrong

## Testing
- python -m compileall test/test_tensor_creation_ops.py

------
https://chatgpt.com/codex/tasks/task_e_68e708879b30832f89b10ae55faa68e8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165023
Approved by: https://github.com/ezyang
2025-10-09 20:04:32 +00:00
ed2d514ad8 Revert "Fix truediv numerics between eager and compile (#164144)"
This reverts commit 724463d5a2fba369cd14e89215b84d1b01435df7.

Reverted https://github.com/pytorch/pytorch/pull/164144 on behalf of https://github.com/malfet due to Not sure if it's related, but looks it triggered fuzzer compiler test failure, see a2f29bcd63/1 ([comment](https://github.com/pytorch/pytorch/pull/164144#issuecomment-3387288464))
2025-10-09 19:53:38 +00:00
a2f29bcd63 [inductor] Remove Repeated Code in Subgraph (#164892)
Discovered some repeated code blocks in the subgraph.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164892
Approved by: https://github.com/PaulZhang12
2025-10-09 19:16:02 +00:00
5390324984 [CodeClean] Replace std::runtime_error with TORCH_CHECK (#164129)
As the title stated.

**Changes**:
- torch/csrc/Module.cpp
- torch/csrc/utils.cpp
- torch/csrc/stable
- torch/lib/libshm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164129
Approved by: https://github.com/albanD
2025-10-09 19:01:07 +00:00
ae25ec569c reorder wrappers in aot_stage2_inference to match forward compile in aot_stage2_autograd (#165016)
In aot_stage2_autograd:
Before calling fw_compiler, we run pre_compile for the following wrappers:
* FakifiedOutWrapper
* FunctionalizedRngRuntimeWrapper

After, we run post_compile for the following wrappers:
 * EffectTokensWrapper
 * AOTDispatchSubclassWrapper
 * FunctionalizedRngRuntimeWrapper
 * FakifiedOutWrapper

In aot_stage2_inference:
Before calling inference compiler, we run pre_compile for the following wrappers (same as above):
 * FakifiedOutWrapper
 * FunctionalizedRngRuntimeWrapper

After, we run post_compile for the following wrappers  (different than above):
 * FunctionalizedRngRuntimeWrapper
 * FakifiedOutWrapper
 * EffectTokensWrapper
 * AOTDispatchSubclassWrapper

This PR makes both do the post_compiles in the same order.

Differential Revision: D84213657

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165016
Approved by: https://github.com/zhxchen17, https://github.com/bdhirsh
2025-10-09 18:36:04 +00:00
8e1f409b8c [ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the 2 easier.
The PTX diff between the old and new order shows a few extra instructions: https://www.diffchecker.com/h6ly3INC/
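
The effect of the shuffle order on bitwise results can be modelled on the CPU. The sketch below is an assumption-laden illustration, not the ATen kernel: it emulates a 32-lane `__shfl_down_sync` style sum in float32 and lets the caller choose the offset order. Which of the two orders is the "old" one and which is the "new" one is not stated in the text above, so the constants are simply named `DESCENDING` and `ASCENDING`; all function names are hypothetical.

```python
import math
import struct

WARP_SIZE = 32


def f32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def shfl_down(vals, offset):
    """Model of a shuffle-down: lane i reads lane i + offset,
    or keeps its own value when that lane is out of range."""
    return [
        vals[i + offset] if i + offset < WARP_SIZE else vals[i]
        for i in range(WARP_SIZE)
    ]


def warp_sum(values, offsets):
    """Sum 32 lane values with float32 rounding after every add.
    The full sum ends up in lane 0 for either offset order."""
    vals = [f32(v) for v in values]
    for offset in offsets:
        shifted = shfl_down(vals, offset)
        vals = [f32(a + b) for a, b in zip(vals, shifted)]
    return vals[0]


DESCENDING = (16, 8, 4, 2, 1)  # halve the stride each step
ASCENDING = (1, 2, 4, 8, 16)   # double the stride each step
```

Both orders add the same 32 values but associate them differently, so with inexact float32 inputs the two results may differ in the last bits. Matching the order used by the other code generator removes that source of mismatch.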

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-09 18:08:30 +00:00
ee6a1ecb0a [ROCm] Enable MI355 CI on PRs, and run full set of UTs on PRs (#160215)
Useful to have PR testing for PRs such as https://github.com/pytorch/pytorch/pull/151360

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160215
Approved by: https://github.com/malfet, https://github.com/atalman

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-09 18:03:12 +00:00
3c0577bd15 Remove shared_ptr from MHAGraphCache (#164895)
This commit makes several cleanup changes to MHA.cpp, the main
one of which is removal of shared_ptr from MHAGraphCache as the
cache does not actually intend to share ownership. The changes are:

1. Remove shared_ptr from MHAGraphCache
2. Remove template arguments from MHAGraphCache
3. Remove unnecessary optional<shared_ptr<...>> vars
4. Change some functions with auto return type to the actual type

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164895
Approved by: https://github.com/eqy
2025-10-09 17:44:28 +00:00
688efd9741 Revert "Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)"
This reverts commit 87eccf10e8484c9e59ef81ae7bdee68d3db4f605.

Reverted https://github.com/pytorch/pytorch/pull/164741 on behalf of https://github.com/malfet due to But it breaks MacOS builds, see https://github.com/pytorch/pytorch/actions/runs/18382886648/job/52373781138 ([comment](https://github.com/pytorch/pytorch/pull/164741#issuecomment-3386859778))
2025-10-09 17:30:25 +00:00
91040f4934 Revert "[Code Clean] Remove support of python3.9 (#163846)"
This reverts commit bc1690c7e859dee8c47a7f0bbd3c43cc27c6fd2a.

Reverted https://github.com/pytorch/pytorch/pull/163846 on behalf of https://github.com/izaitsevfb due to breaks distributed tests ([comment](https://github.com/pytorch/pytorch/pull/163846#issuecomment-3386855437))
2025-10-09 17:27:08 +00:00
87eccf10e8 Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)
This change removes the Windows requirement for mimalloc builds, and makes mimalloc the default c10 system allocator for AArch64 builds. This significantly improves the performance of AArch64 builds of PyTorch as large allocations are better cached by mimalloc than glibc.

**Updated Results**

Torchbench FP32 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-fp32-diff" src="https://github.com/user-attachments/assets/7fe3ea0c-3b52-42e7-879b-612444479c90" />

Torchbench BF16 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-bf16-diff" src="https://github.com/user-attachments/assets/56469a72-9e06-4d57-ae2a-aeb139ca79a3" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164741
Approved by: https://github.com/fadara01, https://github.com/aditew01, https://github.com/malfet
2025-10-09 16:45:31 +00:00
5d459dd609 avoid bit cast for bfloat16_t (#159946)
using bit_cast<bfloat16_t> triggers a static_assert, so replace it with intrinsics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159946
Approved by: https://github.com/aditew01, https://github.com/malfet
2025-10-09 16:42:49 +00:00
24d69c57cb Add view support for library custom Function (#164520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164520
Approved by: https://github.com/soulitzer, https://github.com/ezyang
2025-10-09 16:17:48 +00:00
eaa02655ea [CI] Run cpp tests on windows in one run_tests call (#164861)
The Windows cpp tests take ~1 hour according to logs.  Each has run_test called on it individually, so I tried batching them together so it's just one run_test call for all of them.  I believe it now takes 30min.  I turned off TD since I don't think cpp tests are included in TD stuff.

As always with batch, I'm not sure if the errorlevel/error surfacing stuff is correct

This code is written with a lot of help from ChatGPT and Copilot
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164861
Approved by: https://github.com/huydhn
2025-10-09 16:07:28 +00:00
aea57b3aa3 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation  to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```
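
The "after" code hands a capturing lambda to a C-style shim by passing a plain callback together with an opaque pointer to the wrapped lambda. A rough Python model of that trampoline pattern is sketched below. Every name in it (`run_command_block`, `shared_callback`, `run_kernel`) is a hypothetical stand-in and not one of the `aoti_torch_mps_*` functions.

```python
from typing import Any, Callable, List, Tuple


def run_command_block(
    handle: str,
    callback: Callable[[str, Any], None],
    user_data: Any,
) -> None:
    # Hypothetical shim entry point: it accepts only a plain callback plus
    # an opaque user_data value, never a closure directly.
    callback(handle, user_data)


def shared_callback(handle: str, user_data: Any) -> None:
    # Hypothetical shared trampoline: recovers the real callable from
    # user_data and invokes it with the kernel handle.
    user_data(handle)


def run_kernel(handle: str, args: List[str]) -> List[Tuple]:
    log: List[Tuple] = []

    def encode(h: str) -> None:
        # The closure captures args and log, like the [&] lambda above.
        log.append(("start_encoding", h))
        for index, arg in enumerate(args):
            log.append(("set_arg", index, arg))
        log.append(("dispatch", h))

    run_command_block(handle, shared_callback, encode)
    return log
```

The shim boundary only ever sees a function and an opaque value, so a single shared callback can serve every generated closure.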

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 16:06:36 +00:00
3d1fa40ae1 Revert "[BC-Breaking] Remove long-deprecated casting functions from native_functions.yaml (#164641)"
This reverts commit 64108bdbed2f099d527060b4c9fdd5a11cad2afc.

Reverted https://github.com/pytorch/pytorch/pull/164641 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164641#issuecomment-3386346474))
2025-10-09 15:42:51 +00:00
a7fa1a91e3 fix flex attention eager bwd: more rounding (#164317)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164317
Approved by: https://github.com/drisspg
ghstack dependencies: #163986
2025-10-09 15:40:49 +00:00
afeec56a5a Fix replacement reconstruct (#164937)
If we return a DTensor, the object is created via an FX graph call, so we never needed to reconstruct it. But if there is a side effect, we do need to reconstruct it.

Differential Revision: [D84159000](https://our.internmc.facebook.com/intern/diff/D84159000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164937
Approved by: https://github.com/StrongerXi
2025-10-09 15:31:23 +00:00
724463d5a2 Fix truediv numerics between eager and compile (#164144)
Addresses numeric differences between eager and compile in https://github.com/pytorch/pytorch/issues/141753

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164144
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/ngimel
ghstack dependencies: #164997
2025-10-09 14:31:33 +00:00
f79e212733 Revert "[CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)"
This reverts commit ab94a0d544503b5c27e889b45e45ef8cf75c8183.

Reverted https://github.com/pytorch/pytorch/pull/163955 on behalf of https://github.com/jeffdaily due to broke on cuda and rocm after landing though this PR had a clean signal initially ([comment](https://github.com/pytorch/pytorch/pull/163955#issuecomment-3386127145))
2025-10-09 14:24:56 +00:00
b28b24a9fc Switch build jobs that use linux.12xlarge to c7i (#164941)
This PR updates build jobs that currently use linux.12xlarge to the
c7i variant, which should speed up builds by 15% - 20% depending
on the job and reduce costs of these jobs by 10% - 15%.

Signed-off-by: Thanh Ha <thanh.ha@linuxfoundation.org>
2025-10-09 09:58:52 -04:00
17c7170ca6 Fix Avoid DDE in item numel check (#164934)
address https://github.com/pytorch/pytorch/issues/164725 and https://github.com/pytorch/pytorch/issues/164704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164934
Approved by: https://github.com/ezyang, https://github.com/aorenste, https://github.com/Skylion007
2025-10-09 13:09:06 +00:00
6a7f5c0d21 Add scaled_mm python API, test (#164142)
Summary:

* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++
  methods
* Wraps `torch._scaled_mm_v2` API by default, but user can force use of
  the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API
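
For readers unfamiliar with scaled matrix multiplication, the arithmetic can be sketched in plain Python. This is a conceptual reference only, assuming the simplest case of one scale per tensor. It does not use or mirror the signature of `torch.nn.functional.scaled_mm`, and `matmul` and `scaled_mm_reference` are hypothetical helper names.

```python
def matmul(a, b):
    """Plain nested-list matrix multiply."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]


def scaled_mm_reference(a_q, b_q, scale_a, scale_b):
    """Tensor-wise scaled matmul: (scale_a * a_q) @ (scale_b * b_q).

    With one scalar scale per operand, the scales factor out of the
    product, so they can be applied once to the raw result."""
    raw = matmul(a_q, b_q)
    return [[scale_a * scale_b * x for x in row] for row in raw]
```

Finer-grained schemes (per-row or per-block scales) no longer factor out as a single scalar, which is one reason an API would need the scaling layout to be described explicitly.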

Test Plan:

`pytest test/test_scaled_matmul_cuda.py`

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142
Approved by: https://github.com/drisspg
ghstack dependencies: #164141
2025-10-09 12:43:18 +00:00
512b6b59f0 Add _scaled_mm_v2 API (#164141)
Summary:

* Add new scaled-MM API to future-proof / clean-up existing code.
* Scaling is explicitly described rather than inferred
* Swizzling of scales must now be defined (vs. inferred)
* Adds API support for multi-level scaling
* Refactor dispatch logic to make it easier to add new implementations

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164141
Approved by: https://github.com/drisspg
2025-10-09 12:43:18 +00:00
bc1690c7e8 [Code Clean] Remove support of python3.9 (#163846)
As the title stated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163846
Approved by: https://github.com/ezyang
2025-10-09 11:54:10 +00:00
53f5af8c92 Update torch-xpu-ops commit pin (#164237)
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@f30173](f301733b03), includes:

- Install xpu internal headers to PyTorch
- Fix error handling for BatchLinearAlgebra Ops
- Fix unnecessary double data type conversion
- Fix overflow when calculating workgroups count
- Fix segmentation fault and calculation error in AveragePool2dKernel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164237
Approved by: https://github.com/EikanWang
2025-10-09 10:38:59 +00:00
4412026949 Revert "AOTI MPS Shim Implementation (#163865)"
This reverts commit 874efa2d72d83b00894097130f18062ce331a265.

Reverted https://github.com/pytorch/pytorch/pull/163865 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/163865#issuecomment-3385196387))
2025-10-09 10:26:01 +00:00
06d86e58d0 Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit d40a9bfb8da0dc1ac1e6e56b33a25979112874de.

Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3385056722))
2025-10-09 09:50:59 +00:00
874efa2d72 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc.

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 09:28:10 +00:00
e09fb44ef1 Revert "Fix truediv numerics between eager and compile (#164144)"
This reverts commit d386325ca9a142419f45b987391f4bb175dd7d0b.

Reverted https://github.com/pytorch/pytorch/pull/164144 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164144#issuecomment-3384769092))
2025-10-09 08:40:52 +00:00
418 changed files with 7875 additions and 3974 deletions

View File

@ -181,7 +181,7 @@ case "$tag" in
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950"
PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100"
if [[ $tag =~ "benchmarks" ]]; then
INDUCTOR_BENCHMARKS=yes
fi
@ -344,7 +344,7 @@ docker build \
--build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \
--build-arg "KATEX=${KATEX:-}" \
--build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a;gfx942;gfx1100}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}" \
--build-arg "IMAGE_NAME=${IMAGE_NAME}" \
--build-arg "UCX_COMMIT=${UCX_COMMIT}" \
--build-arg "UCC_COMMIT=${UCC_COMMIT}" \

View File

@ -10,11 +10,6 @@ BAD_SSL = "https://self-signed.badssl.com"
print("Testing SSL certificate checking for Python:", sys.version)
if sys.version_info[:2] < (2, 7) or sys.version_info[:2] < (3, 4):
print("This version never checks SSL certs; skipping tests")
sys.exit(0)
EXC = OSError
print(f"Connecting to {GOOD_SSL} should work")

View File

@ -233,7 +233,9 @@ if [[ "${BUILD_ENVIRONMENT}" != *cuda* ]]; then
export BUILD_STATIC_RUNTIME_BENCHMARK=ON
fi
if [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then
if [[ "$BUILD_ENVIRONMENT" == *-full-debug* ]]; then
export CMAKE_BUILD_TYPE=Debug
elif [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then
export CMAKE_BUILD_TYPE=RelWithAssert
fi
@ -299,6 +301,11 @@ else
python -m build --wheel --no-isolation
fi
pip_install_whl "$(echo dist/*.whl)"
if [[ "$BUILD_ENVIRONMENT" == *full-debug* ]]; then
# Regression test for https://github.com/pytorch/pytorch/issues/164297
# Torch should be importable and that's about it
pushd /; python -c "import torch;print(torch.__config__.show(), torch.randn(5) + 1.7)"; popd
fi
if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *vision* ]]; then
install_torchvision
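
The first hunk above tests `*-full-debug*` before `*-debug*`. The order matters because the broader `*-debug*` pattern also matches full-debug environment names. A minimal Python sketch of the same selection follows, using `fnmatch` as a stand-in for the shell glob match. The function name `cmake_build_type` and the second environment name in the usage note are hypothetical; the real logic lives in the shell script.

```python
from fnmatch import fnmatchcase
from typing import Optional


def cmake_build_type(build_environment: str) -> Optional[str]:
    """Return the CMAKE_BUILD_TYPE this snippet would export, or None if it sets nothing."""
    # The narrower pattern must come first: "*-debug*" also matches "-full-debug".
    if fnmatchcase(build_environment, "*-full-debug*"):
        return "Debug"
    if fnmatchcase(build_environment, "*-debug*"):
        return "RelWithAssert"
    return None
```

With this ordering, `linux-jammy-py3.10-gcc11-full-debug-build-only` selects `Debug`, while a name such as `linux-jammy-py3.10-clang18-debug` still selects `RelWithAssert`.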

View File

@ -337,13 +337,13 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}
test_python_smoke_b200() {
# Targeted smoke tests for B200 - staged approach to avoid too many failures
time python test/run_test.py --include test_matmul_cuda inductor/test_fp8 $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

View File

@ -15,37 +15,35 @@ if errorlevel 1 exit /b 1
if not errorlevel 0 exit /b 1
cd %TMP_DIR_WIN%\build\torch\test
:: Enable delayed variable expansion to make the list
setlocal enabledelayedexpansion
set EXE_LIST=
for /r "." %%a in (*.exe) do (
call :libtorch_check "%%~na" "%%~fa"
if "%%~na" == "c10_intrusive_ptr_benchmark" (
@REM NB: This is not a gtest executable file, thus couldn't be handled by
@REM pytest-cpp and is excluded from test discovery by run_test
call "%%~fa"
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
) else (
if "%%~na" == "verify_api_visibility" (
@REM Skip verify_api_visibility as it is a compile-level test
) else (
set EXE_LIST=!EXE_LIST! cpp/%%~na
)
)
)
goto :eof
:libtorch_check
cd %CWD%
set CPP_TESTS_DIR=%TMP_DIR_WIN%\build\torch\test
:: Skip verify_api_visibility as it a compile level test
if "%~1" == "verify_api_visibility" goto :eof
:: Run python test\run_test.py on the list
set NO_TD=True && python test\run_test.py --cpp --verbose -i !EXE_LIST!
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
echo Running "%~2"
if "%~1" == "c10_intrusive_ptr_benchmark" (
:: NB: This is not a gtest executable file, thus couldn't be handled by pytest-cpp
call "%~2"
goto :eof
)
python test\run_test.py --cpp --verbose -i "cpp/%~1"
if errorlevel 1 (
echo %1 failed with exit code %errorlevel%
goto fail
)
if not errorlevel 0 (
echo %1 failed with exit code %errorlevel%
goto fail
)
goto :eof
:eof
exit /b 0

View File

@ -12,7 +12,7 @@ ignore =
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907,B908,B910
B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
# these ignores are from flake8-comprehensions; please fix!
C407,
# these ignores are from flake8-logging-format; please fix!

View File

@ -274,8 +274,6 @@ runs:
-w /var/lib/jenkins/workspace \
"${DOCKER_IMAGE}"
)
# Propagate download.pytorch.org IP to container
grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" sudo bash -c "/bin/cat >> /etc/hosts"
echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"
docker exec -t "${container_name}" sh -c "pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}"

View File

@ -111,3 +111,16 @@ runs:
# This video group ID maps to subgid 1 inside the docker image due to the /etc/subgid entries.
# The group name corresponding to group ID 1 can change depending on the OS, so both are necessary.
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd $DEVICE_FLAG --group-add video --group-add $render_gid --group-add daemon --group-add bin --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --network=host" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1

View File

@ -33,10 +33,6 @@ runs:
)
echo "CONTAINER_NAME=${container_name}" >> "$GITHUB_ENV"
if [[ "${GPU_ARCH_TYPE}" != "rocm" && "${BUILD_ENVIRONMENT}" != "linux-aarch64-binary-manywheel" && "${BUILD_ENVIRONMENT}" != "linux-s390x-binary-manywheel" && "${GPU_ARCH_TYPE}" != "xpu" ]]; then
# Propagate download.pytorch.org IP to container. This is only needed on Linux non aarch64 runner
grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" bash -c "/bin/cat >> /etc/hosts"
fi
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
# Generate test script

View File

@ -30,6 +30,7 @@ ciflow_push_tags:
- ciflow/riscv64
- ciflow/rocm
- ciflow/rocm-mi300
- ciflow/rocm-mi355
- ciflow/s390
- ciflow/slow
- ciflow/torchbench

View File

@ -177,6 +177,9 @@ jobs:
runs-on: linux.rocm.gpu.mi250
timeout-minutes: !{{ common.timeout_minutes }}
!{{ upload.binary_env(config) }}
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm

View File

@ -389,8 +389,6 @@ jobs:
"${DOCKER_IMAGE}" \
${DOCKER_SHELL_CMD}
)
# Propagate download.pytorch.org IP to container
grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" sudo bash -c "/bin/cat >> /etc/hosts"
echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"
if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then

View File

@ -102,19 +102,6 @@ jobs:
exit 1
fi
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main

View File

@ -358,6 +358,9 @@ jobs:
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -473,6 +476,9 @@ jobs:
DOCKER_IMAGE_TAG_PREFIX: rocm7.0
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm

View File

@ -347,6 +347,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.10"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -459,6 +462,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.0
DESIRED_PYTHON: "3.10"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -941,6 +947,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.11"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -1053,6 +1062,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.0
DESIRED_PYTHON: "3.11"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -1535,6 +1547,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.12"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -1647,6 +1662,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.0
DESIRED_PYTHON: "3.12"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -2129,6 +2147,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.13"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -2241,6 +2262,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.0
DESIRED_PYTHON: "3.13"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -2723,6 +2747,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.13t"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -2835,6 +2862,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.0
DESIRED_PYTHON: "3.13t"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -3317,6 +3347,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.14"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -3429,6 +3462,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.0
DESIRED_PYTHON: "3.14"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -3911,6 +3947,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
DESIRED_PYTHON: "3.14t"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
@ -4023,6 +4062,9 @@ jobs:
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.0
DESIRED_PYTHON: "3.14t"
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm

View File

@ -37,7 +37,7 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: "linux.12xlarge"
runner: "linux.c7i.12xlarge"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '9.0'

View File

@ -130,7 +130,7 @@ jobs:
name: test-periodically
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '15 0,12 * * 1-6'
if: github.event.schedule == '15 0 * * 1-6'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true

View File

@ -12,6 +12,7 @@ on:
- landchecks/*
tags:
- ciflow/pull/*
- ciflow/trunk/*
workflow_dispatch:
permissions: read-all
@ -32,10 +33,12 @@ jobs:
name: Get changed files
uses: ./.github/workflows/_get-changed-files.yml
with:
all_files: ${{ contains(github.event.pull_request.labels.*.name, 'lint-all-files') || contains(github.event.pull_request.labels.*.name, 'Reverted') }}
all_files: ${{ contains(github.event.pull_request.labels.*.name, 'lint-all-files') || contains(github.event.pull_request.labels.*.name, 'Reverted') || github.event_name == 'push' }}
lintrunner-clang:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
# Needed to prevent deduping on HUD
name: lintrunner-clang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
needs: [get-label-type, get-changed-files]
# Only run if there are changed files relevant to clangtidy / clangformat
if: |
@ -75,6 +78,7 @@ jobs:
# fails to find types when it should
lintrunner-mypy:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
needs: [get-label-type, get-changed-files]
# Only run if there are changed files relevant to mypy
if: |
@ -99,6 +103,7 @@ jobs:
lintrunner-noclang:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
name: lintrunner-noclang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
needs: [get-label-type, get-changed-files]
with:
timeout: 120

View File

@ -182,11 +182,11 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11
test-matrix: |
{ include: [
{ config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
{ config: "nogpu_AVX512", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
{ config: "nogpu_AVX512", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
{ config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
{ config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "nogpu_AVX512", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "nogpu_AVX512", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
]}
secrets: inherit

View File

@ -127,6 +127,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner: linux.2xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.10-clang18-asan
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan

View File

@ -1,6 +1,9 @@
name: rocm-mi355
on:
push:
tags:
- ciflow/rocm-mi355/*
workflow_dispatch:
schedule:
- cron: 30 11,1 * * * # about 4:30am PDT and 6:30pm PDT
@ -64,5 +67,7 @@ jobs:
build-environment: linux-noble-rocm-py3.12-mi355
docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
tests-to-include: >-
${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor test_matmul_cuda test_scaled_matmul_cuda'
|| '' }}
secrets: inherit

View File

@ -140,6 +140,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner: linux.2xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.10-clang18-asan
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan

View File

@ -56,7 +56,7 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
build-generates-artifacts: false
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: "linux.4xlarge"
runner: "linux.c7i.4xlarge"
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 1 },
@ -249,3 +249,14 @@ jobs:
docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3_10-gcc11-full-debug-build-only:
name: linux-jammy-py3.10-gcc11-full-debug-build-only
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.2xlarge.memory
build-environment: linux-jammy-py3.10-gcc11-full-debug-build-only
docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
secrets: inherit

View File

@ -35,7 +35,7 @@ jobs:
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-n-1-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-1-py3
runner: linux.12xlarge
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
@ -56,7 +56,7 @@ jobs:
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
runner: linux.12xlarge
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 8, runner: "linux.idc.xpu" },

View File

@ -388,9 +388,9 @@ cmake_dependent_option(USE_PRIORITIZED_TEXT_FOR_LD "Use prioritized text linker
option(USE_MIMALLOC "Use mimalloc" OFF)
# Enable third party mimalloc library to improve memory allocation performance
# on Windows.
# on Windows and AArch64.
option(USE_MIMALLOC_ON_MKL "Use mimalloc on MKL" OFF)
if(WIN32)
if(WIN32 OR (CPU_AARCH64 AND NOT APPLE))
set(USE_MIMALLOC ON)
# Not enable USE_MIMALLOC_ON_MKL due to it caused issue:

View File

@ -28,4 +28,19 @@ inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
return stream << BlasBackendToString(backend);
}
namespace blas {
enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};
enum class SwizzleType : std::uint8_t { NO_SWIZZLE = 0, SWIZZLE_32_4_4 = 1 };
} // namespace blas
} // namespace at

View File

@ -16,8 +16,8 @@ inline void check_size_nonnegative(ArrayRef<int64_t> size) {
inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
for (const auto& x : size) {
TORCH_CHECK(
x.expect_size(__FILE__, __LINE__),
TORCH_SYM_CHECK(
x.sym_ge(0),
"Trying to create tensor with negative dimension ",
x,
": ",

View File

@ -4,6 +4,7 @@
#include <c10/core/ScalarType.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <optional>
#include <sstream>
#include <vector>
@ -26,9 +27,7 @@ inline void infer_size_impl(
std::optional<int64_t> infer_dim;
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
if (TORCH_GUARD_OR_FALSE(sym_eq(shape[dim], -1))) {
if (infer_dim) {
throw std::runtime_error("only one dimension can be inferred");
}
TORCH_CHECK(!infer_dim, "only one dimension can be inferred");
infer_dim = dim;
} else {
// in case of unbacked shape[dim] we assume it's not -1 and add a runtime
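
The hunk above replaces a raw `std::runtime_error` with `TORCH_CHECK` for the "only one dimension can be inferred" rule. As a rough illustration of that rule, here is a simplified Python sketch for concrete integer shapes. The function name `infer_size` and the second error message are assumptions, symbolic and unbacked sizes are ignored, and `ValueError` stands in for the C++ check.

```python
def infer_size(shape, numel):
    """Resolve at most one -1 entry in `shape` so the product equals `numel`."""
    infer_dim = None
    known = 1  # product of all explicitly given sizes
    for dim, size in enumerate(shape):
        if size == -1:
            # Mirrors the check in the diff: a second -1 is an error.
            if infer_dim is not None:
                raise ValueError("only one dimension can be inferred")
            infer_dim = dim
        else:
            known *= size

    result = list(shape)
    if infer_dim is not None:
        # The inferred size must be a well-defined integer.
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {shape} is invalid for input of size {numel}")
        result[infer_dim] = numel // known
    elif known != numel:
        raise ValueError(f"shape {shape} is invalid for input of size {numel}")
    return result
```

For example, `infer_size([2, -1, 4], 24)` fills in the middle dimension, while `infer_size([-1, -1], 4)` is rejected.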

View File

@ -103,9 +103,7 @@ std::string get_cpu_capability() {
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
case native::CPUCapability::ZVECTOR:
return "Z VECTOR";
#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
case native::CPUCapability::SVE128:
return "SVE128";
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
case native::CPUCapability::SVE256:
return "SVE256";
#else

View File

@ -102,31 +102,8 @@ struct VecReduceAllSIMD<float, Op> {
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
// !defined(C10_MOBILE)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
#if defined(CPU_CAPABILITY_SVE256)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(
const Op& vec_fun,
const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 64-bit shuffle
ind = svdupq_n_u32(2, 3, 0, 1);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 32-bit shuffle
ind = svdupq_n_u32(1, 0, 2, 3);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
return svlasta(svpfalse(), v);
}
};
#else
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
!defined(CPU_CAPABILITY_SVE)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(
@ -163,8 +140,35 @@ struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
return vaddvq_f32(acc_vec);
}
};
#endif // defined(CPU_CAPABILITY_SVE256)
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
// && !defined(CPU_CAPABILITY_SVE)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
defined(CPU_CAPABILITY_SVE256)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(
const Op& vec_fun,
const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 64-bit shuffle
ind = svdupq_n_u32(2, 3, 0, 1);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 32-bit shuffle
ind = svdupq_n_u32(1, 0, 2, 3);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
return svlasta(svpfalse(), v);
}
};
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
// && defined(CPU_CAPABILITY_SVE256)
template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(
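
The SVE256 specialization above reduces eight float lanes in three shuffle-and-combine steps (128-bit, 64-bit, then 32-bit). A small Python emulation of that pattern follows. The helpers `dupq` and `tbl` are hypothetical stand-ins for `svdupq_n_u32` and `svtbl_f32`, and the final `v[0]` assumes that `svlasta` with an all-false predicate extracts lane 0.

```python
LANES = 8  # a 256-bit vector holds eight 32-bit floats


def dupq(a, b, c, d):
    """Replicate a 4-index pattern across every 128-bit quadword."""
    return [a, b, c, d] * (LANES // 4)


def tbl(v, ind):
    """Table lookup: result[i] = v[ind[i]]."""
    return [v[i] for i in ind]


def reduce_all(op, v):
    """Combine all lanes of `v` with `op` using three shuffle steps."""
    assert len(v) == LANES
    # Index patterns copied from the C++ code: 128-bit, 64-bit, 32-bit shuffles.
    for pattern in ((4, 5, 6, 7), (2, 3, 0, 1), (1, 0, 2, 3)):
        shuffled = tbl(v, dupq(*pattern))
        v = [op(x, y) for x, y in zip(v, shuffled)]
    # After the third step, lane 0 holds the reduction over all eight inputs.
    return v[0]
```

Each step halves the number of distinct partial results that lane 0 depends on, so three steps suffice for eight lanes. The other lanes hold intermediate values that are simply discarded.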

View File

@ -1,21 +1,9 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <c10/macros/Macros.h>
#include <cstdint>
#include <ATen/cpu/vec/vec_base.h>
#if defined(__aarch64__) && \
(defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) || \
defined(AT_BUILD_ARM_VECSVE_WITH_SLEEF))
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif
#if defined(CPU_CAPABILITY_SVE)
// Define the data type of VLS(vector-length specific).

View File

@ -2,6 +2,7 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/sve/vec_common_sve.h>
#include <ATen/cpu/vec/sve/vec_float.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/bit_cast.h>
@ -307,8 +308,8 @@ Vectorized<c10::BFloat16> inline operator/(
}
inline Vectorized<BFloat16>::Vectorized() {
const short zero = 0;
values = svdup_n_bf16(c10::bit_cast<bfloat16_t>(zero));
auto vals_f = svdup_n_f32(0);
values = convert_float_bfloat16(vals_f, vals_f);
}
inline Vectorized<BFloat16>::Vectorized(int val) {

View File

@ -8,48 +8,13 @@
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/vec_base.h>
#ifdef CPU_CAPABILITY_SVE128
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_convert.h>
#include <ATen/cpu/vec/sve/vec_qint.h>
#elif defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/sve/vec_float.h>
#if defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/sve/vec_bfloat16.h>
#include <ATen/cpu/vec/sve/vec_double.h>
#include <ATen/cpu/vec/sve/vec_float.h>
#include <ATen/cpu/vec/sve/vec_int.h>
#include <ATen/cpu/vec/sve/vec_qint.h>
#include <ATen/cpu/vec/vec256/vec256_half.h>
#include <ATen/cpu/vec/vec256/vec256_convert.h>
#else // NEON
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_convert.h>
#include <ATen/cpu/vec/vec256/vec256_qint.h>
#endif // defined(CPU_CAPABILITY_SVE128)
#include <ATen/cpu/vec/functional.h>
#endif
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
@ -83,6 +48,12 @@ DEFINE_SVE_CAST(int32_t, s32, float, f32)
DEFINE_SVE_CAST(int16_t, s16, float, f32)
DEFINE_SVE_CAST(float, f32, double, f64)
#ifdef __ARM_FEATURE_BF16
DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16)
DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16)
DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16)
#endif // __ARM_FEATURE_BF16
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <int64_t scale = 1>
@ -202,11 +173,9 @@ std::pair<
// group cols crossing lanes:
// return {a0, b0, a1, b1, a2, b2, a3, b3}
// {a4, b4, a5, b5, a6, b6, a7, b7}
svbfloat16_t aReg = a;
svbfloat16_t bReg = b;
Vectorized<c10::BFloat16> c = svzip1_bf16(aReg, bReg);
Vectorized<c10::BFloat16> d = svzip2_bf16(aReg, bReg);
return std::make_pair(c, d);
return std::make_pair(
Vectorized<c10::BFloat16>(svzip1_bf16(a, b)),
Vectorized<c10::BFloat16>(svzip2_bf16(a, b)));
}
#endif // __ARM_FEATURE_BF16
@ -255,27 +224,12 @@ std::pair<
// swap lanes:
// return {a0, a1, a2, a3, a4, a5, a6, a7}
// {b0, b1, b2, b3, b4, b5, b6, b7}
svbfloat16_t aReg = a;
svbfloat16_t bReg = b;
Vectorized<c10::BFloat16> c = svuzp1_bf16(aReg, bReg);
Vectorized<c10::BFloat16> d = svuzp2_bf16(aReg, bReg);
return std::make_pair(c, d);
return std::make_pair(
Vectorized<c10::BFloat16>(svuzp1_bf16((svbfloat16_t)a, (svbfloat16_t)b)),
Vectorized<c10::BFloat16>(svuzp2_bf16((svbfloat16_t)a, (svbfloat16_t)b)));
}
#endif // __ARM_FEATURE_BF16
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#define DEFINE_FLIP_FUNC(type, sve_func) \
inline Vectorized<type> flip(const Vectorized<type>& v) { \
return Vectorized<type>(sve_func(v)); \
}
// Use the macro to define the flip functions
DEFINE_FLIP_FUNC(float, svrev_f32)
DEFINE_FLIP_FUNC(double, svrev_f64)
DEFINE_FLIP_FUNC(int64_t, svrev_s64)
DEFINE_FLIP_FUNC(int32_t, svrev_s32)
DEFINE_FLIP_FUNC(int16_t, svrev_s16)
DEFINE_FLIP_FUNC(int8_t, svrev_s8)
#endif // defined(CPU_CAPABILITY_SVE)
} // namespace CPU_CAPABILITY
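The `svuzp1_bf16` / `svuzp2_bf16` hunk in this file performs the inverse of the interleave: it separates even-indexed and odd-indexed lanes. A minimal scalar sketch of that semantics follows, assuming the two inputs hold interleaved `a`/`b` values; `deinterleave2_model` is a hypothetical name and does not appear in ATen.

```cpp
#include <array>
#include <cstddef>
#include <utility>

// Scalar reference model of deinterleaving two N-element vectors x and y.
// `even` corresponds to the uzp1 result (even-indexed lanes of x, then of y),
// `odd` corresponds to the uzp2 result (odd-indexed lanes of x, then of y).
// Hypothetical helper, not part of ATen.
template <typename T, std::size_t N>
std::pair<std::array<T, N>, std::array<T, N>> deinterleave2_model(
    const std::array<T, N>& x,
    const std::array<T, N>& y) {
  static_assert(N % 2 == 0, "N must be even");
  std::array<T, N> even{};
  std::array<T, N> odd{};
  for (std::size_t i = 0; i < N / 2; ++i) {
    even[i] = x[2 * i];
    odd[i] = x[2 * i + 1];
    even[N / 2 + i] = y[2 * i];
    odd[N / 2 + i] = y[2 * i + 1];
  }
  return std::make_pair(even, odd);
}
```

Feeding it `{a0, b0, a1, b1, ...}` and `{a4, b4, a5, b5, ...}` returns the original `a` and `b` vectors, which is the layout the "swap lanes" comment describes.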

@ -1,8 +1,6 @@
#pragma once
#if defined(__aarch64__)
#include <ATen/cpu/vec/vec_common_aarch64.h>
#elif defined(CPU_CAPABILITY_AVX512)
#if defined(CPU_CAPABILITY_AVX512)
#include <ATen/cpu/vec/vec512/vec512.h>
#else
#include <ATen/cpu/vec/vec128/vec128.h>
@ -13,34 +11,6 @@ namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
stream << val.val_;
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
stream << static_cast<int>(val.val_);
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
stream << static_cast<unsigned int>(val.val_);
return stream;
}
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
T buf[Vectorized<T>::size()];
vec.store(buf);
stream << "vec[";
for (int i = 0; i != Vectorized<T>::size(); i++) {
if (i != 0) {
stream << ", ";
}
stream << buf[i];
}
stream << "]";
return stream;
}
inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
__at_align__ bool buffer[x.size()];
x.ne(Vectorized<int8_t>(0)).store(buffer);

@ -2,7 +2,6 @@
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h>
#include <ATen/cpu/vec/vec_base.h>
@ -263,13 +262,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
c10::bit_cast<at_bfloat16_t>(val6.x),
c10::bit_cast<at_bfloat16_t>(val7.x)}) {}
#ifdef CPU_CAPABILITY_SVE128
Vectorized(svbfloat16_t v) : Vectorized16(svget_neonq(v)) {}
operator svbfloat16_t() const {
return svset_neonq(svundef_bf16(), values);
}
#endif
static Vectorized<c10::BFloat16> blendv(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
@ -382,23 +374,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
Vectorized ge(const Vectorized& other) const;
Vectorized lt(const Vectorized& other) const;
Vectorized le(const Vectorized& other) const;
#ifdef CPU_CAPABILITY_SVE128
template <typename step_t>
static Vectorized<BFloat16> arange(
BFloat16 base = 0.f,
step_t step = static_cast<step_t>(1)) {
__at_align__ BFloat16 buffer[size()];
for (int64_t i = 0; i < size(); i++) {
buffer[i] = base + i * step;
}
return svget_neonq(
svld1_bf16(ptrue, reinterpret_cast<bfloat16_t*>(buffer)));
}
#endif // CPU_CAPABILITY_SVE128
}; // Vectorized<c10::BFloat16>
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
@ -422,24 +397,6 @@ inline Vectorized<c10::BFloat16> convert_float_bfloat16(
return Vectorized<c10::BFloat16>(at_vcombine_bf16(x1, x2));
}
inline void load_fp32_from_bf16(const BFloat16* data, Vectorized<float>& out) {
__at_align__ float values[Vectorized<float>::size()];
for (const auto k : c10::irange(Vectorized<float>::size())) {
values[k] = data[k];
}
out = Vectorized<float>::loadu(values);
}
inline void load_fp32_from_bf16(
const BFloat16* data,
Vectorized<float>& out1,
Vectorized<float>& out2) {
Vectorized<BFloat16> bf16_vec = Vectorized<BFloat16>::loadu(data);
auto floats = convert_bfloat16_float(bf16_vec);
out1 = std::get<0>(floats);
out2 = std::get<1>(floats);
}
template <typename Op>
Vectorized<c10::BFloat16> binary_operator_via_float(
Op op,
@ -622,12 +579,6 @@ Vectorized<c10::BFloat16> inline fnmsub(
return -a * b - c;
}
#else //
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16)
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
} // namespace CPU_CAPABILITY

@ -4,7 +4,7 @@
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
template <typename src_t>
struct VecConvert<
float,
@ -60,7 +60,6 @@ struct VecConvert<float, 1, BFloat16, 1> {
}
};
#endif // defined(__aarch64__) && (!defined(CPU_CAPABILITY_SVE) ||
// defined(CPU_CAPABILITY_SVE128))
#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)
} // namespace CPU_CAPABILITY
} // namespace at::vec

@ -4,10 +4,13 @@
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#endif
// Sleef offers vectorized versions of some transcendentals
// such as sin, cos, tan, etc.
// However for now opting for STL, since we are not building
@ -32,6 +35,12 @@ inline namespace CPU_CAPABILITY {
#error "Big endian is not supported."
#endif
#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif
template <int index, bool mask_val>
struct BlendRegs {
static float32x4_t impl(
@ -85,12 +94,6 @@ class Vectorized<float> {
operator float32x4_t() const {
return values;
}
#ifdef CPU_CAPABILITY_SVE128
Vectorized(svfloat32_t v) : values(svget_neonq(v)) {}
operator svfloat32_t() const {
return svset_neonq(svundef_f32(), values);
}
#endif
template <int64_t mask>
static Vectorized<float> blend(
const Vectorized<float>& a,

@ -4,6 +4,7 @@
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec128/vec128_convert.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h>
#include <ATen/cpu/vec/vec_base.h>
@ -24,6 +25,7 @@ inline namespace CPU_CAPABILITY {
// https://bugs.llvm.org/show_bug.cgi?id=45824
// Most likely we will do aarch32 support with inline asm.
#if !defined(C10_MOBILE) && defined(__aarch64__)
#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif
@ -419,24 +421,6 @@ Vectorized<c10::Half> inline operator+(
#endif
}
inline void load_fp32_from_fp16(const c10::Half* data, Vectorized<float>& out) {
__at_align__ float values[Vectorized<float>::size()];
for (const auto k : c10::irange(Vectorized<float>::size())) {
values[k] = data[k];
}
out = Vectorized<float>::loadu(values);
}
inline void load_fp32_from_fp16(
const c10::Half* data,
Vectorized<float>& out1,
Vectorized<float>& out2) {
Vectorized<c10::Half> f16_vec = Vectorized<c10::Half>::loadu(data);
auto floats = convert_half_float(f16_vec);
out1 = std::get<0>(floats);
out2 = std::get<1>(floats);
}
template <>
Vectorized<c10::Half> inline operator-(
const Vectorized<c10::Half>& a,
@ -672,53 +656,6 @@ Vectorized<c10::Half> inline fnmsub(
return -a * b - c;
#endif
}
#else
#define CONVERT_NON_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> \
convert_##name##_float(const Vectorized<type>& a) { \
constexpr int64_t K = Vectorized<type>::size(); \
__at_align__ float arr[K]; \
__at_align__ type arr2[K]; \
a.store(arr2); \
convert(arr2, arr, K); \
return std::make_tuple( \
Vectorized<float>::loadu(arr), \
Vectorized<float>::loadu(arr + Vectorized<float>::size())); \
} \
inline Vectorized<type> convert_float_##name( \
const Vectorized<float>& a, const Vectorized<float>& b) { \
constexpr int64_t K = Vectorized<type>::size(); \
__at_align__ float arr[K]; \
__at_align__ type arr2[K]; \
a.store(arr); \
b.store(arr + Vectorized<float>::size()); \
convert(arr, arr2, K); \
return Vectorized<type>::loadu(arr2); \
}
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name( \
const type* data, Vectorized<float>& out) { \
__at_align__ float values[Vectorized<float>::size()]; \
for (const auto k : c10::irange(Vectorized<float>::size())) { \
values[k] = data[k]; \
} \
out = Vectorized<float>::loadu(values); \
} \
\
inline void load_fp32_from_##name( \
const type* data, Vectorized<float>& out1, Vectorized<float>& out2) { \
load_fp32_from_##name(data, out1); \
data += Vectorized<float>::size(); \
load_fp32_from_##name(data, out2); \
}
CONVERT_NON_VECTORIZED_INIT(Half, half)
LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16)
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
} // namespace CPU_CAPABILITY

@ -9,16 +9,21 @@
#if !( \
defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || \
defined(CPU_CAPABILITY_ZVECTOR))
#include <ATen/cpu/vec/vec256/vec256_double.h>
#if defined(CPU_CAPABILITY_SVE256)
#include <ATen/cpu/vec/sve/vec_common_sve.h>
#else
// clang-format off
#include <ATen/cpu/vec/vec256/vec256_float.h>
#include <ATen/cpu/vec/vec256/vec256_double.h>
#include <ATen/cpu/vec/vec256/vec256_int.h>
#include <ATen/cpu/vec/vec256/vec256_qint.h>
#endif
#if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16)
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
#endif
#include <ATen/cpu/vec/vec256/vec256_complex_double.h>
#include <ATen/cpu/vec/vec256/vec256_complex_float.h>
#include <ATen/cpu/vec/vec256/vec256_half.h>
#include <ATen/cpu/vec/vec256/vec256_complex_float.h>
#include <ATen/cpu/vec/vec256/vec256_complex_double.h>
// clang-format on
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h>
@ -51,6 +56,34 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
stream << val.val_;
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
stream << static_cast<int>(val.val_);
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
stream << static_cast<unsigned int>(val.val_);
return stream;
}
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
T buf[Vectorized<T>::size()];
vec.store(buf);
stream << "vec[";
for (int i = 0; i != Vectorized<T>::size(); i++) {
if (i != 0) {
stream << ", ";
}
stream << buf[i];
}
stream << "]";
return stream;
}
#if defined(CPU_CAPABILITY_AVX2)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@ -268,7 +268,9 @@ LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16)
#else // defined(CPU_CAPABILITY_AVX2)
#if !(defined(__aarch64__))
#if !( \
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
!defined(CPU_CAPABILITY_SVE256))
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
#endif

@ -268,7 +268,9 @@ LOAD_FP32_VECTORIZED_INIT(Half, fp16)
#else // defined(CPU_CAPABILITY_AVX2)
#if !defined(__aarch64__) || defined(CPU_CAPABILITY_SVE256)
#if !( \
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
!defined(CPU_CAPABILITY_SVE256))
CONVERT_NON_VECTORIZED_INIT(Half, half)
#endif

@ -5,13 +5,6 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#ifdef __aarch64__
#if defined(CPU_CAPABILITY_SVE128) || !defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#endif
#endif
#include <ATen/native/quantized/AffineQuantizerBase.h>
#include <c10/util/irange.h>
@ -922,7 +915,7 @@ Vectorized<c10::quint8> inline maximum(
return a.maximum(b);
}
#else
#elif !defined(CPU_CAPABILITY_SVE256)
// NOTE: These are low-performance implementations that we fall back on
// if we are not building with AVX2. This may not be an issue, because
@ -1379,18 +1372,12 @@ Vectorized<c10::quint8> inline maximum(
return a.maximum(b);
}
#if defined(__aarch64__) && \
(defined(CPU_CAPABILITY_SVE128) || !defined(CPU_CAPABILITY_SVE))
#endif // if defined(CPU_CAPABILITY_AVX2)
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<int8_t> src) {
#ifdef CPU_CAPABILITY_SVE
svint8_t x = src;
auto s8x8 = vget_low_s8(svget_neonq(x));
#else
auto s8x8 = vld1_s8(src.operator const int8_t*());
#endif
auto s16x8 = vmovl_s8(s8x8);
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
@ -1415,14 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
Vectorized<float> inline convert_int8_half_register_to_float(
at::vec::Vectorized<int8_t> src) {
#ifdef CPU_CAPABILITY_SVE
svint8_t x = src;
auto s8x8 = vget_low_s8(svget_neonq(x));
#else
auto s8x8 = vld1_s8(src.operator const int8_t*());
#endif
auto s16x8 = vmovl_s8(s8x8);
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
@ -1440,8 +1420,5 @@ Vectorized<float> inline convert_int8_half_register_to_float(
}
#endif
#endif // if defined(CPU_CAPABILITY_AVX2)
} // namespace CPU_CAPABILITY
} // namespace at::vec

@ -31,6 +31,34 @@ namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
stream << val.val_;
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
stream << static_cast<int>(val.val_);
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
stream << static_cast<unsigned int>(val.val_);
return stream;
}
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
T buf[Vectorized<T>::size()];
vec.store(buf);
stream << "vec[";
for (int i = 0; i != Vectorized<T>::size(); i++) {
if (i != 0) {
stream << ", ";
}
stream << buf[i];
}
stream << "]";
return stream;
}
#if defined(CPU_CAPABILITY_AVX512)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512)

@ -67,7 +67,18 @@ Windows llvm will not have this definition.
#endif
#define VECTOR_WIDTH 64
#define int_vector __m512i
#elif defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_SVE256)
#elif defined(__aarch64__) && \
!defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512
// SVE code expects 256-vectors; leave that set for SVE?
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(16)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(16))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 16
#else // CPU_CAPABILITY_AVX512
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
@ -77,27 +88,7 @@ Windows llvm will not have this definition.
#endif
#define VECTOR_WIDTH 32
#define int_vector __m256i
#elif defined(__aarch64__)
// Define alignment and vector width for SVE128/Default (e.g., NEON)
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(16)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(16))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 16
#else
// Fallback: define default alignment and vector width
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(32))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 32
#endif
#endif // CPU_CAPABILITY_AVX512
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
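The hunk above ties an alignment attribute to a vector width in bytes (16 for the NEON-sized path, 32 otherwise). A minimal portable sketch of the same idea using standard `alignas` is shown below. The names `kSketchVectorWidth`, `SketchFloatLanes`, and `is_vector_aligned` are hypothetical, and the value 16 is picked only to mirror the aarch64 branch.

```cpp
#include <cstddef>
#include <cstdint>

// Assumed vector width in bytes; 16 mirrors the NEON-sized branch.
constexpr std::size_t kSketchVectorWidth = 16;

// A buffer holding exactly one vector's worth of floats, aligned to the
// vector width so that aligned loads and stores are valid on it.
struct alignas(kSketchVectorWidth) SketchFloatLanes {
  float lanes[kSketchVectorWidth / sizeof(float)];
};

// Returns true when p sits on a vector-width boundary.
inline bool is_vector_aligned(const void* p) {
  return reinterpret_cast<std::uintptr_t>(p) % kSketchVectorWidth == 0;
}
```

The compiler-specific attributes in the header serve the same purpose as `alignas` here; the macro form lets the header also expand to nothing on toolchains it does not recognise.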

@ -1861,6 +1861,8 @@ template bool gemm_and_bias(
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
using at::blas::ScalingType;
int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
switch (scaling_type) {
case ScalingType::BlockWise1x32:

@ -14,6 +14,7 @@
*/
#include <ATen/cuda/CUDAContext.h>
#include <ATen/BlasBackend.h>
#include <ATen/OpMathType.h>
namespace at::cuda::blas {
@ -136,15 +137,6 @@ void int8_gemm(
int32_t* result_ptr,
int64_t result_ld);
enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};
void scaled_gemm(
char transa,
char transb,
@ -156,13 +148,13 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
at::blas::ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
at::blas::ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,

@ -29,7 +29,7 @@
namespace at::cuda::tunable {
using at::cuda::blas::ScalingType;
using at::blas::ScalingType;
enum class BlasOp {
N = 0,

@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
namespace at::native {
cudnnDataType_t getCudnnDataTypeFromScalarType(const at::ScalarType dtype) {
@ -20,9 +22,10 @@ cudnnDataType_t getCudnnDataTypeFromScalarType(const at::ScalarType dtype) {
} else if (dtype == at::kByte) {
return CUDNN_DATA_UINT8;
}
std::string msg("getCudnnDataTypeFromScalarType() not supported for ");
msg += toString(dtype);
throw std::runtime_error(msg);
TORCH_CHECK(false,
"getCudnnDataTypeFromScalarType() not supported for ",
toString(dtype)
);
}
cudnnDataType_t getCudnnDataType(const at::Tensor& tensor) {

@ -12,6 +12,7 @@
#include <ATen/native/IndexKernel.h>
#include <ATen/native/IndexingUtils.h>
#include <torch/library.h>
#include <c10/util/Exception.h>
// NOLINTBEGIN(bugprone-unchecked-optional-access)
@ -94,9 +95,10 @@ static std::vector<std::optional<Tensor>> batchIndices(
if (index.has_value() && index->sym_numel() != 0) {
const auto idx_bdim = indices_bdims[i];
indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
}
TORCH_CHECK(
!(index.value().dtype() == kBool) || !indices_bdims[i].has_value(),
"vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask."
);
} else {
indices_.push_back(index);
}

@ -3,6 +3,7 @@
#include <ATen/functorch/Macros.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Exception.h>
#include <optional>
#include <bitset>
#include <utility>
@ -106,9 +107,10 @@ struct VmapInterpreterMeta {
template <typename T>
friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
if (json_t.batchSize_.is_heap_allocated()) {
throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet");
}
TORCH_CHECK(
!json_t.batchSize_.is_heap_allocated(),
"Serialization for heap-allocated SymInt is not implemented yet"
);
json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
}
@ -302,7 +304,7 @@ struct Interpreter {
} else if (meta.contains("Functionalize")) {
json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
} else {
throw std::runtime_error("unknown interpreter metadata type");
TORCH_CHECK(false, "unknown interpreter metadata type");
}
}

@ -6,6 +6,7 @@
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/Dispatch.h>
#include <c10/util/irange.h>
#include <c10/util/Exception.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/xnnpack/Engine.h>
@ -108,9 +109,7 @@ Tensor binary_cross_entropy_with_logits_hack(
}
Tensor trace_backward_decomp(const Tensor& grad, IntArrayRef sizes) {
if (sizes.size() != 2) {
throw std::runtime_error("expected matrix input");
}
TORCH_CHECK(sizes.size() == 2, "expected matrix input");
auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options());
auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
// Workaround using index_put instead of yet unsupported index_fill_
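Several hunks in this comparison replace `if (...) throw std::runtime_error(...)` with a single check that states the invariant and builds the message from its arguments. The sketch below imitates that pattern without depending on c10. `SKETCH_CHECK`, `sketch_concat`, and `require_matrix` are hypothetical names, and the sketch throws `std::runtime_error` only to stay dependency-free, whereas `TORCH_CHECK` itself raises `c10::Error`.

```cpp
#include <cstddef>
#include <initializer_list>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical helper: streams every message part into one string.
template <typename... Args>
std::string sketch_concat(const Args&... args) {
  std::ostringstream oss;
  // Expand the pack in order; the initializer list guarantees left-to-right
  // evaluation.
  (void)std::initializer_list<int>{((oss << args), 0)...};
  return oss.str();
}

// Hypothetical stand-in for a check macro: throws when `cond` is false.
#define SKETCH_CHECK(cond, ...)                             \
  do {                                                      \
    if (!(cond)) {                                          \
      throw std::runtime_error(sketch_concat(__VA_ARGS__)); \
    }                                                       \
  } while (0)

// Before: if (ndim != 2) { throw std::runtime_error("..."); }
// After: the condition expresses what must hold, not what must fail.
inline void require_matrix(std::size_t ndim) {
  SKETCH_CHECK(ndim == 2, "expected matrix input, got ", ndim, " dims");
}
```

Note that the condition is inverted relative to the old `if`: a guard of the form `if (A && B) throw` becomes a check on `!A || !B`, which is the rewrite applied to the boolean-mask case in the batching code.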

@ -1157,103 +1157,103 @@ REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_SVE_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_SVE256_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl)
REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_SVE_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel)
REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_SVE_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_SVE256_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel)
REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_SVE_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_SVE256_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel)
REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_SVE_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_SVE256_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl)
REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_SVE_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_SVE256_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel)
REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_SVE_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_SVE256_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel)
REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_SVE_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_SVE256_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel)
REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_SVE_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_SVE256_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel)
REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_SVE_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_SVE256_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel)
REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_SVE_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_SVE256_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel)
REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_SVE_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_SVE256_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel)
REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_SVE_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_SVE256_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel)
REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel)
REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel)
REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel)
REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel)
REGISTER_SVE_DISPATCH(svd_stub, &svd_kernel)
REGISTER_SVE256_DISPATCH(svd_stub, &svd_kernel)
REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel)
REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_SVE_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_SVE256_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
} // namespace at::native

@ -39,21 +39,19 @@ static CPUCapability compute_cpu_capability() {
}
#elif defined(HAVE_SVE_CPU_DEFINITION)
int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW.
if (envar == "sve") {
// Select SVE capability based on the maximum SVE VL supported by the HW.
#ifdef HAVE_SVE256_CPU_DEFINITION
if (envar == "sve256") {
if (sve_vl == 256) {
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16()) {
return CPUCapability::SVE256;
}
} else if (sve_vl == 128) {
if (cpuinfo_has_arm_bf16()) {
return CPUCapability::SVE128;
}
} else {
TORCH_WARN("SVE capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
#endif
}
TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
}
#endif
#else
#ifdef HAVE_AVX512_CPU_DEFINITION
if (envar == "avx512") {
@ -115,11 +113,6 @@ static CPUCapability compute_cpu_capability() {
#endif
}
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
if (sve_vl == 128) { // Check for SVE128
return CPUCapability::SVE128;
}
#endif
// Return the default CPU capability.
return CPUCapability::DEFAULT;
}
@ -154,9 +147,6 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, void *SVE128
#endif
) {
constexpr auto supported_devices = c10::array_of<c10::DeviceType>(
c10::DeviceType::CPU,
@ -194,9 +184,6 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, SVE128
#endif
);
if (!std::holds_alternative<ErrorType>(result)) {
@ -255,9 +242,6 @@ void* DispatchStubImpl::get_call_ptr(
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, void *SVE128
#endif
) {
auto result = try_get_call_ptr(
@ -282,10 +266,6 @@ void* DispatchStubImpl::get_call_ptr(
#ifdef HAVE_SVE256_CPU_DEFINITION
,
SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
,
SVE128
#endif
);
if (std::holds_alternative<ErrorType>(result)) {
@ -320,9 +300,6 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, void *SVE128
#endif
){
@ -365,16 +342,6 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
return DispatchResult(SVE256);
}
}
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::SVE128)) {
if (C10_UNLIKELY(!SVE128)) {
// dispatch to DEFAULT, since the SVE kernel is missing
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
} else {
return DispatchResult(SVE128);
}
}
#endif
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
}
@ -396,9 +363,6 @@ void* DispatchStubImpl::choose_cpu_impl(
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, void *SVE128
#endif
) {
auto capability = static_cast<int>(get_cpu_capability());
(void)capability;
@ -444,17 +408,6 @@ void* DispatchStubImpl::choose_cpu_impl(
return SVE256;
}
}
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::SVE128)) {
if (C10_UNLIKELY(!SVE128)) {
// dispatch to DEFAULT, since the SVE kernel is missing
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
return DEFAULT;
} else {
return SVE128;
}
}
#endif
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
return DEFAULT;
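The selection logic in `choose_cpu_impl` follows one rule: prefer the kernel for the detected capability, fall back to DEFAULT when that kernel was not registered, and fail if DEFAULT is missing too. A stripped-down sketch of that rule with a single specialised level is given below. `SketchCapability`, `KernelFn`, `choose_kernel_sketch`, and the two sample kernels are hypothetical and stand in for the real `DispatchStubImpl` machinery.

```cpp
#include <stdexcept>

// Hypothetical two-level capability enum; higher value means more capable.
enum class SketchCapability { DEFAULT = 0, SVE256 = 1 };

using KernelFn = int (*)(int);

// Sample kernels that only differ in their result, so the choice is visible.
inline int default_kernel(int x) { return x + 1; }
inline int sve256_kernel(int x) { return x + 256; }

// Picks the most specialised registered kernel for `cap`.
inline KernelFn choose_kernel_sketch(
    SketchCapability cap,
    KernelFn default_fn,
    KernelFn sve256_fn) {
  if (static_cast<int>(cap) >= static_cast<int>(SketchCapability::SVE256)) {
    if (sve256_fn != nullptr) {
      return sve256_fn;
    }
    // Specialised kernel missing: fall through to DEFAULT.
  }
  if (default_fn == nullptr) {
    throw std::runtime_error("dispatch sketch: missing default kernel");
  }
  return default_fn;
}
```

Removing a capability level, as this diff does for SVE128, amounts to deleting one such `if` block together with the matching parameter at every call site.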

@ -64,9 +64,8 @@ enum class CPUCapability {
VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
ZVECTOR = 1,
#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
SVE256 = 1,
SVE128 = 2,
#else
AVX2 = 1,
AVX512 = 2,
@ -118,9 +117,6 @@ struct TORCH_API DispatchStubImpl {
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, void *SVE128
#endif
);
@ -142,9 +138,6 @@ struct TORCH_API DispatchStubImpl {
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, void *SVE128
#endif
);
@ -166,9 +159,6 @@ struct TORCH_API DispatchStubImpl {
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, void *SVE128
#endif
);
@ -193,9 +183,6 @@ struct TORCH_API DispatchStubImpl {
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, void *SVE128
#endif
);
@ -253,9 +240,6 @@ private:
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, reinterpret_cast<void*>(SVE256)
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, reinterpret_cast<void*>(SVE128)
#endif
)
);
@ -317,9 +301,6 @@ public:
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, reinterpret_cast<void*>(SVE256)
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
, reinterpret_cast<void*>(SVE128)
#endif
);
if (std::holds_alternative<ErrorType>(result)){
@ -344,9 +325,6 @@ public:
#ifdef HAVE_SVE256_CPU_DEFINITION
static TORCH_API FnPtr SVE256;
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
static TORCH_API FnPtr SVE128;
#endif
private:
DispatchStubImpl impl;
};
@ -454,12 +432,6 @@ struct RegisterPRIVATEUSE1Dispatch {
#define REGISTER_SVE256_DISPATCH(name, fn)
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
#define REGISTER_SVE128_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE128, fn)
#else
#define REGISTER_SVE128_DISPATCH(name, fn)
#endif
// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
@ -468,11 +440,6 @@ struct RegisterPRIVATEUSE1Dispatch {
REGISTER_AVX2_DISPATCH(name, fn) \
REGISTER_VSX_DISPATCH(name, fn) \
REGISTER_ZVECTOR_DISPATCH(name, fn) \
REGISTER_SVE256_DISPATCH(name, fn) \
REGISTER_SVE128_DISPATCH(name, fn)
#define REGISTER_SVE_DISPATCH(name, fn) \
REGISTER_SVE128_DISPATCH(name, fn) \
REGISTER_SVE256_DISPATCH(name, fn)
#define REGISTER_NO_CPU_DISPATCH(name) \
@ -515,7 +482,6 @@ struct RegisterPRIVATEUSE1Dispatch {
// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
// ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others.
// ALSO_REGISTER_SVE128_DISPATCH should be used for ensuring SVE128 dispatch, among others.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
@ -523,7 +489,6 @@ struct RegisterPRIVATEUSE1Dispatch {
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define ALSO_REGISTER_SVE128_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
} // namespace at::native
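
Note on the registration macros in this file: each `REGISTER_<ARCH>_DISPATCH` macro expands to a real registration only when the matching `HAVE_<ARCH>_CPU_DEFINITION` is defined, and to nothing otherwise, which is why removing the SVE128 variants only requires deleting their macro pairs. The following is a minimal sketch of that pattern, not the PyTorch implementation: `DemoStub`, `DEMO_HAVE_SVE256`, and the `DEMO_REGISTER_*` macros are hypothetical names invented for illustration.

```cpp
#include <cassert>

// Hypothetical stand-in for a dispatch stub: one function-pointer slot per
// CPU capability. This is NOT the real at::native::DispatchStub.
using DemoFn = int (*)(int);

struct DemoStub {
  DemoFn DEFAULT = nullptr;
  DemoFn SVE256 = nullptr;
};

// Assumption for this sketch: the build enabled the SVE256 kernels.
#define DEMO_HAVE_SVE256 1

// Store `fn` in the slot named `arch`.
#define DEMO_REGISTER_ARCH(stub, arch, fn) (stub).arch = (fn);

// Per-capability macro: a real registration if the capability was built,
// an empty expansion otherwise.
#ifdef DEMO_HAVE_SVE256
#define DEMO_REGISTER_SVE256(stub, fn) DEMO_REGISTER_ARCH(stub, SVE256, fn)
#else
#define DEMO_REGISTER_SVE256(stub, fn)
#endif

// Register the same kernel for every capability known to this sketch.
#define DEMO_REGISTER_ALL(stub, fn)     \
  DEMO_REGISTER_ARCH(stub, DEFAULT, fn) \
  DEMO_REGISTER_SVE256(stub, fn)

inline int demo_kernel(int x) {
  return x + 1;
}

inline DemoStub make_demo_stub() {
  DemoStub stub;
  DEMO_REGISTER_ALL(stub, &demo_kernel)
  return stub;
}
```

Calling `make_demo_stub()` fills both slots with `demo_kernel`; undefining `DEMO_HAVE_SVE256` would leave the `SVE256` slot null without touching any call site.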

View File

@ -15,7 +15,11 @@ namespace at::native {
Scalar item(const Tensor& self) {
auto numel = self.sym_numel();
TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
TORCH_SYM_CHECK(
numel.sym_eq(1),
"a Tensor with ",
numel,
" elements cannot be converted to Scalar");
if (self.is_sparse()) {
if (self._nnz() == 0) return Scalar(0);
if (self.is_coalesced()) return at::_local_scalar_dense(self._values());

View File

@ -466,7 +466,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cp
REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
REGISTER_SVE_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
REGISTER_SVE256_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
// offsets dispatches
REGISTER_ARCH_DISPATCH(
@ -477,7 +477,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cp
REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
REGISTER_SVE_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
// Currently some computation is being duplicated across forward and backward.
// TODO: Cache indices in forward pass to reuse in backward
@ -548,7 +548,7 @@ REGISTER_VSX_DISPATCH(
REGISTER_ZVECTOR_DISPATCH(
_segment_reduce_lengths_backward_stub,
&_segment_reduce_cpu_lengths_backward_kernel)
REGISTER_SVE_DISPATCH(
REGISTER_SVE256_DISPATCH(
_segment_reduce_lengths_backward_stub,
&_segment_reduce_cpu_lengths_backward_kernel)
@ -568,7 +568,7 @@ REGISTER_VSX_DISPATCH(
REGISTER_ZVECTOR_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel)
REGISTER_SVE_DISPATCH(
REGISTER_SVE256_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel)

View File

@ -23,6 +23,14 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_cast_Byte_native.h>
#include <ATen/ops/_cast_Char_native.h>
#include <ATen/ops/_cast_Double_native.h>
#include <ATen/ops/_cast_Float_native.h>
#include <ATen/ops/_cast_Half_native.h>
#include <ATen/ops/_cast_Int_native.h>
#include <ATen/ops/_cast_Long_native.h>
#include <ATen/ops/_cast_Short_native.h>
#include <ATen/ops/_dim_arange_native.h>
#include <ATen/ops/_efficientzerotensor_native.h>
#include <ATen/ops/_empty_affine_quantized.h>

View File

@ -406,7 +406,7 @@ scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
}
template <typename scalar_t>
void get_cubic_upsample_coefficients(
static inline void get_cubic_upsample_coefficients(
scalar_t coeffs[4],
scalar_t t) {
scalar_t A = -0.75;

View File

@ -212,7 +212,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
const vec::Vectorized<c10::Half>& b,
const vec::Vectorized<float>& acc_low,
const vec::Vectorized<float>& acc_high) {
#if defined(__aarch64__) && ((defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)) || (defined(CPU_CAPABILITY_SVE128)))
#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE)
return std::make_pair(vfmlalq_low_f16(acc_low, a, b), vfmlalq_high_f16(acc_high, a, b));
#else
const auto [a_float_low, a_float_high] = convert_half_float(a);
@ -233,7 +233,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
// Return a + b_low * c_low + b_high * c_high
vec::Vectorized<float> fmadd(vec::Vectorized<float> a, vec::Vectorized<Half> b, vec::Vectorized<Half> c) {
#if defined(__aarch64__) && ((defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)) || (defined(CPU_CAPABILITY_SVE128)))
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)
// NOTE: this instruction is an optional instruction in ARM v8.2 and
// v8.3, but mandatory in v8.4 per
// https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en

File diff suppressed because it is too large

View File

@ -88,9 +88,9 @@ __global__ void compute_grad_weight_bags(
const int64_t stride_warped) {
int64_t num_of_segments = *num_of_segments_ptr;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t id = gid / stride_warped;
const int64_t startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}
@ -134,9 +134,9 @@ __global__ void compute_grad_weight(
int64_t num_of_segments = *num_of_segments_ptr;
using accscalar_t = acc_type<scalar_t, true>;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t id = gid / stride_warped;
const int64_t startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}
@ -167,9 +167,9 @@ __global__ void sum_and_scatter(
int64_t num_of_segments = *num_of_segments_ptr;
int64_t num_of_partial_segments = *num_of_partial_segments_ptr;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t id = gid / stride_warped;
const int64_t startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}
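
Note on the `int64_t` change in these kernels: `blockIdx.x` and `blockDim.x` are 32-bit unsigned values, so their product wraps modulo 2^32 unless one operand is widened first. The host-side sketch below illustrates the difference under that assumption; `global_id_32` and `global_id_64` are hypothetical helper names, and the 32-bit version returns `uint32_t` rather than `int` so the wraparound stays well defined in plain C++.

```cpp
#include <cassert>
#include <cstdint>

// Old form (sketch): the multiply and add happen in 32-bit unsigned
// arithmetic, so large grids wrap around modulo 2^32.
inline uint32_t global_id_32(uint32_t block_idx,
                             uint32_t block_dim,
                             uint32_t thread_idx) {
  return block_idx * block_dim + thread_idx;
}

// New form (sketch): widening one operand first makes the whole expression
// evaluate in 64 bits, mirroring static_cast<int64_t>(blockIdx.x) * blockDim.x.
inline int64_t global_id_64(uint32_t block_idx,
                            uint32_t block_dim,
                            uint32_t thread_idx) {
  return static_cast<int64_t>(block_idx) * block_dim + thread_idx;
}
```

With `block_idx = 8388608` (2^23), `block_dim = 512` (2^9), and `thread_idx = 7`, the product is exactly 2^32, so the 32-bit form collapses to 7 while the 64-bit form keeps the full index.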

View File

@ -710,6 +710,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
dim3 block(warp_size, indices_per_block);
#ifdef USE_ROCM
dim3 new_grid_many_indices(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)),
grid.y == 1 ? std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size))) : grid.y,
grid.z);
dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
#define KERNEL_GRID new_grid
@ -788,7 +791,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
expandedValue.scalar_type(),
"indexing_backward_many_indices",
AT_WRAP([&] {
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid_many_indices, block, smem_dups_size, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),

View File

@ -488,15 +488,16 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
}
}
int cat_dim = dimension;
if (memory_format != c10::MemoryFormat::Contiguous) {
switch (dimension) {
switch (cat_dim) {
case 0:
break;
case 1:
dimension = nDims - dimension;
cat_dim = nDims - cat_dim;
break;
default:
dimension--;
cat_dim--;
}
}
// Template Declarations for dim = 1, 2, 3, 4
@ -505,23 +506,23 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
(char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
(char*)data, catMetaData, kernelOutputParam, cat_dim, trailingSize);\
} else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
} else if (isContig && isAligned && sizeof(scalar_t) == 2) { \
CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_8><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
} else if (isContig) {\
CatArrayBatchedCopy_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
} else {\
CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
}\
C10_CUDA_KERNEL_LAUNCH_CHECK();
switch (nDims) {
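
Note on `cat_dim`: the hunk above stops overwriting the `dimension` parameter and instead remaps a local copy, presumably so that any later use of `dimension` still sees the logical index. The remapping itself can be sketched as a pure function; `remap_cat_dim` and its parameters are hypothetical names, and the boolean stands in for the `memory_format != c10::MemoryFormat::Contiguous` test.

```cpp
#include <cassert>

// Sketch: map a logical concatenation dimension to the physical one used by
// the kernels when the tensor is not in contiguous memory format.
// The input value is left untouched; only the local copy changes.
inline int remap_cat_dim(int dimension, int n_dims, bool non_contiguous_format) {
  assert(dimension >= 0 && dimension < n_dims);
  int cat_dim = dimension;
  if (non_contiguous_format) {
    switch (cat_dim) {
      case 0:
        break;                      // batch dimension stays first
      case 1:
        cat_dim = n_dims - cat_dim; // channels move to the last position
        break;
      default:
        cat_dim--;                  // spatial dimensions shift down by one
    }
  }
  return cat_dim;
}
```

For a 4-D tensor this maps logical dimensions 0, 1, 2, 3 to 0, 3, 1, 2, which matches the physical order N, H, W, C of a channels-last layout.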

View File

@ -127,7 +127,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
return diff == 0 ? 0 : uint32_t(Align) - diff;
}
#if defined (__gfx90a__) || defined(__gfx942__)
#if defined (__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#define CDNA2_OR_LATER 1
#else
#define CDNA2_OR_LATER 0
@ -143,7 +143,7 @@ template<typename T, uint32_t Rank>
using VecT = T __attribute__((ext_vector_type(Rank)));
static bool isCDNA2orLater(int index) {
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index);
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942", "gfx950"}, index);
}
#else

View File

@ -341,16 +341,22 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
}
};
template <typename T, typename KeyType>
struct MHAGraphCache {
std::unordered_map<KeyType, T, ParamsWrapperHash<KeyType>> engine_cache;
using KeyType = MHACacheKeyWrapper;
using ValueType = std::unique_ptr<fe::graph::Graph>;
using MapType =
std::unordered_map<KeyType, ValueType, ParamsWrapperHash<KeyType>>;
using iterator = typename MapType::iterator;
using const_iterator = typename MapType::const_iterator;
MapType engine_cache;
int count = 0;
int hits = 0;
// no mutexes here as caches are now thread local for v8, can also return a
// pointer to the Execution Plan if we know it will not be invalidated by
// another thread
T* find(const KeyType& key) {
iterator find(const KeyType& key) {
static bool flag =
c10::utils::check_env("TORCH_CUDNN_SDPA_CACHE_DEBUG") == true;
if (flag && count) {
@ -363,15 +369,19 @@ struct MHAGraphCache {
}
count++;
auto it = engine_cache.find(key);
if (it == engine_cache.end()) {
return nullptr;
if (it != engine_cache.end()) {
hits++;
}
hits++;
return &(it->second);
return it;
}
void update(const KeyType& key, T& results) {
engine_cache.insert_or_assign(key, std::move(results));
const_iterator end() const {
return engine_cache.end();
}
template <typename... Args>
std::pair<iterator, bool> try_emplace(const KeyType& key, Args&&... args) {
return engine_cache.try_emplace(key, std::forward<Args>(args)...);
}
};
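
Note on the new cache interface: with `unique_ptr` values and `try_emplace`, a caller can reserve the slot and build the graph only on a miss, using a single hash lookup. The sketch below shows that lookup-or-build pattern under simplified assumptions (C++17): `DemoGraph`, `DemoGraphCache`, and `get_or_build` are hypothetical names, and a string key replaces `MHACacheKeyWrapper`.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for fe::graph::Graph.
struct DemoGraph {
  int id;
};

struct DemoGraphCache {
  using MapType = std::unordered_map<std::string, std::unique_ptr<DemoGraph>>;
  MapType engine_cache;
  int builds = 0;  // counts how many graphs were actually constructed

  // Return the cached graph for `key`, building it only on a miss.
  const DemoGraph& get_or_build(const std::string& key) {
    // try_emplace inserts a null placeholder only if the key is absent and
    // reports whether the insertion happened.
    auto [it, inserted] = engine_cache.try_emplace(key, nullptr);
    if (inserted) {
      it->second = std::make_unique<DemoGraph>(DemoGraph{++builds});
    }
    return *it->second;
  }
};
```

One design consequence worth keeping in mind: if the build step throws after `try_emplace` succeeded, the null placeholder stays in the map, so a sketch like this one would need to erase the entry (or check for null) before a later lookup dereferences it.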
@ -380,16 +390,14 @@ struct MHAGraphCache {
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html
// We also leak the caches to work around potential teardown race issues.
auto& getMHAGraphCache_() {
thread_local auto& instance =
*new MHAGraphCache<std::shared_ptr<fe::graph::Graph>, MHACacheKeyWrapper>;
return instance;
MHAGraphCache& getMHAGraphCache_() {
thread_local MHAGraphCache* instance{new MHAGraphCache()};
return *instance;
}
auto& getMHAGraphBackwardCache_() {
thread_local auto& instance =
*new MHAGraphCache<std::shared_ptr<fe::graph::Graph>, MHACacheKeyWrapper>;
return instance;
MHAGraphCache& getMHAGraphBackwardCache_() {
thread_local MHAGraphCache* instance{new MHAGraphCache()};
return *instance;
}
namespace {
@ -437,7 +445,7 @@ auto fixSizeOneDimStrideSDPA(
} // namespace
auto build_graph(
std::unique_ptr<fe::graph::Graph> build_graph(
int64_t b,
int64_t h,
int64_t s_q,
@ -461,7 +469,7 @@ auto build_graph(
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
auto mha_graph = std::make_unique<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
@ -531,15 +539,13 @@ auto build_graph(
fe::graph::Tensor_attributes().set_uid(K).set_name("K"));
auto V_ = mha_graph->tensor(
fe::graph::Tensor_attributes().set_uid(V).set_name("V"));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
scaled_dot_product_flash_attention_options.set_bias(
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_uid(BIAS)
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
.set_stride(attn_bias.value().strides().vec())));
}
auto [O_, Stats] =
@ -640,7 +646,7 @@ auto build_graph(
return mha_graph;
}
auto build_graph_nestedtensor(
std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
@ -668,7 +674,7 @@ auto build_graph_nestedtensor(
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
auto mha_graph = std::make_unique<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
@ -766,18 +772,16 @@ auto build_graph_nestedtensor(
v_strides[strideidx0],
v_strides[strideidx1],
v_strides[strideidx2]}));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
TORCH_CHECK(
false,
"attn_bias not yet supported with cuDNN Attention and NestedTensor");
bias =
scaled_dot_product_flash_attention_options.set_bias(
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_uid(BIAS)
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
.set_stride(attn_bias.value().strides().vec())));
}
auto RAG_Q_OFF_ =
mha_graph->tensor(fe::graph::Tensor_attributes()
@ -847,7 +851,7 @@ auto build_graph_nestedtensor(
return mha_graph;
}
auto build_graph_backward(
std::unique_ptr<fe::graph::Graph> build_graph_backward(
int64_t b,
int64_t h,
int64_t s_q,
@ -874,7 +878,7 @@ auto build_graph_backward(
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
auto mha_graph = std::make_unique<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
@ -919,15 +923,13 @@ auto build_graph_backward(
fe::graph::Tensor_attributes().set_uid(K).set_name("K"));
auto V_ = mha_graph->tensor(
fe::graph::Tensor_attributes().set_uid(V).set_name("V"));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
sdpa_backward_options.set_bias(
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_uid(BIAS)
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
sdpa_backward_options.set_bias(bias.value());
.set_stride(attn_bias.value().strides().vec())));
}
if (dropout_probability != 0.0f) {
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
@ -1061,7 +1063,7 @@ auto build_graph_backward(
return mha_graph;
}
auto build_graph_backward_nestedtensor(
std::unique_ptr<fe::graph::Graph> build_graph_backward_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
@ -1092,7 +1094,7 @@ auto build_graph_backward_nestedtensor(
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
auto mha_graph = std::make_unique<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
@ -1195,18 +1197,16 @@ auto build_graph_backward_nestedtensor(
o_strides[strideidx1],
o_strides[strideidx2]}));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
TORCH_CHECK(
false,
"attn_bias not yet supported with cuDNN Attention and NestedTensor");
bias =
sdpa_backward_options.set_bias(
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_uid(BIAS)
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
sdpa_backward_options.set_bias(bias.value());
.set_stride(attn_bias.value().strides().vec())));
}
auto RAG_Q_OFF_ =
mha_graph->tensor(fe::graph::Tensor_attributes()
@ -1378,7 +1378,7 @@ void run_cudnn_SDP_fprop(
// NB: The key initialization will round up sequence length, stride data etc.
// if use_ragged_in_dense is enabled (to allow multiple sequence lengths to
// reuse the same cached value/graph)
auto key = MHACacheKeyWrapper(
MHACacheKeyWrapper key(
b,
h,
s_q,
@ -1393,12 +1393,9 @@ void run_cudnn_SDP_fprop(
is_causal,
return_softmaxstats,
false);
auto graph_ptr = getMHAGraphCache_().find(key);
std::shared_ptr<fe::graph::Graph> mha_graph;
if (graph_ptr) {
mha_graph = *graph_ptr;
} else {
mha_graph = build_graph(
auto [cache_it, not_found] = getMHAGraphCache_().try_emplace(key, nullptr);
if (not_found) {
cache_it->second = build_graph(
b,
h,
s_q,
@ -1419,39 +1416,39 @@ void run_cudnn_SDP_fprop(
_dropoutoffset,
handle);
}
const fe::graph::Graph& mha_graph = *cache_it->second;
std::unordered_map<int64_t, void*> variant_pack = {
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{SCALE, &scaling_factor},
{O, o.data_ptr()}};
{O, o.mutable_data_ptr()}};
if (return_softmaxstats) {
variant_pack[LSE] = softmaxstats.data_ptr();
variant_pack[LSE] = softmaxstats.mutable_data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[BIAS] = attn_bias.value().data_ptr();
variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
}
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
if (return_softmaxstats) {
variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
}
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_size = mha_graph.get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
getMHAGraphCache_().update(key, mha_graph);
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
void run_cudnn_SDP_fprop_nestedtensor(
@ -1491,7 +1488,7 @@ void run_cudnn_SDP_fprop_nestedtensor(
softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat));
}
auto key = MHACacheKeyWrapper(
MHACacheKeyWrapper key(
b,
h_q,
s_q, // max-seqlen-q
@ -1506,13 +1503,12 @@ void run_cudnn_SDP_fprop_nestedtensor(
is_causal,
return_softmaxstats,
true);
auto graph_ptr = getMHAGraphCache_().find(key);
std::shared_ptr<fe::graph::Graph> mha_graph;
if (graph_ptr) {
mha_graph = *graph_ptr;
} else {
mha_graph = build_graph_nestedtensor(
MHAGraphCache& cache = getMHAGraphCache_();
auto cache_it = cache.find(key);
std::unique_ptr<fe::graph::Graph> mha_graph_storage;
if (cache_it == cache.end()) {
mha_graph_storage = build_graph_nestedtensor(
b,
h_q,
h_k,
@ -1537,40 +1533,44 @@ void run_cudnn_SDP_fprop_nestedtensor(
dropoutoffset,
handle);
}
const fe::graph::Graph& mha_graph =
mha_graph_storage ? *mha_graph_storage : *cache_it->second;
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
auto rag_stats_off = cum_seqlen_q.mul(h_q);
std::unordered_map<int64_t, void*> variant_pack = {
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{SCALE, &scaling_factor},
{O, o.data_ptr()},
{RAG_Q_OFF, rag_q_off.data_ptr()},
{RAG_O_OFF, rag_q_off.data_ptr()},
{RAG_K_OFF, rag_k_off.data_ptr()},
{RAG_V_OFF, rag_v_off.data_ptr()},
{SEQ_LEN_Q, seqlen_q.data_ptr()},
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
{O, o.mutable_data_ptr()},
{RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
{RAG_O_OFF, rag_o_off.mutable_data_ptr()},
{RAG_K_OFF, rag_k_off.mutable_data_ptr()},
{RAG_V_OFF, rag_v_off.mutable_data_ptr()},
{SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
{SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
if (return_softmaxstats) {
variant_pack[LSE] = softmaxstats.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr();
variant_pack[LSE] = softmaxstats.mutable_data_ptr();
variant_pack[RAG_LSE_OFF] = rag_stats_off.mutable_data_ptr();
}
if (dropout_probability != 0.0f) {
variant_pack[SEED] = dropoutseed.data_ptr();
variant_pack[OFFSET] = dropoutoffset.data_ptr();
variant_pack[SEED] = dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = dropoutoffset.mutable_data_ptr();
}
if (attn_bias.has_value()) {
TORCH_CHECK(false, "bias not supported with nestedtensor");
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_size = mha_graph.get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
void run_cudnn_SDP_bprop(
@ -1652,7 +1652,7 @@ void run_cudnn_SDP_bprop(
}
cudnnHandle_t handle = getCudnnHandle();
auto key = MHACacheKeyWrapper(
MHACacheKeyWrapper key(
b,
h,
s_q,
@ -1667,12 +1667,10 @@ void run_cudnn_SDP_bprop(
is_causal,
true,
false);
auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key);
std::shared_ptr<fe::graph::Graph> mha_graph;
if (graph_backward_ptr) {
mha_graph = *graph_backward_ptr;
} else {
mha_graph = build_graph_backward(
auto [cache_it, not_found] =
getMHAGraphBackwardCache_().try_emplace(key, nullptr);
if (not_found) {
cache_it->second = build_graph_backward(
b,
h,
s_q,
@ -1696,43 +1694,44 @@ void run_cudnn_SDP_bprop(
_dropoutoffset,
handle);
}
const fe::graph::Graph& mha_graph = *cache_it->second;
std::unordered_map<int64_t, void*> variant_pack = {
// inputs
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{O, o.data_ptr()},
{DO, dO_.data_ptr()},
{LSE, softmaxstats.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{O, o.mutable_data_ptr()},
{DO, dO_.mutable_data_ptr()},
{LSE, softmaxstats.mutable_data_ptr()},
// outputs
{DQ, dQ.data_ptr()},
{DK, dK.data_ptr()},
{DV, dV.data_ptr()},
{DQ, dQ.mutable_data_ptr()},
{DK, dK.mutable_data_ptr()},
{DV, dV.mutable_data_ptr()},
{SCALE, &scaling_factor}};
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[BIAS] = attn_bias.value().data_ptr();
variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
}
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_size = mha_graph.get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(!workspace_size || workspace_ptr.get());
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
getMHAGraphBackwardCache_().update(key, mha_graph);
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
void run_cudnn_SDP_bprop_nestedtensor(
@ -1775,9 +1774,10 @@ void run_cudnn_SDP_bprop_nestedtensor(
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
auto rag_stats_off = cum_seqlen_q.mul(h_q);
auto dprops = at::cuda::getCurrentDeviceProperties();
@ -1791,7 +1791,7 @@ void run_cudnn_SDP_bprop_nestedtensor(
cudnnHandle_t handle = getCudnnHandle();
auto key = MHACacheKeyWrapper(
MHACacheKeyWrapper key(
b,
h_q,
s_q, // max-seqlen-q
@ -1806,13 +1806,12 @@ void run_cudnn_SDP_bprop_nestedtensor(
is_causal,
true,
true);
auto graph_ptr = getMHAGraphCache_().find(key);
std::shared_ptr<fe::graph::Graph> mha_graph;
if (graph_ptr) {
mha_graph = *graph_ptr;
} else {
mha_graph = build_graph_backward_nestedtensor(
MHAGraphCache& cache = getMHAGraphCache_();
auto cache_it = cache.find(key);
std::unique_ptr<fe::graph::Graph> mha_graph_storage;
if (cache_it == cache.end()) {
mha_graph_storage = build_graph_backward_nestedtensor(
b,
h_q,
h_k,
@ -1840,41 +1839,43 @@ void run_cudnn_SDP_bprop_nestedtensor(
dropoutoffset,
handle);
}
const fe::graph::Graph& mha_graph =
mha_graph_storage ? *mha_graph_storage : *cache_it->second;
std::unordered_map<int64_t, void*> variant_pack = {
// inputs
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{O, o.data_ptr()},
{DO, dO_.data_ptr()},
{LSE, softmaxstats.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{O, o.mutable_data_ptr()},
{DO, dO_.mutable_data_ptr()},
{LSE, softmaxstats.mutable_data_ptr()},
// outputs
{DQ, dQ.data_ptr()},
{DK, dK.data_ptr()},
{DV, dV.data_ptr()},
{DQ, dQ.mutable_data_ptr()},
{DK, dK.mutable_data_ptr()},
{DV, dV.mutable_data_ptr()},
{SCALE, &scaling_factor},
{RAG_Q_OFF, rag_q_off.data_ptr()},
{RAG_O_OFF, rag_q_off.data_ptr()},
{RAG_K_OFF, rag_k_off.data_ptr()},
{RAG_V_OFF, rag_v_off.data_ptr()},
{RAG_LSE_OFF, rag_stats_off.data_ptr()},
{SEQ_LEN_Q, seqlen_q.data_ptr()},
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
{RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
{RAG_O_OFF, rag_o_off.mutable_data_ptr()},
{RAG_K_OFF, rag_k_off.mutable_data_ptr()},
{RAG_V_OFF, rag_v_off.mutable_data_ptr()},
{RAG_LSE_OFF, rag_stats_off.mutable_data_ptr()},
{SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
{SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
TORCH_CHECK(
!attn_bias.has_value(),
"attn_bias not yet supported with cuDNN Attention and NestedTensor");
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_size = mha_graph.get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(!workspace_size || workspace_ptr.get());
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
} // namespace native
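
Note on the ragged offsets in the nested-tensor paths: the offsets passed as `RAG_*_OFF` are the cumulative sequence lengths scaled by the number of elements per token. The diff switches from computing that factor as a product of head count and head dimension to reading it from each tensor's own `stride(-3)`, and gives the output tensor its own offsets instead of reusing the query offsets. A minimal sketch of the offset computation, with `ragged_offsets` as a hypothetical helper and plain vectors in place of tensors:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: element offset of each sequence start in a packed (ragged) buffer.
// `cum_seqlen` holds cumulative token counts (starting at 0) and
// `token_stride` is the number of elements between consecutive tokens,
// i.e. the role played by tensor.stride(-3) in the diff above.
inline std::vector<int64_t> ragged_offsets(
    const std::vector<int64_t>& cum_seqlen,
    int64_t token_stride) {
  assert(token_stride > 0);
  std::vector<int64_t> offsets;
  offsets.reserve(cum_seqlen.size());
  for (int64_t tokens_before : cum_seqlen) {
    offsets.push_back(tokens_before * token_stride);
  }
  return offsets;
}
```

Taking the stride from the tensor appears to matter whenever the per-token stride differs from heads times head dimension, for example when the query/key and value head dimensions are not equal; that reading is an inference from the diff, not something it states.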

View File

@ -165,7 +165,7 @@ REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_co
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE256_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
// _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,

View File

@ -116,6 +116,8 @@ class MetalShaderLibrary {
std::vector<std::string> getFunctionNames();
std::shared_ptr<MetalKernelFunction> getKernelFunction(
const std::string& name);
// Returns a raw pointer to the kernel function for use in C APIs
MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
inline MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
@ -164,6 +166,9 @@ class MetalShaderLibrary {
std::string,
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
cplMap;
// Cache for kernel functions returned by getCachedKernelFunctionPtr
std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
kernelCache;
};
class DynamicMetalShaderLibrary : public MetalShaderLibrary {

@ -917,6 +917,22 @@ std::shared_ptr<MetalKernelFunction> MetalShaderLibrary::getKernelFunction(const
return std::make_shared<MetalKernelFunction>(cpl, func);
}
MetalKernelFunction* MetalShaderLibrary::getCachedKernelFunctionPtr(const std::string& name) {
// Check if kernel is already cached
auto it = kernelCache.find(name);
if (it != kernelCache.end()) {
return it->second.get();
}
// Create new kernel function and cache it
auto [cpl, func] = getLibraryPipelineState(getLibrary(), name);
auto kernel = std::make_unique<MetalKernelFunction>(cpl, func);
MetalKernelFunction* raw_ptr = kernel.get();
kernelCache[name] = std::move(kernel);
return raw_ptr;
}
class BundledShaderLibary : public MetalShaderLibrary {
public:
BundledShaderLibary() : MetalShaderLibrary("") {}

@ -5,6 +5,38 @@
# representing ScalarType's. They are now superseded by usage of
# `aten::to()`. The ops remain here for backward compatibility purposes.
# DEPRECATED. DO NOT USE
- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# Computes the gradient of current tensor w.r.t. graph leaves.
- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
manual_cpp_binding: True
@ -7125,6 +7157,7 @@
CUDA: _scaled_mm_cuda
tags: needs_exact_strides
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
@ -7132,6 +7165,16 @@
CUDA: _scaled_mm_out_cuda
tags: needs_exact_strides
- func: _scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_mm_cuda_v2
- func: _scaled_mm_v2.out(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CUDA: _scaled_mm_cuda_v2_out
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function

@ -48,8 +48,8 @@ std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
int64_t axis,
int64_t quant_min,
int64_t quant_max) {
TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
"Scale must be Float, found ", scale.scalar_type());
TORCH_CHECK(scale.scalar_type() == ScalarType::Float || scale.scalar_type() == at::kBFloat16,
"Scale must be Float or BFloat16, found ", scale.scalar_type());
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Int || zero_point.scalar_type() == ScalarType::Float || zero_point.scalar_type() == ScalarType::Half,
"Zero-point must be Int32, Float or Half, found ", zero_point.scalar_type());
TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");

@ -27,6 +27,6 @@ REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_SVE_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
} // namespace at::native

@ -161,19 +161,19 @@ REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_SVE_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_SVE_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel)
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_SVE_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
}

@ -448,7 +448,7 @@ REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta)
int64_t _fused_sdp_choice_meta(

@ -637,13 +637,7 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled.");
}
return false;
} else if (has_for_nested_inputs(params) && (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad())) {
if (debug) {
TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward.");
return false;
}
}
const auto dprop = at::cuda::getCurrentDeviceProperties();
// Check that the input is nested
if (!(dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {

@ -37,7 +37,7 @@ class Benchmark(BenchmarkBase):
def f(a, b):
xs = b.tolist()
for x in xs:
torch._check_is_size(x)
torch._check(x >= 0)
torch._check(x <= self.N)
return a.split(xs)

@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(

@ -130,14 +130,6 @@ int64_t SymInt::guard_int(const char* file, int64_t line) const {
}
}
bool SymInt::expect_size(const char* file, int64_t line) const {
if (auto ma = maybe_as_int()) {
return *ma >= 0;
} else {
return toSymNodeImplUnowned()->expect_size(file, line);
}
}
SymInt operator-(const SymInt& s) {
if (auto ma = s.maybe_as_int()) {
const auto val = *ma;

@ -153,14 +153,6 @@ class C10_API SymInt {
// number can be used to diagnose overspecialization.
int64_t guard_int(const char* file, int64_t line) const;
// Insert a guard that this SymInt must be size-like, returning true if
// the integer actually is >= 0. Unlike manually performing a >= 0 test,
// if the SymInt in question is an unbacked SymInt (or, potentially in the
// future, if it contains unbacked SymInts), we will also treat the
// unbacked SymInt as statically testing >= 2 (which will prevent us from
// choking on, e.g., contiguity checks.)
bool expect_size(const char* file, int64_t line) const;
// Distinguish actual symbolic values from constants stored on the heap
bool is_symbolic() const {
return is_heap_allocated() &&

@ -210,11 +210,6 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
// with a better implementation!
return guard_bool(file, line);
}
virtual bool expect_size(const char* file, int64_t line) {
// No improvement for unbacked SymInts by default, replace this
// with a better implementation!
return ge(wrap_int(0))->guard_bool(file, line);
}
virtual int64_t int_() {
TORCH_CHECK(false, "NYI");
}

@ -108,12 +108,15 @@ void* alloc_cpu(size_t nbytes) {
"DefaultCPUAllocator: not enough memory: you tried to allocate ",
nbytes,
" bytes.");
#elif defined(_MSC_VER)
#ifdef USE_MIMALLOC
#elif defined(USE_MIMALLOC)
data = mi_malloc_aligned(nbytes, gAlignment);
#else
CAFFE_ENFORCE(
data,
"DefaultCPUAllocator: not enough memory: you tried to allocate ",
nbytes,
" bytes.");
#elif defined(_MSC_VER)
data = _aligned_malloc(nbytes, gAlignment);
#endif
CAFFE_ENFORCE(
data,
"DefaultCPUAllocator: not enough memory: you tried to allocate ",
@ -160,12 +163,10 @@ void* alloc_cpu(size_t nbytes) {
}
void free_cpu(void* data) {
#ifdef _MSC_VER
#ifdef USE_MIMALLOC
mi_free(data);
#else
#elif defined(_MSC_VER)
_aligned_free(data);
#endif
#else
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
free(data);

@ -638,11 +638,11 @@ struct ExpandableSegment {
return *stream_;
}
size_t getMappedSize() {
size_t getMappedSize() const {
return mapped_size_;
}
size_t getSegmentSize() {
size_t getSegmentSize() const {
return segment_size_;
}
@ -799,11 +799,11 @@ struct ExpandableSegment {
return nullptr;
}
size_t getMappedSize() {
size_t getMappedSize() const {
return 0;
}
size_t getSegmentSize() {
size_t getSegmentSize() const {
return 0;
}
void addPeer(c10::DeviceIndex device) {}
@ -824,14 +824,14 @@ struct BlockState {
// maintain invariant that event_count == 0 ;
// history will be left alone in checkpoint
BlockState(Block* block);
explicit BlockState(Block* block);
};
struct SegmentState {
std::vector<BlockState> blocks;
bool is_small = false;
SegmentState(Block* head);
explicit SegmentState(Block* head);
};
struct PrivatePoolState : AllocatorState {
@ -949,7 +949,7 @@ class EventPool {
// CUDA graphs helper
struct PrivatePool {
PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
explicit PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
: id(std::move(id)),
allocator_(allocator),
large_blocks(/*small=*/false, this),
@ -1078,7 +1078,7 @@ class RingBuffer {
}
}
void getEntries(std::vector<T>& result) {
void getEntries(std::vector<T>& result) const {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
result.reserve(alloc_trace->size());
result.insert(
@ -1106,7 +1106,7 @@ class RingBuffer {
// Both alloc_trace and alloc_trace_next needs to be used
// under alloc_trace_lock.
std::mutex alloc_trace_lock;
mutable std::mutex alloc_trace_lock;
size_t alloc_trace_next = 0;
std::vector<T>*
alloc_trace; // pointer because we need to intentionally leak this on
@ -1299,7 +1299,7 @@ class DeviceCachingAllocator {
}
}
bool isHistoryEnabled() {
bool isHistoryEnabled() const {
return record_history;
}
@ -1315,7 +1315,7 @@ class DeviceCachingAllocator {
bool checkPoolLiveAllocations(
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
const std::unordered_set<void*>& expected_live_allocations) const {
std::unique_lock<std::recursive_mutex> lock(mutex);
PrivatePool* pool = nullptr;
@ -2081,7 +2081,7 @@ class DeviceCachingAllocator {
}
/** Returns a copy of the memory allocator stats **/
DeviceStats getStats() {
DeviceStats getStats() const {
std::lock_guard<std::recursive_mutex> lock(mutex);
return stats;
}
@ -2457,7 +2457,7 @@ class DeviceCachingAllocator {
}
std::vector<TraceEntry> trace(
const std::function<time_t(approx_time_t)>& tsc_to_us) {
const std::function<time_t(approx_time_t)>& tsc_to_us) const {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<TraceEntry> result;
alloc_buffer.getEntries(result);
@ -2593,7 +2593,7 @@ class DeviceCachingAllocator {
}
}
int getPoolUseCount(MempoolId_t mempool_id) {
int getPoolUseCount(MempoolId_t mempool_id) const {
std::lock_guard<std::recursive_mutex> lock(mutex);
auto pp = get_private_pool(mempool_id);
return pp->use_count;
@ -2689,7 +2689,7 @@ class DeviceCachingAllocator {
}
}
PrivatePool* get_private_pool(MempoolId_t mempool_id) {
PrivatePool* get_private_pool(MempoolId_t mempool_id) const {
auto it = graph_pools.find(mempool_id);
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
return it->second.get();
@ -3686,7 +3686,7 @@ class DeviceCachingAllocator {
if (!compile_context.empty()) {
compile_string = compile_context.top();
}
auto te = TraceEntry(
TraceEntry te(
action,
device,
addr,

@ -439,10 +439,6 @@ function(torch_compile_options libname)
$<$<COMPILE_LANGUAGE:CXX>: -fvisibility=hidden>)
endif()
# Use -O2 for release builds (-O3 doesn't improve perf, and -Os results in perf regression)
target_compile_options(${libname} PRIVATE
$<$<AND:$<COMPILE_LANGUAGE:CXX>,$<OR:$<CONFIG:Release>,$<CONFIG:RelWithDebInfo>>>:-O2>)
endfunction()
##############################################################################
@ -530,4 +526,4 @@ function(target_link_options_if_supported tgt flag)
else()
message(WARNING "Attempted to use unsupported link option : ${flag}.")
endif()
endfunction()
endfunction()

@ -553,42 +553,6 @@ coverage_ignore_functions = [
# torch.distributed.checkpoint.utils
"find_state_dict_object",
"find_tensor_shard",
# torch.distributed.collective_utils
"all_gather",
"all_gather_object_enforce_type",
"broadcast",
# torch.distributed.distributed_c10d
"all_gather",
"all_gather_coalesced",
"all_gather_into_tensor",
"all_gather_object",
"all_reduce",
"all_reduce_coalesced",
"all_to_all",
"all_to_all_single",
"barrier",
"batch_isend_irecv",
"broadcast",
"broadcast_object_list",
"destroy_process_group",
"gather",
"gather_object",
"get_backend",
"get_backend_config",
"get_global_rank",
"get_group_rank",
"get_process_group_ranks",
"get_rank",
"get_world_size",
"init_process_group",
"irecv",
"is_backend_available",
"is_gloo_available",
"is_initialized",
"is_mpi_available",
"is_nccl_available",
"is_torchelastic_launched",
"is_ucc_available",
"isend",
"monitored_barrier",
"new_group",
@ -662,15 +626,8 @@ coverage_ignore_functions = [
"transformer_auto_wrap_policy",
"wrap",
# torch.distributed.nn.functional
"all_gather",
"all_reduce",
"all_to_all",
"all_to_all_single",
"broadcast",
"gather",
"reduce",
"reduce_scatter",
"scatter",
# torch.distributed.nn.jit.instantiator
"get_arg_return_types_from_interface",
"instantiate_non_scriptable_remote_module_template",

@ -10,6 +10,7 @@ torch.cpu
current_device
current_stream
is_available
is_initialized
synchronize
stream
set_device

@ -221,6 +221,16 @@ inconsistent 'UUID' assignment across ranks, and to prevent races during initial
```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.is_xccl_available
.. autofunction:: torch.distributed.distributed_c10d.batch_isend_irecv
.. autofunction:: torch.distributed.distributed_c10d.destroy_process_group
.. autofunction:: torch.distributed.distributed_c10d.is_backend_available
.. autofunction:: torch.distributed.distributed_c10d.irecv
.. autofunction:: torch.distributed.distributed_c10d.is_gloo_available
.. autofunction:: torch.distributed.distributed_c10d.is_initialized
.. autofunction:: torch.distributed.distributed_c10d.is_mpi_available
.. autofunction:: torch.distributed.distributed_c10d.is_nccl_available
.. autofunction:: torch.distributed.distributed_c10d.is_torchelastic_launched
.. autofunction:: torch.distributed.distributed_c10d.is_ucc_available
```
```{eval-rst}

@ -218,3 +218,13 @@ DataParallel functions (multi-GPU, distributed)
:nosignatures:
torch.nn.parallel.data_parallel
Low-Precision functions
-----------------------
.. autosummary::
:toctree: generated
:nosignatures:
ScalingType
SwizzleType
scaled_mm

@ -1,6 +1,65 @@
# LibTorch Stable ABI
This note will eventually contain more details on how to use the APIs in torch/csrc/stable. For the moment, it contains a table of internal representations:
## Overview
The LibTorch Stable ABI (Application Binary Interface) provides an interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases.
The stable ABI consists of three main components:
1. **Stable C headers** - Low-level C API implemented by libtorch (primarily `torch/csrc/inductor/aoti_torch/c/shim.h`)
2. **Header-only C++ library** - Standalone utilities implemented in only headers such that there is no dependence on libtorch (`torch/headeronly/*`)
3. **Stable C++ wrappers** - High-level C++ convenience wrappers (`torch/csrc/stable/*`)
We discuss each of these in detail below.
### `torch/headeronly`
This is a set of inlined C++ headers that are completely decoupled from libtorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`.
### `torch/csrc/stable`
This is a set of inlined C++ headers that provide wrappers around the C API that handle the rough edges
discussed below.
It consists of
- torch/csrc/stable/library.h: Provides a stable version of TORCH_LIBRARY and similar macros.
- torch/csrc/stable/tensor_struct.h: Provides torch::stable::Tensor, a stable version of at::Tensor.
- torch/csrc/stable/ops.h: Provides a stable interface for calling ATen ops from `native_functions.yaml`.
- torch/csrc/stable/accelerator.h: Provides a stable interface for device-generic objects and APIs
(e.g. `getCurrentStream`, `DeviceGuard`).
We are continuing to improve coverage in our `torch/csrc/stable` APIs. Please file an issue if you'd like to see support for particular APIs in your custom extension.
### Stable C headers
The stable C headers used by AOTInductor form the foundation of the stable ABI. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs.
Further, the stack-based APIs discussed below which allow the user to call the PyTorch dispatcher don't provide strong guarantees on forward and backward compatibility.
Unless absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable`
which will handle all the rough edges of the C API for the user.
## How are objects passed across the ABI boundary when interacting with the dispatcher?
When interacting with the dispatcher via the stable APIs (``STABLE_TORCH_LIBRARY`` etc.) we use a boxed convention. Arguments and returns are represented as a stack of ``StableIValue`` which correlates with a `torch::jit::stack` of IValues. We discuss the following below:
1. StableIValue Conversions
2. StableIValue stack Conventions
3. Stable APIs that interact with the dispatcher
### StableIValue Conversions
We provide utilities for users to convert objects to and from StableIValues with the synonymous
`to` and `from` APIs in `torch/csrc/stable/stableivalue_conversions.h`. We document the stable custom extension representation, libtorch representation and StableIValue
representations below. Our confidently supported types are the ones in the table that have completed
rows. You can rely on this subset for proper ABI stability, meaning that you can call `to<T_custom_ext>(arg/ret)` or `from(T)` on these types.
For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only.
You can always work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions by not introspecting into the StableIValue. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`.
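To illustrate the "representable within 64 bits" idea, here is a minimal, self-contained sketch. It is not the libtorch implementation: `StableIValue` is assumed to be a plain 64-bit integer, `FakeDevice` is a hypothetical stand-in for a small device-like type (not `c10::Device`), and `pack`/`unpack` are illustrative names rather than the real `from`/`to` APIs. The sketch uses `std::memcpy` instead of a cast so that the byte copy is well defined.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Assumption: a 64-bit integer stands in for the real StableIValue.
using StableIValue = uint64_t;

// Hypothetical small, trivially copyable type (NOT the real c10::Device).
struct FakeDevice {
  int8_t type;
  int8_t index;
};

// Copy the bytes of a small value into a zero-initialized 64-bit slot.
template <typename T>
StableIValue pack(const T& value) {
  static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
  static_assert(sizeof(T) <= sizeof(StableIValue), "T must fit in 64 bits");
  StableIValue out = 0;
  std::memcpy(&out, &value, sizeof(T));
  return out;
}

// Copy the same number of bytes back out of the 64-bit slot.
template <typename T>
T unpack(StableIValue packed) {
  static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
  static_assert(sizeof(T) <= sizeof(StableIValue), "T must fit in 64 bits");
  T value;
  std::memcpy(&value, &packed, sizeof(T));
  return value;
}
```

A kernel that only forwards such a value never needs to call `unpack`: it can pass the 64-bit slot through unchanged, which is the "not introspecting" pattern described above.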
1. type in custom extension: type used within the end user custom library.
2. StableIValue representation: a stable conversion of the type that serves as a liaison between the user model and libtorch.so in an ABI-stable manner.
3. type in libtorch: type used within libtorch.so (or any code binary locked with libtorch).
@ -31,16 +90,10 @@ This note will eventually contain more details on how to use the APIs in torch/c
| ? | ? | c10::SymBool | SymBool |
| ? | ? | at::QScheme | QScheme |
Our confidently supported types are the ones in the table that have completed rows. You can rely on this subset for proper ABI stability.
For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only.
### Stack Conventions
You can always work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions by not introspecting into the StableIValue. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`.
## How to use stack-based APIs
`aoti_torch_call_dispatcher` is what we consider a stack-based API because it takes as input a stack of StableIValues, which correlates with a `torch::jit::stack` of IValues. Working with the dispatcher will likely bring you into proximity with stack-based APIs, so we are documenting some invariants:
There are two invariants for the stack:
1. The stack is populated left to right.
a. For example, a stack representing arguments `arg0`, `arg1`, and `arg2` will have `arg0` at index 0, `arg1` at index 1, and `arg2` at index 2.
@ -49,3 +102,32 @@ You can always work with StableIValue abstractions in your custom kernel for typ
2. The stack always has ownership of the objects it holds.
a. When calling a stack-based API, you must give owning references to the calling stack and steal references from the returned stack.
b. When registering your function to be called with a stack, you must steal references from your argument stack and push onto the stack new references.
### Stack-based APIs
The above is relevant in two places:
1. `STABLE_TORCH_LIBRARY`
Unlike `TORCH_LIBRARY`, the dispatcher expects kernels registered via `STABLE_TORCH_LIBRARY` to be boxed. This means they must have the signature `(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) -> void`. We plan to eventually abstract away the need for manual boxing, but, for the time being, please use `from` and `to`.
```cpp
Tensor my_amax_vec(Tensor t) {
std::vector<int64_t> v = {0,1};
return amax(t, v, false);
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax_vec(to<Tensor>(stack[0]));
stack[0] = from(res);
}
```
2. `aoti_torch_call_dispatcher`
This API allows you to call the PyTorch dispatcher from C/C++ code. It has the following signature:
```cpp
aoti_torch_call_dispatcher(const char* opName, const char* overloadName, StableIValue* stack);
```
`aoti_torch_call_dispatcher` will call the op overload defined by a given `opName`, `overloadName`, and a stack of
StableIValues. This call will populate any return values of the op into the stack in their StableIValue form,
with `ret0` at index 0, `ret1` at index 1, and so on.
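The stack layout described above (arguments at indices 0, 1, ... on entry; returns at indices 0, 1, ... on exit) can be sketched without libtorch. The following is a mock under stated assumptions: `StableIValue` is taken to be a 64-bit integer, `from_i64`/`to_i64` are hypothetical helpers standing in for `from`/`to`, and `boxed_divmod` is an invented kernel used only to show a two-argument, two-return case.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Assumption: a 64-bit integer stands in for the real StableIValue.
using StableIValue = uint64_t;

// Hypothetical helpers mimicking `from`/`to` for int64_t only.
inline StableIValue from_i64(int64_t v) {
  StableIValue s = 0;
  std::memcpy(&s, &v, sizeof(v));
  return s;
}

inline int64_t to_i64(StableIValue s) {
  int64_t v = 0;
  std::memcpy(&v, &s, sizeof(v));
  return v;
}

// Boxed kernel with the (stack, num_args, num_outputs) signature.
// On entry:  stack[0] = arg0, stack[1] = arg1.
// On return: stack[0] = ret0 (quotient), stack[1] = ret1 (remainder).
void boxed_divmod(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  assert(num_args == 2 && num_outputs == 2);
  // Read every argument before overwriting any slot with a return value.
  const int64_t a = to_i64(stack[0]);
  const int64_t b = to_i64(stack[1]);
  stack[0] = from_i64(a / b);
  stack[1] = from_i64(a % b);
}
```

In this sketch the caller allocates the stack, fills it left to right, invokes the kernel, and then reads the results from the same buffer starting at index 0. Plain integers carry no ownership, so the sketch does not model the reference-stealing rules that apply to tensors and other owning types.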

@ -35,7 +35,6 @@ and supported quantized modules and functions.
quantization-support
.. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now.
.. They are here for tracking purposes until they are more permanently fixed.
.. py:module:: torch.ao

@ -20,8 +20,10 @@ project-includes = [
project-excludes = [
# ==== below will be enabled directory by directory ====
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/**",
# formatting issues
"torch/_inductor/runtime",
"torch/_inductor/codegen",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
@ -31,6 +33,9 @@ project-excludes = [
"torch/_export/utils.py",
"torch/fx/experimental/unification/multipledispatch/__init__.py",
"torch/nn/modules/__init__.py",
"torch/nn/modules/rnn.py", # only remove when parsing errors are fixed
"torch/_inductor/codecache.py",
"torch/distributed/elastic/metrics/__init__.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

@ -89,7 +89,7 @@ if venv_dir.exists():
print("Removing existing hook venv...")
shutil.rmtree(venv_dir)
run(["uv", "venv", str(venv_dir), "--python", "3.9"])
run(["uv", "venv", str(venv_dir), "--python", "3.10"])
# Install lintrunner in the isolated environment
print("Installing lintrunner in isolated environment...")

@ -225,7 +225,7 @@
#
# USE_MIMALLOC
# Static link mimalloc into C10, and use mimalloc in alloc_cpu & alloc_free.
# By default, It is only enabled on Windows.
# By default, It is only enabled on Windows and AArch64.
#
# BUILD_LIBTORCH_WHL
# Builds libtorch.so and its dependencies as a wheel

@ -1111,6 +1111,14 @@
"_amp_update_scale_",
"_assert_async",
"_batch_norm_impl_index",
"_cast_Byte",
"_cast_Char",
"_cast_Double",
"_cast_Float",
"_cast_Half",
"_cast_Int",
"_cast_Long",
"_cast_Short",
"_choose_qparams_per_tensor",
"_coalesce",
"_compute_linear_combination",

@ -1292,12 +1292,6 @@ torch::Tensor view_op(const torch::Tensor& self) {
return self.alias();
}
torch::Tensor view_op_with_extra_arg(
const torch::Tensor& self,
const torch::Tensor& other) {
return self.alias();
}
std::vector<torch::Tensor> ret_tensor_vector_view(
const torch::Tensor& self,
const torch::Tensor& other) {
@ -1534,35 +1528,9 @@ TEST(TestAutogradNotImplementedFallback, ViewOp) {
// Test inplace on view
auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
// raise on rebase_history when it refreshes grad_fn
ASSERT_THROWS_WITH(
v1.add_(t), "which does not have a derivative implemented is forbidden");
// base should not be aware of the views, so this is still okay
// this works as we can properly replay the view given by the user
v1.add_(t);
b1.add_(t);
ASSERT_THROWS_WITH(
v1.grad_fn(),
"which does not have a derivative implemented is forbidden");
}
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
REGISTER_TEST_OP(
"view_op_with_extra_arg",
"_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
view_op_with_extra_arg);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::view_op_with_extra_arg", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
assertBasicChecks(op);
auto a = torch::tensor({1.}, {torch::kFloat32});
auto b = torch::tensor({2.}, {torch::kFloat32});
auto out1 = op(a, b);
ASSERT_TRUE(out1.is_view());
ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {

@ -564,508 +564,3 @@ TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
check_lr_change_for_reduce_on_plateau(
optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
}
// Tests for Issue 141884: Parameter group inheritance functionality
// Validates that partial options in parameter groups correctly inherit
// defaults from the optimizer while preserving explicitly set values
TEST(OptimTest, MergeWithDefaultOptions_Adam) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only weight_decay specified, should inherit lr, betas, eps,
// amsgrad
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11)));
// Group 2: Only eps specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.002)
.betas(std::make_tuple(0.8, 0.88))
.eps(1e-12)
.weight_decay(0.05)
.amsgrad(true);
Adam optimizer(param_groups, defaults);
// Check Group 1: weight_decay preserved, others inherited
auto& group1_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.002); // Inherited
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-12); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.11); // Preserved
ASSERT_TRUE(group1_opts.amsgrad()); // Inherited
// Check Group 2: eps preserved, others inherited
auto& group2_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.002); // Inherited
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_EQ(group2_opts.eps(), 1e-6); // Preserved
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
ASSERT_TRUE(group2_opts.amsgrad()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_SGD) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only lr and weight_decay specified, should inherit momentum,
// dampening, nesterov
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<SGDOptions>(SGDOptions(0.01).weight_decay(0.22)));
// Group 2: Only lr specified, should inherit all others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<SGDOptions>(SGDOptions(0.02)));
// Create optimizer with specific defaults
SGDOptions defaults(0.001); // lr should be overridden by param groups
defaults.momentum(0.9)
.dampening(0.0) // Must be 0 for Nesterov
.weight_decay(0.05)
.nesterov(true);
SGD optimizer(param_groups, defaults);
// Check Group 1: lr and weight_decay preserved, others inherited
auto& group1_opts =
static_cast<SGDOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.01); // Preserved
ASSERT_EQ(group1_opts.momentum(), 0.9); // Inherited
ASSERT_EQ(group1_opts.dampening(), 0.0); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.22); // Preserved
ASSERT_TRUE(group1_opts.nesterov()); // Inherited
// Check Group 2: lr preserved, others inherited
auto& group2_opts =
static_cast<SGDOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.02); // Preserved
ASSERT_EQ(group2_opts.momentum(), 0.9); // Inherited
ASSERT_EQ(group2_opts.dampening(), 0.0); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
ASSERT_TRUE(group2_opts.nesterov()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_AdamW) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only eps specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamWOptions>(AdamWOptions().eps(1e-6)));
// Group 2: Only betas specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamWOptions>(
AdamWOptions().betas(std::make_tuple(0.95, 0.999))));
// Create optimizer with specific defaults
AdamWOptions defaults;
defaults.lr(0.003)
.betas(std::make_tuple(0.9, 0.98))
.eps(1e-8)
.weight_decay(0.02)
.amsgrad(false);
AdamW optimizer(param_groups, defaults);
// Check Group 1: eps preserved, others inherited
auto& group1_opts =
static_cast<AdamWOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.003); // Inherited
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.9, 0.98)); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-6); // Preserved
ASSERT_EQ(group1_opts.weight_decay(), 0.02); // Inherited
ASSERT_FALSE(group1_opts.amsgrad()); // Inherited
// Check Group 2: betas preserved, others inherited
auto& group2_opts =
static_cast<AdamWOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.003); // Inherited
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.95, 0.999)); // Preserved
ASSERT_EQ(group2_opts.eps(), 1e-8); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.02); // Inherited
ASSERT_FALSE(group2_opts.amsgrad()); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_Adagrad) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only lr_decay specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdagradOptions>(AdagradOptions().lr_decay(0.001)));
// Group 2: Only initial_accumulator_value specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdagradOptions>(
AdagradOptions().initial_accumulator_value(0.5)));
// Create optimizer with specific defaults
AdagradOptions defaults;
defaults.lr(0.04)
.lr_decay(0.002)
.weight_decay(0.03)
.initial_accumulator_value(0.1)
.eps(1e-11);
Adagrad optimizer(param_groups, defaults);
// Check Group 1: lr_decay preserved, others inherited
auto& group1_opts =
static_cast<AdagradOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.04); // Inherited
ASSERT_EQ(group1_opts.lr_decay(), 0.001); // Preserved
ASSERT_EQ(group1_opts.weight_decay(), 0.03); // Inherited
ASSERT_EQ(group1_opts.initial_accumulator_value(), 0.1); // Inherited
ASSERT_EQ(group1_opts.eps(), 1e-11); // Inherited
// Check Group 2: initial_accumulator_value preserved, others inherited
auto& group2_opts =
static_cast<AdagradOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.04); // Inherited
ASSERT_EQ(group2_opts.lr_decay(), 0.002); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.03); // Inherited
ASSERT_EQ(group2_opts.initial_accumulator_value(), 0.5); // Preserved
ASSERT_EQ(group2_opts.eps(), 1e-11); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_RMSprop) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only alpha specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<RMSpropOptions>(RMSpropOptions().alpha(0.95)));
// Group 2: Only momentum and centered specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<RMSpropOptions>(
RMSpropOptions().momentum(0.8).centered(true)));
// Create optimizer with specific defaults
RMSpropOptions defaults;
defaults.lr(0.015)
.alpha(0.98)
.eps(1e-9)
.weight_decay(0.01)
.momentum(0.7)
.centered(false);
RMSprop optimizer(param_groups, defaults);
// Check Group 1: alpha preserved, others inherited
auto& group1_opts =
static_cast<RMSpropOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group1_opts.lr(), 0.015); // Inherited
ASSERT_EQ(group1_opts.alpha(), 0.95); // Preserved
ASSERT_EQ(group1_opts.eps(), 1e-9); // Inherited
ASSERT_EQ(group1_opts.weight_decay(), 0.01); // Inherited
ASSERT_EQ(group1_opts.momentum(), 0.7); // Inherited
ASSERT_FALSE(group1_opts.centered()); // Inherited
// Check Group 2: momentum and centered preserved, others inherited
auto& group2_opts =
static_cast<RMSpropOptions&>(optimizer.param_groups()[1].options());
ASSERT_EQ(group2_opts.lr(), 0.015); // Inherited
ASSERT_EQ(group2_opts.alpha(), 0.98); // Inherited
ASSERT_EQ(group2_opts.eps(), 1e-9); // Inherited
ASSERT_EQ(group2_opts.weight_decay(), 0.01); // Inherited
ASSERT_EQ(group2_opts.momentum(), 0.8); // Preserved
ASSERT_TRUE(group2_opts.centered()); // Preserved
}
TEST(OptimTest, MergeWithDefaultOptions_LBFGS) {
// Create tensors for a single parameter group (LBFGS supports only one group)
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param group with partial options
std::vector<OptimizerParamGroup> param_groups;
// Single group: Only max_iter specified, should inherit others
param_groups.emplace_back(
std::vector<torch::Tensor>{
tensor1, tensor2}, // Combine tensors in single group
std::make_unique<LBFGSOptions>(LBFGSOptions().max_iter(15)));
// Create optimizer with specific defaults
LBFGSOptions defaults;
defaults.lr(0.8)
.max_iter(25)
.max_eval(31) // Same value the default rule max_iter * 5 / 4 gives for max_iter = 25
.tolerance_grad(1e-5)
.tolerance_change(1e-8)
.history_size(80)
.line_search_fn("strong_wolfe");
LBFGS optimizer(param_groups, defaults);
// Check Group: max_iter preserved, others inherited
auto& group_opts =
static_cast<LBFGSOptions&>(optimizer.param_groups()[0].options());
ASSERT_EQ(group_opts.lr(), 0.8); // Inherited
ASSERT_EQ(group_opts.max_iter(), 15); // Preserved
ASSERT_EQ(group_opts.max_eval(), 31); // Inherited
ASSERT_EQ(group_opts.tolerance_grad(), 1e-5); // Inherited
ASSERT_EQ(group_opts.tolerance_change(), 1e-8); // Inherited
ASSERT_EQ(group_opts.history_size(), 80); // Inherited
ASSERT_EQ(group_opts.line_search_fn(), "strong_wolfe"); // Inherited
}
TEST(OptimTest, MergeWithDefaultOptions_NoOptionsInheritance) {
// Test that param groups without options get full defaults
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> param_groups;
// Groups with no options - should inherit everything
param_groups.emplace_back(std::vector<torch::Tensor>{tensor1});
param_groups.emplace_back(std::vector<torch::Tensor>{tensor2});
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.005)
.betas(std::make_tuple(0.85, 0.95))
.eps(1e-7)
.weight_decay(0.08)
.amsgrad(true);
Adam optimizer(param_groups, defaults);
// Both groups should have exactly the default options
for (int i = 0; i < 2; i++) {
auto& group_opts =
static_cast<AdamOptions&>(optimizer.param_groups()[i].options());
ASSERT_EQ(group_opts.lr(), 0.005);
ASSERT_EQ(group_opts.betas(), std::make_tuple(0.85, 0.95));
ASSERT_EQ(group_opts.eps(), 1e-7);
ASSERT_EQ(group_opts.weight_decay(), 0.08);
ASSERT_TRUE(group_opts.amsgrad());
}
}
// Test that field tracking survives serialization/deserialization cycles
TEST(OptimTest, SerializationPreservesFieldTracking_Adam) {
// Create tensors for parameter groups
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
// Create param groups with partial options using fluent API (marks fields as
// explicit)
std::vector<OptimizerParamGroup> param_groups;
// Group 1: Only weight_decay and amsgrad explicitly set via fluent API
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<AdamOptions>(
AdamOptions().weight_decay(0.11).amsgrad(true)));
// Group 2: Only eps explicitly set via fluent API
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
// Create optimizer with specific defaults
AdamOptions defaults;
defaults.lr(0.002)
.betas(std::make_tuple(0.8, 0.88))
.eps(1e-12)
.weight_decay(0.05)
.amsgrad(false);
Adam original_optimizer(param_groups, defaults);
// Capture original state for comparison
auto& orig_group1_opts =
static_cast<AdamOptions&>(original_optimizer.param_groups()[0].options());
auto& orig_group2_opts =
static_cast<AdamOptions&>(original_optimizer.param_groups()[1].options());
// Verify original state (sanity check)
ASSERT_NEAR(orig_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
ASSERT_TRUE(orig_group1_opts.amsgrad()); // Explicitly set
ASSERT_NEAR(orig_group1_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_NEAR(orig_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
ASSERT_NEAR(orig_group2_opts.lr(), 0.002, 1e-6); // Inherited
// Test serialization of the options objects (where field tracking lives)
std::stringstream ss1, ss2;
// Serialize the parameter group options
{
torch::serialize::OutputArchive archive;
orig_group1_opts.serialize(archive);
archive.save_to(ss1);
}
{
torch::serialize::OutputArchive archive;
orig_group2_opts.serialize(archive);
archive.save_to(ss2);
}
// Create new options objects and deserialize
AdamOptions loaded_group1_opts;
AdamOptions loaded_group2_opts;
{
torch::serialize::InputArchive archive;
archive.load_from(ss1);
loaded_group1_opts.serialize(archive);
}
{
torch::serialize::InputArchive archive;
archive.load_from(ss2);
loaded_group2_opts.serialize(archive);
}
// Verify that all parameter values are preserved after deserialization
// Group 1: weight_decay and amsgrad should be preserved as explicitly set,
// others inherited
ASSERT_NEAR(loaded_group1_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_EQ(
loaded_group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_NEAR(loaded_group1_opts.eps(), 1e-12, 1e-15); // Inherited
ASSERT_NEAR(loaded_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
ASSERT_TRUE(loaded_group1_opts.amsgrad()); // Explicitly set
// Group 2: eps should be preserved as explicitly set, others inherited
ASSERT_NEAR(loaded_group2_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_EQ(
loaded_group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_NEAR(loaded_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
ASSERT_NEAR(loaded_group2_opts.weight_decay(), 0.05, 1e-6); // Inherited
ASSERT_FALSE(loaded_group2_opts.amsgrad()); // Inherited
// CRITICAL: Test that field tracking is preserved after serialization
// Create a new optimizer using the deserialized options to test inheritance
auto tensor3 = torch::randn({2, 2}).requires_grad_(true);
auto tensor4 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> test_param_groups;
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor3},
std::make_unique<AdamOptions>(loaded_group1_opts));
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor4},
std::make_unique<AdamOptions>(loaded_group2_opts));
Adam test_optimizer(test_param_groups, defaults);
// The field tracking should work correctly for inheritance
auto& final_group1_opts =
static_cast<AdamOptions&>(test_optimizer.param_groups()[0].options());
auto& final_group2_opts =
static_cast<AdamOptions&>(test_optimizer.param_groups()[1].options());
// Group 1: weight_decay and amsgrad should still be preserved as explicitly
// set
ASSERT_NEAR(
final_group1_opts.weight_decay(),
0.11,
1e-6); // Explicitly set (preserved)
ASSERT_TRUE(final_group1_opts.amsgrad()); // Explicitly set (preserved)
ASSERT_NEAR(final_group1_opts.lr(), 0.002, 1e-6); // Inherited from defaults
// Group 2: eps should still be preserved as explicitly set
ASSERT_NEAR(
final_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set (preserved)
ASSERT_NEAR(final_group2_opts.lr(), 0.002, 1e-6); // Inherited from defaults
}
// Test serialization with SGD (different parameter types)
TEST(OptimTest, SerializationPreservesFieldTracking_SGD) {
// Create tensors
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
// Create param group with partial options using fluent API
std::vector<OptimizerParamGroup> param_groups;
param_groups.emplace_back(
std::vector<torch::Tensor>{tensor1},
std::make_unique<SGDOptions>(
SGDOptions(0.01).weight_decay(0.22).nesterov(true)));
// Create optimizer with defaults
SGDOptions defaults(0.001);
defaults.momentum(0.9).dampening(0.0).weight_decay(0.05).nesterov(false);
SGD original_optimizer(param_groups, defaults);
// Test serialization of the SGD options (where field tracking lives)
auto& original_opts =
static_cast<SGDOptions&>(original_optimizer.param_groups()[0].options());
std::stringstream ss;
{
torch::serialize::OutputArchive archive;
original_opts.serialize(archive);
archive.save_to(ss);
}
SGDOptions loaded_opts(0.0); // Dummy initial value
{
torch::serialize::InputArchive archive;
archive.load_from(ss);
loaded_opts.serialize(archive);
}
ASSERT_NEAR(loaded_opts.lr(), 0.01, 1e-6); // Explicitly set
ASSERT_NEAR(loaded_opts.momentum(), 0.9, 1e-6); // Inherited
ASSERT_NEAR(loaded_opts.dampening(), 0.0, 1e-6); // Inherited
ASSERT_NEAR(loaded_opts.weight_decay(), 0.22, 1e-6); // Explicitly set
ASSERT_TRUE(loaded_opts.nesterov()); // Explicitly set
// Test that field tracking still works after deserialization by creating new
// optimizer
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
std::vector<OptimizerParamGroup> test_param_groups;
test_param_groups.emplace_back(
std::vector<torch::Tensor>{tensor2},
std::make_unique<SGDOptions>(loaded_opts));
SGD test_optimizer(test_param_groups, defaults);
auto& final_opts =
static_cast<SGDOptions&>(test_optimizer.param_groups()[0].options());
ASSERT_NEAR(final_opts.lr(), 0.01, 1e-6); // Explicitly set (preserved)
ASSERT_NEAR(
final_opts.weight_decay(), 0.22, 1e-6); // Explicitly set (preserved)
ASSERT_TRUE(final_opts.nesterov()); // Explicitly set (preserved)
ASSERT_NEAR(final_opts.momentum(), 0.9, 1e-6); // Inherited from defaults
ASSERT_NEAR(final_opts.dampening(), 0.0, 1e-6); // Inherited from defaults
}
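The tests above exercise one rule: a field that a parameter group set explicitly is preserved, and every other field is inherited from the optimizer defaults. The sketch below illustrates that rule in plain Python. It is not the libtorch implementation; `ToyAdamOptions`, `with_values`, `explicit`, and `merge_with_defaults` are hypothetical names introduced only for illustration, and tracking explicit fields by name is an assumption suggested by the "marks fields as explicit" comment in the serialization test.

```python
from dataclasses import dataclass, field, fields


@dataclass
class ToyAdamOptions:
    """Hypothetical stand-in for an options struct; not the libtorch AdamOptions."""

    lr: float = 1e-3
    eps: float = 1e-8
    weight_decay: float = 0.0
    amsgrad: bool = False
    # Names of the fields the caller set on purpose (assumed tracking mechanism).
    explicit: set = field(default_factory=set, compare=False)

    def with_values(self, **kwargs):
        """Fluent-style setter: store each value and mark the field as explicit."""
        for name, value in kwargs.items():
            setattr(self, name, value)
            self.explicit.add(name)
        return self


def merge_with_defaults(group, defaults):
    """Keep explicitly set fields from `group`; take all others from `defaults`."""
    merged = ToyAdamOptions()
    for f in fields(ToyAdamOptions):
        if f.name == "explicit":
            continue
        source = group if f.name in group.explicit else defaults
        setattr(merged, f.name, getattr(source, f.name))
    # Carry the tracking set forward so a later merge behaves the same way.
    merged.explicit = set(group.explicit)
    return merged


# Mirrors group 1 of MergeWithDefaultOptions_Adam: only weight_decay is explicit.
defaults = ToyAdamOptions().with_values(
    lr=0.002, eps=1e-12, weight_decay=0.05, amsgrad=True
)
group1 = ToyAdamOptions().with_values(weight_decay=0.11)
merged = merge_with_defaults(group1, defaults)

print(merged.lr, merged.eps, merged.weight_decay, merged.amsgrad)
```

Tracking by name rather than by comparing values means a group that explicitly sets `weight_decay` to `0.0` (the toy struct's own default) keeps `0.0` instead of silently inheriting `0.05`.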

@ -135,6 +135,84 @@ TEST_F(LazyOpsTest, TestIsSigned) {
});
}
TEST_F(LazyOpsTest, TestCastByte) {
torch::Tensor a =
torch::rand(
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
100.0;
torch::Tensor b = torch::_cast_Byte(a);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor lazy_a = CopyToDevice(a, device);
torch::Tensor lazy_b = torch::_cast_Byte(lazy_a);
AllEqual(b, lazy_b);
});
}
TEST_F(LazyOpsTest, TestCastChar) {
torch::Tensor a =
torch::rand(
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
100.0;
torch::Tensor b = torch::_cast_Char(a);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor lazy_a = CopyToDevice(a, device);
torch::Tensor lazy_b = torch::_cast_Char(lazy_a);
AllEqual(b, lazy_b);
});
}
TEST_F(LazyOpsTest, TestCastShort) {
torch::Tensor a =
torch::rand(
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
100.0;
torch::Tensor b = torch::_cast_Short(a);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor lazy_a = CopyToDevice(a, device);
torch::Tensor lazy_b = torch::_cast_Short(lazy_a);
AllEqual(b, lazy_b);
});
}
TEST_F(LazyOpsTest, TestCastInt) {
torch::Tensor a =
torch::rand(
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
100.0;
torch::Tensor b = torch::_cast_Int(a);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor lazy_a = CopyToDevice(a, device);
torch::Tensor lazy_b = torch::_cast_Int(lazy_a);
AllEqual(b, lazy_b);
});
}
TEST_F(LazyOpsTest, TestCastLong) {
torch::Tensor a =
torch::rand(
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
100.0;
torch::Tensor b = torch::_cast_Long(a);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor lazy_a = CopyToDevice(a, device);
torch::Tensor lazy_b = torch::_cast_Long(lazy_a);
AllEqual(b, lazy_b);
});
}
TEST_F(LazyOpsTest, TestCastFloat) {
torch::Tensor a =
torch::rand(
{2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
100.0;
torch::Tensor b = torch::_cast_Float(a);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor lazy_a = CopyToDevice(a, device);
torch::Tensor lazy_b = torch::_cast_Float(lazy_a);
AllEqual(b, lazy_b);
});
}
TEST_F(LazyOpsTest, TestRetainType) {
torch::Tensor lazy_a = torch::zeros(
{2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));

@ -32,7 +32,7 @@ from torch.testing._internal.common_distributed import (
sm_is_or_higher_than,
)
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
@ -133,7 +133,11 @@ class TestFullyShardCompile(FSDPTest):
device_type.type,
self.rank % torch.get_device_module(device_type).device_count(),
)
if device_type.type == "cuda" and not sm_is_or_higher_than(device, 8, 0):
if (
device_type.type == "cuda"
and not torch.version.hip
and not sm_is_or_higher_than(device, 8, 0)
):
self.skipTest("bf16 requires sm >= 8.0")
def test_dynamo_trace_use_training_state(self):
@ -478,7 +482,6 @@ val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
return file_check
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_compiled_autograd_ctx(self):
self.skipTestForOldSm()
@ -643,14 +646,12 @@ Unsupported Tensor.backward() call
return model_init_fn, input_creation_fn
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_aot_eager(self):
self._test_traceable_fsdp(
*self._create_simple_mlp_factory_fns(), "aot_eager", fwd_fullgraph=True
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
self._test_traceable_fsdp(
@ -659,7 +660,6 @@ Unsupported Tensor.backward() call
fwd_fullgraph=True,
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_inductor(self):
self.skipTestForOldSm()
@ -731,7 +731,6 @@ Unsupported Tensor.backward() call
return model_init_fn, input_creation_fn
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_aot_eager(self):
# TODO: fix fwd_fullgraph=False case
@ -744,7 +743,6 @@ Unsupported Tensor.backward() call
fwd_fullgraph=fwd_fullgraph,
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
# TODO: fix fwd_fullgraph=False case
@ -866,19 +864,16 @@ Unsupported Tensor.backward() call
pass
file_check.run(bwd_code)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
self._test_nested_fully_shard_backend_inductor_fullgraph_True()
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch("graph_partition", True)
def test_nested_fully_shard_backend_inductor_fullgraph_True_graph_partition(self):
self._test_nested_fully_shard_backend_inductor_fullgraph_True()
@unittest.skip("TODO: fix fwd_fullgraph=False case")
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
self.skipTestForOldSm()
@ -956,7 +951,6 @@ Unsupported Tensor.backward() call
else:
return contextlib.nullcontext()
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_transformer_backend_aot_eager(self):
# TODO: fix fwd_fullgraph=False case
@ -975,7 +969,6 @@ Unsupported Tensor.backward() call
fwd_fullgraph=fwd_fullgraph,
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout has worse accuracy after decomp, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
@ -1111,7 +1104,6 @@ Unsupported Tensor.backward() call
file_check.run(bwd_code)
@unittest.skip('"Traceable FSDP2" is not being maintained anymore.')
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
@ -1119,7 +1111,6 @@ Unsupported Tensor.backward() call
self._test_transformer_backend_inductor_fullgraph_True()
@unittest.skip('"Traceable FSDP2" is not being maintained anymore.')
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
@ -1128,7 +1119,6 @@ Unsupported Tensor.backward() call
self._test_transformer_backend_inductor_fullgraph_True()
@unittest.skip("TODO: fix fwd_fullgraph=False case")
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)

@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_utils import run_tests
@ -73,8 +73,8 @@ class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):
self.device_type, (2, 4), mesh_dim_names=("dp", "tp")
)
# TODO: we are using an internal API atm. Change to a public API once it is ready.
mesh_fsdp_ep = _mesh_resources.create_sub_mesh(mesh_fsdp_tp, ("dp",), [(0,)])
del _mesh_resources.child_to_root_mapping[mesh_fsdp_ep]
mesh_fsdp_ep = mesh_fsdp_tp["dp"]
mesh_fsdp_ep._root_mesh = None
mesh_fsdp = init_device_mesh(self.device_type, (8,))
for i, l in enumerate(model.second.ep_layers):

@ -8,6 +8,7 @@ import os
from model_registry import MultiMLP
import torch
from torch._dynamo import OptimizedModule
from torch.distributed.pipelining import (
Schedule1F1B,
ScheduleDualPipeV,
@ -258,7 +259,15 @@ class ScheduleTest(TestCase):
finally:
torch.distributed.destroy_process_group()
def test_zero_bubble_schedule_errors_with_compile(self):
@parametrize(
"ScheduleClass",
[
ScheduleInterleavedZeroBubble,
ScheduleZBVZeroBubble,
ScheduleDualPipeV,
],
)
def test_zero_bubble_schedule_errors_with_compile(self, ScheduleClass):
"""
Test that zero bubble schedules raise an error when used with torch.compile.
"""
@ -271,16 +280,18 @@ class ScheduleTest(TestCase):
model = MultiMLP(8, n_layers=n_stages)
# full_mod
compiled_model = torch.compile(model)
self.assertTrue(isinstance(compiled_model, OptimizedModule))
stage = PipelineStage(
compiled_model,
0,
n_stages,
device,
)
with self.assertRaises(RuntimeError):
ScheduleInterleavedZeroBubble([stage], 2)
torch.distributed.destroy_process_group()
try:
with self.assertRaises(RuntimeError):
ScheduleClass([stage], 2)
finally:
torch.distributed.destroy_process_group()
instantiate_parametrized_tests(ScheduleTest)

@ -4,6 +4,7 @@ import contextlib
import torch
import torch.distributed as dist
from torch._dynamo.testing import CompileCounterWithBackend
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.testing._internal.common_utils import (
@ -214,6 +215,29 @@ class TestDTensorDebugMode(TestCase):
aten::_unsafe_view(ft: f32[64, 8], [8, 8, 8])""",
)
def test_tensor_attributes(self):
x = torch.randn(8, 8)
x.a1 = "x1"
x.a2 = "x2"
y = torch.randn(8, 8, 8)
y.a1 = "y"
with DebugMode(
record_torchfunction=True,
record_faketensor=True,
record_tensor_attributes=["a1", "a2"],
) as debug_mode:
torch.matmul(y, x)
self.assertExpectedInline(
debug_mode.debug_string(),
"""\
torch.matmul(t: f32[8, 8, 8]{a1=y}, t: f32[8, 8]{a1=x1, a2=x2})
aten::view(t: f32[8, 8, 8]{a1=y}, [64, 8])
aten::mm(t: f32[64, 8], t: f32[8, 8]{a1=x1, a2=x2})
aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
)
@parametrize("has_inner_mode", [True, False])
@parametrize("has_outer_mode", [True, False])
def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):
@ -262,14 +286,21 @@ class TestDTensorDebugMode(TestCase):
self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
def test_compile(self):
@torch.compile
cnt = CompileCounterWithBackend("inductor")
@torch.compile(backend=cnt)
def f(x):
return x.sin().cos()
x = torch.randn(8)
with DebugMode() as debug_mode:
f(x)
self.assertEqual(len(debug_mode.debug_string()), 0)
self.assertEqual(len(debug_mode.debug_string()), 0)
f(x)
f(x)
self.assertEqual(
cnt.frame_count, 1
) # check DebugMode doesn't trigger additional recompilations
instantiate_parametrized_tests(TestDTensorDebugMode)

@ -11,7 +11,8 @@ from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._guards import tracing, TracingContext
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate
from torch.distributed.tensor import distribute_tensor, Partial, Replicate, Shard
from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@ -39,6 +40,21 @@ class SimpleModel(torch.nn.Module):
return self.mlp_1(self.mlp_0(input))
class EinsumModel(torch.nn.Module):
    """Simple model that uses einsum with DTensor inputs and returns DTensor."""

    def __init__(self):
        super().__init__()
        self.placement = None

    def forward(self, x, y, z):
        result = torch.einsum("bsh,hd->bsd", x, y)
        self.placement = result.placements[0]
        self.placement_2 = y.placements[0]
        self.placement_3 = z.placements[0]
        return result
class SimpleModelDynamicShapes(torch.nn.Module):
def __init__(self, device):
super().__init__()
@ -334,6 +350,32 @@ class DTensorExportTest(TestCase):
"""[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
)
def test_einsum_dtensor_export(self):
"""Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
world_size = 4
# Create device mesh
device_mesh = init_device_mesh(self.device_type, mesh_shape=(world_size,))
model = EinsumModel()
x = torch.randn(4, 8, 16)
x_dtensor = distribute_tensor(x, device_mesh, placements=[Shard(0)])
# y: [16, 16] replicated
y = torch.randn(16, 16)
z = torch.randn(16, 16)
y_dtensor = distribute_tensor(y, device_mesh, placements=[Replicate()])
z_dtensor = DTensor.from_local(z, device_mesh, placements=[Partial()])
# Run model to verify it works
output = model(x_dtensor, y_dtensor, z_dtensor)
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(
x_dtensor, y_dtensor, z_dtensor
)
output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
self.assertEqual(output, output_gm)
instantiate_parametrized_tests(DTensorExportTest)

Some files were not shown because too many files have changed in this diff.