Compare commits

..

72 Commits

Author SHA1 Message Date
26f67ef050 Add an option to store large mmap weights on disk (#164526)
Summary:
As title

On Windows, we cannot modify the .dll to append weights at the end; the Windows .dll loader will complain that it is not a valid .dll file. So we store the weight blob as a separate file.

1. We add the following APIs, which allow querying the size of the weight blob and passing in a pointer to it:

```cpp
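// Get the size of the weight blob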
AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize(
    AOTInductorModelContainerHandle container_handle,
    uint64_t* ret_size);

// Load weights from a single blob in weight_blob_ptr
AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob(
    AOTInductorModelContainerHandle container_handle,
    const uint8_t* weight_blob_ptr);
```

2. We also add a method in ModelContainerRunner to load the weights:

If the runner sees that there is a `.blob` file in the package, it will mmap the `.blob` file and use its contents to load the constants.

3. We also add the `USE_MMAP_EXTERNAL` macro. When this macro is defined, the model expects to load its weights from an external mmap'd weight blob.


Test Plan:
```
buck run mode/dev-nosan caffe2/test/inductor:test_aot_inductor -- -r test_large_mmaped_weights_on_disk
```

Also tested Windows cross-compilation with 6542566585/demo/main_voxtral.cpp

```
Loaded model.dll
audio_encoder loaded
C:\Users\shangdiy\source\repos\torchnative\demo\token_embedding\data\aotinductor\model\model.wrapper.so
Loaded model.dll
token_embedding loaded
C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper.so
Loaded model.dll
Loading weights from C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper_weights.blob
text_decoder loaded
Load latency (ms):
  audio_encoder: 1011.234
    archive extraction: 0.000
    .so loading: 1011.197
  token_embedding: 525.773
    archive extraction: 0.000
    .so loading: 525.704
  text_decoder: 3324.130
    archive extraction: 0.000
    .so loading: 3323.979
Run latency (ms):
  audio_encoder: 285.958
    audio_encoder output: dtype=bfloat16, shape=[1, 1125, 3072], numel=3456000
  token_embedding: 6.676
    token_embedding output: dtype=bfloat16, shape=[1, 1138, 3072], numel=3495936
  text_decoder: 576.519
    text_decoder output: dtype=bfloat16, shape=[1, 1138, 131072], numel=149159936
```


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

Differential Revision: D84093310

Pulled By: yushangdi
2025-10-09 15:21:58 -07:00
6d27a8e509 [CD] Do not propagate download.pytorch.org IP into container (#165075)
Followup after https://github.com/pytorch/pytorch/pull/164969

Should fix binary build test failures
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165075
Approved by: https://github.com/seemethere, https://github.com/huydhn
ghstack dependencies: #164968, #164969
2025-10-09 21:59:31 +00:00
cd62a73dcb [cuDNN][SDPA] Handle noncontig nested tensors in cuDNN SDPA (#164958)
Previously we hardcoded the assumption in cuDNN SDPA that the inputs would be dense, which breaks when, e.g., the user chunks tensors, yielding noncontiguous inputs.
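A minimal illustration of that scenario (assumed shapes and a CUDA device; whether the cuDNN nested-tensor path is actually taken also depends on the `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1` flag mentioned below):

```python
import torch
import torch.nn.functional as F

# Chunking a packed QKV projection yields non-contiguous views, the kind of
# input the cuDNN SDPA path previously assumed was dense.
qkv = torch.randn(2, 8, 128, 3 * 64, device="cuda", dtype=torch.float16)
q, k, v = qkv.chunk(3, dim=-1)   # each chunk is a non-contiguous view
print(q.is_contiguous())         # False
out = F.scaled_dot_product_attention(q, k, v)
```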

A new test is added in `test/test_transformers.py` to check this when `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1` is set.

One issue I noticed is that the old gating of nested tensors in `sdp_utils.cpp` seems to be a no-op: all of the inputs are reported as "dense" by the time that function is called in the nested tensor tests in `test/test_nestedtensor.py -k sdpa`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164958
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2025-10-09 21:58:54 +00:00
4d7f9f3aed Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit 8e1f409b8ccf64b2cf3933ece13587ad57e9d8a9.

Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/jeffdaily due to broke cuda and rocm ci ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3387558806))
2025-10-09 21:36:10 +00:00
2b9ff99535 [flex attention] change "==" to "is" in inspect parameter comparison (#165003)
Patch for https://github.com/pytorch/pytorch/issues/164760.

This doesn't actually fix the underlying torch function issue though.

Explanation: `is` is traced differently compared to `__eq__`, so we end up avoiding the issue where we attempt to evaluate `torch.eq(tensor, inspect._empty)` in the first place.
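A small standalone sketch of the pattern (hypothetical function `f`; `inspect.Parameter.empty` is the same object as `inspect._empty`):

```python
import inspect

def f(x, y=None):
    return x

param = inspect.signature(f).parameters["x"]
# `is` checks identity and never calls __eq__, so tracing it cannot turn into
# an attempted torch.eq(tensor, inspect._empty) the way "==" can.
has_no_default = param.default is inspect.Parameter.empty
print(has_no_default)  # True
```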

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165003
Approved by: https://github.com/mlazos
2025-10-09 21:18:05 +00:00
98a081a24c Call internal log_compilation_event if it exists (#164855)
Summary: For internal conda on mast jobs, call the internal version of log_compilation_event if it exists.

Test Plan: Ran a simple test job that just calls the API: https://fburl.com/scuba/dynamo_compile/dqx8d10g
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164855
Approved by: https://github.com/c00w
2025-10-09 21:15:11 +00:00
6c0125dbc0 Mark functions const in CUDACachingAllocator (#165007)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165007
Approved by: https://github.com/eqy
2025-10-09 20:53:58 +00:00
0fd976b65c Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)
This change removes the Windows requirement for mimalloc builds, and makes mimalloc the default c10 system allocator for AArch64 builds. This significantly improves the performance of AArch64 builds of PyTorch as large allocations are better cached by mimalloc than glibc.

**Updated Results**

Torchbench FP32 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-fp32-diff" src="https://github.com/user-attachments/assets/7fe3ea0c-3b52-42e7-879b-612444479c90" />

Torchbench BF16 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-bf16-diff" src="https://github.com/user-attachments/assets/56469a72-9e06-4d57-ae2a-aeb139ca79a3" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164741
Approved by: https://github.com/fadara01, https://github.com/aditew01, https://github.com/malfet
2025-10-09 20:49:46 +00:00
9944cac6e6 Add suppressions to torch/_inductor (#165062)
Adds suppressions so pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Split this directory into two PRs to keep them from being too large.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165062
Approved by: https://github.com/oulgen, https://github.com/mlazos
2025-10-09 20:34:20 +00:00
e7fd296930 [CI] Add full debug build to trunk (#164974)
But do not run tests, just import torch, as a regression test for https://github.com/pytorch/pytorch/issues/164297

Test plan: Re-apply #164974 on top of this change and observe the failure in the workflows: https://github.com/pytorch/pytorch/actions/runs/18383302153/job/52375282838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164974
Approved by: https://github.com/seemethere, https://github.com/clee2000, https://github.com/atalman
ghstack dependencies: #164968, #164969
2025-10-09 20:12:16 +00:00
fac85fcfb5 [inductor] custom_graph_pass.get_hash_for_files: don't hash paths (#165020)
Summary: We have an internal user where caching broke because the paths that are unzipped are probably different per host. We can't think of a use case where a path change matters when the file content has not changed, so we are removing this part.
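A minimal sketch of the content-only hashing idea (not the actual `torch._inductor.custom_graph_pass` implementation):

```python
import hashlib

def get_hash_for_files(paths):
    # Hash only the bytes the files contain, not where they live on disk, so
    # the same pass source unzipped to a different path on another host
    # produces the same cache key.
    h = hashlib.sha256()
    for path in sorted(paths):
        with open(path, "rb") as f:
            h.update(f.read())
    return h.hexdigest()
```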

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165020
Approved by: https://github.com/oulgen
2025-10-09 20:07:53 +00:00
228973df7f Fix channels-last dimension mapping in CUDA parallel_cat (#165023)
Fixes #164849
`dimension` was updated in-place, so for more than one batch of channels-last tensors the concat `dimension` for the second kernel launch was wrong

## Testing
- python -m compileall test/test_tensor_creation_ops.py

------
https://chatgpt.com/codex/tasks/task_e_68e708879b30832f89b10ae55faa68e8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165023
Approved by: https://github.com/ezyang
2025-10-09 20:04:32 +00:00
ed2d514ad8 Revert "Fix truediv numerics between eager and compile (#164144)"
This reverts commit 724463d5a2fba369cd14e89215b84d1b01435df7.

Reverted https://github.com/pytorch/pytorch/pull/164144 on behalf of https://github.com/malfet due to Not sure if it's related, but looks it triggered fuzzer compiler test failure, see a2f29bcd63/1 ([comment](https://github.com/pytorch/pytorch/pull/164144#issuecomment-3387288464))
2025-10-09 19:53:38 +00:00
a2f29bcd63 [inductor] Remove Repeated Code in Subgraph (#164892)
Discovered some repeated code blocks in subgraph.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164892
Approved by: https://github.com/PaulZhang12
2025-10-09 19:16:02 +00:00
5390324984 [CodeClean] Replace std::runtime_error with TORCH_CHECK (#164129)
As the title stated.

**Changes**:
- torch/csrc/Module.cpp
- torch/csrc/utils.cpp
- torch/csrc/stable
- torch/lib/libshm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164129
Approved by: https://github.com/albanD
2025-10-09 19:01:07 +00:00
ae25ec569c reorder wrappers in aot_stage2_inference to match forward compile in aot_stage2_autograd (#165016)
In aot_stage2_autograd:
Before calling fw_compiler, we run pre_compile for the following wrappers:
* FakifiedOutWrapper
* FunctionalizedRngRuntimeWrapper

After, we run post_compile for the following wrappers:
 * EffectTokensWrapper
 * AOTDispatchSubclassWrapper
 * FunctionalizedRngRuntimeWrapper
 * FakifiedOutWrapper

In aot_stage2_inference:
Before calling inference compiler, we run pre_compile for the following wrappers (same as above):
 * FakifiedOutWrapper
 * FunctionalizedRngRuntimeWrapper

After, we run post_compile for the following wrappers  (different than above):
 * FunctionalizedRngRuntimeWrapper
 * FakifiedOutWrapper
 * EffectTokensWrapper
 * AOTDispatchSubclassWrapper

This PR makes both do the post_compiles in the same order.

Differential Revision: D84213657

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165016
Approved by: https://github.com/zhxchen17, https://github.com/bdhirsh
2025-10-09 18:36:04 +00:00
8e1f409b8c [ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the two easier.
PTX difference between old and new (we see a few extra instructions): https://www.diffchecker.com/h6ly3INC/
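A pure-Python illustration of a tree-style warp reduction, the general pattern whose lane ordering this PR changes to match Triton (illustrative only; not the ATen kernel or the Triton code):

```python
# Each step folds the upper half of the "warp" onto the lower half, halving
# the number of active lanes until lane 0 holds the result.
vals = list(range(32))          # one value per lane of a 32-lane warp
offset = len(vals) // 2
while offset > 0:
    for lane in range(offset):
        vals[lane] += vals[lane + offset]
    offset //= 2
print(vals[0])                  # 496 == sum(range(32))
```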

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-09 18:08:30 +00:00
ee6a1ecb0a [ROCm] Enable MI355 CI on PRs, and run full set of UTs on PRs (#160215)
Useful to have PR testing for PRs such as https://github.com/pytorch/pytorch/pull/151360

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160215
Approved by: https://github.com/malfet, https://github.com/atalman

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-09 18:03:12 +00:00
3c0577bd15 Remove shared_ptr from MHAGraphCache (#164895)
This commit makes several cleanup changes to MHA.cpp, the main
one of which is removal of shared_ptr from MHAGraphCache as the
cache does not actually intend to share ownership. The changes are:

1. Remove shared_ptr from MHAGraphCache
2. Remove template arguments from MHAGraphCache
3. Remove unnecessary optional<shared_ptr<...>> vars
4. Change some functions with auto return type to the actual type

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164895
Approved by: https://github.com/eqy
2025-10-09 17:44:28 +00:00
688efd9741 Revert "Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)"
This reverts commit 87eccf10e8484c9e59ef81ae7bdee68d3db4f605.

Reverted https://github.com/pytorch/pytorch/pull/164741 on behalf of https://github.com/malfet due to But it breaks MacOS builds, see https://github.com/pytorch/pytorch/actions/runs/18382886648/job/52373781138 ([comment](https://github.com/pytorch/pytorch/pull/164741#issuecomment-3386859778))
2025-10-09 17:30:25 +00:00
91040f4934 Revert "[Code Clean] Remove support of python3.9 (#163846)"
This reverts commit bc1690c7e859dee8c47a7f0bbd3c43cc27c6fd2a.

Reverted https://github.com/pytorch/pytorch/pull/163846 on behalf of https://github.com/izaitsevfb due to breaks distributed tests ([comment](https://github.com/pytorch/pytorch/pull/163846#issuecomment-3386855437))
2025-10-09 17:27:08 +00:00
87eccf10e8 Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)
This change removes the Windows requirement for mimalloc builds, and makes mimalloc the default c10 system allocator for AArch64 builds. This significantly improves the performance of AArch64 builds of PyTorch as large allocations are better cached by mimalloc than glibc.

**Updated Results**

Torchbench FP32 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-fp32-diff" src="https://github.com/user-attachments/assets/7fe3ea0c-3b52-42e7-879b-612444479c90" />

Torchbench BF16 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-bf16-diff" src="https://github.com/user-attachments/assets/56469a72-9e06-4d57-ae2a-aeb139ca79a3" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164741
Approved by: https://github.com/fadara01, https://github.com/aditew01, https://github.com/malfet
2025-10-09 16:45:31 +00:00
5d459dd609 avoid bit cast for bfloat16_t (#159946)
Using `bit_cast<bfloat16_t>` triggers a static_assert, so replace it with intrinsics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159946
Approved by: https://github.com/aditew01, https://github.com/malfet
2025-10-09 16:42:49 +00:00
24d69c57cb Add view support for library custom Function (#164520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164520
Approved by: https://github.com/soulitzer, https://github.com/ezyang
2025-10-09 16:17:48 +00:00
eaa02655ea [CI] Run cpp tests on windows in one run_tests call (#164861)
The Windows cpp tests take ~1 hour according to logs. Each has run_test called on it individually, so I tried batching them together so it's just one run_test call for all of them. I believe it now takes 30 min. I turned off TD since I don't think cpp tests are included in TD stuff.

As always with batch scripts, I'm not sure if the errorlevel/error-surfacing handling is correct.

This code was written with a lot of help from ChatGPT and Copilot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164861
Approved by: https://github.com/huydhn
2025-10-09 16:07:28 +00:00
aea57b3aa3 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation  to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 16:06:36 +00:00
3d1fa40ae1 Revert "[BC-Breaking] Remove long-deprecated casting functions from native_functions.yaml (#164641)"
This reverts commit 64108bdbed2f099d527060b4c9fdd5a11cad2afc.

Reverted https://github.com/pytorch/pytorch/pull/164641 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164641#issuecomment-3386346474))
2025-10-09 15:42:51 +00:00
a7fa1a91e3 fix flex attention eager bwd: more rounding (#164317)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164317
Approved by: https://github.com/drisspg
ghstack dependencies: #163986
2025-10-09 15:40:49 +00:00
afeec56a5a Fix replacement reconstruct (#164937)
If we return a DTensor, the object is created via an fx graph call, so we never needed to reconstruct it. But if there is a side effect, we do need to reconstruct it.

Differential Revision: [D84159000](https://our.internmc.facebook.com/intern/diff/D84159000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164937
Approved by: https://github.com/StrongerXi
2025-10-09 15:31:23 +00:00
724463d5a2 Fix truediv numerics between eager and compile (#164144)
Addresses numeric differences between eager and compile in https://github.com/pytorch/pytorch/issues/141753

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164144
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/ngimel
ghstack dependencies: #164997
2025-10-09 14:31:33 +00:00
f79e212733 Revert "[CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)"
This reverts commit ab94a0d544503b5c27e889b45e45ef8cf75c8183.

Reverted https://github.com/pytorch/pytorch/pull/163955 on behalf of https://github.com/jeffdaily due to broke on cuda and rocm after landing though this PR had a clean signal initially ([comment](https://github.com/pytorch/pytorch/pull/163955#issuecomment-3386127145))
2025-10-09 14:24:56 +00:00
b28b24a9fc Switch build jobs that use linux.12xlarge to c7i (#164941)
This PR updates build jobs that currently use linux.12xlarge to the
c7i variant, which should improve build times by 15% - 20% depending
on the job and reduce the costs of these jobs by 10% - 15%.

Signed-off-by: Thanh Ha <thanh.ha@linuxfoundation.org>
2025-10-09 09:58:52 -04:00
17c7170ca6 Fix Avoid DDE in item numel check (#164934)
address https://github.com/pytorch/pytorch/issues/164725 and https://github.com/pytorch/pytorch/issues/164704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164934
Approved by: https://github.com/ezyang, https://github.com/aorenste, https://github.com/Skylion007
2025-10-09 13:09:06 +00:00
6a7f5c0d21 Add scaled_mm python API, test (#164142)
Summary:

* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++
  methods
* Wraps `torch._scaled_mm_v2` API by default, but user can force use of
  the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API

Test Plan:

`pytest test/test_scaled_matmul_cuda.py`

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142
Approved by: https://github.com/drisspg
ghstack dependencies: #164141
2025-10-09 12:43:18 +00:00
512b6b59f0 Add _scaled_mm_v2 API (#164141)
Summary:

* Add new scaled-MM API to future-proof / clean-up existing code.
* Scaling is explicitly described rather than inferred
* Swizzling of scales must now be defined (vs. inferred)
* Adds API support for multi-level scaling
* Refactor dispatch logic to make it easier to add new implementations

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164141
Approved by: https://github.com/drisspg
2025-10-09 12:43:18 +00:00
bc1690c7e8 [Code Clean] Remove support of python3.9 (#163846)
As the title stated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163846
Approved by: https://github.com/ezyang
2025-10-09 11:54:10 +00:00
53f5af8c92 Update torch-xpu-ops commit pin (#164237)
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@f30173](f301733b03), which includes:

- Install xpu internal headers to PyTorch
- Fix error handling for BatchLinearAlgebra Ops
- Fix unnecessary double data type conversion
- Fix overflow when calculating workgroups count
- Fix segmentation fault and calculation error in AveragePool2dKernel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164237
Approved by: https://github.com/EikanWang
2025-10-09 10:38:59 +00:00
4412026949 Revert "AOTI MPS Shim Implementation (#163865)"
This reverts commit 874efa2d72d83b00894097130f18062ce331a265.

Reverted https://github.com/pytorch/pytorch/pull/163865 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/163865#issuecomment-3385196387))
2025-10-09 10:26:01 +00:00
06d86e58d0 Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit d40a9bfb8da0dc1ac1e6e56b33a25979112874de.

Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3385056722))
2025-10-09 09:50:59 +00:00
874efa2d72 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation  to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {
            if (h) aoti_torch_mps_delete_shader_library(h);
        };

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 09:28:10 +00:00
e09fb44ef1 Revert "Fix truediv numerics between eager and compile (#164144)"
This reverts commit d386325ca9a142419f45b987391f4bb175dd7d0b.

Reverted https://github.com/pytorch/pytorch/pull/164144 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164144#issuecomment-3384769092))
2025-10-09 08:40:52 +00:00
5b8174bc28 Revert "[vllm hash update] update the pinned vllm hash (#164628)"
This reverts commit 7b691546d2949790ffc8f6bd3c674faa6a46ff7c.

Reverted https://github.com/pytorch/pytorch/pull/164628 on behalf of https://github.com/huydhn due to There are some broken vLLM tests ([comment](https://github.com/pytorch/pytorch/pull/164628#issuecomment-3384560957))
2025-10-09 07:43:02 +00:00
5209c8ce07 Revert "Fix Avoid DDE in item numel check (#164934)"
This reverts commit a9a9a3438a374f96a308b707a1718036aaec790d.

Reverted https://github.com/pytorch/pytorch/pull/164934 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164934#issuecomment-3384390621))
2025-10-09 06:57:03 +00:00
f231be25c6 Mark unused parameters in C++ code (#164912)
This PR adds unused parameter name comments in C++ declarations to improve code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164912
Approved by: https://github.com/Skylion007
2025-10-09 06:23:25 +00:00
a753ffa9af Revert "Use runner with more memory for ASAN builds (#165000)"
This reverts commit f5fd18f7e24378bd9eb91404f697f1c81a8187d5.

Reverted https://github.com/pytorch/pytorch/pull/165000 on behalf of https://github.com/izaitsevfb due to not sure how, but this broke lint ([comment](https://github.com/pytorch/pytorch/pull/165000#issuecomment-3384286412))
2025-10-09 06:22:28 +00:00
a9a9a3438a Fix Avoid DDE in item numel check (#164934)
address https://github.com/pytorch/pytorch/issues/164725 and https://github.com/pytorch/pytorch/issues/164704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164934
Approved by: https://github.com/ezyang, https://github.com/aorenste, https://github.com/Skylion007
2025-10-09 06:06:25 +00:00
263db92563 Add knobs in FR dump by watchdog (stacktrace and only active collectives) and trigger FR even on any exceptions (#164591)
This PR includes a couple of changes to extend the FlightRecorder (FR) dump triggered by the PyTorch watchdog:

- New knobs to control the FR dump, as suggested in the public documentation, even for the watchdog
  (TORCH_INCLUDE_STACK_TRACE, TORCH_INCLUDE_ONLY_ACTIVE)
- Trigger the flight recorder dump on exceptions, which could be raised by any CUDA / host-side error
  (TORCH_NCCL_EXTRA_DUMP_ON_EXEC)
  -> Can be used as a snapshot of the workload's progress for post-mortem analysis
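A minimal sketch of opting in from Python; the knob names are taken from the list above, and when exactly they are read is an assumption of this sketch, so they are set before process-group initialization:

```python
import os

# Knob names copied from the description above; semantics are as described there.
os.environ["TORCH_INCLUDE_STACK_TRACE"] = "1"      # include stack traces in FR entries
os.environ["TORCH_INCLUDE_ONLY_ACTIVE"] = "1"      # record only active collectives
os.environ["TORCH_NCCL_EXTRA_DUMP_ON_EXEC"] = "1"  # also dump FR on exceptions
```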

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164591
Approved by: https://github.com/fduwjj
2025-10-09 05:33:35 +00:00
ed6156e3ea non-fb impls + unit tests (#164722)
Test Plan:
```
buck test fbcode//mode/opt caffe2/test/inductor:caching
```

Differential Revision: D83714692

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164722
Approved by: https://github.com/NikhilAPatel, https://github.com/adamomainz
2025-10-09 05:10:57 +00:00
d40a9bfb8d Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition. You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now preserve the fused call by default.
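A hedged sketch of how one might check the eager/compile agreement this targets (illustrative only; whether the outputs match bitwise depends on backend and build):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 64)
w = torch.randn(64)

def f(t):
    return F.rms_norm(t, (64,), w)

eager_out = f(x)
compiled_out = torch.compile(f)(x)
# The PR's goal is that the fused rms_norm call is preserved through
# AOTAutograd, so eager and compiled outputs can agree exactly.
print(torch.equal(eager_out, compiled_out))
```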

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164573
2025-10-09 04:49:44 +00:00
e532f62e0d Introduce joint_custom_pass callback (#164981)
```
import torch
import torch._functorch.config

def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
    # apply your pass to the joint graph here
    return joint_gm

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()

x = torch.randn(10, requires_grad=False)
compiled_fn = torch.compile(M(), backend="aot_eager")

with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
    out = compiled_fn(x)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164981
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2025-10-09 04:40:54 +00:00
1f73b96668 [PGO] log missing sources in allowlist (#164881)
Summary:
- logs missing dynamic sources
- emits MLHub insight only on size mismatch recompiles

Test Plan: test_pgo

Differential Revision: D84098898

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164881
Approved by: https://github.com/bobrenjc93
2025-10-09 04:39:09 +00:00
7b691546d2 [vllm hash update] update the pinned vllm hash (#164628)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vllm hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164628
Approved by: https://github.com/pytorchbot
2025-10-09 04:35:36 +00:00
f05e23e1bc Add less warps config to inner reductions (#162447)
Add configs with fewer warps to ensure proper vectorization + memory coalescing for inner reductions, preferring more work per thread.

<img width="1717" height="731" alt="Screenshot 2025-09-17 at 10 03 25 AM" src="https://github.com/user-attachments/assets/7b1f4a30-62f2-4bee-bb9c-122501bde63e" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162447
Approved by: https://github.com/v0i0, https://github.com/eellison, https://github.com/shunting314
2025-10-09 04:22:16 +00:00
d386325ca9 Fix truediv numerics between eager and compile (#164144)
Addresses numeric differences between eager and compile in https://github.com/pytorch/pytorch/issues/141753

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164144
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/ngimel
ghstack dependencies: #164997
2025-10-09 04:22:03 +00:00
7457d139c5 Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions so pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
2025-10-09 04:08:25 +00:00
ab94a0d544 [CUDA][cuBLAS] addmm -- some refactoring for easier navigation between the Lt and non-Lt paths (#163955)
As per title. Additionally, some Lt selection conditions are revisited, and some redundancy removed (especially in the ROCm vs non-ROCm paths).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163955
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-09 04:07:45 +00:00
0e9b3a772a [export] Turn on install_free_tensors flag (#164691)
The final step in removing the discrepancy between
torch.compile(fullgraph=True) and torch.export(strict=True).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164691
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #164721
2025-10-09 03:25:15 +00:00
af7ca55ced [export][dynamo] Fallback to slowpath for MultiHeadAttention for strict export (#164721)
In https://github.com/pytorch/pytorch/pull/106824, export decided to take the slow path for the MultiHeadAttention module (see that PR's description as to why). But that PR eventually caused a divergence between Dynamo and export.

Today, strict-export does not inline into builtin modules (like MultiHeadAttention), and therefore make_fx sees the original nn.Module and takes the slow path. But compile inlines into the nn module, and at this time the condition `_is_make_fx_tracing` is False. As a result, Dynamo takes a fast path, resulting in a different op being called.

This divergence is undesirable. There are 2 ways to fix it

1) Make export take the fast path - as explained in https://github.com/pytorch/pytorch/pull/106824, this might be difficult, so we go with (2).
2) Make compile take the slow path as well - this is easy to implement. The con is that PyTorch eager and compile will use different operators, which can cause numerical issues, etc.

Since (2) is easy to do, we will follow this path. We are tracking the issue in  https://github.com/pytorch/pytorch/issues/164062
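A minimal sketch, assuming a plain nn.MultiheadAttention module: after this change, compile follows the same (slow) path as eager/export, so the outputs of the two should line up.

```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True).eval()
x = torch.randn(2, 8, 16)

with torch.no_grad():
    eager_out, _ = mha(x, x, x)
    compiled_out, _ = torch.compile(mha)(x, x, x)

# Both paths now use the same operator, so the results should match closely.
print(torch.allclose(eager_out, compiled_out, atol=1e-6))
```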

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164721
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
2025-10-09 03:25:15 +00:00
a029675f6f More ruff SIM fixes (#164695)
This PR applies ruff `SIM` rules to more files. Most changes are about simplifying `dict.get` because `None` is already the default value.
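Illustrative only: the kind of rewrite ruff's SIM rules flag, since `dict.get` already returns `None` when the key is missing.

```python
config = {"device": "cuda"}

before = config.get("dtype", None)  # redundant explicit default
after = config.get("dtype")         # equivalent: None is already the default
assert before is None and after is None
```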

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164695
Approved by: https://github.com/ezyang
2025-10-09 03:24:50 +00:00
54ae61c573 Change test_emulate_precision_casts_mean_ratio_chain from gelu to relu (#164997)
gelu can be unstable on local builds due to libdevice differences, as we lower to libdevice.erf. That, combined with the semantics of the test, can lead to catastrophic cancellation. We switch this test from gelu to relu to fix the instability.
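A small sketch of the instability being avoided (not the test itself): the erf-based gelu subtracts two nearly equal quantities for large negative inputs, so tiny libdevice differences in erf get hugely amplified.

```python
import math

x = -8.0
erf_term = math.erf(x / math.sqrt(2.0))  # very close to -1.0
gelu = 0.5 * x * (1.0 + erf_term)        # (1 + erf_term) is ~1e-15: cancellation
print(erf_term, gelu)
```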

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164997
Approved by: https://github.com/eellison, https://github.com/jansel
2025-10-09 03:14:05 +00:00
2fe37b5fde [RecSys][Combo Kernel] skip combo kernel generation if partition group is empty (#164918)
Summary: Noticed that the combo kernel partition sometimes contains an empty group. Skip kernel generation in this case to unblock head model launching. The change in this diff is safe, but it would be better to root-cause why an empty group is being created.
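Hedged sketch with hypothetical names (`partition_groups`, `codegen_combo_kernel` are stand-ins, not the real Inductor helpers): the guard simply skips code generation when a partition group comes back empty.

```python
def codegen_combo_kernel(group):     # stand-in for the real codegen
    return f"combo_kernel({len(group)} nodes)"

def generate_combo_kernels(partition_groups):
    kernels = []
    for group in partition_groups:
        if not group:                # empty group: nothing to fuse, skip codegen
            continue
        kernels.append(codegen_combo_kernel(group))
    return kernels

print(generate_combo_kernels([["add", "mul"], [], ["relu"]]))
```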

Test Plan:
Lowering passed after applying the diff

Differential Revision: D84134471

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164918
Approved by: https://github.com/mlazos
2025-10-09 02:55:23 +00:00
96d91da792 [dynamo] allow placement subclass to be traceable (#164985)
This PR unblocks SimpleFSDP+`gradient_divide_factor` [here](https://github.com/pytorch/torchtitan/pull/1793). We need to create a subclass of the DTensor `Partial` placement. When tracing `SimpleFSDPPartial`, I hit an assertion error because `SimpleFSDPPartial` is not in `ok_types`. I'm updating the code to check the placement type via `isinstance` instead of `type(val)`.
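Illustrative only, with stand-in classes rather than the real DTensor placements: an `isinstance` check accepts subclasses, while a `type(val)` lookup does not.

```python
class Partial:                      # stands in for DTensor's Partial placement
    pass

class SimpleFSDPPartial(Partial):   # hypothetical subclass from the linked PR
    pass

ok_types = (Partial,)
val = SimpleFSDPPartial()

print(type(val) in ok_types)        # False: exact-type match rejects subclasses
print(isinstance(val, ok_types))    # True: subclasses pass
```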

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164985
Approved by: https://github.com/ezyang, https://github.com/eellison
2025-10-09 01:44:21 +00:00
f5fd18f7e2 Use runner with more memory for ASAN builds (#165000)
An attempt to [address OOM here](aed5ed1076/1).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165000
Approved by: https://github.com/seemethere, https://github.com/malfet, https://github.com/huydhn
2025-10-09 01:09:28 +00:00
8ca986ee60 [fr] Enable reset the FR recording for fault tolerance (#164988)
We also want a Python-side API for users to reset the FR recording entries. We don't need to reset PGNCCL's member counter since we are creating a new PGNCCL anyway, but FR is a global ring buffer, so we do need to reset it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164988
Approved by: https://github.com/tushar00jain
ghstack dependencies: #164752
2025-10-09 01:03:01 +00:00
81dbeb06f4 CUDA aarch64 12.6 and 12.8 builds fix triton constraints (#165013)
Since we have introduced CUDA aarch64 builds for all CUDA versions, we need to remove this constraint.
This was missed by https://github.com/pytorch/pytorch/pull/162364

Proper constraint on triton should be:
```
Requires-Dist: triton==3.5.0; platform_system == "Linux"
```

not:
```
Requires-Dist: triton==3.5.0; platform_system == "Linux" and platform_machine == "x86_64"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165013
Approved by: https://github.com/Camyll, https://github.com/nWEIdia, https://github.com/tinglvv
2025-10-09 00:49:28 +00:00
7a1ead755f [DeviceMesh] Add a warning for slicing flattened dim from root mesh and types for _get_slice_mesh_layout (#164993)
As the title says, we want to add a deprecation warning for slicing a flattened dim from the root mesh. Also, cosmetic changes adding types for `_get_slice_mesh_layout`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164993
Approved by: https://github.com/fegin
ghstack dependencies: #164750, #164954
2025-10-09 00:47:08 +00:00
90b4e130d6 [Benchmark] cleanup torchbench models (#164816)
Prune models from the TorchInductor dashboard to reduce CI cost. This PR prunes torchbench models according to the [doc](https://docs.google.com/document/d/1nLPNNAU-_M9Clx9FMrJ1ycdPxe-xRA54olPnsFzdpoU/edit?tab=t.0), removing the timm and huggingface models from torchbench.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164816
Approved by: https://github.com/anijain2305, https://github.com/seemethere, https://github.com/huydhn, https://github.com/malfet
2025-10-09 00:31:25 +00:00
4308b8a28f [dynamo] Support torch.fx.traceback.annotate (#164678)
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in the followup PRs to apply regional inductor compilation.

The existing implementation let Dynamo trace into `torch.fx.traceback.annotate`, but that's not what we want. We want Dynamo to essentially run the `torch.fx.traceback.annotate` function in eager, so that every FX node created in the Dynamo FX graph carries the custom meta (a usage sketch follows the list below).

What does not work?
* We still have to set the context manager `torch.fx.traceback.preserve_node_meta()` in the user code because CI was unhappy without it. This can be fixed, but with some perseverance.
* This does not work with graph breaks yet. But we can solve that problem, if needed, in a separate PR.
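Hedged sketch of the intended usage: the exact signature of `torch.fx.traceback.annotate` (here, a context manager taking a metadata dict) is an assumption, not taken from the PR; `preserve_node_meta()` is the context manager the message mentions.

```python
import torch
import torch.fx.traceback as fx_traceback

def f(x):
    with fx_traceback.annotate({"region": "pointwise"}):  # assumed signature
        return torch.relu(x) + 1

with fx_traceback.preserve_node_meta():
    gm = torch.fx.symbolic_trace(f)  # nodes created under annotate carry the custom meta
print(gm.graph)
```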

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
2025-10-08 22:41:00 +00:00
94b1ec8c7c [BE] Use torch check the way its intended (#164987)
Replace
`if (!foo) TORCH_CHECK(false, "bar");` with `TORCH_CHECK(foo, "bar");`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164987
Approved by: https://github.com/albanD, https://github.com/Skylion007
2025-10-08 22:28:08 +00:00
054268c9eb Consider collective inputs to be deallocated only when wait is completed (#164945)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164945
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944
2025-10-08 22:19:25 +00:00
af40828bbb Limit coll bucketing within node idxs (#164944)
Respect max_coll_distance from the overlap scheduler in bucketing; also, add an optimization in path searching.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164944
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783
2025-10-08 22:18:53 +00:00
5a1fbf45ad [ez] remove unnecessary wrapper (#164720)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164720
Approved by: https://github.com/ydwu4
2025-10-08 22:12:29 +00:00
493 changed files with 5789 additions and 4326 deletions

View File

@ -233,7 +233,9 @@ if [[ "${BUILD_ENVIRONMENT}" != *cuda* ]]; then
export BUILD_STATIC_RUNTIME_BENCHMARK=ON
fi
if [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then
if [[ "$BUILD_ENVIRONMENT" == *-full-debug* ]]; then
export CMAKE_BUILD_TYPE=Debug
elif [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then
export CMAKE_BUILD_TYPE=RelWithAssert
fi
@ -299,6 +301,11 @@ else
python -m build --wheel --no-isolation
fi
pip_install_whl "$(echo dist/*.whl)"
if [[ "$BUILD_ENVIRONMENT" == *full-debug* ]]; then
# Regression test for https://github.com/pytorch/pytorch/issues/164297
# Torch should be importable and that's about it
pushd /; python -c "import torch;print(torch.__config__.show(), torch.randn(5) + 1.7)"; popd
fi
if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *vision* ]]; then
install_torchvision

View File

@ -256,7 +256,7 @@ test_torchbench_smoketest() {
local device=mps
local dtypes=(undefined float16 bfloat16 notset)
local dtype=${dtypes[$1]}
local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor timm_resnet timm_vovnet vgg16)
local models=(llama BERT_pytorch dcgan yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor vgg16)
for backend in eager inductor; do
@ -319,7 +319,7 @@ test_aoti_torchbench_smoketest() {
local device=mps
local dtypes=(undefined float16 bfloat16 notset)
local dtype=${dtypes[$1]}
local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor timm_resnet timm_vovnet vgg16)
local models=(llama BERT_pytorch dcgan yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor vgg16)
echo "Launching torchbench inference performance run for AOT Inductor and dtype ${dtype}"
local dtype_arg="--${dtype}"

View File

@ -838,7 +838,7 @@ test_dynamo_benchmark() {
elif [[ "${suite}" == "timm_models" ]]; then
export TORCHBENCH_ONLY_MODELS="inception_v3"
elif [[ "${suite}" == "torchbench" ]]; then
export TORCHBENCH_ONLY_MODELS="hf_Bert"
export TORCHBENCH_ONLY_MODELS="BERT_pytorch"
fi
fi
test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"
@ -869,13 +869,13 @@ test_inductor_torchbench_smoketest_perf() {
mkdir -p "$TEST_REPORTS_DIR"
python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \
--batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \
--batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only BERT_pytorch \
--output "$TEST_REPORTS_DIR/inductor_training_smoketest.csv"
# The threshold value needs to be actively maintained to make this check useful
python benchmarks/dynamo/check_perf_csv.py -f "$TEST_REPORTS_DIR/inductor_training_smoketest.csv" -t 1.4
# Check memory compression ratio for a few models
for test in hf_Albert timm_vision_transformer; do
for test in BERT_pytorch yolov3; do
python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --amp --training \
--disable-cudagraphs --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" \
--only $test --output "$TEST_REPORTS_DIR/inductor_training_smoketest_$test.csv"

View File

@ -15,37 +15,35 @@ if errorlevel 1 exit /b 1
if not errorlevel 0 exit /b 1
cd %TMP_DIR_WIN%\build\torch\test
:: Enable delayed variable expansion to make the list
setlocal enabledelayedexpansion
set EXE_LIST=
for /r "." %%a in (*.exe) do (
call :libtorch_check "%%~na" "%%~fa"
if "%%~na" == "c10_intrusive_ptr_benchmark" (
@REM NB: This is not a gtest executable file, thus couldn't be handled by
@REM pytest-cpp and is excluded from test discovery by run_test
call "%%~fa"
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
) else (
if "%%~na" == "verify_api_visibility" (
@REM Skip verify_api_visibility as it is a compile-level test
) else (
set EXE_LIST=!EXE_LIST! cpp/%%~na
)
)
)
goto :eof
:libtorch_check
cd %CWD%
set CPP_TESTS_DIR=%TMP_DIR_WIN%\build\torch\test
:: Skip verify_api_visibility as it a compile level test
if "%~1" == "verify_api_visibility" goto :eof
:: Run python test\run_test.py on the list
set NO_TD=True && python test\run_test.py --cpp --verbose -i !EXE_LIST!
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
echo Running "%~2"
if "%~1" == "c10_intrusive_ptr_benchmark" (
:: NB: This is not a gtest executable file, thus couldn't be handled by pytest-cpp
call "%~2"
goto :eof
)
python test\run_test.py --cpp --verbose -i "cpp/%~1"
if errorlevel 1 (
echo %1 failed with exit code %errorlevel%
goto fail
)
if not errorlevel 0 (
echo %1 failed with exit code %errorlevel%
goto fail
)
goto :eof
:eof
exit /b 0

View File

@ -71,14 +71,7 @@ export PYTORCH_BUILD_NUMBER=1
# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
# Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'"
# CUDA 12.9/13.0 builds have triton for Linux and Linux aarch64 binaries.
if [[ "$DESIRED_CUDA" == "cu129" ]] || [[ "$DESIRED_CUDA" == "cu130" ]]; then
TRITON_CONSTRAINT="platform_system == 'Linux'"
fi
TRITON_CONSTRAINT="platform_system == 'Linux'"
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" && ! "$PYTORCH_BUILD_VERSION" =~ .*xpu.* ]]; then
TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"

View File

@ -274,8 +274,6 @@ runs:
-w /var/lib/jenkins/workspace \
"${DOCKER_IMAGE}"
)
# Propagate download.pytorch.org IP to container
grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" sudo bash -c "/bin/cat >> /etc/hosts"
echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"
docker exec -t "${container_name}" sh -c "pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}"

View File

@ -33,10 +33,6 @@ runs:
)
echo "CONTAINER_NAME=${container_name}" >> "$GITHUB_ENV"
if [[ "${GPU_ARCH_TYPE}" != "rocm" && "${BUILD_ENVIRONMENT}" != "linux-aarch64-binary-manywheel" && "${BUILD_ENVIRONMENT}" != "linux-s390x-binary-manywheel" && "${GPU_ARCH_TYPE}" != "xpu" ]]; then
# Propagate download.pytorch.org IP to container. This is only needed on Linux non aarch64 runner
grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" bash -c "/bin/cat >> /etc/hosts"
fi
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
# Generate test script

View File

@ -30,6 +30,7 @@ ciflow_push_tags:
- ciflow/riscv64
- ciflow/rocm
- ciflow/rocm-mi300
- ciflow/rocm-mi355
- ciflow/s390
- ciflow/slow
- ciflow/torchbench

View File

@ -389,8 +389,6 @@ jobs:
"${DOCKER_IMAGE}" \
${DOCKER_SHELL_CMD}
)
# Propagate download.pytorch.org IP to container
grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" sudo bash -c "/bin/cat >> /etc/hosts"
echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"
if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then

View File

@ -37,7 +37,7 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: "linux.12xlarge"
runner: "linux.c7i.12xlarge"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '9.0'

View File

@ -128,7 +128,6 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.2xlarge.memory
build-environment: linux-jammy-py3.10-clang18-asan
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan
test-matrix: |

View File

@ -1,6 +1,9 @@
name: rocm-mi355
on:
push:
tags:
- ciflow/rocm-mi355/*
workflow_dispatch:
schedule:
- cron: 30 11,1 * * * # about 4:30am PDT and 6:30pm PDT
@ -64,5 +67,7 @@ jobs:
build-environment: linux-noble-rocm-py3.12-mi355
docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
tests-to-include: >-
${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor test_matmul_cuda test_scaled_matmul_cuda'
|| '' }}
secrets: inherit

View File

@ -249,3 +249,14 @@ jobs:
docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3_10-gcc11-full-debug-build-only:
name: linux-jammy-py3.10-gcc11-full-debug-build-only
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.2xlarge.memory
build-environment: linux-jammy-py3.10-gcc11-full-debug-build-only
docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
secrets: inherit

View File

@ -35,7 +35,7 @@ jobs:
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-n-1-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-1-py3
runner: linux.12xlarge
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
@ -56,7 +56,7 @@ jobs:
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
runner: linux.12xlarge
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 8, runner: "linux.idc.xpu" },

View File

@ -388,9 +388,9 @@ cmake_dependent_option(USE_PRIORITIZED_TEXT_FOR_LD "Use prioritized text linker
option(USE_MIMALLOC "Use mimalloc" OFF)
# Enable third party mimalloc library to improve memory allocation performance
# on Windows.
# on Windows and AArch64.
option(USE_MIMALLOC_ON_MKL "Use mimalloc on MKL" OFF)
if(WIN32)
if(WIN32 OR (CPU_AARCH64 AND NOT APPLE))
set(USE_MIMALLOC ON)
# Not enable USE_MIMALLOC_ON_MKL due to it caused issue:

View File

@ -28,4 +28,19 @@ inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
return stream << BlasBackendToString(backend);
}
namespace blas {
enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};
enum class SwizzleType : std::uint8_t { NO_SWIZZLE = 0, SWIZZLE_32_4_4 = 1 };
} // namespace blas
} // namespace at

View File

@ -144,8 +144,7 @@ inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
checkDeviceType("CPU_tensor_apply", tensors, kCPU);
checkLayout("CPU_tensor_apply", tensors, kStrided);
if (!_all_equal_numel(tensors))
TORCH_CHECK(false, _all_equal_numel_error(tensors));
TORCH_CHECK(_all_equal_numel(tensors), _all_equal_numel_error(tensors));
// An empty tensor has no elements
for (auto& t : tensors)
if (t.numel() == 0)

View File

@ -226,15 +226,15 @@ class TORCH_API Context {
bool userEnabledMkldnn() const;
void setUserEnabledMkldnn(bool e);
bool benchmarkCuDNN() const;
void setBenchmarkCuDNN(bool);
void setBenchmarkCuDNN(bool /*b*/);
int benchmarkLimitCuDNN() const;
void setBenchmarkLimitCuDNN(int);
void setBenchmarkLimitCuDNN(int /*b*/);
bool immediateMiopen() const;
void setImmediateMiopen(bool);
void setImmediateMiopen(bool /*b*/);
bool deterministicCuDNN() const;
void setDeterministicCuDNN(bool);
void setDeterministicCuDNN(bool /*b*/);
bool deterministicMkldnn() const;
void setDeterministicMkldnn(bool);
void setDeterministicMkldnn(bool /*b*/);
bool userEnabledNNPACK() const;
void setUserEnabledNNPACK(bool e);
@ -252,32 +252,32 @@ class TORCH_API Context {
void setSDPPriorityOrder(const std::vector<int64_t>& order);
std::array<at::SDPBackend, at::num_sdp_backends> sDPPriorityOrder();
void setSDPUseFlash(bool);
void setSDPUseFlash(bool /*e*/);
bool userEnabledFlashSDP() const;
void setSDPUseMemEfficient(bool);
void setSDPUseMemEfficient(bool /*e*/);
bool userEnabledMemEfficientSDP() const;
void setSDPUseMath(bool);
void setSDPUseMath(bool /*e*/);
bool userEnabledMathSDP() const;
void setSDPUseCuDNN(bool);
void setSDPUseCuDNN(bool /*e*/);
bool userEnabledCuDNNSDP() const;
void setAllowFP16BF16ReductionMathSDP(bool);
void setAllowFP16BF16ReductionMathSDP(bool /*e*/);
bool allowFP16BF16ReductionMathSDP() const;
void setSDPUseOverrideable(bool);
void setSDPUseOverrideable(bool /*e*/);
bool userEnabledOverrideableSDP() const;
at::LinalgBackend linalgPreferredBackend() const;
void setLinalgPreferredBackend(at::LinalgBackend);
void setLinalgPreferredBackend(at::LinalgBackend /*b*/);
at::BlasBackend blasPreferredBackend();
void setBlasPreferredBackend(at::BlasBackend);
void setBlasPreferredBackend(at::BlasBackend /*b*/);
at::ROCmFABackend getROCmFAPreferredBackend();
void setROCmFAPreferredBackend(at::ROCmFABackend);
void setROCmFAPreferredBackend(at::ROCmFABackend /*b*/);
// Note [Enabling Deterministic Operations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -310,9 +310,9 @@ class TORCH_API Context {
bool deterministicAlgorithms() const;
bool deterministicAlgorithmsWarnOnly() const;
void setDeterministicAlgorithms(bool, bool);
void setDeterministicAlgorithms(bool /*b*/, bool /*warn_only*/);
bool deterministicFillUninitializedMemory() const;
void setDeterministicFillUninitializedMemory(bool);
void setDeterministicFillUninitializedMemory(bool /*b*/);
// Note [Writing Nondeterministic Operations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -356,11 +356,11 @@ class TORCH_API Context {
Float32Op op,
Float32Precision p);
bool allowTF32CuDNN(std::optional<Float32Op> op = std::nullopt) const;
void setAllowTF32CuDNN(bool);
void setAllowTF32CuDNN(bool /*b*/);
bool allowTF32OneDNN() const;
void setAllowTF32OneDNN(bool);
void setAllowTF32OneDNN(bool /*b*/);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
void setAllowTF32CuBLAS(bool /*b*/);
Float32MatmulPrecision float32MatmulPrecision() const;
Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
CuBLASReductionOption allowFP16ReductionCuBLAS() const;
@ -372,7 +372,7 @@ class TORCH_API Context {
bool allow_reduced_precision,
bool allow_splitk = true);
bool allowFP16AccumulationCuBLAS() const;
void setAllowFP16AccumulationCuBLAS(bool);
void setAllowFP16AccumulationCuBLAS(bool /*b*/);
// Matmuls can use a so-called "persistent" kernel which launches one CUDA
// block for each SM on the GPU, and each block then iterates over multiple
@ -384,7 +384,7 @@ class TORCH_API Context {
// to make matmuls target only a subset of the SMs, so they can fully schedule
// even next to a comms kernel, and only be a few percent slower.
std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t>);
void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t> /*c*/);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
@ -405,7 +405,7 @@ class TORCH_API Context {
void setDefaultMobileCPUAllocator();
void unsetDefaultMobileCPUAllocator();
bool allowFP16ReductionCPU() const;
void setAllowFP16ReductionCPU(bool);
void setAllowFP16ReductionCPU(bool /*b*/);
// Preserved for BC
void lazyInitCUDA() {

View File

@ -62,7 +62,7 @@ constexpr const char* unknown_eventname = "eventname not specified";
#endif
} // namespace (anonymous)
MapAllocator::MapAllocator(WithFd, std::string_view filename, int fd, int flags, size_t size)
MapAllocator::MapAllocator(WithFd /*unused*/, std::string_view filename, int fd, int flags, size_t size)
: filename_(filename.empty() ? unknown_filename : filename)
, size_(0) // to be filled later
#ifdef _WIN32
@ -494,7 +494,7 @@ RefcountedMapAllocator::RefcountedMapAllocator(const char *filename, int flags,
initializeAlloc();
}
RefcountedMapAllocator::RefcountedMapAllocator(WithFd, const char *filename, int fd, int flags, size_t size)
RefcountedMapAllocator::RefcountedMapAllocator(WithFd /*unused*/, const char *filename, int fd, int flags, size_t size)
: RefcountedMapAllocatorArgCheck(flags)
, MapAllocator(WITH_FD, filename, flags, fd, size + map_alloc_alignment) {
@ -614,7 +614,7 @@ at::DataPtr MapAllocator::makeDataPtr(std::string_view filename, int flags, size
return {context->data(), context, &deleteMapAllocator, at::DeviceType::CPU};
}
at::DataPtr MapAllocator::makeDataPtr(WithFd, const char *filename, int fd, int flags, size_t size, size_t* actual_size_out) {
at::DataPtr MapAllocator::makeDataPtr(WithFd /*unused*/, const char *filename, int fd, int flags, size_t size, size_t* actual_size_out) {
auto* context = new MapAllocator(WITH_FD, filename, fd, flags, size);
if (actual_size_out) *actual_size_out = context->size();
return {context->data(), context, &deleteMapAllocator, at::DeviceType::CPU};
@ -626,7 +626,7 @@ at::DataPtr RefcountedMapAllocator::makeDataPtr(const char *filename, int flags,
return {context->data(), context, &deleteRefcountedMapAllocator, at::DeviceType::CPU};
}
at::DataPtr RefcountedMapAllocator::makeDataPtr(WithFd, const char *filename, int fd, int flags, size_t size, size_t* actual_size_out) {
at::DataPtr RefcountedMapAllocator::makeDataPtr(WithFd /*unused*/, const char *filename, int fd, int flags, size_t size, size_t* actual_size_out) {
auto* context = new RefcountedMapAllocator(WITH_FD, filename, fd, flags, size);
if (actual_size_out) *actual_size_out = context->size() - map_alloc_alignment;
return {context->data(), context, &deleteRefcountedMapAllocator, at::DeviceType::CPU};

View File

@ -25,7 +25,7 @@ class TORCH_API MapAllocator {
public:
MapAllocator(std::string_view filename, int flags, size_t size);
MapAllocator(
WithFd,
WithFd /*unused*/,
std::string_view filename,
int fd,
int flags,
@ -59,14 +59,14 @@ class TORCH_API MapAllocator {
return flags_;
}
static MapAllocator* fromDataPtr(const at::DataPtr&);
static MapAllocator* fromDataPtr(const at::DataPtr& /*dptr*/);
static at::DataPtr makeDataPtr(
std::string_view filename,
int flags,
size_t size,
size_t* actual_size_out);
static at::DataPtr makeDataPtr(
WithFd,
WithFd /*unused*/,
const char* filename,
int fd,
int flags,
@ -105,13 +105,13 @@ class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
public:
RefcountedMapAllocator(const char* filename, int flags, size_t size);
RefcountedMapAllocator(
WithFd,
WithFd /*unused*/,
const char* filename,
int fd,
int flags,
size_t size);
static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
static RefcountedMapAllocator* fromDataPtr(const at::DataPtr& /*dptr*/);
RefcountedMapAllocator(const RefcountedMapAllocator&) = delete;
RefcountedMapAllocator(RefcountedMapAllocator&&) = delete;
RefcountedMapAllocator& operator=(const RefcountedMapAllocator&) = delete;
@ -122,7 +122,7 @@ class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
size_t size,
size_t* actual_size_out);
static at::DataPtr makeDataPtr(
WithFd,
WithFd /*unused*/,
const char* filename,
int fd,
int flags,

View File

@ -273,7 +273,7 @@ c10::SymInt NestedTensorImpl::sym_numel_custom() const {
return NestedTensorImpl::numel_custom();
}
c10::SymBool NestedTensorImpl::sym_is_contiguous_custom(MemoryFormat) const {
c10::SymBool NestedTensorImpl::sym_is_contiguous_custom(MemoryFormat /*memory_format*/) const {
return nested_tensor_impl_is_contiguous(this);
}
IntArrayRef NestedTensorImpl::sizes_custom() const {

View File

@ -115,7 +115,8 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
// with real implementations
int64_t numel_custom() const override;
c10::SymInt sym_numel_custom() const override;
c10::SymBool sym_is_contiguous_custom(MemoryFormat) const override;
c10::SymBool sym_is_contiguous_custom(
MemoryFormat /*memory_format*/) const override;
int64_t size_custom(int64_t d) const override {
return this->size(d);
}

View File

@ -14,7 +14,7 @@ inline int64_t divup(int64_t x, int64_t y) {
TORCH_API void init_num_threads();
// Sets the number of threads to be used in parallel region
TORCH_API void set_num_threads(int);
TORCH_API void set_num_threads(int /*nthreads*/);
// Returns the maximum number of threads that may be used in a parallel region
TORCH_API int get_num_threads();
@ -37,7 +37,7 @@ inline void lazy_init_num_threads() {
}
}
TORCH_API void set_thread_num(int);
TORCH_API void set_thread_num(int /*id*/);
class TORCH_API ThreadIdGuard {
public:
@ -130,7 +130,7 @@ inline scalar_t parallel_reduce(
TORCH_API std::string get_parallel_info();
// Sets number of threads used for inter-op parallelism
TORCH_API void set_num_interop_threads(int);
TORCH_API void set_num_interop_threads(int /*nthreads*/);
// Returns the number of threads used for inter-op parallelism
TORCH_API size_t get_num_interop_threads();

View File

@ -252,7 +252,7 @@ void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) {
TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset.");
}
c10::SymBool SparseCsrTensorImpl::sym_is_contiguous_custom(MemoryFormat) const {
c10::SymBool SparseCsrTensorImpl::sym_is_contiguous_custom(MemoryFormat /*memory_format*/) const {
TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous");
}
} // namespace at

View File

@ -32,10 +32,10 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
public:
explicit SparseCsrTensorImpl(
at::DispatchKeySet,
at::DispatchKeySet /*key_set*/,
at::Device device,
Layout layout,
const caffe2::TypeMeta);
const caffe2::TypeMeta /*data_type*/);
void resize_(int64_t nnz, IntArrayRef size);
void resize_and_clear_(
@ -86,7 +86,8 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
protected:
IntArrayRef strides_custom() const override;
SymIntArrayRef sym_strides_custom() const override;
SymBool sym_is_contiguous_custom(MemoryFormat) const override;
SymBool sym_is_contiguous_custom(
MemoryFormat /*memory_format*/) const override;
public:
void set_size(int64_t dim, int64_t new_size) override;

View File

@ -46,7 +46,9 @@ struct TORCH_API SparseTensorImpl : public TensorImpl {
public:
// Public for now...
explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
explicit SparseTensorImpl(
at::DispatchKeySet /*key_set*/,
const caffe2::TypeMeta /*data_type*/);
void release_resources() override;
@ -384,8 +386,8 @@ struct TORCH_API SparseTensorImpl : public TensorImpl {
private:
explicit SparseTensorImpl(
at::DispatchKeySet,
const caffe2::TypeMeta,
at::DispatchKeySet /*key_set*/,
const caffe2::TypeMeta /*data_type*/,
at::Tensor indices,
at::Tensor values);

View File

@ -112,10 +112,10 @@ TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
// `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
struct TORCH_API TensorIndex final {
// Case 1: `at::indexing::None`
TensorIndex(std::nullopt_t) : type_(TensorIndexType::None) {}
TensorIndex(std::nullopt_t /*unused*/) : type_(TensorIndexType::None) {}
// Case 2: "..." / `at::indexing::Ellipsis`
TensorIndex(at::indexing::EllipsisIndexType)
TensorIndex(at::indexing::EllipsisIndexType /*unused*/)
: type_(TensorIndexType::Ellipsis) {}
TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
TORCH_CHECK_VALUE(

View File

@ -250,7 +250,7 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase {
using PtrVector = SmallVector<char*, 4>;
using StrideVector = SmallVector<int64_t, 6>;
void build(TensorIteratorConfig&);
void build(TensorIteratorConfig& /*config*/);
// The inner-loop function operates on the fastest moving dimension. It
// implements element-wise operations in terms of 1-d strided tensors.
@ -618,20 +618,20 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase {
#undef TORCH_DISALLOW_TEMPORARIES
protected:
// Mutable reference as it moves tensors out of TensorIteratorConfig
void populate_operands(TensorIteratorConfig&);
void populate_operands(TensorIteratorConfig& /*config*/);
void mark_outputs();
void mark_resize_outputs(const TensorIteratorConfig&);
void compute_mem_overlaps(const TensorIteratorConfig&);
void compute_shape(const TensorIteratorConfig&);
void compute_strides(const TensorIteratorConfig&);
void mark_resize_outputs(const TensorIteratorConfig& /*config*/);
void compute_mem_overlaps(const TensorIteratorConfig& /*config*/);
void compute_shape(const TensorIteratorConfig& /*config*/);
void compute_strides(const TensorIteratorConfig& /*config*/);
void reorder_dimensions();
void permute_dimensions(IntArrayRef perm);
void compute_types(const TensorIteratorConfig&);
void compute_types(const TensorIteratorConfig& /*config*/);
ScalarType compute_common_dtype();
void allocate_or_resize_outputs();
bool fast_set_up(const TensorIteratorConfig&);
FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
void compute_names(const TensorIteratorConfig&);
bool fast_set_up(const TensorIteratorConfig& /*config*/);
FastSetupType compute_fast_setup_type(const TensorIteratorConfig& /*config*/);
void compute_names(const TensorIteratorConfig& /*config*/);
void propagate_names_to_outputs();
void coalesce_dimensions();

View File

@ -20,7 +20,7 @@
namespace at {
TORCH_API int _crash_if_asan(int);
TORCH_API int _crash_if_asan(int /*arg*/);
// Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*)
// NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.

View File

@ -148,7 +148,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
Banned functions
*******************************/
static Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const std::optional<Tensor>&, int64_t) {
static Tensor binary_cross_entropy_banned(const Tensor & /*unused*/, const Tensor & /*unused*/, const std::optional<Tensor>& /*unused*/, int64_t /*unused*/) {
TORCH_CHECK(false, "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
"Many models use a sigmoid layer right before the binary cross entropy layer.\n"
"In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"

View File

@ -27,11 +27,11 @@ struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
HasNonWildcard
};
explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names)
explicit NamedTensorMeta(HAS_NON_WILDCARD /*unused*/, DimnameList names)
: names_(names.vec()) {
check_invariants();
}
explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector<Dimname>&& names)
explicit NamedTensorMeta(HAS_NON_WILDCARD /*unused*/, std::vector<Dimname>&& names)
: names_(std::move(names)) {
check_invariants();
}
@ -52,13 +52,13 @@ struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
}
void set_names(HAS_NON_WILDCARD, DimnameList new_names) {
void set_names(HAS_NON_WILDCARD /*unused*/, DimnameList new_names) {
TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
std::copy(new_names.begin(), new_names.end(), names_.begin());
check_invariants();
}
void set_names(HAS_NON_WILDCARD, std::vector<Dimname>&& new_names) {
void set_names(HAS_NON_WILDCARD /*unused*/, std::vector<Dimname>&& new_names) {
TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
names_ = std::move(new_names);
check_invariants();

View File

@ -13,7 +13,7 @@ class TORCH_API PythonOpRegistrationTrampoline final {
public:
// Returns true if you successfully registered yourself (that means
// you are in the hot seat for doing the operator registrations!)
static bool registerInterpreter(c10::impl::PyInterpreter*);
static bool registerInterpreter(c10::impl::PyInterpreter* /*interp*/);
// Returns nullptr if no interpreter has been registered yet.
static c10::impl::PyInterpreter* getInterpreter();

View File

@ -100,7 +100,7 @@ class TORCH_API TensorBase {
// Create a Tensor with a +0 reference count. Special care must be
// taken to avoid decrementing this reference count at destruction
// time. Intended to support MaybeOwnedTraits<Tensor>.
explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
explicit TensorBase(unsafe_borrow_t /*unused*/, const TensorBase& rhs)
: impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>(rhs.impl_.get(), c10::raw::DontIncreaseRefcount{})) {}
friend MaybeOwnedTraits<TensorBase>;
@ -954,7 +954,7 @@ protected:
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
private:
TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
TensorBase __dispatch_contiguous(c10::MemoryFormat /*memory_format*/) const;
};
inline DeviceIndex get_device(const TensorBase& self) {

View File

@ -18,10 +18,10 @@ class KernelFunction;
// implementation notes; notably, this does NOT actually go through the
// boxing/unboxing codepath.
TORCH_API void fallthrough_kernel(
OperatorKernel*,
const OperatorHandle&,
DispatchKeySet,
Stack*);
OperatorKernel* /*unused*/,
const OperatorHandle& /*unused*/,
DispatchKeySet /*unused*/,
Stack* /*unused*/);
// Note [Ambiguity in AutogradOther kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -62,10 +62,10 @@ TORCH_API void fallthrough_kernel(
// than arbitrarily pick one or the other, we just register a kernel that raises
// an error and let the user decide how to proceed.
TORCH_API void ambiguous_autogradother_kernel(
OperatorKernel*,
const OperatorHandle&,
DispatchKeySet,
Stack*);
OperatorKernel* /*unused*/,
const OperatorHandle& /*op*/,
DispatchKeySet /*unused*/,
Stack* /*unused*/);
// Note [named_not_supported_kernel]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -75,10 +75,10 @@ TORCH_API void ambiguous_autogradother_kernel(
// give a good error message in cases when boxing is not supported). When
// boxing is universally supported this can be removed.
[[noreturn]] TORCH_API void named_not_supported_kernel(
OperatorKernel*,
const OperatorHandle&,
DispatchKeySet,
Stack*);
OperatorKernel* /*unused*/,
const OperatorHandle& /*op*/,
DispatchKeySet /*unused*/,
Stack* /*unused*/);
/**
* BoxedKernel is similar to a std::function storing a boxed kernel.
@ -185,16 +185,16 @@ class TORCH_API BoxedKernel final {
template <BoxedKernelFunction* func>
static void make_boxed_function(
OperatorKernel*,
OperatorKernel* /*unused*/,
const OperatorHandle& opHandle,
DispatchKeySet,
DispatchKeySet /*unused*/,
Stack* stack);
template <BoxedKernelFunction_withDispatchKeys* func>
static void make_boxed_function(
OperatorKernel*,
OperatorKernel* /*unused*/,
const OperatorHandle& opHandle,
DispatchKeySet,
DispatchKeySet /*ks*/,
Stack* stack);
explicit BoxedKernel(

View File

@ -11,9 +11,9 @@ inline BoxedKernel::BoxedKernel(
template <BoxedKernel::BoxedKernelFunction* func>
inline void BoxedKernel::make_boxed_function(
OperatorKernel*,
OperatorKernel* /*unused*/,
const OperatorHandle& opHandle,
DispatchKeySet,
DispatchKeySet /*unused*/,
Stack* stack) {
// Note that we're dropping the DispatchKeySet argument.
// See Note [Plumbing Keys Through The Dispatcher 2] for details.
@ -22,7 +22,7 @@ inline void BoxedKernel::make_boxed_function(
template <BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
inline void BoxedKernel::make_boxed_function(
OperatorKernel*,
OperatorKernel* /*unused*/,
const OperatorHandle& opHandle,
DispatchKeySet ks,
Stack* stack) {

View File

@ -10,7 +10,7 @@ namespace c10 {
// be handled specially. Its semantics is that it redispatches to the
// *next* dispatch key that would have been processed, skipping the current
// one.
void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*) {
void fallthrough_kernel(OperatorKernel* /*unused*/, const OperatorHandle& /*unused*/, DispatchKeySet /*unused*/, Stack* /*unused*/) {
TORCH_INTERNAL_ASSERT(0,
"fallthrough_kernel was executed but it should have been short-circuited by the dispatcher. "
"This could occur if you registered a fallthrough kernel as a override for a specific operator "
@ -19,7 +19,7 @@ void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet,
"let us know in the bug tracker.");
}
void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle& op, DispatchKeySet, Stack*) {
void ambiguous_autogradother_kernel(OperatorKernel* /*unused*/, const OperatorHandle& op, DispatchKeySet /*unused*/, Stack* /*unused*/) {
TORCH_INTERNAL_ASSERT(0,
op.operator_name(), " has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther. "
"This makes the backend kernel unreachable; the dispatcher will always prefer the CompositeImplicitAutograd lowering "
@ -32,7 +32,7 @@ void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle& op, D
"\nCanonical state\n~~~~~~~~~~~\n", op.dumpState(), "\n\n");
}
void named_not_supported_kernel(OperatorKernel*, const OperatorHandle& op, DispatchKeySet, Stack*) {
void named_not_supported_kernel(OperatorKernel* /*unused*/, const OperatorHandle& op, DispatchKeySet /*unused*/, Stack* /*unused*/) {
// DO NOT LOOK AT STACK, YOU HAVE SHORT CIRCUITED BOXING
// See Note [named_not_supported_kernel]
TORCH_CHECK(0,

View File

@ -229,7 +229,7 @@ class TORCH_API KernelFunction final {
* &unboxed_func>();
*/
template <class FuncPtr, bool AllowLegacyTypes = false>
static KernelFunction makeFromUnboxedFunction(FuncPtr);
static KernelFunction makeFromUnboxedFunction(FuncPtr /*func_ptr*/);
/**
* Create a KernelFunction from an unboxed function.
@ -271,7 +271,7 @@ class TORCH_API KernelFunction final {
std::string dumpState() const;
// For testing internal invariants only
bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
bool _equalsBoxedAndUnboxed(const KernelFunction& /*other*/) const;
// Register a token to be invalidated when this KernelFunction is destroyed
void registerToken(std::weak_ptr<KernelToken> token) const;

View File

@ -131,7 +131,7 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(
new (dest++) IValue(options.pinned_memory());
}
inline void boxArgsToStack(IValue*&) {}
inline void boxArgsToStack(IValue*& /*unused*/) {}
template <typename T, typename... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack(
@ -185,7 +185,7 @@ struct PopResult<std::tuple<Types...>> final {
template <size_t... indices>
static Result pop_to_tuple_impl(
Stack& stack,
std::index_sequence<indices...>) {
std::index_sequence<indices...> /*unused*/) {
return std::make_tuple((std::move(stack[indices]).template to<Types>())...);
}
};

View File

@ -561,7 +561,7 @@ struct wrap_kernel_functor_unboxed_<
// doesn't use &&
static ReturnType call(
OperatorKernel* functor,
DispatchKeySet,
DispatchKeySet /*unused*/,
ParameterTypes... args) {
KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
// Note [Plumbing Keys Through The Dispatcher 2]
@ -629,8 +629,8 @@ call_functor_with_args_from_stack_(
OperatorKernel* functor,
DispatchKeySet dispatchKeySet,
Stack* stack,
std::index_sequence<ivalue_arg_indices...>,
guts::typelist::typelist<ArgTypes...>*) {
std::index_sequence<ivalue_arg_indices...> /*unused*/,
guts::typelist::typelist<ArgTypes...>* /*unused*/) {
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
// be unused and we have to silence the compiler warning.
@ -708,7 +708,7 @@ struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
static void call_(
std::tuple<OutputTypes...>&& output,
Stack* stack,
std::index_sequence<indices...>) {
std::index_sequence<indices...> /*unused*/) {
torch::jit::push(
*stack,
return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(
@ -718,7 +718,7 @@ struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
static void copy_(
const std::tuple<OutputTypes...>& output,
Stack* stack,
std::index_sequence<indices...>) {
std::index_sequence<indices...> /*unused*/) {
torch::jit::push(
*stack,
return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(
@ -741,7 +741,7 @@ struct make_boxed_from_unboxed_functor final {
static void call(
OperatorKernel* functor,
const OperatorHandle&,
const OperatorHandle& /*unused*/,
DispatchKeySet dispatchKeySet,
Stack* stack) {
using ReturnType =

View File

@ -63,13 +63,13 @@ struct BuiltinOpFunction : public Function {
bool call(
Stack& stack,
std::optional<size_t>,
c10::function_ref<void(const Code&)>) override {
std::optional<size_t> /*unused*/,
c10::function_ref<void(const Code&)> /*unused*/) override {
run(stack);
return false;
}
bool call(Stack& stack, c10::function_ref<void(const mobile::Code&)>)
bool call(Stack& stack, c10::function_ref<void(const mobile::Code&)> /*unused*/)
override {
run(stack);
return false;

View File

@ -80,7 +80,8 @@ struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
ts = ts | x.key_set();
}
}
[[noreturn]] void operator()(at::ArrayRef<std::optional<at::Tensor>>) {
[[noreturn]] void operator()(
at::ArrayRef<std::optional<at::Tensor>> /*unused*/) {
// Just checking that the handling of Tensor?[] didn't change.
TORCH_INTERNAL_ASSERT(false);
}
@ -95,7 +96,7 @@ struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
}
}
template <typename T>
void operator()(const T&) {
void operator()(const T& /*unused*/) {
// do nothing
}
};

View File

@ -633,7 +633,7 @@ class TypedOperatorHandle<Return(Args...)> final : public OperatorHandle {
namespace detail {
template <class... Args>
inline void unused_arg_(const Args&...) {}
inline void unused_arg_(const Args&... /*unused*/) {}
// CaptureKernelCall is intended to capture return values from Dispatcher
// unboxed kernel calls. A record function may request to get outputs from the

View File

@ -105,7 +105,7 @@ class TORCH_API OperatorEntry final {
// versa that is an error. (Refcounting for the registrations is
// handled in the OperatorHandle in Dispatcher)
void registerSchema(
FunctionSchema&&,
FunctionSchema&& /*schema*/,
std::string&& debug,
std::vector<at::Tag> tags = {});
void deregisterSchema();

View File

@ -177,7 +177,7 @@ bool DynamicType::equals(const Type& rhs) const {
return equals(*create(rhs));
}
bool DynamicType::isSubtypeOfExt(const Type& rhs, std::ostream*) const {
bool DynamicType::isSubtypeOfExt(const Type& rhs, std::ostream* /*why_not*/) const {
auto other = create(rhs);
if (tag_ == other->tag_) {
if (equals(*other)) {
@ -371,7 +371,7 @@ DynamicTypePtr ivalue::TupleTypeFactory<c10::DynamicType>::create(
}
DynamicTypePtr ivalue::TupleTypeFactory<c10::DynamicType>::fallback(
const Type&) {
const Type& /*unused*/) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
return nullptr;
}

View File

@ -138,8 +138,8 @@ class DynamicType : public SharedType {
struct Arguments {
Arguments() = default;
Arguments(c10::ArrayRef<TypePtr>);
Arguments(const std::vector<std::string_view>&, c10::ArrayRef<TypePtr>);
Arguments(c10::ArrayRef<TypePtr> /*args*/);
Arguments(const std::vector<std::string_view>& /*names*/, c10::ArrayRef<TypePtr> /*args*/);
std::vector<LabeledDynamicType> elems;
};
@ -156,15 +156,15 @@ class DynamicType : public SharedType {
static const TypeKind Kind = TypeKind::DynamicType;
static TORCH_API DynamicTypePtr create(Type& ty);
explicit DynamicType(Tag, Arguments);
explicit DynamicType(Tag, std::string_view, Arguments);
explicit DynamicType(Tag /*tag*/, Arguments /*arguments*/);
explicit DynamicType(Tag /*tag*/, std::string_view /*name*/, Arguments /*arguments*/);
DynamicType(DynamicType&& other) = delete;
DynamicType(const DynamicType&) = delete;
DynamicType& operator=(const DynamicType&) = delete;
DynamicType& operator=(DynamicType&&) = delete;
TypePtr containedType(size_t) const override;
TypePtr containedType(size_t /*i*/) const override;
size_t containedTypeSize() const override;
Tag tag() const {
return tag_;

View File

@ -96,15 +96,15 @@ struct TORCH_API Function {
// Overload for server interpreter, a bailout size is needed for graph
// executor.
virtual bool call(
Stack&,
std::optional<size_t>,
c10::function_ref<void(const Code&)>) {
Stack& /*unused*/,
std::optional<size_t> /*unused*/,
c10::function_ref<void(const Code&)> /*unused*/) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
return false;
}
// Overload for mobile interpreter.
virtual bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) {
virtual bool call(Stack& /*unused*/, c10::function_ref<void(const mobile::Code&)> /*unused*/) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
return false;
}

View File

@ -847,7 +847,7 @@ struct TORCH_API IValue final {
IValue(std::optional<T> v);
template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
IValue(c10::OptionalArrayRef<T> v);
IValue(std::nullopt_t);
IValue(std::nullopt_t /*unused*/);
// ClassType
IValue(c10::intrusive_ptr<ivalue::Object> v);

View File

@ -660,7 +660,7 @@ struct TORCH_API TupleTypeFactory<TupleType> {
template <>
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
static DynamicTypePtr create(const std::vector<TypePtr>& elemTypes);
static DynamicTypePtr fallback(const Type&);
static DynamicTypePtr fallback(const Type& /*unused*/);
};
struct TORCH_API Tuple : c10::intrusive_ptr_target {
@ -1682,7 +1682,7 @@ struct ivalue::EnumHolder : c10::intrusive_ptr_target {
namespace detail {
struct _guarded_unsigned_long_unique_dummy final {
_guarded_unsigned_long_unique_dummy(int64_t){}
_guarded_unsigned_long_unique_dummy(int64_t /*unused*/){}
};
using _guarded_unsigned_long = std::conditional_t<
std::is_same_v<unsigned long, uint32_t> ||
@ -1776,7 +1776,7 @@ template <class Elem>
// native_functions.yaml still return std::vector.
// C10_DEPRECATED_MESSAGE("IValues based on std::vector<T> are potentially slow
// and deprecated. Please use torch::List<T> instead.")
std::vector<Elem> generic_to(IValue ivalue, _fake_type<std::vector<Elem>>) {
std::vector<Elem> generic_to(IValue ivalue, _fake_type<std::vector<Elem>> /*unused*/) {
// We need to do a deep copy of the vector because there might be other
// references to this same IValue that also use the list. We can't just
// move the elements out.
@ -1826,18 +1826,18 @@ c10::intrusive_ptr<T> IValue::toCustomClass() const& {
}
template <typename T>
T generic_to(IValue ivalue, _fake_type<T>) {
T generic_to(IValue ivalue, _fake_type<T> /*unused*/) {
using ElemType = typename std::remove_pointer<T>::type::element_type;
return std::move(ivalue).template toCustomClass<ElemType>();
}
template <typename T>
tagged_capsule<T> generic_to(IValue ivalue, _fake_type<tagged_capsule<T>>) {
tagged_capsule<T> generic_to(IValue ivalue, _fake_type<tagged_capsule<T>> /*unused*/) {
return tagged_capsule<T>{std::move(ivalue)};
}
template <typename Elem>
c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) {
c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>> /*unused*/) {
return impl::toTypedList<Elem>(std::move(ivalue).toList());
}
@ -1867,7 +1867,7 @@ std::vector<T> createVectorFromList(const c10::List<T>& impl) {
}
template <typename T>
OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>>) {
OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>> /*unused*/) {
if (ivalue.isNone()) {
return {};
}
@ -1880,8 +1880,8 @@ namespace detail {
template <typename Elem, size_t... I>
std::array<Elem, sizeof...(I)> generic_to_array(
IValue ivalue,
_fake_type<std::array<Elem, sizeof...(I)>>,
std::index_sequence<I...>) {
_fake_type<std::array<Elem, sizeof...(I)>> /*unused*/,
std::index_sequence<I...> /*unused*/) {
// We need to do a deep copy of the array because there might be other
// references to this same IValue that also use the list. We can't just
// move the elements out.
@ -1906,7 +1906,7 @@ std::array<Elem, N> generic_to(
template <typename Key, typename Value>
c10::Dict<Key, Value> generic_to(
IValue ivalue,
_fake_type<c10::Dict<Key, Value>>) {
_fake_type<c10::Dict<Key, Value>> /*unused*/) {
return impl::toTypedDict<Key, Value>(std::move(ivalue).toGenericDict());
}
@ -1915,7 +1915,7 @@ C10_DEPRECATED_MESSAGE(
"IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict<K, V> instead.")
std::unordered_map<K, V> generic_to(
IValue ivalue,
_fake_type<std::unordered_map<K, V>>) {
_fake_type<std::unordered_map<K, V>> /*unused*/) {
std::unordered_map<K, V> specialized_dict;
for (const auto& item : std::move(ivalue).toGenericDict()) {
@ -1926,7 +1926,7 @@ std::unordered_map<K, V> generic_to(
}
template <typename T>
std::optional<T> generic_to(IValue ivalue, _fake_type<std::optional<T>>) {
std::optional<T> generic_to(IValue ivalue, _fake_type<std::optional<T>> /*unused*/) {
if (ivalue.isNone()) {
return std::nullopt;
}
@ -1937,7 +1937,7 @@ namespace detail {
template <typename Tuple, std::size_t... INDEX>
Tuple generic_to_tuple_impl(
const ivalue::TupleElements& t,
std::index_sequence<INDEX...>) {
std::index_sequence<INDEX...> /*unused*/) {
return std::make_tuple(
t[INDEX].to<typename std::tuple_element<INDEX, Tuple>::type>()...);
}
@ -1951,7 +1951,7 @@ template <
std::is_lvalue_reference<Args>...,
std::negation<std::is_constructible<IValue, Args>>...>,
std::nullptr_t> = nullptr>
std::tuple<Args...> generic_to(const IValue& ivalue, _fake_type<std::tuple<Args...>>) {
std::tuple<Args...> generic_to(const IValue& ivalue, _fake_type<std::tuple<Args...>> /*unused*/) {
const auto& vals = ivalue.toTupleRef().elements();
TORCH_CHECK(vals.size() == sizeof...(Args));
return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
@ -2311,7 +2311,7 @@ inline IValue::IValue(std::optional<T> v) : IValue() {
}
}
inline IValue::IValue(std::nullopt_t) : IValue() {}
inline IValue::IValue(std::nullopt_t /*unused*/) : IValue() {}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
: tag(Tag::Object) {
@ -2482,15 +2482,15 @@ namespace ivalue {
namespace detail {
template <typename T>
IValue from_(T&& x, std::true_type) {
IValue from_(T&& x, std::true_type /*unused*/) {
return IValue(std::forward<T>(x));
}
template <typename T>
IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
IValue from_(c10::intrusive_ptr<T> x, std::false_type /*unused*/) {
return IValue(std::move(x));
}
template <typename T>
IValue from_(T&& /*x*/, std::false_type) {
IValue from_(T&& /*x*/, std::false_type /*unused*/) {
static_assert(
guts::false_t<T>::value,
"You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
@ -2546,19 +2546,19 @@ struct MaybeOwnedTraits<IValue> {
return &borrow;
}
static bool debugBorrowIsValid(const borrow_type&) {
static bool debugBorrowIsValid(const borrow_type& /*unused*/) {
return true;
}
};
template <>
struct IValue::TagType<c10::Type> {
static TORCH_API c10::TypePtr get(const IValue&);
static TORCH_API c10::TypePtr get(const IValue& /*v*/);
};
template <>
struct IValue::TagType<c10::DynamicType> {
static TORCH_API c10::TypePtr get(const IValue&);
static TORCH_API c10::TypePtr get(const IValue& /*v*/);
};
template <typename T>

View File

@ -44,7 +44,7 @@ constexpr int checkStaticTypes() {
}
template <typename... Ts, size_t... Is>
constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(std::index_sequence<Is...>) {
constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(std::index_sequence<Is...> /*unused*/) {
return (
// Check types for common errors
checkStaticTypes<Ts...>(),

View File

@ -83,7 +83,7 @@ inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) {
}
TORCH_API std::string toString(const OperatorName& opName);
TORCH_API std::ostream& operator<<(std::ostream&, const OperatorName&);
TORCH_API std::ostream& operator<<(std::ostream& /*os*/, const OperatorName& /*opName*/);
} // namespace c10

View File

@ -16,7 +16,7 @@ class SingletonTypePtr {
/* implicit */ SingletonTypePtr(T* p) : repr_(p) {}
// We need this to satisfy Pybind11, but it shouldn't be hit.
explicit SingletonTypePtr(std::shared_ptr<T>) { TORCH_CHECK(false); }
explicit SingletonTypePtr(std::shared_ptr<T> /*unused*/) { TORCH_CHECK(false); }
using element_type = typename std::shared_ptr<T>::element_type;

View File

@ -307,8 +307,8 @@ Vectorized<c10::BFloat16> inline operator/(
}
inline Vectorized<BFloat16>::Vectorized() {
const short zero = 0;
values = svdup_n_bf16(c10::bit_cast<bfloat16_t>(zero));
auto vals_f = svdup_n_f32(0);
values = convert_float_bfloat16(vals_f, vals_f);
}
inline Vectorized<BFloat16>::Vectorized(int val) {

View File

@ -342,19 +342,19 @@ class Vectorized<c10::complex<double>> {
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
}
Vectorized<c10::complex<double>> operator<(
const Vectorized<c10::complex<double>>&) const {
const Vectorized<c10::complex<double>>& /*unused*/) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator<=(
const Vectorized<c10::complex<double>>&) const {
const Vectorized<c10::complex<double>>& /*unused*/) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator>(
const Vectorized<c10::complex<double>>&) const {
const Vectorized<c10::complex<double>>& /*unused*/) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator>=(
const Vectorized<c10::complex<double>>&) const {
const Vectorized<c10::complex<double>>& /*unused*/) const {
TORCH_CHECK(false, "not supported for complex numbers");
}


@ -1861,6 +1861,8 @@ template bool gemm_and_bias(
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
using at::blas::ScalingType;
int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
switch (scaling_type) {
case ScalingType::BlockWise1x32:


@ -14,6 +14,7 @@
*/
#include <ATen/cuda/CUDAContext.h>
#include <ATen/BlasBackend.h>
#include <ATen/OpMathType.h>
namespace at::cuda::blas {
@ -136,15 +137,6 @@ void int8_gemm(
int32_t* result_ptr,
int64_t result_ld);
enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};
void scaled_gemm(
char transa,
char transb,
@ -156,13 +148,13 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
at::blas::ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
at::blas::ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,


@ -17,7 +17,7 @@ TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
// The real implementation of CUDAHooksInterface
struct CUDAHooks : public at::CUDAHooksInterface {
CUDAHooks(at::CUDAHooksArgs) {}
CUDAHooks(at::CUDAHooksArgs /*unused*/) {}
void init() const override;
Device getDeviceFromPtr(void* data) const override;
bool isPinnedPtr(const void* data) const override;


@ -29,7 +29,7 @@
namespace at::cuda::tunable {
using at::cuda::blas::ScalingType;
using at::blas::ScalingType;
enum class BlasOp {
N = 0,


@ -29,7 +29,7 @@ template <typename ParamsT>
class Callable {
public:
virtual ~Callable() = default;
virtual TuningStatus Call(const ParamsT*) {
virtual TuningStatus Call(const ParamsT* /*unused*/) {
return FAIL;
}
virtual TuningStatus IsSupported(const ParamsT* params) {


@ -25,7 +25,7 @@ struct TORCH_API HPUHooksInterface : AcceleratorHooksInterface {
false, "Cannot get device of pointer on HPU without HPU backend");
}
bool isPinnedPtr(const void*) const override {
bool isPinnedPtr(const void* /*data*/) const override {
return false;
}


@ -410,7 +410,7 @@ struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T..
template <typename F, F Method, typename... ExtraArgs>
Tensor& unary_inplace_batch_rule(Tensor& self, std::optional<int64_t>, ExtraArgs... extra_args) {
Tensor& unary_inplace_batch_rule(Tensor& self, std::optional<int64_t> /*unused*/, ExtraArgs... extra_args) {
INVOKE(self, Method)(std::forward<ExtraArgs>(extra_args)...);
return self;
}


@ -18,7 +18,7 @@ extern std::atomic<const MetalInterface*> g_metal_impl_registry;
class MetalImplRegistrar {
public:
explicit MetalImplRegistrar(MetalInterface*);
explicit MetalImplRegistrar(MetalInterface* /*impl*/);
};
at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src);


@ -2060,7 +2060,7 @@ std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
}
// TODO Deprecate this function in favour of linalg_lu_factor_ex
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool /*unused*/) {
TORCH_WARN_ONCE(
"torch.lu is deprecated in favor of torch.linalg.lu_factor / torch.linalg.lu_factor_ex and will be ",
"removed in a future PyTorch release.\n",


@ -15,7 +15,11 @@ namespace at::native {
Scalar item(const Tensor& self) {
auto numel = self.sym_numel();
TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
TORCH_SYM_CHECK(
numel.sym_eq(1),
"a Tensor with ",
numel,
" elements cannot be converted to Scalar");
if (self.is_sparse()) {
if (self._nnz() == 0) return Scalar(0);
if (self.is_coalesced()) return at::_local_scalar_dense(self._values());


@ -346,17 +346,17 @@ template<typename acc_t>
struct AbsSwitch {};
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t> /*unused*/) {
return static_cast<acc_t>(data);
}
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t> /*unused*/) {
return static_cast<acc_t>(std::abs(data));
}
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t> /*unused*/) {
return static_cast<acc_t>(std::abs(at::opmath_type<c10::complex<scalar_t>>(data)));
}


@ -846,7 +846,7 @@ TORCH_IMPL_FUNC(clamp_Tensor_out)
(const Tensor& self,
const OptionalTensorRef min,
const OptionalTensorRef max,
const Tensor&) {
const Tensor& /*unused*/) {
if (min && max) {
clamp_stub(device_type(), *this);
} else if (min) {


@ -23,6 +23,14 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_cast_Byte_native.h>
#include <ATen/ops/_cast_Char_native.h>
#include <ATen/ops/_cast_Double_native.h>
#include <ATen/ops/_cast_Float_native.h>
#include <ATen/ops/_cast_Half_native.h>
#include <ATen/ops/_cast_Int_native.h>
#include <ATen/ops/_cast_Long_native.h>
#include <ATen/ops/_cast_Short_native.h>
#include <ATen/ops/_dim_arange_native.h>
#include <ATen/ops/_efficientzerotensor_native.h>
#include <ATen/ops/_empty_affine_quantized.h>


@ -452,11 +452,11 @@ void convolution_depthwise3x3_winograd_impl(
#else
void convolution_depthwise3x3_winograd_impl(
const Arguments&,
const float* const,
const float* const,
const float* const,
float* const) {
const Arguments& /*unused*/,
const float* const /*unused*/,
const float* const /*unused*/,
const float* const /*unused*/,
float* const /*unused*/) {
}
#endif /* __ARM_NEON__ */


@ -46,7 +46,7 @@ using namespace vec;
template <typename traits, std::size_t... INDEX>
typename traits::ArgsTuple
dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
std::index_sequence<INDEX...>) {
std::index_sequence<INDEX...> /*unused*/) {
return std::make_tuple(
c10::load<typename traits::template arg<INDEX>::type>(
data[INDEX] + i * strides[INDEX])...);
@ -65,7 +65,7 @@ dereference_vec_impl(char* C10_RESTRICT data[],
const typename traits::result_type& opt_scalar,
size_t S,
int64_t i,
std::index_sequence<INDEX...>) {
std::index_sequence<INDEX...> /*unused*/) {
using Vec = typename traits::result_type;
using scalar_t = typename Vec::value_type;
return std::make_tuple(
@ -231,7 +231,7 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve
template <typename traits, typename cb_t>
inline void unroll_contiguous_scalar_checks(
const int64_t* /*strides*/,
std::index_sequence<>,
std::index_sequence<> /*unused*/,
cb_t&& cb) {
cb(0);
}
@ -239,7 +239,7 @@ inline void unroll_contiguous_scalar_checks(
template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
inline void unroll_contiguous_scalar_checks(
const int64_t* strides,
std::index_sequence<INDEX0, INDEX...>,
std::index_sequence<INDEX0, INDEX...> /*unused*/,
cb_t&& cb) {
if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
cb(INDEX0 + 1);


@ -4,6 +4,7 @@
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
@ -105,7 +106,8 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
}
}
using at::cuda::blas::ScalingType;
using at::blas::ScalingType;
using at::blas::SwizzleType;
/**
* @brief Prepares matrices for CUBLAS operation
@ -1112,7 +1114,7 @@ namespace{
* - Returns Error.
*/
using at::cuda::blas::ScalingType;
using at::blas::ScalingType;
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1;
@ -1679,9 +1681,890 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
/**
* Track concrete implementations available
*/
enum class ScaledGemmImplementation {
NONE = 0,
TENSORWISE_TENSORWISE = 1,
ROWWISE_ROWWISE = 2,
BLOCK_128x128_1x128 = 3,
BLOCK_1x128_128x128 = 4,
BLOCK_1x128_1x128 = 5,
MXFP8_MXFP8 = 6,
NVFP4_NVFP4 = 7,
NVFP4_NVFP4_SINGLE_SCALE = 8,
};
/**
* Convert passed int (enum) from python back into a
* strictly-typed enum
*/
template <class EnumType, class ArrayType>
std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
std::vector<EnumType> converted;
converted.reserve(v.size());
for (auto vi : v) {
converted.push_back(static_cast<EnumType>(vi));
}
return converted;
}
/**
* Both inputs must be fp8,
* Each needs a single scale, {Tensorwise (float)}
*/
bool check_tensorwise_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {Tensorwise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {TensorWise, float} for A & B
if (recipe_a[0] != ScalingType::TensorWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::TensorWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Both inputs must be fp8,
* Each needs scales, {Rowwise (float)}
*/
bool check_rowwise_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (!isFloat8Type(type_a) || !isFloat8Type(type_b)) {
return false;
}
// 1 scale each, {RowWise, float}
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {RowWise, fp32} for A & B
if (recipe_a[0] != ScalingType::RowWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::RowWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Two-level scaling, canonical NVFP4
* Both inputs must be fp4
* A, B need 2 scales, {Blockwise_1x16 (e4m3), Tensorwise (fp32)}
*/
bool check_nvfp4_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 2 scales, 2 recipes for each input
if (scales_a.size() != 2 || recipe_a.size() != 2 || scales_b.size() != 2 || recipe_b.size() != 2) {
return false;
}
// Need {Blockwise_1x16, e4m3} for scale[0], {Tensorwise, fp32} for scale[1]
if (recipe_a[0] != ScalingType::BlockWise1x16 || recipe_a[1] != ScalingType::TensorWise) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_a[1].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != ScalingType::BlockWise1x16 || recipe_b[1] != ScalingType::TensorWise) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn || scales_b[1].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Single-level scaling, what PyT currently understands
* Both inputs must be fp4
* A, B need 1 scale, {Blockwise_1x16 (e4m3)}
*/
bool check_nvfp4_recipe_single_scale
(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp4
if (type_a != ScalarType::Float4_e2m1fn_x2 || type_b != ScalarType::Float4_e2m1fn_x2) {
return false;
}
// 1 scale, 1 recipe for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x16, e4m3} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x16) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
if (recipe_b[0] != ScalingType::BlockWise1x16) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e4m3fn) return false;
return true;
}
/**
* Both inputs must be fp8
* A, B must only have 1 scale each, e.g. A: Blockwise_1x128 (float), B: Blockwise_128x128 (float)
*/
bool check_deepseek_recipe(ScalingType expected_recipe_a,
ScalingType expected_recipe_b,
c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
return false;
}
// 1 scale, 1 recipe for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {expected_recipe_a, float} for A, {expected_recipe_b, float} for B
if (recipe_a[0] != expected_recipe_a) return false;
if (scales_a[0].scalar_type() != ScalarType::Float) return false;
if (recipe_b[0] != expected_recipe_b) return false;
if (scales_b[0].scalar_type() != ScalarType::Float) return false;
return true;
}
/**
* Both inputs must be fp8
* A, B must have 1 scale each, {Blockwise_1x32, e8m0}
*/
bool check_mxfp8_recipe(c10::ScalarType type_a,
std::vector<ScalingType>& recipe_a,
ArrayRef<Tensor>& scales_a,
c10::ScalarType type_b,
std::vector<ScalingType>& recipe_b,
ArrayRef<Tensor>& scales_b) {
// both types must be fp8
if (type_a != ScalarType::Float8_e4m3fn || type_b != ScalarType::Float8_e4m3fn) {
return false;
}
// 1 scale, 1 recipe for each input
if (scales_a.size() != 1 || recipe_a.size() != 1 || scales_b.size() != 1 || recipe_b.size() != 1) {
return false;
}
// Need {Blockwise_1x32, e8m0} for A & B
if (recipe_a[0] != ScalingType::BlockWise1x32) return false;
if (scales_a[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
if (recipe_b[0] != ScalingType::BlockWise1x32) return false;
if (scales_b[0].scalar_type() != ScalarType::Float8_e8m0fnu) return false;
return true;
}
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
using namespace std::placeholders;
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> scale_kernel_dispatch = {{
{ "tensorwise_tensorwise", check_tensorwise_recipe, ScaledGemmImplementation::TENSORWISE_TENSORWISE },
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "block_1x128_128x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise128x128, _1, _2, _3, _4, _5, _6),
ScaledGemmImplementation::BLOCK_1x128_128x128},
{ "block_128x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise128x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
ScaledGemmImplementation::BLOCK_128x128_1x128},
{ "block_1x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
ScaledGemmImplementation::BLOCK_1x128_1x128},
{ "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
Tensor&
_cutlass_scaled_gemm(
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
}
return out;
}
Tensor&
_scaled_tensorwise_tensorwise(
const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32
//
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.numel() == 1 && scale_a.scalar_type() == kFloat, "scale_a must have 1 Float element")
TORCH_CHECK_VALUE(scale_b.numel() == 1 && scale_b.scalar_type() == kFloat, "scale_b must have 1 Float element")
auto scaling_choice_a = ScalingType::TensorWise;
auto scaling_choice_b = ScalingType::TensorWise;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
Tensor&
_scaled_rowwise_rowwise(
const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape M/N for A/B
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.size(0) == mat_a.size(0) && scale_a.size(1) == 1, "scale_a must have shape [", mat_a.size(0), ", 1], got [", scale_a.sizes(), "]");
TORCH_CHECK_VALUE(scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat, "scale_a must have ", mat_a.size(0), " Float elements, got ", scale_a.numel())
TORCH_CHECK_VALUE(scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat, "scale_b must have ", mat_b.size(1), " Float elements, got ", scale_b.numel())
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
auto scaling_choice_a = ScalingType::RowWise;
auto scaling_choice_b = ScalingType::RowWise;
//
// NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
// and only for compute capability 9.0+. In other cases we use CUTLASS.
#ifndef USE_ROCM
// We are doing row-wise scaling
auto dprops = at::cuda::getCurrentDeviceProperties();
if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat_a,
mat_b,
scale_a,
scale_b,
bias,
use_fast_accum,
out);
return out;
}
#else
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
//Tensor b = mat_b;
if (_scaled_mm_is_fnuz()) {
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fnuz, "expected mat_b.dtype() to be at::kFloat8_e4m3fnuz, but got ", mat_b.dtype());
}
else {
TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "expected mat_b.dtype() to be at::kFloat8_e4m3fn, but got ", mat_b.dtype());
}
// Until more than bf16 is supported.
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16,
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
#endif
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
Tensor&
_scaled_block1x128_block1x128(
const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
Tensor&
_scaled_block128x128_block1x128(
const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));
auto scaling_choice_a = ScalingType::BlockWise128x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
Tensor&
_scaled_block1x128_block128x128(
const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const Tensor& scale_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
const bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0));
TORCH_CHECK_VALUE(scale_b.stride(1) == scale_b.size(0),
"expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1));
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise128x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
Tensor&
_scaled_mxfp8_mxfp8(
const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const SwizzleType swizzle_a,
const Tensor& scale_b, const SwizzleType swizzle_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are e8m0, A: shape K//32, B: K, N//32
// Scales must be swizzled
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0), 32), 4);
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out.scalar_type());
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;
#ifdef USE_ROCM
#if ROCM_VERSION >= 70000
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 &&
mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0,
"Matrix dimensions must be multiples of 32 for block-wise scaling");
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
#endif
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
Tensor&
_scaled_nvfp4_nvfp4(
const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const SwizzleType swizzle_a,
const Tensor& scale_b, const SwizzleType swizzle_b,
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
const bool single_scale,
Tensor& out) {
#ifdef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM");
#endif
TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported");
// Restrictions:
// A, B are FP4, scales are e4m3, blocked 1x16 along K
// Scales must be swizzled
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
// Note: fp4x2 format, need to double the K dimension for checking purposes.
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1) * 2, 16), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0) * 2, 16), 4);
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
auto scaling_choice_a = ScalingType::BlockWise1x16;
auto scaling_choice_b = ScalingType::BlockWise1x16;
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
// V2: Computes matrix multiply + bias while applying scaling to input and output matrices
// Scales are only applicable when matrices are of Float8 type and are assumed to be 1.0 by default.
// If the output matrix type is a 16- or 32-bit type, scale_result is not applied.
// Known limitations:
//  - Only works if mat1 is row-major and mat2 is column-major
//  - Only works if matrix sizes are divisible by 32
//  - If 1-dimensional tensors are used then scale_a should have size = mat1.size(0)
//    and scale_b should have size = mat2.size(1)
// Arguments:
// - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_recipe_a`: An integer corresponding to an enum describing the scaling scheme used for `scale_a`
// - `swizzle_a`: An integer corresponding to a `SwizzleType` enum describing the swizzling scheme for `scale_a`
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_recipe_b`: An integer corresponding to an enum describing the scaling scheme used for `scale_b`
// - `swizzle_b`: An integer corresponding to a `SwizzleType` enum describing the swizzling scheme for `scale_b`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
// - `use_fast_accum`: if true, enables fast float8 accumulation. Backends may ignore this option if not applicable.
// - `out`: a reference to the output tensor
Tensor&
_scaled_mm_cuda_v2_out(
const Tensor& mat_a, const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum,
Tensor& out) {
// Check sizes
bool allowed_device = _scaled_mm_allowed_device();
TORCH_CHECK_NOT_IMPLEMENTED(allowed_device,
"torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
TORCH_CHECK_VALUE(mat_a.dim() == 2, "mat_a must be a matrix");
TORCH_CHECK_VALUE(mat_b.dim() == 2, "mat_b must be a matrix");
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels
// do not support this case).
if (mat_a.size(0) == 0 || mat_a.size(1) == 0 || mat_b.size(1) == 0) {
// `out` was created with `at::empty`. In the case where we are multiplying
// MxK by KxN and K is the zero dim, we need to initialize here to properly
// return a tensor of zeros.
at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)});
if (mat_a.size(1) == 0) {
out.zero_();
}
return out;
}
// Check if the input matrix sizes can be multiplied
// - if optional contraction dims are provided, use those
// -- mostly for sub-byte formats (i.e. nvfp4x2) where a cheap .t() is not available.
if (contraction_dim.size() > 0) {
TORCH_CHECK_VALUE(contraction_dim.size() == 2, "contraction_dim must have exactly 2 elements");
auto mat_a_dim = contraction_dim[0];
auto mat_b_dim = contraction_dim[1];
TORCH_CHECK_VALUE(
mat_a.size(mat_a_dim) == mat_b.size(mat_b_dim), "mat_a and mat_b shapes cannot be multiplied (",
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ") ",
"with contraction dims mat_a: ", mat_a_dim, ", mat_b: ", mat_b_dim);
} else {
TORCH_CHECK_VALUE(
mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied (",
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");
// TODO(slayton): Existing checks, not sure if they should really be here.
TORCH_CHECK_VALUE(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) || mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2,
"Expected mat_a to be Float8 or Float4_x2 matrix got ", mat_a.scalar_type());
TORCH_CHECK_VALUE(isFloat8Type(mat_b.scalar_type()) || mat_b.scalar_type() == ScalarType::Float4_e2m1fn_x2,
"Expected mat_b to be Float8 or Float4_x2 matrix got ", mat_b.scalar_type());
#ifndef USE_ROCM
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
TORCH_CHECK_VALUE(mat_a.scalar_type() != ScalarType::Float8_e5m2 || mat_b.scalar_type() != ScalarType::Float8_e5m2,
"Multiplication of two Float8_e5m2 matrices is not supported");
#endif
if (use_fast_accum) {
TORCH_CHECK_VALUE(mat_a.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat_b.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat_a` or `mat_b` tensors have the `Float4_e2m1fn_x2` dtype.");
}
#ifdef USE_ROCM
if (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat_b.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
TORCH_CHECK_NOT_IMPLEMENTED(ROCM_VERSION >= 70000,
"Float4_e2m1fn_x2 is only supported for ROCm 7.0 and above");
}
if (mat_a.scalar_type() == ScalarType::Float8_e5m2 || mat_b.scalar_type() == ScalarType::Float8_e5m2) {
TORCH_CHECK_NOT_IMPLEMENTED(ROCM_VERSION >= 60500,
"Float8_e5m2 is only supported for ROCm 6.5 and above");
}
if (mat_a.scalar_type() == ScalarType::Float8_e4m3fn || mat_b.scalar_type() == ScalarType::Float8_e4m3fn) {
TORCH_CHECK_NOT_IMPLEMENTED(ROCM_VERSION >= 60500,
"Float8_e4m3fn is only supported for ROCm 6.5 and above");
}
#endif
if (bias) {
TORCH_CHECK_VALUE(out.scalar_type() != kFloat,
"Bias is not supported when out_dtype is set to Float32");
TORCH_CHECK_VALUE(bias->scalar_type() == ScalarType::BFloat16 ||
bias->scalar_type() == ScalarType::Half,
"Bias must be BFloat16 or Half, but got ", bias->scalar_type());
TORCH_CHECK_VALUE((out.scalar_type() != kFloat &&
out.scalar_type() != ScalarType::BFloat16) ||
bias->scalar_type() == ScalarType::BFloat16,
"Bias must be BFloat16 to compute ", out.scalar_type(),
" output, but got ", bias->scalar_type());
TORCH_CHECK_VALUE(out.scalar_type() != ScalarType::Half ||
bias->scalar_type() == ScalarType::Half,
"Bias must be Float16 to compute ", out.scalar_type(),
" output, but got ", bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{out, "out", 0}, {mat_a, "mat_a", 1}, {mat_b, "mat_b", 2},
{bias_, "bias", 3}, {scale_a[0], "scale_a", 4}, {scale_b[0], "scale_b", 5}};
checkAllSameGPU(__func__, targs);
}
auto out_dtype_ = out_dtype.value_or(at::ScalarType::BFloat16);
// Conversion of implicitly-defined enums to explicit
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
// at this point we can start working out what we want to be doing
// Try to do as few steps as possible.
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
bool found_impl = false;
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
for (const auto& fn_entry : scale_kernel_dispatch) {
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
bool ok = accept_fn(mat_a.scalar_type(),
scale_recipe_a_enum,
scale_a,
mat_b.scalar_type(),
scale_recipe_b_enum,
scale_b);
if (ok) {
gemm_impl = scaled_gemm_impl;
found_impl = true;
break;
}
}
TORCH_CHECK_VALUE(
found_impl,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (", mat_a.size(0), ", 1) and scale_b should be (1, ", mat_b.size(1), "), and both should be contiguous.\n"
"- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (", mat_a.size(0), ", ", ceil_div<int64_t>(mat_a.size(1), 128), ") and scale_b should be (", ceil_div<int64_t>(mat_b.size(0), 128), ", ", mat_b.size(1), "), and both should be outer-dim-major.\n"
"- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (", ceil_div<int64_t>(mat_a.size(0), 128), ", ", ceil_div<int64_t>(mat_a.size(1), 128), ") and scale_b should be (", ceil_div<int64_t>(mat_b.size(0), 128), ", ", ceil_div<int64_t>(mat_b.size(1), 128), "), and both should be near-inner-dim-major (with 16-byte aligned strides).\n"
"- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have ", round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1), 32), 4), " elements and scale_b should have ", round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0), 32), 4), " elements, and both should be contiguous.\n"
"- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have ", round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_a.size(1) * 2, 16), 4), " elements and scale_b should have ", round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(mat_b.size(0) * 2, 16), 4), " elements, and both should be contiguous.\n"
"Got mat_a.dtype()=", mat_a.scalar_type(), ", scale_a[0].dtype()=", scale_a[0].scalar_type(), ", scale_a[0].size()=", scale_a[0].sizes(), ", scale_a[0].stride()=", scale_a[0].strides(), ", ",
"mat_b.dtype()=", mat_b.scalar_type(), ", scale_b[0].dtype()=", scale_b[0].scalar_type(), ", scale_b[0].size()=", scale_b[0].sizes(), " and scale_b[0].stride()=", scale_b[0].strides()
);
at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)});
auto bias_ = bias.value_or(Tensor());
// dispatch to appropriate lower-level calls for error checking & execution
if (gemm_impl == ScaledGemmImplementation::TENSORWISE_TENSORWISE) {
return _scaled_tensorwise_tensorwise(mat_a, mat_b, scale_a[0], scale_b[0], bias, out_dtype_, use_fast_accum, out);
} else if (gemm_impl == ScaledGemmImplementation::ROWWISE_ROWWISE) {
return _scaled_rowwise_rowwise(mat_a, mat_b, scale_a[0], scale_b[0], bias, out_dtype_, use_fast_accum, out);
} else if (gemm_impl == ScaledGemmImplementation::BLOCK_128x128_1x128) {
return _scaled_block128x128_block1x128(mat_a, mat_b, scale_a[0], scale_b[0], bias, out_dtype_, use_fast_accum, out);
} else if (gemm_impl == ScaledGemmImplementation::BLOCK_1x128_128x128) {
return _scaled_block1x128_block128x128(mat_a, mat_b, scale_a[0], scale_b[0], bias, out_dtype_, use_fast_accum, out);
} else if (gemm_impl == ScaledGemmImplementation::BLOCK_1x128_1x128) {
return _scaled_block1x128_block1x128(mat_a, mat_b, scale_a[0], scale_b[0], bias, out_dtype_, use_fast_accum, out);
} else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) {
return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out);
} else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported");
} else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) {
return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out);
} else {
TORCH_CHECK_VALUE(false, "Invalid state - found an implementation, but not really");
}
}
Tensor
_scaled_mm_cuda_v2(
const Tensor& mat_a, const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_cuda_v2_out(
mat_a, mat_b,
scale_a, scale_recipe_a, swizzle_a,
scale_b, scale_recipe_b, swizzle_b,
bias,
out_dtype,
contraction_dim,
use_fast_accum,
out);
}
Tensor
_scaled_grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,


@ -488,15 +488,16 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
}
}
int cat_dim = dimension;
if (memory_format != c10::MemoryFormat::Contiguous) {
switch (dimension) {
switch (cat_dim) {
case 0:
break;
case 1:
dimension = nDims - dimension;
cat_dim = nDims - cat_dim;
break;
default:
dimension--;
cat_dim--;
}
}
// Template Declarations for dim = 1, 2, 3, 4
@ -505,23 +506,23 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
(char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
(char*)data, catMetaData, kernelOutputParam, cat_dim, trailingSize);\
} else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
} else if (isContig && isAligned && sizeof(scalar_t) == 2) { \
CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_8><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
} else if (isContig) {\
CatArrayBatchedCopy_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
} else {\
CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
}\
C10_CUDA_KERNEL_LAUNCH_CHECK();
switch (nDims) {


@ -127,7 +127,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
return diff == 0 ? 0 : uint32_t(Align) - diff;
}
#if defined (__gfx90a__) || defined(__gfx942__)
#if defined (__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#define CDNA2_OR_LATER 1
#else
#define CDNA2_OR_LATER 0
@ -143,7 +143,7 @@ template<typename T, uint32_t Rank>
using VecT = T __attribute__((ext_vector_type(Rank)));
static bool isCDNA2orLater(int index) {
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index);
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942", "gfx950"}, index);
}
#else


@ -86,7 +86,7 @@ namespace cuda { namespace detail {
struct LinalgDispatch {
Tensor (*cholesky_solve_helper)(const Tensor& self, const Tensor& A, bool upper);
};
C10_EXPORT void registerLinalgDispatch(const LinalgDispatch&);
C10_EXPORT void registerLinalgDispatch(const LinalgDispatch& /*disp_*/);
}} // namespace cuda::detail
#endif


@ -341,16 +341,22 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
}
};
template <typename T, typename KeyType>
struct MHAGraphCache {
std::unordered_map<KeyType, T, ParamsWrapperHash<KeyType>> engine_cache;
using KeyType = MHACacheKeyWrapper;
using ValueType = std::unique_ptr<fe::graph::Graph>;
using MapType =
std::unordered_map<KeyType, ValueType, ParamsWrapperHash<KeyType>>;
using iterator = typename MapType::iterator;
using const_iterator = typename MapType::const_iterator;
MapType engine_cache;
int count = 0;
int hits = 0;
// no mutexes here as caches are now thread local for v8, can also return a
// pointer to the Execution Plan if we know it will not be invalidated by
// another thread
T* find(const KeyType& key) {
iterator find(const KeyType& key) {
static bool flag =
c10::utils::check_env("TORCH_CUDNN_SDPA_CACHE_DEBUG") == true;
if (flag && count) {
@ -363,15 +369,19 @@ struct MHAGraphCache {
}
count++;
auto it = engine_cache.find(key);
if (it == engine_cache.end()) {
return nullptr;
if (it != engine_cache.end()) {
hits++;
}
hits++;
return &(it->second);
return it;
}
void update(const KeyType& key, T& results) {
engine_cache.insert_or_assign(key, std::move(results));
const_iterator end() const {
return engine_cache.end();
}
template <typename... Args>
std::pair<iterator, bool> try_emplace(const KeyType& key, Args&&... args) {
return engine_cache.try_emplace(key, std::forward<Args>(args)...);
}
};
@ -380,16 +390,14 @@ struct MHAGraphCache {
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html
// We also leak the caches to workaround potential teardown race issues.
auto& getMHAGraphCache_() {
thread_local auto& instance =
*new MHAGraphCache<std::shared_ptr<fe::graph::Graph>, MHACacheKeyWrapper>;
return instance;
MHAGraphCache& getMHAGraphCache_() {
thread_local MHAGraphCache* instance{new MHAGraphCache()};
return *instance;
}
auto& getMHAGraphBackwardCache_() {
thread_local auto& instance =
*new MHAGraphCache<std::shared_ptr<fe::graph::Graph>, MHACacheKeyWrapper>;
return instance;
MHAGraphCache& getMHAGraphBackwardCache_() {
thread_local MHAGraphCache* instance{new MHAGraphCache()};
return *instance;
}
namespace {
@ -437,7 +445,7 @@ auto fixSizeOneDimStrideSDPA(
} // namespace
auto build_graph(
std::unique_ptr<fe::graph::Graph> build_graph(
int64_t b,
int64_t h,
int64_t s_q,
@ -461,7 +469,7 @@ auto build_graph(
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
auto mha_graph = std::make_unique<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
@ -531,15 +539,13 @@ auto build_graph(
fe::graph::Tensor_attributes().set_uid(K).set_name("K"));
auto V_ = mha_graph->tensor(
fe::graph::Tensor_attributes().set_uid(V).set_name("V"));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
scaled_dot_product_flash_attention_options.set_bias(
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_uid(BIAS)
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
.set_stride(attn_bias.value().strides().vec())));
}
auto [O_, Stats] =
@ -640,7 +646,7 @@ auto build_graph(
return mha_graph;
}
auto build_graph_nestedtensor(
std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
@ -668,7 +674,7 @@ auto build_graph_nestedtensor(
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
auto mha_graph = std::make_unique<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
@ -766,18 +772,16 @@ auto build_graph_nestedtensor(
v_strides[strideidx0],
v_strides[strideidx1],
v_strides[strideidx2]}));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
TORCH_CHECK(
false,
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
bias =
scaled_dot_product_flash_attention_options.set_bias(
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_uid(BIAS)
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
scaled_dot_product_flash_attention_options.set_bias(bias.value());
.set_stride(attn_bias.value().strides().vec())));
}
auto RAG_Q_OFF_ =
mha_graph->tensor(fe::graph::Tensor_attributes()
@ -847,7 +851,7 @@ auto build_graph_nestedtensor(
return mha_graph;
}
auto build_graph_backward(
std::unique_ptr<fe::graph::Graph> build_graph_backward(
int64_t b,
int64_t h,
int64_t s_q,
@ -874,7 +878,7 @@ auto build_graph_backward(
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
auto mha_graph = std::make_unique<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
@ -919,15 +923,13 @@ auto build_graph_backward(
fe::graph::Tensor_attributes().set_uid(K).set_name("K"));
auto V_ = mha_graph->tensor(
fe::graph::Tensor_attributes().set_uid(V).set_name("V"));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
bias =
sdpa_backward_options.set_bias(
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_uid(BIAS)
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
sdpa_backward_options.set_bias(bias.value());
.set_stride(attn_bias.value().strides().vec())));
}
if (dropout_probability != 0.0f) {
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
@ -1061,7 +1063,7 @@ auto build_graph_backward(
return mha_graph;
}
auto build_graph_backward_nestedtensor(
std::unique_ptr<fe::graph::Graph> build_graph_backward_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
@ -1092,7 +1094,7 @@ auto build_graph_backward_nestedtensor(
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
auto mha_graph = std::make_unique<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
@ -1195,18 +1197,16 @@ auto build_graph_backward_nestedtensor(
o_strides[strideidx1],
o_strides[strideidx2]}));
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
if (attn_bias.has_value()) {
TORCH_CHECK(
false,
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
bias =
sdpa_backward_options.set_bias(
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_uid(BIAS)
.set_name("bias")
.set_dim(attn_bias.value().sizes().vec())
.set_stride(attn_bias.value().strides().vec()));
sdpa_backward_options.set_bias(bias.value());
.set_stride(attn_bias.value().strides().vec())));
}
auto RAG_Q_OFF_ =
mha_graph->tensor(fe::graph::Tensor_attributes()
@ -1378,7 +1378,7 @@ void run_cudnn_SDP_fprop(
// NB: The key initialization will round up sequence length, stride data etc.
// if use_ragged_in_dense is enabled (to allow multiple sequence lengths to
// reuse the same cached value/graph)
auto key = MHACacheKeyWrapper(
MHACacheKeyWrapper key(
b,
h,
s_q,
@ -1393,12 +1393,9 @@ void run_cudnn_SDP_fprop(
is_causal,
return_softmaxstats,
false);
auto graph_ptr = getMHAGraphCache_().find(key);
std::shared_ptr<fe::graph::Graph> mha_graph;
if (graph_ptr) {
mha_graph = *graph_ptr;
} else {
mha_graph = build_graph(
auto [cache_it, not_found] = getMHAGraphCache_().try_emplace(key, nullptr);
if (not_found) {
cache_it->second = build_graph(
b,
h,
s_q,
@ -1419,39 +1416,39 @@ void run_cudnn_SDP_fprop(
_dropoutoffset,
handle);
}
const fe::graph::Graph& mha_graph = *cache_it->second;
std::unordered_map<int64_t, void*> variant_pack = {
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{SCALE, &scaling_factor},
{O, o.data_ptr()}};
{O, o.mutable_data_ptr()}};
if (return_softmaxstats) {
variant_pack[LSE] = softmaxstats.data_ptr();
variant_pack[LSE] = softmaxstats.mutable_data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[BIAS] = attn_bias.value().data_ptr();
variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
}
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
if (return_softmaxstats) {
variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
}
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_size = mha_graph.get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
getMHAGraphCache_().update(key, mha_graph);
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
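For readers following the forward-path change above: the graph cache is a thread_local map keyed by MHACacheKeyWrapper, and try_emplace both looks up and reserves the slot, so the cuDNN graph is built exactly once per key (per the NB comment, the key rounds up sequence lengths and strides when use_ragged_in_dense is enabled, so several shapes can share one cached graph). A minimal standalone sketch of that build-once-per-key idiom, with hypothetical Key/Graph/build stand-ins rather than the real PyTorch types:

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for MHACacheKeyWrapper and fe::graph::Graph.
struct Graph { /* expensive-to-build object */ };
using Key = std::string;

std::unique_ptr<Graph> build(const Key& /*key*/) {
  return std::make_unique<Graph>();
}

Graph& get_or_build(const Key& key) {
  // One cache per thread, intentionally leaked so it survives static
  // teardown, mirroring the thread_local-pointer idiom in the diff.
  thread_local auto& cache =
      *new std::unordered_map<Key, std::unique_ptr<Graph>>();
  // try_emplace inserts {key, nullptr} only if the key is absent; the bool
  // plays the role of `not_found` in the code above.
  auto [it, inserted] = cache.try_emplace(key, nullptr);
  if (inserted) {
    it->second = build(key);  // built exactly once per key
  }
  return *it->second;
}
```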
void run_cudnn_SDP_fprop_nestedtensor(
@ -1491,7 +1488,7 @@ void run_cudnn_SDP_fprop_nestedtensor(
softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat));
}
auto key = MHACacheKeyWrapper(
MHACacheKeyWrapper key(
b,
h_q,
s_q, // max-seqlen-q
@ -1506,13 +1503,12 @@ void run_cudnn_SDP_fprop_nestedtensor(
is_causal,
return_softmaxstats,
true);
auto graph_ptr = getMHAGraphCache_().find(key);
std::shared_ptr<fe::graph::Graph> mha_graph;
if (graph_ptr) {
mha_graph = *graph_ptr;
} else {
mha_graph = build_graph_nestedtensor(
MHAGraphCache& cache = getMHAGraphCache_();
auto cache_it = cache.find(key);
std::unique_ptr<fe::graph::Graph> mha_graph_storage;
if (cache_it == cache.end()) {
mha_graph_storage = build_graph_nestedtensor(
b,
h_q,
h_k,
@ -1537,40 +1533,44 @@ void run_cudnn_SDP_fprop_nestedtensor(
dropoutoffset,
handle);
}
const fe::graph::Graph& mha_graph =
mha_graph_storage ? *mha_graph_storage : *cache_it->second;
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
auto rag_stats_off = cum_seqlen_q.mul(h_q);
std::unordered_map<int64_t, void*> variant_pack = {
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{SCALE, &scaling_factor},
{O, o.data_ptr()},
{RAG_Q_OFF, rag_q_off.data_ptr()},
{RAG_O_OFF, rag_q_off.data_ptr()},
{RAG_K_OFF, rag_k_off.data_ptr()},
{RAG_V_OFF, rag_v_off.data_ptr()},
{SEQ_LEN_Q, seqlen_q.data_ptr()},
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
{O, o.mutable_data_ptr()},
{RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
{RAG_O_OFF, rag_o_off.mutable_data_ptr()},
{RAG_K_OFF, rag_k_off.mutable_data_ptr()},
{RAG_V_OFF, rag_v_off.mutable_data_ptr()},
{SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
{SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
if (return_softmaxstats) {
variant_pack[LSE] = softmaxstats.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr();
variant_pack[LSE] = softmaxstats.mutable_data_ptr();
variant_pack[RAG_LSE_OFF] = rag_stats_off.mutable_data_ptr();
}
if (dropout_probability != 0.0f) {
variant_pack[SEED] = dropoutseed.data_ptr();
variant_pack[OFFSET] = dropoutoffset.data_ptr();
variant_pack[SEED] = dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = dropoutoffset.mutable_data_ptr();
}
if (attn_bias.has_value()) {
TORCH_CHECK("bias not supported with nestedtensor");
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_size = mha_graph.get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
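One notable change in the nested-tensor forward path above: the ragged offsets are now derived from the tensors' actual strides (cum_seqlen multiplied by t.stride(-3)) instead of assuming dense packing (cum_seqlen multiplied by heads * head_dim), and the output offset uses o.stride(-3) rather than reusing the query offset. That is what lets non-contiguous inputs, such as chunks of a larger tensor, address the right rows. A rough ATen-style sketch of the distinction, assuming a [total_tokens, num_heads, head_dim] packing:

```cpp
#include <ATen/ATen.h>

// cum_seqlen: 1-D cumulative sequence lengths per batch entry.
// t: packed [total_tokens, num_heads, head_dim] tensor, possibly non-contiguous.
at::Tensor ragged_offsets(const at::Tensor& cum_seqlen, const at::Tensor& t) {
  // Old assumption (dense packing): advancing one token moves the base
  // pointer by num_heads * head_dim elements.
  //   return cum_seqlen.mul(t.size(-2) * t.size(-1));

  // Stride-based version: stays correct when t is a non-contiguous view,
  // e.g. one chunk of a wider tensor.
  return cum_seqlen.mul(t.stride(-3));
}
```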
void run_cudnn_SDP_bprop(
@ -1652,7 +1652,7 @@ void run_cudnn_SDP_bprop(
}
cudnnHandle_t handle = getCudnnHandle();
auto key = MHACacheKeyWrapper(
MHACacheKeyWrapper key(
b,
h,
s_q,
@ -1667,12 +1667,10 @@ void run_cudnn_SDP_bprop(
is_causal,
true,
false);
auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key);
std::shared_ptr<fe::graph::Graph> mha_graph;
if (graph_backward_ptr) {
mha_graph = *graph_backward_ptr;
} else {
mha_graph = build_graph_backward(
auto [cache_it, not_found] =
getMHAGraphBackwardCache_().try_emplace(key, nullptr);
if (not_found) {
cache_it->second = build_graph_backward(
b,
h,
s_q,
@ -1696,43 +1694,44 @@ void run_cudnn_SDP_bprop(
_dropoutoffset,
handle);
}
const fe::graph::Graph& mha_graph = *cache_it->second;
std::unordered_map<int64_t, void*> variant_pack = {
// inputs
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{O, o.data_ptr()},
{DO, dO_.data_ptr()},
{LSE, softmaxstats.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{O, o.mutable_data_ptr()},
{DO, dO_.mutable_data_ptr()},
{LSE, softmaxstats.mutable_data_ptr()},
// outputs
{DQ, dQ.data_ptr()},
{DK, dK.data_ptr()},
{DV, dV.data_ptr()},
{DQ, dQ.mutable_data_ptr()},
{DK, dK.mutable_data_ptr()},
{DV, dV.mutable_data_ptr()},
{SCALE, &scaling_factor}};
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[BIAS] = attn_bias.value().data_ptr();
variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
}
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_size = mha_graph.get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(!workspace_size || workspace_ptr.get());
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
getMHAGraphBackwardCache_().update(key, mha_graph);
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
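A smaller, pervasive change in these functions is data_ptr() being replaced by mutable_data_ptr() when filling the variant pack. Both hand back a void* suitable for the cuDNN frontend; the mutable/const accessor pair just makes the write intent explicit. An illustrative snippet (not taken from the diff) of how the two accessors pair up:

```cpp
#include <ATen/ATen.h>
#include <cstdint>
#include <unordered_map>

// Illustrative only: assemble a variant-pack-style map of raw pointers.
std::unordered_map<int64_t, void*> make_pack(at::Tensor& out, const at::Tensor& in) {
  std::unordered_map<int64_t, void*> pack;
  // mutable_data_ptr(): we intend to write through this pointer (outputs).
  pack[0] = out.mutable_data_ptr();
  // const_data_ptr() returns const void*; the const_cast documents that the
  // downstream C API takes void* but only reads from this buffer.
  pack[1] = const_cast<void*>(in.const_data_ptr());
  return pack;
}
```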
void run_cudnn_SDP_bprop_nestedtensor(
@ -1775,9 +1774,10 @@ void run_cudnn_SDP_bprop_nestedtensor(
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
auto rag_stats_off = cum_seqlen_q.mul(h_q);
auto dprops = at::cuda::getCurrentDeviceProperties();
@ -1791,7 +1791,7 @@ void run_cudnn_SDP_bprop_nestedtensor(
cudnnHandle_t handle = getCudnnHandle();
auto key = MHACacheKeyWrapper(
MHACacheKeyWrapper key(
b,
h_q,
s_q, // max-seqlen-q
@ -1806,13 +1806,12 @@ void run_cudnn_SDP_bprop_nestedtensor(
is_causal,
true,
true);
auto graph_ptr = getMHAGraphCache_().find(key);
std::shared_ptr<fe::graph::Graph> mha_graph;
if (graph_ptr) {
mha_graph = *graph_ptr;
} else {
mha_graph = build_graph_backward_nestedtensor(
MHAGraphCache& cache = getMHAGraphCache_();
auto cache_it = cache.find(key);
std::unique_ptr<fe::graph::Graph> mha_graph_storage;
if (cache_it == cache.end()) {
mha_graph_storage = build_graph_backward_nestedtensor(
b,
h_q,
h_k,
@ -1840,41 +1839,43 @@ void run_cudnn_SDP_bprop_nestedtensor(
dropoutoffset,
handle);
}
const fe::graph::Graph& mha_graph =
mha_graph_storage ? *mha_graph_storage : *cache_it->second;
std::unordered_map<int64_t, void*> variant_pack = {
// inputs
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{O, o.data_ptr()},
{DO, dO_.data_ptr()},
{LSE, softmaxstats.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{O, o.mutable_data_ptr()},
{DO, dO_.mutable_data_ptr()},
{LSE, softmaxstats.mutable_data_ptr()},
// outputs
{DQ, dQ.data_ptr()},
{DK, dK.data_ptr()},
{DV, dV.data_ptr()},
{DQ, dQ.mutable_data_ptr()},
{DK, dK.mutable_data_ptr()},
{DV, dV.mutable_data_ptr()},
{SCALE, &scaling_factor},
{RAG_Q_OFF, rag_q_off.data_ptr()},
{RAG_O_OFF, rag_q_off.data_ptr()},
{RAG_K_OFF, rag_k_off.data_ptr()},
{RAG_V_OFF, rag_v_off.data_ptr()},
{RAG_LSE_OFF, rag_stats_off.data_ptr()},
{SEQ_LEN_Q, seqlen_q.data_ptr()},
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
{RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
{RAG_O_OFF, rag_o_off.mutable_data_ptr()},
{RAG_K_OFF, rag_k_off.mutable_data_ptr()},
{RAG_V_OFF, rag_v_off.mutable_data_ptr()},
{RAG_LSE_OFF, rag_stats_off.mutable_data_ptr()},
{SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
{SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
TORCH_CHECK(
!attn_bias.has_value(),
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_size = mha_graph.get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(!workspace_size || workspace_ptr.get());
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
}
} // namespace native

View File

@ -116,6 +116,8 @@ class MetalShaderLibrary {
std::vector<std::string> getFunctionNames();
std::shared_ptr<MetalKernelFunction> getKernelFunction(
const std::string& name);
// Returns a raw pointer to the kernel function for use in C APIs
MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
inline MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
@ -164,6 +166,9 @@ class MetalShaderLibrary {
std::string,
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
cplMap;
// Cache for kernel functions returned by getCachedKernelFunctionPtr
std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
kernelCache;
};
class DynamicMetalShaderLibrary : public MetalShaderLibrary {

View File

@ -917,6 +917,22 @@ std::shared_ptr<MetalKernelFunction> MetalShaderLibrary::getKernelFunction(const
return std::make_shared<MetalKernelFunction>(cpl, func);
}
MetalKernelFunction* MetalShaderLibrary::getCachedKernelFunctionPtr(const std::string& name) {
// Check if kernel is already cached
auto it = kernelCache.find(name);
if (it != kernelCache.end()) {
return it->second.get();
}
// Create new kernel function and cache it
auto [cpl, func] = getLibraryPipelineState(getLibrary(), name);
auto kernel = std::make_unique<MetalKernelFunction>(cpl, func);
MetalKernelFunction* raw_ptr = kernel.get();
kernelCache[name] = std::move(kernel);
return raw_ptr;
}
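The new accessor keeps each MetalKernelFunction alive inside kernelCache (owned by a unique_ptr) and hands out a raw, non-owning pointer, which is exactly what a C API can hold without taking part in C++ ownership. A self-contained sketch of that general pattern, using a stand-in Kernel type rather than the real MetalKernelFunction:

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for MetalKernelFunction; the point is the ownership pattern.
struct Kernel {
  explicit Kernel(std::string n) : name(std::move(n)) {}
  std::string name;
};

class KernelCache {
 public:
  // Returns a non-owning pointer that stays valid as long as the cache lives;
  // callers (e.g. a C shim) must not delete it.
  Kernel* get_or_create(const std::string& name) {
    auto it = cache_.find(name);
    if (it != cache_.end()) {
      return it->second.get();  // cache hit: reuse the existing object
    }
    auto kernel = std::make_unique<Kernel>(name);
    Kernel* raw = kernel.get();
    cache_[name] = std::move(kernel);
    return raw;
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Kernel>> cache_;
};
```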
class BundledShaderLibary : public MetalShaderLibrary {
public:
BundledShaderLibary() : MetalShaderLibrary("") {}

View File

@ -5,6 +5,38 @@
# representing ScalarType's. They are now superseded by usage of
# `aten::to()`. The ops remain here for backward compatibility purposes.
# DEPRECATED. DO NOT USE
- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
variants: function
# DEPRECATED. DO NOT USE
- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
variants: function
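As the comment above notes, these _cast_* entries survive only for backward compatibility and are superseded by aten::to(). A hedged C++ (ATen) sketch of the replacement for, say, the Float cast:

```cpp
#include <ATen/ATen.h>

// Deprecated form:   at::_cast_Float(x, /*non_blocking=*/false)
// Preferred form:    Tensor::to() with the target ScalarType.
at::Tensor to_float(const at::Tensor& x) {
  return x.to(at::kFloat, /*non_blocking=*/false, /*copy=*/false);
}
```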
# Computes the gradient of current tensor w.r.t. graph leaves.
- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
manual_cpp_binding: True
@ -7125,6 +7157,7 @@
CUDA: _scaled_mm_cuda
tags: needs_exact_strides
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
@ -7132,6 +7165,16 @@
CUDA: _scaled_mm_out_cuda
tags: needs_exact_strides
- func: _scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_mm_cuda_v2
- func: _scaled_mm_v2.out(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CUDA: _scaled_mm_cuda_v2_out
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function

View File

@ -31,7 +31,7 @@ TORCH_API float dequantize_vec(
float* dst,
size_t count = 8);
template <typename SRC_T, typename DST_T>
TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src);
TORCH_API DST_T requantize_val(double /*src_scale*/, int64_t /*src_zero_point*/, double /*dst_scale*/, int64_t /*dst_zero_point*/, SRC_T src);
// Given a multiplier and a zero_point, requantize int32_t computed values back
// to quantized values. See comment above
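For context on the parameter names spelled out above: requantization maps a value quantized with (src_scale, src_zero_point) onto a new affine grid (dst_scale, dst_zero_point) by dequantizing and re-quantizing. A schematic scalar sketch of that arithmetic, not the actual vectorized kernel, which folds the scales into a single multiplier and clamps per destination dtype:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Schematic int8 -> int8 requantization; real kernels are vectorized and
// use the destination dtype's own limits.
std::int8_t requantize_scalar(
    std::int8_t src,
    double src_scale, std::int64_t src_zero_point,
    double dst_scale, std::int64_t dst_zero_point) {
  const double real = src_scale * (static_cast<double>(src) - src_zero_point);
  const double q = std::nearbyint(real / dst_scale) + dst_zero_point;
  return static_cast<std::int8_t>(std::min(127.0, std::max(-128.0, q)));
}
```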

View File

@ -104,27 +104,27 @@ Tensor empty_strided_unknown_quantized(
// Provide better error message if dtype is wrong
Tensor empty_affine_quantized_other_backends_stub(
IntArrayRef,
std::optional<ScalarType>,
std::optional<Layout>,
std::optional<Device>,
std::optional<bool>,
double,
int64_t,
std::optional<c10::MemoryFormat>) {
IntArrayRef /*unused*/,
std::optional<ScalarType> /*unused*/,
std::optional<Layout> /*unused*/,
std::optional<Device> /*unused*/,
std::optional<bool> /*unused*/,
double /*unused*/,
int64_t /*unused*/,
std::optional<c10::MemoryFormat> /*unused*/) {
TORCH_CHECK(false, "Creation of quantized tensor requires quantized dtype like torch.quint8");
}
Tensor empty_per_channel_affine_quantized_other_backends_stub(
IntArrayRef,
const Tensor&,
const Tensor&,
int64_t,
std::optional<ScalarType>,
std::optional<Layout>,
std::optional<Device>,
std::optional<bool>,
std::optional<c10::MemoryFormat>) {
IntArrayRef /*unused*/,
const Tensor& /*unused*/,
const Tensor& /*unused*/,
int64_t /*unused*/,
std::optional<ScalarType> /*unused*/,
std::optional<Layout> /*unused*/,
std::optional<Device> /*unused*/,
std::optional<bool> /*unused*/,
std::optional<c10::MemoryFormat> /*unused*/) {
TORCH_CHECK(false, "Creation of quantized tensor requires quantized dtype like torch.quint8");
}

View File

@ -637,13 +637,7 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled.");
}
return false;
} else if (has_for_nested_inputs(params) && (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad())) {
if (debug) {
TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward.");
return false;
}
}
const auto dprop = at::cuda::getCurrentDeviceProperties();
// Check that the input is nested
if (!(dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {

View File

@ -29,63 +29,63 @@ bool available() {
}
bool use_convolution2d(
const Tensor&,
const Tensor&,
const at::OptionalIntArrayRef,
const IntArrayRef,
const IntArrayRef,
const IntArrayRef,
const int64_t,
bool) {
const Tensor& /*unused*/,
const Tensor& /*unused*/,
const at::OptionalIntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
const int64_t /*unused*/,
bool /*unused*/) {
return false;
}
Tensor convolution2d(
const Tensor&,
const Tensor&,
const Tensor&,
const IntArrayRef,
const IntArrayRef,
const IntArrayRef,
const int64_t) {
const Tensor& /*unused*/,
const Tensor& /*unused*/,
const Tensor& /*unused*/,
const IntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
const int64_t /*unused*/) {
TORCH_CHECK(false, internal::kError);
}
bool use_linear(
const Tensor&,
const Tensor&,
const Tensor&) {
const Tensor& /*unused*/,
const Tensor& /*unused*/,
const Tensor& /*unused*/) {
return false;
}
Tensor linear(
const Tensor&,
const Tensor&,
const Tensor&) {
const Tensor& /*unused*/,
const Tensor& /*unused*/,
const Tensor& /*unused*/) {
TORCH_CHECK(false, internal::kError);
}
bool use_max_pool2d(
const Tensor&,
const IntArrayRef,
const IntArrayRef,
IntArrayRef,
const IntArrayRef,
const bool,
const float,
const float) {
const Tensor& /*unused*/,
const IntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
IntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
const bool /*unused*/,
const float /*unused*/,
const float /*unused*/) {
return false;
}
Tensor max_pool2d(
const Tensor&,
const IntArrayRef,
const IntArrayRef,
IntArrayRef,
const IntArrayRef,
const bool,
const float,
const float) {
const Tensor& /*unused*/,
const IntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
IntArrayRef /*unused*/,
const IntArrayRef /*unused*/,
const bool /*unused*/,
const float /*unused*/,
const float /*unused*/) {
TORCH_CHECK(false, internal::kError);
}

View File

@ -5,7 +5,7 @@ namespace at {
namespace detail {
inline void noopDelete(void*) {}
inline void noopDelete(void* /*unused*/) {}
} // namespace detail

View File

@ -18,7 +18,7 @@ extern std::atomic<const VulkanImplInterface*> g_vulkan_impl_registry;
class VulkanImplRegistrar {
public:
explicit VulkanImplRegistrar(VulkanImplInterface*);
explicit VulkanImplRegistrar(VulkanImplInterface* /*impl*/);
};
at::Tensor& vulkan_copy_(at::Tensor& self, const at::Tensor& src);

View File

@ -25,15 +25,6 @@ drq
fambench_dlrm
fambench_xlmr
fastNLP_Bert
hf_Albert
hf_Bart
hf_Bert
hf_BigBird
hf_DistilBert
hf_GPT2
hf_Longformer
hf_Reformer
hf_T5
maml
maml_omniglot
mnasnet1_0
@ -60,13 +51,6 @@ soft_actor_critic
speech_transformer
squeezenet1_1
tacotron2
timm_efficientdet
timm_efficientnet
timm_nfnet
timm_regnet
timm_resnest
timm_vision_transformer
timm_vovnet
tts_angular
vgg16
vision_maskrcnn

View File

@ -23,7 +23,6 @@ TORCHBENCH_MODELS: list[str] = [
"resnet50",
"moco",
"llama",
"hf_T5",
]
HUGGINGFACE_MODELS: list[str] = [
"AllenaiLongformerBase",

View File

@ -11,7 +11,6 @@ import pandas as pd
flaky_models = {
"yolov3",
"detectron2_maskrcnn_r_101_c4",
"timm_efficientnet", # see https://github.com/pytorch/pytorch/issues/148699
"XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148
"moondream", # discovered in https://github.com/pytorch/pytorch/pull/159291
# discovered in https://github.com/pytorch/pytorch/issues/161419. Its not flaky but really hard to repro, so skipping it
@ -40,13 +39,9 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
"detectron2_fcos_r_50_fpn",
"doctr_det_predictor",
"doctr_reco_predictor",
"hf_BigBird",
"hf_Longformer",
"hf_Reformer",
"hf_Roberta_base",
"hf_T5",
"hf_T5_base",
"hf_T5_generate",
"dpn107",
"fbnetv3_b",
"levit_128",
"llava",
"microbench_unbacked_tolist_sum",
"mnasnet1_0",
@ -63,12 +58,7 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
"squeezenet1_1",
"stable_diffusion_text_encoder",
"stable_diffusion_unet",
"timm_efficientdet",
"timm_efficientnet",
"timm_nfnet",
"timm_regnet",
"timm_resnest",
"timm_vovnet",
"swsl_resnext101_32x16d",
"torchrec_dlrm",
"vgg16",
# LLM

View File

@ -36,12 +36,7 @@ def check_graph_breaks(actual_csv, expected_csv, expected_filename):
"detectron2_fcos_r_50_fpn",
"doctr_det_predictor",
"doctr_reco_predictor",
"hf_BigBird",
"hf_Longformer",
"hf_Reformer",
"hf_Roberta_base",
"hf_T5",
"hf_T5_base",
"levit_128",
"llava",
"microbench_unbacked_tolist_sum",
"resnet50",
@ -51,7 +46,6 @@ def check_graph_breaks(actual_csv, expected_csv, expected_filename):
"stable_diffusion_text_encoder",
"stable_diffusion_unet",
"timm_efficientdet",
"timm_nfnet",
"torchrec_dlrm",
"vgg16",
# LLM

View File

@ -130,70 +130,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -342,30 +278,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,62 +78,6 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -250,30 +194,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,fail_accuracy,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -118,62 +118,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -314,30 +258,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -114,58 +114,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,38 +226,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -114,58 +114,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,38 +226,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -122,66 +122,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -302,38 +242,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -122,66 +122,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -302,38 +242,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -122,66 +122,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -302,38 +242,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -130,70 +130,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -342,30 +278,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,62 +78,6 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -246,30 +190,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -98,58 +98,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -262,38 +210,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -98,58 +98,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -262,38 +210,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -106,66 +106,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -286,38 +226,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -122,66 +122,6 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,25
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -302,38 +242,6 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,3


Some files were not shown because too many files have changed in this diff.