17271 Commits

Author SHA1 Message Date
33bfec27ff Revert "use sym_numel, to allow fake tensors to work (#163831)"
This reverts commit e71c75680f2d6ce5f61ad4b2125f4934087762eb.

Reverted https://github.com/pytorch/pytorch/pull/163831 on behalf of https://github.com/isuruf due to test failure on mps introduced ([comment](https://github.com/pytorch/pytorch/pull/163831#issuecomment-3400131730))
2025-10-14 05:10:56 +00:00
29c5368e0f MTIA _cdist_forward registration (#165333)
Summary: Added registration for _cdist_forward on MTIA

Differential Revision: D84357997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165333
Approved by: https://github.com/albanD
2025-10-14 03:51:31 +00:00
e71c75680f use sym_numel, to allow fake tensors to work (#163831)
Fixes #[163759](https://github.com/pytorch/pytorch/issues/163759)

Replace `numel` with `sym_numel`. Tested with example in issue and it works now .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163831
Approved by: https://github.com/bobrenjc93
2025-10-14 03:33:28 +00:00
37d57ac9cb Use sym_eq in _check_rms_norm_inputs_symint (#165112)
Summary:
### Problem
ArrayRef's `equals()`does elementwise quality using `==` operator. This can cause a DDE for unbacked symints since `==`  operator calls `guard_bool`.
```
// SymInt.h
bool operator==(const SymInt& o) const {
  return sym_eq(o).guard_bool(__FILE__, __LINE__);
}
```

### Solution
Adds `sym_equals()` to do elementwise equality for `SymIntArrayRef`. Use this instead of `equals()` for `SymIntArrayRef`.

Reviewed By: guangy10, pianpwk, muchulee8

Differential Revision: D84168401

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165112
Approved by: https://github.com/Skylion007
2025-10-14 00:06:24 +00:00
ecb53078fa Turn some const strings into constexpr in C++ code (#165203)
This PR turns more const strings into constexpr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165203
Approved by: https://github.com/Skylion007
2025-10-13 20:25:20 +00:00
cad2d473bf Force inlining into torch_function_mode_enabled (#164617)
This function is relatively hot; inlining here reduces time reported by `python -m timeit --setup 'import torch; t = torch.tensor([1])' 't._cdata'` from about 125 nsec/loop to about 110 nsec/loop. (To be fair, variance is high, but I did confirm with perf that time in this path seems to have roughly halved during torchtitan training.)

Note that locally I am getting bit by a GCC bug that I documented in a comment. Would be interested to hear if this does anything for clang.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164617
Approved by: https://github.com/ezyang
2025-10-13 19:25:51 +00:00
dcce473352 [BE] Fix unused parameter warning (#165272)
Fixes
```
[23/1155] Compiling /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal to EmbeddingBag_31.air
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal:252:62: warning: unused parameter 'bag_size' [-Wunused-parameter]
  inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
                                                             ^
1 warning generated.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165272
Approved by: https://github.com/Skylion007
2025-10-13 18:52:51 +00:00
83cbba8759 [MPS] Support large tensors in torch.cat (#164416)
Fixes #164415
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164416
Approved by: https://github.com/malfet
2025-10-13 16:56:56 +00:00
59ad8f1ac6 [XPU] Enhance XPUGeneratorImpl functionality to support XPUGraph (#163332)
As this [XPUGraph RFC](https://github.com/pytorch/pytorch/issues/162143) descripted. This PR enhances `XPUGeneratorImpl` to support XPUGraph.
In this PR, we add `XPUGerneratorState` and `PhiloxXpuState`. Which makes XPUGraph update philox state during graph capture and replay correctly

XPUGraph PR submission plan:

- [ ] 1, Enhance XPUGenerator functionality. Add XPUGeneratorState and philoxState
- [ ] 2, implemenet XPUGraph capture_begin/capture_end/instantiate functionality
- [ ] 3, implemenet XPUGraph replay/debug_dump/reset functionality
- [ ] 4, python APIs: is_current_stream_capturing/graph_pool_handle/graph
- [ ] 5, python APIs: make_graphed_callables

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163332
Approved by: https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
2025-10-13 02:10:41 +00:00
de8d81275a Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
2025-10-11 01:03:55 +00:00
ef50c9b557 Remove unnecessary "static" for definitions in anonymous namespace (#165035)
This PR removes unnecessary "static" for C++ functions and variables in anonymous namespace as detected by clang-tidy. This enhances code readability. The related rules are planed to be enabled in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165035
Approved by: https://github.com/Skylion007
2025-10-11 00:04:23 +00:00
5c3fe9fb30 Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit a6fa4f9c283971c0fb6f60a89674a1f35370ac79.

Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/izaitsevfb due to introduces numeric issues internally, see [D84326613](https://www.internalfb.com/diff/D84326613) ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3392203314))
2025-10-10 20:21:12 +00:00
94e634942a Fix int32 overflow in embedding_dense_backward (#165095)
If `max_partial_segment` is large we can overflow `gid` and cause a bunch of IMA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165095
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-10 19:47:38 +00:00
8f78999d77 [Inductor][ATen] Fix stride rounding on Blockwise128x128 to accommodate for small shapes (#164953)
Summary: Fix rounding issue on `Blockwise128x128` to accommodate for small shapes. The original implementation rounded all strides to 4, which caused failures for `test_fp8.py` tests as well as `test_scaled_matmul_cuda.py::test_scaled_mm_vs_emulated_block_wise` tests ([GitHub PR](https://github.com/pytorch/pytorch/pull/164259)).

Test Plan:
`test_fp8.py`
`test_scaled_matmul_cuda.py`

Differential Revision: D84103213

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164953
Approved by: https://github.com/slayton58, https://github.com/eqy
2025-10-10 19:12:58 +00:00
6f31406723 [Code Clean] Replace std::runtime_error with TORCH_CHECK (#163927)
Fixes part of  #148114

Including:

- aten/src/ATen/InferSize.h
- aten/src/ATen/functorch
- aten/src/ATen/cudnn/Types.cpp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163927
Approved by: https://github.com/FFFrog, https://github.com/albanD

Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
2025-10-10 18:23:27 +00:00
f975bd58af Revert "Warn if AccumulateGrad stream does not match producer node stream (#165065)"
This reverts commit a70ef954b919e990ebaba715b4072e76352867bf.

Reverted https://github.com/pytorch/pytorch/pull/165065 on behalf of https://github.com/izaitsevfb due to breaks lint ([comment](https://github.com/pytorch/pytorch/pull/165065#issuecomment-3391387386))
2025-10-10 17:29:29 +00:00
a70ef954b9 Warn if AccumulateGrad stream does not match producer node stream (#165065)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165065
Approved by: https://github.com/ngimel
ghstack dependencies: #162815
2025-10-10 16:46:01 +00:00
01a2812f48 [ROCm] Adjust grid size for non-unit stride backwards indexing (#165026)
Adjust grid size for non-unit stride backwards indexing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165026
Approved by: https://github.com/jeffdaily
2025-10-10 16:36:38 +00:00
253fd765bd bf16 support for fake_quantize_learnable_per_channel_affine (#165098)
Adding bf16 support for `torch._fake_quantize_learnable_per_channel_affine()` op by relaxing the type check on scale

TODO: need to add bf16 support to `per_tensor_affine_` as `torch._fake_quantize_learnable_per_tensor_affine_backward` gets called in the backward pass

**Test**
Modified unit test in `test_workflow_ops.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165098
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14
2025-10-10 16:24:52 +00:00
172d6ed8b8 Refactor _scaled_grouped_mm_cuda dispatch (#165060)
Summary:

* Clean & simplify different scaling recipe dispatch
* Split out recipes into separate dispatch functions

Test Plan:

```
pytest -svv -k grouped  test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165060
Approved by: https://github.com/danielvegamyhre, https://github.com/ngimel
2025-10-10 04:44:25 +00:00
9a3c4b917e [CMake] Remove forcing of -O2 from torch_compile_options (#164894)
That was introduced by 75a65ffe0f
Hattip to @jathu for alerting me about the issue. As result, all our PyTorch builds were shipped with `-O2` for almost all of its modern history

Partially undo the damage introduced by https://github.com/pytorch/pytorch/pull/128406 that cause cross-ISA symbols leak, to be properly followed up in https://github.com/pytorch/pytorch/issues/165123

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164894
Approved by: https://github.com/ezyang
2025-10-10 04:43:53 +00:00
7f2a902ea2 more sizelike deprecation (#164889)
remove expext_size c++ bindings and usages

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164889
Approved by: https://github.com/mlazos
ghstack dependencies: #164884, #164885, #164886, #164887, #164888
2025-10-10 03:45:06 +00:00
7614338b69 Revert "Add SVE128 ISA (#158932)"
This reverts commit 92284fb2ff44f09a9c7df0d8cf6cac9903e376a4.

Reverted https://github.com/pytorch/pytorch/pull/158932 on behalf of https://github.com/malfet due to Hmm, but from OSS point of view, this is a no-op ([comment](https://github.com/pytorch/pytorch/pull/158932#issuecomment-3387961238))
2025-10-10 01:17:02 +00:00
a6fa4f9c28 Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
2025-10-10 00:15:00 +00:00
cd62a73dcb [cuDNN][SDPA] Handle noncontig nested tensors in cuDNN SDPA (#164958)
Previously we hardcoded the assumption in cuDNN that the inputs would be dense which breaks when e.g., the user is chunking tensors yielding noncontig inputs

New test added to check this  when `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1` is set in `test/test_transformers.py`

One issue I noticed was that the old gating of nested tensor in `sdp_utils.cpp` seems to be a no-op? All of the inputs are reported as "dense" by the time that function is called in the nested tensor tests in `test/test_nestedtensor.py -k sdpa`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164958
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2025-10-09 21:58:54 +00:00
4d7f9f3aed Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit 8e1f409b8ccf64b2cf3933ece13587ad57e9d8a9.

Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/jeffdaily due to broke cuda and rocm ci ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3387558806))
2025-10-09 21:36:10 +00:00
228973df7f Fix channels-last dimension mapping in CUDA parallel_cat (#165023)
Fixes #164849
`dimension` was updated in-place, so for more than one batch of channels-last tensors the concat `dimension` for the second kernel launch was wrong

## Testing
- python -m compileall test/test_tensor_creation_ops.py

------
https://chatgpt.com/codex/tasks/task_e_68e708879b30832f89b10ae55faa68e8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165023
Approved by: https://github.com/ezyang
2025-10-09 20:04:32 +00:00
8e1f409b8c [ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the 2 easier.
PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-09 18:08:30 +00:00
ee6a1ecb0a [ROCm] Enable MI355 CI on PRs, and run full set of UTs on PRs (#160215)
Useful to have PR testing for PRs such as https://github.com/pytorch/pytorch/pull/151360

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160215
Approved by: https://github.com/malfet, https://github.com/atalman

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-09 18:03:12 +00:00
3c0577bd15 Remove shared_ptr from MHAGraphCache (#164895)
This commit makes several cleanup changes to MHA.cpp, the main
one of which is removal of shared_ptr from MHAGraphCache as the
cache does not actually intend to share ownership. The changes are:

1. Remove shared_ptr from MHAGraphCache
2. Remove template arguments from MHAGraphCache
3. Remove unnecessary optional<shared_ptr<...>> vars
4. Change some functions with auto return type to the actual type

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164895
Approved by: https://github.com/eqy
2025-10-09 17:44:28 +00:00
5d459dd609 avoid bit cast for bfloat16_t (#159946)
using bit_cast<bfloat16_t> triggers a static_assert, so replace it with intrinsics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159946
Approved by: https://github.com/aditew01, https://github.com/malfet
2025-10-09 16:42:49 +00:00
aea57b3aa3 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation  to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{
            if (h) aoti_torch_mps_delete_shader_library(h);
        }};

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 16:06:36 +00:00
3d1fa40ae1 Revert "[BC-Breaking] Remove long-deprecated casting functions from native_functions.yaml (#164641)"
This reverts commit 64108bdbed2f099d527060b4c9fdd5a11cad2afc.

Reverted https://github.com/pytorch/pytorch/pull/164641 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164641#issuecomment-3386346474))
2025-10-09 15:42:51 +00:00
f79e212733 Revert "[CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)"
This reverts commit ab94a0d544503b5c27e889b45e45ef8cf75c8183.

Reverted https://github.com/pytorch/pytorch/pull/163955 on behalf of https://github.com/jeffdaily due to broke on cuda and rocm after landing though this PR had a clean signal initially ([comment](https://github.com/pytorch/pytorch/pull/163955#issuecomment-3386127145))
2025-10-09 14:24:56 +00:00
17c7170ca6 Fix Avoid DDE in item numel check (#164934)
address https://github.com/pytorch/pytorch/issues/164725 and https://github.com/pytorch/pytorch/issues/164704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164934
Approved by: https://github.com/ezyang, https://github.com/aorenste, https://github.com/Skylion007
2025-10-09 13:09:06 +00:00
6a7f5c0d21 Add scaled_mm python API, test (#164142)
Summary:

* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++
  methods
* Wraps `torch._scaled_mm_v2` API by default, but user can force use of
  the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API

Test Plan:

`pytest test/test_scaled_matmul_cuda.py`

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142
Approved by: https://github.com/drisspg
ghstack dependencies: #164141
2025-10-09 12:43:18 +00:00
512b6b59f0 Add _scaled_mm_v2 API (#164141)
Summary:

* Add new scaled-MM API to future-proof / clean-up existing code.
* Scaling is explicitly described rather than infer
* Swizzling of scaled must now be defined (vs. inferred)
* Adds API support for multi-level scaling
* Refactor dispatch logic to make it easier to add new implementations

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164141
Approved by: https://github.com/drisspg
2025-10-09 12:43:18 +00:00
4412026949 Revert "AOTI MPS Shim Implementation (#163865)"
This reverts commit 874efa2d72d83b00894097130f18062ce331a265.

Reverted https://github.com/pytorch/pytorch/pull/163865 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/163865#issuecomment-3385196387))
2025-10-09 10:26:01 +00:00
874efa2d72 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation  to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{
            if (h) aoti_torch_mps_delete_shader_library(h);
        }};

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 09:28:10 +00:00
5209c8ce07 Revert "Fix Avoid DDE in item numel check (#164934)"
This reverts commit a9a9a3438a374f96a308b707a1718036aaec790d.

Reverted https://github.com/pytorch/pytorch/pull/164934 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164934#issuecomment-3384390621))
2025-10-09 06:57:03 +00:00
f231be25c6 Mark unused parameters in C++ code (#164912)
This PR adds unused parameter name comments in C++ declarations to improve code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164912
Approved by: https://github.com/Skylion007
2025-10-09 06:23:25 +00:00
a9a9a3438a Fix Avoid DDE in item numel check (#164934)
address https://github.com/pytorch/pytorch/issues/164725 and https://github.com/pytorch/pytorch/issues/164704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164934
Approved by: https://github.com/ezyang, https://github.com/aorenste, https://github.com/Skylion007
2025-10-09 06:06:25 +00:00
ab94a0d544 [CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)
As per title. Additionally, some Lt selection conditions are revisited, and some redundancy removed (especially in the ROCm vs non-ROCm paths).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163955
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-09 04:07:45 +00:00
94b1ec8c7c [BE] Use torch check the way its intended (#164987)
Replace
`if (!foo) TORCH_CHECK(false, "bar");` with `TORCH_CHECK(foo, "bar");`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164987
Approved by: https://github.com/albanD, https://github.com/Skylion007
2025-10-08 22:28:08 +00:00
1d182dd81c [MPS] sparse norm (#164961)
Norms for sparse mps tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164961
Approved by: https://github.com/malfet
2025-10-08 21:41:42 +00:00
71aefd5595 [reland] Allow setting grad_dtype on leaf tensors (#164751)
ghstack-source-id: e44b3941530be83a630ec93f1478eec741ffca2e
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/162815

Fixes #ISSUE_NUMBER

Relanding due to internal weirdness. Separate PR to codev w/o ghstack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164751
Approved by: https://github.com/albanD
2025-10-08 20:23:13 +00:00
a4110fedcf Use insert_or_assign instead of erase+emplace (#164868)
insert_or_assign does effectively the same thing as
erase+emplace but more efficiently since the search
does not need to be repeated

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164868
Approved by: https://github.com/eqy
2025-10-08 19:13:49 +00:00
37c6087334 Add split-K control to cuBLAS reduced-precision settings (#164766)
## Summary
- add a CuBLASReductionOption enum so the CUDA context can track reduced-precision and split-K options
- extend the Python bindings, backend helpers, and docs to accept an optional allow_splitk argument for fp16/bf16 matmul controls
- update cuBLAS/cuBLASLt call sites plus dynamo guards and tests to respect the new combinations

## Testing
- python test/test_cuda.py TestCuda.test_cublas_allow_fp16_reduced_precision_reduction_get_set -v *(fails: ModuleNotFoundError: No module named 'psutil')*

------
https://chatgpt.com/codex/tasks/task_e_68e404623178832f8a3e1d34e1e175da

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164766
Approved by: https://github.com/malfet, https://github.com/albanD
2025-10-08 18:48:45 +00:00
0b01ff4de0 [ROCm] Improve non stride-one backwards indexing for small index sets (#164409)
This patch fixes a performance problem which occurs when a small set of indices is used and there are practically no duplicates.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164409
Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily
2025-10-08 17:04:52 +00:00
01f3a43462 [MPS] Update OS version in error message (#164946)
Followup after https://github.com/pytorch/pytorch/pull/159912
Fixes https://github.com/pytorch/pytorch/issues/164943

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164946
Approved by: https://github.com/Camyll
2025-10-08 16:43:50 +00:00