Compare commits

...

175 Commits

Author SHA1 Message Date
e752a29afd revert expectedFailure on TestFSDPWithEP::test_e2e 2025-10-16 10:55:03 +00:00
36b622bb72 skip test_cupy_as_tensor 2025-10-16 10:55:03 +00:00
83a04f38a4 xfail test_fsdp_ep.py::TestFSDPWithEP::test_e2e 2025-10-16 10:55:03 +00:00
6579829bee fix lint 2025-10-16 10:55:03 +00:00
2b856676f3 skip test on SM100OrLater 2025-10-16 10:55:03 +00:00
5746261c97 Resolve more merge conflicts 2025-10-16 10:55:03 +00:00
b3c94fd0fc Resolve merge conflicts 2025-10-16 10:55:03 +00:00
6fd366b2c7 Further increase timeout.
test_gather_object* require exact world size (4) to succeed
Likely, all NVSHMEM Triton tests would fail. Skipping all for now.
2025-10-16 10:55:03 +00:00
fe25f6ab59 rebase 2025-10-16 10:55:03 +00:00
ca89e5732f Change runner to linux.12xlarge.memory
and fix lint
2025-10-16 10:55:03 +00:00
f12cb265d4 Fix Lint 2025-10-16 10:55:03 +00:00
7dc6bf5377 Skip distributed unit tests with their issue numbers. 2025-10-16 10:55:03 +00:00
e5ba464808 Undo changes to test_3_level_hierarchical_model_averager 2025-10-16 10:55:03 +00:00
7d95185044 Check if using linux.dgx.b200.8 would avoid the fsdp issues. 2025-10-16 10:55:03 +00:00
77fb3c1cac Revert "Use linux.dgx.b200.8 since linux.dgx.b200.4 is not available."
This reverts commit 1ba7b2e20b804f337aa483e0228228273ade7396.
2025-10-16 10:55:03 +00:00
11a3d1d87b Use linux.dgx.b200.8 since linux.dgx.b200.4 is not available. 2025-10-16 10:55:03 +00:00
8c6d9feb26 Add b200 distributed job
Add aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
to test job on B200 runner

Fix runner usage - this distributed job should use linux.dgx.b200.8

@require_world_size(4) does not translate to world_size==4.
For example, an 8-GPU B200 runner would make the unit test run with
world_size = 8.
Tested with:
TEMP_DIR=/tmp BACKEND=nccl WORLD_SIZE=4 pytest -v test/distributed/test_distributed_spawn.py -k test_new_subgroups_world_size_not_divisible_by_group_size

Add require_exact_world_size for distributed unit tests that implicitly
require world_size of 4 to pass.
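
A minimal sketch of the kind of guard described here (a hypothetical illustration; the actual helper lives in the distributed test utilities and may differ):

```python
import functools
import os
import unittest


def require_exact_world_size(n: int):
    """Hypothetical sketch: skip a test unless WORLD_SIZE is exactly n."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            world_size = int(os.environ.get("WORLD_SIZE", "0"))
            if world_size != n:
                raise unittest.SkipTest(f"requires world_size == {n}, got {world_size}")
            return fn(*args, **kwargs)
        return wrapper
    return decorator
```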

Fix test_3_level_hierarchical_model_averager

Mimic H100 distributed: run distributed less often because it takes
quite long to finish all the tests (easily 4 hours+ for each of the 3
shards).
2025-10-16 10:55:03 +00:00
003dd13073 [dynamo, guards] Better error messages when generated guard fails on the same frame (#165242)
Not sure what exactly we want to have in the message, but that's easy to adjust. I tried to find a reliable test to reproduce this message (it happens only when a guard fails right after it's created), but I ended up mocking a `guard_manager.check` function to return `False` to trigger this behavior. I think that's fine, because any other case that we pick (like datetime.now()) is something we want to patch one day anyway, so every time we make the next patch we would need to chase another repro test.

@williamwen42

Fixes #164990

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165242
Approved by: https://github.com/williamwen42
2025-10-16 01:05:31 +00:00
c2bd41ac9f Build vLLM nightly wheels for CUDA 13.0 (#163239)
Now that https://github.com/vllm-project/vllm/pull/24599 has been merged
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163239
Approved by: https://github.com/malfet, https://github.com/atalman
2025-10-16 01:03:26 +00:00
ca8bd5dbed Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164405
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
ghstack dependencies: #164350, #164354
2025-10-16 00:55:43 +00:00
26f3803433 Remove workaround to old CUDA bug (#164354)
As in the title.

A check for https://github.com/pytorch/pytorch/issues/164348 to see if the workaround can be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164354
Approved by: https://github.com/janeyx99, https://github.com/ngimel, https://github.com/malfet, https://github.com/jeffdaily
ghstack dependencies: #164350
2025-10-16 00:55:43 +00:00
48064acf37 Move AT_FORALL_... macros and ScalarTypeToCPPTypeT to headeronly (#164350)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164350
Approved by: https://github.com/janeyx99
2025-10-16 00:55:42 +00:00
e5a9c247bc [Fix XPU CI] [Inductor UT] Fix test cases broken by community. (#165406)
Fixes #163159, Fixes #164098, Fixes #164097, Fixes #164099, Fixes #165025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165406
Approved by: https://github.com/EikanWang, https://github.com/jansel
2025-10-16 00:53:32 +00:00
36371b8ec7 [ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the two easier.
In the PTX diff between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/
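
A small, self-contained Python illustration (not the ATen kernel) of why the offset order matters for bitwise comparisons: float addition is not associative, so two tree reductions over the same 32 lane values need not match bit for bit.

```python
import numpy as np

rng = np.random.default_rng(0)
vals = rng.standard_normal(32).astype(np.float32)  # one value per lane of a warp


def warp_sum(x, offsets):
    """Emulate a shuffle-based tree reduction; return the value held by lane 0."""
    lanes = list(x)
    for off in offsets:
        lanes = [
            np.float32(lanes[i] + lanes[i + off]) if i + off < len(lanes) else lanes[i]
            for i in range(len(lanes))
        ]
    return lanes[0]


# Same inputs and the same mathematical sum, but a different association order;
# the two results need not agree in the last bits.
print(warp_sum(vals, [16, 8, 4, 2, 1]), warp_sum(vals, [1, 2, 4, 8, 16]))
```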

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel, https://github.com/eqy
ghstack dependencies: #165494
2025-10-15 23:54:51 +00:00
7e6721fb0a [BE] Remove confusing opbenchmark-on-demand-build (#165583)
It doesn't have a test shard, so what's the point of running the build? It was added in https://github.com/pytorch/pytorch/pull/143733 and it looks like a test shard never existed for it.

Moreover, allow one to specify the benchmark size as an argument, so one
can technically do a workflow dispatch with different opbenchmark sizes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165583
Approved by: https://github.com/huydhn
2025-10-15 23:48:28 +00:00
901bbcba12 Gate division bitwise numerics under a flag (#165566)
https://github.com/pytorch/pytorch/pull/164144 ensures that division for compile is bitwise equivalent with eager. However, in https://github.com/pytorch/pytorch/issues/164301, the kernel performance regressed.

On B200:
With standard triton `/`:
6511 GB/s

With triton `div_rn`:
4692 GB/s

Further investigation is required for the generated PTX to see why there is such a large slowdown. For now, enable bitwise equivalent results under `TORCHINDUCTOR_EMULATE_DIVISION_ROUNDING` similar to emulate_precision_cast
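
A hedged usage sketch of the opt-in described above (the flag name comes from this PR; the exact enablement mechanics are an assumption):

```python
import os

# Assumed opt-in mechanism: set the flag before inductor compiles anything.
os.environ["TORCHINDUCTOR_EMULATE_DIVISION_ROUNDING"] = "1"

import torch


def f(x, y):
    return x / y


x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")

# With the flag on, the goal is bitwise-identical division between eager and compile.
torch.testing.assert_close(f(x, y), torch.compile(f)(x, y), rtol=0, atol=0)
```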

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165566
Approved by: https://github.com/ngimel, https://github.com/eellison
2025-10-15 23:41:01 +00:00
febb603230 [Inductor][CuTeDSL] Move load_template up two directories (#165347) (#165576)
Summary:

Moves the function used to load CuTeDSL Jinja templates up one level out of the flex attention folder. This way it can be used for more general Inductor templates in the future.

Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`

Reviewed By: drisspg

Differential Revision: D84527470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165576
Approved by: https://github.com/jananisriram
2025-10-15 23:37:55 +00:00
568d2f3ae7 [Dynamo][Logging] Add sources/types to LazyVariableTracker logging (#165402)
Fixes #162860

This task adds variable source attribution to LazyVariableTracker when outputting trace bytecode.

Test plan -- test/dynamo/test_error_messages.py ErrorMessagesTest.test_variable_tracker_source_attribution

The output is as specified in the previously mentioned GitHub issue.

<img width="961" height="59" alt="Screenshot 2025-10-13 at 10 19 44 PM" src="https://github.com/user-attachments/assets/fb27da3f-d00b-437b-bf2e-52e892572cd7" />

This is specifically for the log setup with ``TORCH_LOGS=trace_bytecode``

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165402
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42

Co-authored-by: William Wen <williamwen@meta.com>
2025-10-15 23:23:09 +00:00
b54e466fd0 Megacache integration (#163533)
This diff adds megacache integration for DynamoCache.

Because DynamoCache requires lazy serialization, i.e. it can only be serialized once all relevant backends have been compiled and we're ready for a save, we actually do the DynamoCache saving only on a call to `torch.compiler.save_cache_artifacts`.
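
A hedged sketch of the workflow this describes, using the public Mega-Cache entry points; the exact return values are assumed here:

```python
import torch


@torch.compile
def f(x):
    return torch.sin(x) + 1


f(torch.randn(8))  # compile once so the relevant backends populate their caches

# DynamoCache serializes lazily, so the save happens here rather than at compile time.
artifacts = torch.compiler.save_cache_artifacts()
if artifacts is not None:
    artifact_bytes, cache_info = artifacts
    # artifact_bytes can be persisted and later restored in another process with
    # torch.compiler.load_cache_artifacts(artifact_bytes).
```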

Differential Revision: [D82735763](https://our.internmc.facebook.com/intern/diff/D82735763/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163533
Approved by: https://github.com/oulgen, https://github.com/zhxchen17
2025-10-15 22:49:15 +00:00
53f9ae0e50 [ROCm] new implementation of upsample_bilinear2d_backward (#164572)
Changed the implementation from an output-based approach to an input-based one to remove `atomicAdd` operations, and it appears to deliver at least a 20× speedup.

The changes are from Yu-Yun <YuYun.Chang@amd.com>.

# Summary: Refactor of the implementation of the `upsample_bilinear2d_backward` operation on MI300X/MI325X
- The original "scatter-add" approach
  - Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs.
- The new "gather-sum" approach
  - Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (done by the `compute_output_range` device function).
# Breakdown of the code changes
- Inversion of the parallelization strategy of the kernel function `upsample_bilinear2d_backward_out_frame`
  - Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (`const size_t o_numel = nc * width2 * height2;`).
    - Each thread processed one output pixel.
  - The new loop is parallelized over the number of elements in the input gradient tensor (`const size_t i_numel = nc * height1 * width1;`).
    - Each thread is responsible for calculating the final gradient for a single input pixel.
  - The kernel launch changes accordingly in the function `upsample_bilinear2d_backward_out_cuda_template`.
- Added a device function for calculating the range of output pixels that could have possibly used the input pixel (`input_pos`) during the forward pass interpolation
  - This is essentially the mathematical inverse of the forward pass.
  - This function tries to prune a thread's search space so that it only needs to inspect a small, local window of the output tensor.
- Gradient calculation approach switching from "scatter-add" to "gather-sum" (a small Python sketch of the idea follows this list)
  - Scatter-add
    - For each output pixel, the thread calculated 4 gradient contributions and used `fastAtomicAdd` 4 times to add these values to 4 different (and potentially highly contended) memory locations in the input gradient tensor.
  - Gather-sum
    - A thread responsible for one input pixel calls `compute_output_range` to determine the small rectangular region of output pixels that influence the input's final gradient value.
    - The thread iterates through this region, and for each output pixel in the region, it re-calculates the interpolation weights to determine the exact contribution to its specific input pixel.
    - All these contributions are accumulated into a private, per-thread register variable (`accscalar_t grad_sum = 0;`).
      - W/o any global memory access, this accumulation is extremely fast.
    - When the loops are done, the thread performs a single, direct write (non-atomic) of the final summed gradient to its designated location in global memory (`idata[index] = static_cast<scalar_t>(grad_sum);`).
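
A small Python sketch of the gather-sum idea in 1D (linear rather than bilinear interpolation, plain loops standing in for HIP threads); it is illustrative only, not the ATen kernel:

```python
import numpy as np


def upsample_linear1d_backward_gather(grad_out, in_size, scale):
    """Gather-sum sketch: one loop iteration plays the role of one input-pixel thread."""
    out_size = grad_out.shape[0]
    grad_in = np.zeros(in_size, dtype=grad_out.dtype)
    for i in range(in_size):
        # Output positions whose source interval [floor(o*scale), floor(o*scale)+1]
        # contains i (the 1D analogue of compute_output_range).
        lo = max(int(np.ceil((i - 1) / scale)), 0)
        hi = min(int(np.floor((i + 1) / scale)), out_size - 1)
        acc = 0.0
        for o in range(lo, hi + 1):
            src = o * scale
            i0 = int(np.floor(src))
            w1 = src - i0               # weight of input i0 + 1 in the forward interpolation
            if i == i0:
                acc += (1.0 - w1) * grad_out[o]
            elif i == i0 + 1:
                acc += w1 * grad_out[o]
        grad_in[i] = acc                # one plain (non-atomic) write per input element
    return grad_in


# Example: gradient of an 8-element output w.r.t. a 4-element input.
print(upsample_linear1d_backward_gather(np.ones(8), in_size=4, scale=4 / 8))
```
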
# Why performance gets boosted
- Analysis of the root cause of performance drop
  - Ref. (internal only) - https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1140493327/PyTorch__upsample_bilinear2d_backward
- First and foremost, elimination of the contention of atomic operations
  - Many parallel threads called `atomicAdd` frequently attempting to update the exact same memory location in the input gradient tensor at the same time.
    - The GPU's memory controller has to serialize these operations, effectively nullifying the benefit of parallel capability at those contention points.
  - The MI300X/MI325X chiplet-based CDNA 3 architecture amplified the issue.
    - When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect.
  - The implementation change eliminates hardware-level serialization and cross-chiplet coherence traffic caused by many `atomicAdd`.
- Improved memory access pattern and locality
  - Write coalescing
    - The regular sum writes `idata[index] = static_cast<scalar_t>(grad_sum);` can be perfectly coalesced by GPUs.
  - Read locality
    - Even though there are many (potentially repeated) reads from the output tensor (`static_cast<accscalar_t>(odata[output_idx])`), these are highly cache-friendly, meaning the data for one thread is likely to be in the L1 or L2 cache already due to an access from a neighboring thread.
- Trade-off: computation for memory synchronization
  - The recalculation of interpolation weights fits well on high-computational-throughput modern GPUs like MI300X/MI325X.
  - Removal of atomic operations avoids expensive memory synchronization.

---

Optimizations of `grid_sampler_2d_backward` will be addressed in a separate PR.
Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164572
Approved by: https://github.com/jeffdaily
2025-10-15 22:35:43 +00:00
b42fe389b9 ROCm unit tests enablement (#165366)
Enables:
test_cuda.py::TestCuda::test_streaming_backwards_multiple_streams
test_cuda.py::TestCuda::test_graph_make_graphed_callables_with_amp_cache_disabled_allow_unused_input
test_cuda.py::TestCuda::test_graph_make_graphed_callables_without_amp_allow_unused_input
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_10000_10000_cuda_bfloat16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_10000_10000_cuda_float16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_10000_10000_cuda_float32
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_1000_10000_cuda_bfloat16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_1000_10000_cuda_float16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_1000_10000_cuda_float32
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_1000_1000_1000_cuda_bfloat16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_1000_1000_1000_cuda_float16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_1000_1000_1000_cuda_float32
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_100_100_100_cuda_bfloat16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_100_100_100_cuda_float16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_100_100_100_cuda_float32

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165366
Approved by: https://github.com/jeffdaily
2025-10-15 22:35:03 +00:00
66ea76ec44 [ROCm][tunableop] Improvements to tunableop Numerical Check (#163079)
Modified the flag PYTORCH_TUNABLEOP_NUMERICAL_CHECK so that it accepts numerical tolerances in the format atol_rtol, as compared to the previous 0 and 1. Retains the previous functionality with default values as well.
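
A hedged usage sketch of the new format (the tolerance values below are placeholders, not recommended defaults; TunableOp is enabled via its usual environment variable):

```python
import os

# Enable TunableOp and request a numerical check with explicit tolerances.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
# New atol_rtol format from this PR; "0"/"1" keep the previous on/off behavior.
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-5_1e-3"

import torch  # import after setting the env vars so TunableOp picks them up

a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 512, device="cuda", dtype=torch.float16)
c = a @ b  # candidate GEMM solutions are numerically checked against the reference
```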

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163079
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
2025-10-15 22:26:47 +00:00
e787d532b6 tmp fix for compile internal logger issue (#165568)
Summary: Catch runtime exceptions when we parse and scrub uninteresting configs from the inductor config

Test Plan: tested locally

Differential Revision: D84727788

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165568
Approved by: https://github.com/luccafong, https://github.com/oulgen
2025-10-15 22:03:16 +00:00
b3f6d49b69 Overlap scheduler improvements (#165318)
Bucketing a number of smallish improvements:

- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by its latency time.
- Update compute domination so we are ordering based on compute index, as opposed to compute depth, so we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we can explore reordering in the future.
- When we wait on a collective, force all collectives on the same process group that were enqueued prior to it to wait as well.

Better Memory Handling:
- Pre-fetch limiting - when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.

TODO:
- For each compute node, we know its original memory in the graph; we could limit pre-fetching that goes across peak memory.
- By scheduling off-path collectives for overlap, we reduce memory, but if there isn't enough compute for overlap, we need to proactively schedule them. Not an issue yet in the examples.
- Make some hard-coded constants configurable and clean up enablement (can be done in a subsequent PR).

On small llama 2d backward:
578 of 618 potentially hideable collectives hidden
original mem 14.4 GB, rescheduled mem 15.9 GB

on forward:
254/256 potentially hideable collectives hidden
original mem 5.8 GB, rescheduled mem 5.8 GB

WIP: adding tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
2025-10-15 21:58:47 +00:00
bc1f2108d7 [PP] Update backward_counter and fsdp util to schedule class (#165513)
Fixed one issue with FSDP last reshard not being called.

Rest is mostly refactoring, changing some variables to be class variables so they can be used in https://github.com/pytorch/torchtitan/pull/1721

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165513
Approved by: https://github.com/fegin
2025-10-15 21:58:16 +00:00
f071f17911 [Graph Partition] fix partition x memory plan issue (#165514)
For `test_graph_partition_with_memory_plan_reuse`, before this PR, when using graph partition, it would error ([P1992728479](https://www.internalfb.com/phabricator/paste/view/P1992728479)):

```
def partition_0(args):
    ...
    del buf0
    return (buf3, buf4, buf5, buf2, primals_4, )

...

  File "/tmp/torchinductor_boyuan/ww/cwwc7ukfqscg2vy6ankby2fizdb377tvgyx3fwdgddrxe3g47jg6.py", line 132, in partition_0
    return (buf3, buf4, buf5, buf2, primals_4, )
                              ^^^^
NameError: name 'buf2' is not defined. Did you mean: 'buf0'?
```

When not using graph partition, it would work and give the following code ([P1992997521](https://www.internalfb.com/phabricator/paste/view/P1992997521)):

```
def call(self, args):
    ...
    buf2 = buf0; del buf0  # reuse
    ...
```

Note that the issue is that buf0 is not reused for buf2 when using graph partition.

Why? Because the codegen runs `run_wrapper_ir_passes` and `memory_plan_reuse`, which pops trailing `MemoryPlanningLine`s unless they are in the graph output, as determined by checking `V.graph.get_output_names()`. However, for graph partition, we should check the output of the current partition instead of the graph before partitioning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165514
Approved by: https://github.com/ProExpertProg, https://github.com/eellison
2025-10-15 21:52:16 +00:00
fa1539594b consolidate fw and inference compile paths (#165457)
By design, fw compile and inference compile stages should share a bunch of code; just consolidating the duplication here.

Differential Revision: D84628978

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165457
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2025-10-15 21:33:50 +00:00
dfc8a1c5dd Fix _StridedShard incorrect split (#165533)
https://github.com/pytorch/pytorch/pull/164820 introduced a bug where `_StridedShard` calls the parent class `Shard`'s `split_tensor` method, thus resulting in incorrect data locality. (I think @ezyang spotted this issue, but we have no test to capture this.)

Meanwhile, I noticed another bug: when we normalize a `_StridedShard`'s placement, it will also trigger the parent class `Shard`'s `split_tensor` method because it creates a Shard instance [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I think we never tested `distribute_tensor` for `_StridedShard` before, so I added a test here to compare against ordered shard.

Using a classmethod because the _split_tensor logic is different between `Shard` and `_StridedShard`. Basically I want to shard on local tensors without initializing the Shard object:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
2025-10-15 20:52:41 +00:00
7f9b745494 [ROCm][tunableop] Modified Online Tuning Mode to add Instant Logging (#163965)
- Added instant logging in online tuning mode, so that each tuned GEMM is instantly written
- Allows us to have saved tuning configs, in cases of crashes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163965
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
2025-10-15 20:02:31 +00:00
83f9baf413 [Bugfix][Precompile][vLLM] Support for pickling einops for aot_autograd serialization in vLLM (#165359)
Fixes issue with compiling `Qwen2_5_vl` in https://github.com/vllm-project/vllm/pull/23207 (issue happens with `aot_autograd_cache`)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165359
Approved by: https://github.com/jamesjwu
2025-10-15 20:00:24 +00:00
ffc7552e01 See if we can handle uploading all test data (#165484)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165484
Approved by: https://github.com/izaitsevfb
2025-10-15 19:57:41 +00:00
78f5a1ec60 varlen api (#164502)
**Summary**

Today, the only way to have variable sequence length support in PyTorch attention is through nested tensors [here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support). We also want to add an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA.

This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

|        | Variable Length API | SDPA     |
|--------|--------------------|----------|
| Runtime | 0.21750560760498047 ms       | 0.43171775817871094 ms  |
| TFLOPs | 231.812         | 320.840  |

The sparsity is 0.453, which we can see matches the speedup we get from varlen (approx. 50%). TFLOPs remain around the same, with SDPA slightly higher due to potentially higher overhead and total FLOPs scaling with sequence length.
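
A hedged sketch of the variable-length benchmark setup described above (sequence lengths drawn as random multiples of 64; the `varlen_attn` signature itself is not reproduced here):

```python
import torch

batch_size, max_seq_len = 8, 2048

# Random multiples of 64 up to max_seq_len, as in the benchmark settings.
seq_lens = torch.randint(1, max_seq_len // 64 + 1, (batch_size,)) * 64
sparsity = 1.0 - seq_lens.sum().item() / (batch_size * max_seq_len)

# The padded-SDPA baseline works on (batch, heads, max_seq_len, head_dim) plus a mask;
# a varlen path instead packs the real tokens and passes cumulative sequence lengths.
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)

print(f"sparsity ~ {sparsity:.3f}", cu_seqlens.tolist())
```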

**Testing**

Run `python test/test_varlen_attention.py` for unit tests where we verify basic functionality and confirm numerical match between varlen outputs vs SDPA.

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
2025-10-15 19:45:55 +00:00
2b71b62045 Add Memory Estimation Tracker (#165059)
Add Memory Tracker utility, which will track live memory given alternate ordering of nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165059
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945
2025-10-15 19:44:29 +00:00
8c4b528403 Revert "[Inductor][CuTeDSL] Move load_template up two directories (#165347)"
This reverts commit 815d6415996d5b32b569fd2a8206f1e57c75bfe3.

Reverted https://github.com/pytorch/pytorch/pull/165347 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165347#issuecomment-3407958496))
2025-10-15 19:30:46 +00:00
066f818eea Refactor and unify v1/v2 _scaled_mm codes (#165436)
Summary:

* Refactor out some core routines (scaled_gemm, auto-tuned scaled_gemm)
* Unify v1/v2 dispatch calls where possible
* Simplify call pattern w.r.t. CUDA/ROCM for easier readability.

Test Plan:

```
pytest -svv test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165436
Approved by: https://github.com/drisspg
2025-10-15 19:07:05 +00:00
14af1dc3da [DeviceMesh] Fix layout calculation when flattening non-contiguous dims (#165542)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165542
Approved by: https://github.com/ezyang, https://github.com/fduwjj
2025-10-15 18:55:45 +00:00
2395d7d7da Relax equality check (#165460)
When an object inherits from multiple types, the previous check would fail, so we should relax it to respect eager semantics.
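
A minimal standalone illustration of the failure mode (the actual check being relaxed is not reproduced here):

```python
class Base: ...
class Mixin: ...
class Child(Base, Mixin): ...

x = Child()

# A strict type-equality check rejects objects whose concrete class merely inherits
# from the expected type:
print(type(x) == Base)        # False
# isinstance follows eager Python semantics and accepts subclasses, including ones
# built from multiple bases:
print(isinstance(x, Base))    # True
print(isinstance(x, Mixin))   # True
```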

Differential Revision: [D84635322](https://our.internmc.facebook.com/intern/diff/D84635322)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165460
Approved by: https://github.com/avikchaudhuri
2025-10-15 18:32:01 +00:00
0aa7ebaf03 Fix periodic debug tests failing due to FakeProcessGroup things (#165479)
These happen when building with CMAKE_BUILD_TYPE=RelWithAssert

This should fix two types of failures that started with https://github.com/pytorch/pytorch/pull/163665

Disclaimer that I used a lot of AI since I don't know how pybind works or what refcounts and pointers are, so idk if this is a good solution, or even a solution at all (fwiw the tests pass now).

The first type of failure is

Truncated:
```
    default_pg, _ = _new_process_group_helper(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2096, in _new_process_group_helper
    backend_class = creator_fn(dist_backend_opts, backend_options)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/fake_pg.py", line 25, in _create_fake_pg
    return FakeProcessGroup._create_internal(
RuntimeError: new_refcount != 1 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h":319, please report a bug to PyTorch. intrusive_ptr: Cannot increase refcount after it reached zero.
Exception raised from retain_ at /var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:319 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
#7 c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) from ??:0
#8 void pybind11::class_<c10d::FakeProcessGroup, (anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup> >::init_instance<(anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup>, 0>(pybind11::detail::instance*, void const*) from init.cpp:0
#9 pybind11::detail::type_caster_generic::cast(void const*, pybind11::return_value_policy, pybind11::handle, pybind11::detail::type_info const*, void* (*)(void const*), void* (*)(void const*), void const*) from :0
#10 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> >, int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}&&, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> > (*)(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from init.cpp:0
```
I fix it here by getting rid of `DontIncreaseRefcount` and using make_intrusive to do the refcount handling instead. However, I also had to make the constructor public, which I think is not good, based on the reasoning of the original PR.

The other one type is
```
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/test/test_testing.py", line 2415, in test_no_warning_on_import
    self.assertEqual(out, "")
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4233, in assertEqual
    raise error_metas.pop()[0].to_error(  # type: ignore[index]
AssertionError: String comparison failed: "/opt/conda/envs/py_3.10/lib/python3.10/s[352 chars]):\n" != ''
- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py:29: FutureWarning: pybind11-bound class 'torch._C._distributed_c10d.FakeProcessGroup' is using an old-style placement-new '__init__' which has been deprecated. See the upgrade guide in pybind11's docs. This message is only visible when compiled in debug mode.
-   if is_available() and not torch._C._c10d_init():

To execute this test, run the following from the base repo dir:
    python test/test_testing.py TestImports.test_no_warning_on_import
```
which I fix by getting rid of the `__init__`, which I think is OK since it'll just error if you try to construct one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165479
Approved by: https://github.com/ezyang
2025-10-15 18:16:08 +00:00
7a97832585 [ROCm] Add more timm models, forward fix #165381 (#165569)
PR #165381 added timm models to cuda and cpu expected accuracy files. ROCm expected accuracy files were not updated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165569
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-15 18:11:21 +00:00
84d141e910 Revert "[inductor] Expand use of generic benchmark function (#164938)"
This reverts commit 5c583e2573f29243742e00b9fa36b266c5c78bb3.

Reverted https://github.com/pytorch/pytorch/pull/164938 on behalf of https://github.com/clee2000 due to I think this broke test/inductor/test_cuda_repro.py::CudaReproTests::test_epilogue_fusion_with_view? [GH job link](https://github.com/pytorch/pytorch/actions/runs/18529735968/job/52813191763) [HUD commit link](f58f301313) on both rocm and the slow grad check for linux. It did run successfully on cuda workflow on trunk, I wonder if this a gpu capability thing? no clue though ([comment](https://github.com/pytorch/pytorch/pull/164938#issuecomment-3407600224))
2025-10-15 17:48:38 +00:00
7c6c5d04fe Add scaled_grouped_mm_v2 and python API (#165154)
Summary:

* Add `torch._scaled_grouped_mm_v2` with more functionality and
  extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154
Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
2025-10-15 17:47:23 +00:00
b509fb9b5d Revert "add and fix OpInfo tests for the default partitioner (#165372)"
This reverts commit bcfea48ab7fd489218289693b98c1a6a6582d079.

Reverted https://github.com/pytorch/pytorch/pull/165372 on behalf of https://github.com/malfet due to Looks like it broke slow jobs, see 331b7cc054/1 ([comment](https://github.com/pytorch/pytorch/pull/165372#issuecomment-3407567748))
2025-10-15 17:38:52 +00:00
331b7cc054 Fix double dispatch to Python for detach (#163671)
This fixes #71725.

Differential Revision: [D83857880](https://our.internmc.facebook.com/intern/diff/D83857880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163671
Approved by: https://github.com/ezyang, https://github.com/albanD
2025-10-15 17:24:50 +00:00
815d641599 [Inductor][CuTeDSL] Move load_template up two directories (#165347)
Summary: Moves the function used to load CuTeDSL Jinja templates up one level out of the flex attention folder. This way it can be used for more general Inductor templates in the future.

Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:flex_flash -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`

Differential Revision: D84527470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165347
Approved by: https://github.com/drisspg
2025-10-15 16:34:58 +00:00
ffe3cb226a In pipeline parallelism: Use same dtype for receive and send tensor when initializing p2p communication. (#165539)
When initializing the p2p communication for pipeline parallelism, different default dtypes are currently used for the send and receive tensors here:
5c583e2573/torch/distributed/pipelining/stage.py (L935-L936)

This caused hard-to-trace issues when training on multiple nodes. Multiple stages on one node seem to work for some reason, which is probably why the unit tests did not catch this.
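
A minimal standalone illustration (not the pipelining code) of why both sides must agree on dtype: with mismatched dtypes, the byte counts of the send and receive buffers no longer match, so the point-to-point transfer cannot be correct.

```python
import torch

shape = (4, 1024)
send = torch.randn(shape, dtype=torch.bfloat16)        # what the sending stage transmits
recv_wrong = torch.empty(shape, dtype=torch.float32)   # mismatched default on the receiver
recv_right = torch.empty(shape, dtype=send.dtype)      # derived from the same dtype as the sender

print(send.nbytes, recv_wrong.nbytes, recv_right.nbytes)  # 8192, 16384, 8192
```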

Fixes #165143

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165539
Approved by: https://github.com/H-Huang
2025-10-15 15:05:55 +00:00
7ae123d72c [DeviceMesh] Make _flatten_mapping an object attribute instead of a class attribute (#165521)
The `_flatten_mapping` field was defined as a class attribute with a mutable default value {}:
```
_flatten_mapping: dict[str, "DeviceMesh"] = {}
```
This caused all DeviceMesh instances to share the same dictionary object. When multiple test instances tried to create flattened meshes with the same name (like "dp"), they would conflict because they were all using the same shared dictionary, resulting in the error: "Flatten mesh with mesh_dim_name dp has been created before, Please specify another valid mesh_dim_name."
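
A minimal, standalone demonstration of this pitfall in plain Python (not the DeviceMesh code itself):

```python
class Mesh:
    # Bug pattern: a mutable default defined on the class is shared by every instance.
    _flatten_mapping: dict[str, str] = {}


class FixedMesh:
    def __init__(self) -> None:
        # Fix: give each instance its own dictionary.
        self._flatten_mapping: dict[str, str] = {}


a, b = Mesh(), Mesh()
a._flatten_mapping["dp"] = "flattened"
print("dp" in b._flatten_mapping)   # True: the dict is shared across instances

c, d = FixedMesh(), FixedMesh()
c._flatten_mapping["dp"] = "flattened"
print("dp" in d._flatten_mapping)   # False: each instance owns its mapping
```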

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165521
Approved by: https://github.com/fegin, https://github.com/lw
2025-10-15 14:47:09 +00:00
7719cb75bf [ATen][CMake] Fix duplicated CUTLASS path (#165424)
Fixes #165110

The `PUBLIC` scope causes FBGEMM's CUTLASS to be included for all PyTorch targets, including the special matmuls (RowwiseScaledMM, ScaledGroupMM and GroupMM). Due to the version mismatch between FBGEMM/CUTLASS and PyTorch/CUTLASS, it is unacceptable to use FBGEMM/CUTLASS in PyTorch targets. This PR limits the scope of FBGEMM/CUTLASS to the `fbgemm_genai` target only.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165424
Approved by: https://github.com/cthi, https://github.com/eqy, https://github.com/danielvegamyhre
2025-10-15 14:14:17 +00:00
712f54d453 [ATen] Remove explicit casting of complex nansum during accumulation (#165494)
https://github.com/pytorch/pytorch/pull/164790 modifies aten to perform a different reduction order intra warp. However, this change exposed a large difference in a sum for complex32. Namely the case:

```
import torch

a = torch.tensor([[ 4.82031250+7.34765625j,
           -3.37109375-1.9501953125j],

         [ 3.7832031250-2.43359375j,
           -6.07812500+5.32812500j]], dtype=torch.complex32, device='cuda:0')

sum_out = torch.sum(a)
nansum_out = torch.nansum(a)
torch.testing.assert_close(
    sum_out,
    nansum_out,
    rtol=0,
    atol=0,
)
```

Here, the results of `sum` and `nansum` differed significantly, by 1e-2. Further investigation showed that the explicit casting of b back to `arg_t` from `scalar_t` was the root cause. `arg_t` is the dtype of the accumulator, ComplexFloat, and `scalar_t` is the input dtype, ComplexHalf. When we cast inside the reduction instead of accumulating in the accumulator dtype, the intermediate values are still stored as ComplexHalf, which loses precision.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165494
Approved by: https://github.com/ngimel
2025-10-15 13:49:25 +00:00
f58f301313 Fixes bug with tolist calls to GradTrackingTensors (#165184)
Fixes #161943

## The Fix
I implemented a recursive unwrapping helper function in the `tensor_to_list.cpp` file that looks for wrapped tensors and unwraps them. The recursive implementation was needed for multi-level gradTrackingTensors.

Let me know if there are any more suggestions on fixing this issue!

@guilhermeleobas @KimbingNg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165184
Approved by: https://github.com/zou3519
2025-10-15 12:54:28 +00:00
5c583e2573 [inductor] Expand use of generic benchmark function (#164938)
Use the more generic `Benchmarker.benchmark` function to allow benchmarking other devices that support the required functionality; for example, prologue and epilogue fusion can be benchmarked for Triton CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164938
Approved by: https://github.com/nmacchioni, https://github.com/eellison
2025-10-15 09:18:24 +00:00
0c14f55de6 [ez] fix typo (#165282)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165282
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-15 06:19:24 +00:00
8e510e1095 [MPS] fix empty dot op crash (#165237)
reproducer
```
import torch

# does not crash
a = torch.rand((0), device="cpu")
b = torch.rand((0), device="cpu")
a.dot(b)

# crashes due to internal assert
a = torch.rand((0), device="mps")
b = torch.rand((0), device="mps")
a.dot(b)

```

Discovered when implementing an op for SparseMPS backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165237
Approved by: https://github.com/malfet
2025-10-15 04:49:29 +00:00
59d30d1b75 [vision hash update] update the pinned vision hash (#165496)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vision hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165496
Approved by: https://github.com/pytorchbot
2025-10-15 04:35:50 +00:00
3915898c22 [audio hash update] update the pinned audio hash (#165495)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165495
Approved by: https://github.com/pytorchbot
2025-10-15 04:32:49 +00:00
3044e1a460 Revert "varlen api (#164502)"
This reverts commit 3681312ce03e425e280a110df2153db107616a15.

Reverted https://github.com/pytorch/pytorch/pull/164502 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but the doctests failure is legit ([comment](https://github.com/pytorch/pytorch/pull/164502#issuecomment-3404419420))
2025-10-15 03:56:42 +00:00
b11593c31b [8/N] Apply ruff UP035 rule (#165214)
This is follow-up of #164653 to continue applying `UP035` fixes. The purpose is to finally enable this rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165214
Approved by: https://github.com/ezyang
2025-10-15 03:18:57 +00:00
36871622f1 [2/N] Mark unused parameters in C++ code (#165121)
This is follow-up of #164912 to mark unused C++ parameters to improve code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165121
Approved by: https://github.com/Skylion007
2025-10-15 03:04:39 +00:00
b4fd47179e feat(dynamo): IS#160752 make F.one_hot work with jacfwd + torch.compile(dynamic=True) (#160837)
Fixes #160752

# Background:
`torch.func.jacfwd` is implemented as vmap over forward-mode JVP. With torch.compile(dynamic=True), FakeTensor + SymInt shape reasoning is used while tracing through the transform. The old vmap rule for one_hot decomposed into “zeros_symint + scatter,” which interacted poorly with the transform stack and dynamic shapes, leading to failures mid-trace. Using a functional equality construction makes one_hot composable with vmap/JVP and friendly to dynamic shape tracing.

# Changes:
- functorch vmap batching rule for `aten::one_hot` now uses a purely functional formulation:
- Replace “zeros + scatter” with eq(self.unsqueeze(-1), arange(num_classes)).to(kLong) under FuncTorchBatched (see the Python sketch after this list).
- one_hot native path remains unchanged for regular eager; vmap transform no longer relies on scatter, which was fragile under dynamic shape tracing.
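As a rough Python-level sketch of the equivalence behind the new batching rule (the real change is the C++ FuncTorchBatched kernel):

```python
import torch
import torch.nn.functional as F

num_classes = 5
idx = torch.randint(0, num_classes, (7,))

# scatter-free formulation: compare against the class range and cast to long
scatter_free = (idx.unsqueeze(-1) == torch.arange(num_classes)).to(torch.long)

torch.testing.assert_close(scatter_free, F.one_hot(idx, num_classes))
```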

The minimal repro from the issue is now fixed:
```python
import torch
import torch.nn.functional as F

MAX, BATCH = 3, 37

def func(x, idxs):
    return x.square() * F.one_hot(idxs, MAX)

def jacfunc(x, idxs):
    return torch.func.jacfwd(func, argnums=0)(x, idxs)

idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
x = torch.rand((BATCH, MAX), dtype=torch.float64)

# eager
out_eager = jacfunc(x, idxs)

# compiled dynamic
jacfunc_c = torch.compile(jacfunc, dynamic=True)
out_comp = jacfunc_c(x, idxs)

torch.testing.assert_close(out_eager, out_comp)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160837
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
2025-10-15 02:48:44 +00:00
4f400ab520 Fix: nDims is mutated inside the loop in Shape.cu (#165446)
Summary:
The `nDims` variable is mutated inside the loop but never restored to its original value.
This affects subsequent iterations of the outer loop.
Each batch iteration may get incorrect `nDims` after the first batch.

Test Plan: CI

Reviewed By: ngimel

Differential Revision: D84612194

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165446
Approved by: https://github.com/ngimel
2025-10-15 02:32:15 +00:00
839f6facdb [precompile] Fix frame construction for wrapped model. (#165454)
Summary: If a function is wrapped with functools, we should not look at the wrapped function signature but rather the wrapper, since we need to construct the frame for the top level function here.

Test Plan: test_decorated_function_with_functools_wrap_aot

Differential Revision: D84626752

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165454
Approved by: https://github.com/yiming0416
2025-10-15 02:01:46 +00:00
ca65023b90 [PP] Fix edge case with FSDP when stages_per_rank > 3 (#165467)
There is an edge case with FSDP + PP: when we add UNSHARD + RESHARD, we have at most 3 stages unsharded, 3f83e8915e/torch/distributed/pipelining/schedules.py (L1029-L1031)

This change is needed to be able to unshard and reshard a stage multiple times.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165467
Approved by: https://github.com/wwwjn
2025-10-15 01:53:04 +00:00
132ae8e6dd Don't link with libnvToolsExt when building for 12.9 (#165465)
This is to bring back this logic from https://github.com/pytorch/pytorch/pull/161916/files#diff-bf46b4a09ca67e50622bf84fefc0d11b584ffcc24ee6cc5019cf0fc7565d81a8L170.  Building libtorch on 12.9 is failing otherwise https://github.com/pytorch/pytorch/actions/runs/18458531395/job/52610761895:

```
cp: cannot stat '/usr/local/cuda/lib64/libnvToolsExt.so.1': No such file or directory
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165465
Approved by: https://github.com/atalman, https://github.com/malfet
2025-10-15 01:45:37 +00:00
a20afb6100 Allow at::native::offset_t to be offset using operator+= (#164570)
This will be required by CCCL 3.1.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164570
Approved by: https://github.com/Skylion007, https://github.com/eqy
2025-10-15 01:40:54 +00:00
47524dcc48 [benchmark] Add more timm models (#165381)
Added following models to timm_models

- [convnextv2_nano.fcmae_ft_in22k_in1k](https://huggingface.co/timm/convnextv2_nano.fcmae_ft_in22k_in1k)
- [vit_base_patch14_dinov2.lvd142m](https://huggingface.co/timm/vit_base_patch14_dinov2.lvd142m)
- [ViT-B-16-SigLIP-i18n-256](https://huggingface.co/timm/ViT-B-16-SigLIP-i18n-256)
- [deit_tiny_patch16_224.fb_in1k](https://huggingface.co/timm/deit_tiny_patch16_224.fb_in1k)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165381
Approved by: https://github.com/BoyuanFeng
2025-10-15 01:19:10 +00:00
9ffba8a2f9 fixing stress test failure (#164353)
Summary: This diff fixes a stress test failure by adding a new binary echo4.py and modifying the existing echo1.py binary. The changes are made in both fbcode and xplat directories. The api_test.py file is updated to use the new echo4.py binary, and the BUCK file is updated to include the new binary.

Test Plan:
```
buck test -j 18 'fbcode//mode/opt' fbcode//caffe2/test/distributed/elastic/multiprocessing:api_test -- --exact 'caffe2/test/distributed/elastic/multiprocessing:api_test - test_binary_redirect_and_tee (api_test.StartProcessesListAsBinaryTest)' --run-disabled --stress-runs 20 --record-results
```

```
buck test -j 18 'fbcode//mode/opt' fbcode//caffe2/test/distributed/elastic/multiprocessing:api_test -- --exact 'caffe2/test/distributed/elastic/multiprocessing:api_test - test_binary (api_test.StartProcessesListAsBinaryTest)' --run-disabled --stress-runs 20 --record-results
```

https://www.internalfb.com/intern/testinfra/testrun/17732923648474906

https://www.internalfb.com/intern/testinfra/testrun/15481123834815653

Differential Revision: D83623694

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164353
Approved by: https://github.com/d4l3k
2025-10-15 01:18:50 +00:00
3681312ce0 varlen api (#164502)
**Summary**

Today, the only way to have variable sequence length support in PyTorch attention is through nested tensors [here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support). We also want to add an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA.

This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

|        | Variable Length API | SDPA     |
|--------|--------------------|----------|
| Runtime | 0.21750560760498047 ms       | 0.43171775817871094 ms  |
| TFLOPs | 231.812         | 320.840  |

The sparsity is 0.453, which matches the speedup we get from varlen (approx. 50%). TFLOPs remain around the same, with SDPA slightly higher due to potentially higher overhead and total FLOPs scaling with sequence length.

**Testing**

Run `python test/test_varlen_attention.py` for unit tests where we verify basic functionality and confirm numerical match between varlen outputs vs SDPA.

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
2025-10-15 00:45:06 +00:00
7778a58e7c Revert "[export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)"
This reverts commit bbb902c8dd911e1587253f496c1e2fb178d4b6a1.

Reverted https://github.com/pytorch/pytorch/pull/165334 on behalf of https://github.com/jeffdaily due to trunk CI passed here but failures on HUD after merge?  test/functorch/test_aot_joint_with_descriptors.py::TestAOTJointWithDescriptors::test_module_with_kwargs [GH job link](https://github.com/pytorch/pytorch/actions/runs/18511729262/job/52755708742) [HUD commit link](bbb902c8dd) ([comment](https://github.com/pytorch/pytorch/pull/165334#issuecomment-3404071893))
2025-10-15 00:21:49 +00:00
e7091a47da [AOTI] skip Windows XPU crashed UTs. (#165393)
Skip some UTs, which crashed on Windows XPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165393
Approved by: https://github.com/jansel
2025-10-14 23:45:14 +00:00
bcfea48ab7 add and fix OpInfo tests for the default partitioner (#165372)
I noticed the default partitioner was breaking in some dynamic shape tests, so prior to turning off functionalization I want to tweak it to pass all of our OpInfo tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165372
Approved by: https://github.com/ezyang
ghstack dependencies: #165327
2025-10-14 23:34:34 +00:00
d2e1dbc8f2 make aotdispatcher opinfo tests keep input mutations in graph (#165327)
This stack is going to turn off functionalization and turn on the default partitioner, so I'm going to separate out a few changes before turning off functionalization in our OpInfo tests:

(1) run our tests with input mutations allowed inside the graph

(2) run our tests with the default partitioner

(3) run with functionalization off

(4) (later) make the tests properly test for bitwise equivalence

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165327
Approved by: https://github.com/ezyang
2025-10-14 23:34:33 +00:00
89298ada83 [device_mesh] Implement _unflatten on top of CuTe layout bookkeeping (#161224)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161224
Approved by: https://github.com/lw, https://github.com/fegin
ghstack dependencies: #164510
2025-10-14 23:17:11 +00:00
c467e59cb0 dynamo configs to torch.compiler (#163517)
Moving some dynamo configs to torch.compiler

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163517
Approved by: https://github.com/williamwen42, https://github.com/anijain2305

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
2025-10-14 22:44:53 +00:00
bbb902c8dd [export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)
fx.Interpreter doesn't handle kwargs... not sure how this code worked previously

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165334
Approved by: https://github.com/tugsbayasgalan, https://github.com/ezyang
2025-10-14 22:22:58 +00:00
e6f766c7d7 [Dynamo] Fixes for exceptions (#153966)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153966
Approved by: https://github.com/Lucaskabela
2025-10-14 22:03:58 +00:00
13b621d87c [DTensor] add __repr__ for CommDebugMode(get_total_count()=) (#165006)
I just want to print CommDebugMode and know if there is communication, so this implements `__repr__` for `print(comm_mode)`.

```
comm_mode = CommDebugMode()
with comm_mode:
    out = torch.mm(inps, weight)
print(comm_mode)
# CommDebugMode(get_total_counts()=0)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165006
Approved by: https://github.com/anshul-si
ghstack dependencies: #165024
2025-10-14 21:31:23 +00:00
01738a3fea Continue local tensor mode enablement for DTensor tests (#165451)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165451
Approved by: https://github.com/ezyang, https://github.com/albanD
2025-10-14 21:20:54 +00:00
a2f34bdd7c Revert "Patch the flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for flex attention graph break on create_block_mask) (#164923)"
This reverts commit 3401665110dbfbfa4625646e4a18ebf8c99fa92f.

Reverted https://github.com/pytorch/pytorch/pull/164923 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164923#issuecomment-3403654378))
2025-10-14 21:20:49 +00:00
a63ab0b8cd [Inductor] Fix out-of-bounds indices in repeat_interleave decomposition (#165368)
When `repeat_interleave` is decomposed into:
```python
  cumsum = repeat.cumsum(0)
  pos = torch.arange(output_size, device=repeat.device)
  indices = torch.searchsorted(cumsum, pos, right=True)
```
`searchsorted` op with `right=True` returns the insertion point after matching elements. When query values `pos` are `>= cumsum[-1]`, searchsorted returns `len(cumsum)`, which is out of bounds for indexing (valid range: `[0, len(cumsum)-1]`). These invalid indices trigger CUDA device-side assert errors in downstream indexing operations.

This fix adds clamping to ensure all indices stay within the valid range [0, repeat.size(0)-1].
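A small sketch of the clamped decomposition described above, on a toy input rather than the exact Inductor code:

```python
import torch

repeat = torch.tensor([2, 0, 3])
output_size = int(repeat.sum())

cumsum = repeat.cumsum(0)
pos = torch.arange(output_size, device=repeat.device)
indices = torch.searchsorted(cumsum, pos, right=True)
# clamp so a query value >= cumsum[-1] can never produce len(cumsum)
indices = indices.clamp(max=repeat.size(0) - 1)

assert torch.equal(indices, torch.repeat_interleave(repeat))
```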

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165368
Approved by: https://github.com/mlazos
2025-10-14 21:16:36 +00:00
102b7885ff Add option to run AOT Precompile in benchmark (#164906)
Use the existing benchmark infra to get some signals for AOT precompile pass rate on OSS models. Here we also measure and log the loading time.

```
python ./benchmarks/dynamo/huggingface.py --accuracy --inference --aot-precompile

python ./benchmarks/dynamo/timm_models.py --accuracy --inference --aot-precompile

python ./benchmarks/dynamo/torchbench.py --accuracy --inference --aot-precompile
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164906
Approved by: https://github.com/zhxchen17
2025-10-14 20:59:55 +00:00
382d04a51e [Inductor][ATen][FP8] Add note for supported blockwise scaling strategy pairs (#165450)
Summary: Add note mentioning which scaling type pairs are supported in Inductor ATen, since this was a source of confusion and also informs which scaling strategies we choose to support for other backends, like Triton.

Test Plan: n/a

Reviewed By: lw

Differential Revision: D84522373

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165450
Approved by: https://github.com/NikhilAPatel
2025-10-14 20:43:58 +00:00
1ec0755a7e [ISSUES] Update ci:sev template to include a note about ci: disable-autorevert label (#165459)
We noticed that disabling autorevert in any and all ci:sevs is too impactful, as ci: sevs are sometimes created just to communicate an action or an impactful change. Also, during a SEV we might not want to disable autorevert anyway; an example is a ci: sev impacting jobs we don't use as a basis for autorevert.

So, a note is added reminding the ci:sev author to optionally add this tag to disable auto-revert.

Note: using this opportunity to fix the ci: disable-autorevert issues, as it is best for the title to be simple and the displayed message in the GitHub interface to be decorated with an emoji :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165459
Approved by: https://github.com/malfet
2025-10-14 20:32:46 +00:00
058782c6ab [torch.export] Rmoving unused constants - add support for corner case (#165205)
Summary: In some cases an unused constant has only one level of child node and no second level of child node. Those constants should be removed too. The added test case covers the scenario where this happens.

Test Plan:
```
buck test mode/opt caffe2/test:test_export -- 'test_unused_constant'
```

https://www.internalfb.com/intern/testinfra/testrun/15481123837456594

Differential Revision: D84398413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165205
Approved by: https://github.com/angelayi
2025-10-14 20:26:28 +00:00
2b4ef6b4d6 [opaque_obj_v2] PyObject custom op schema type (#165004)
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead now we just need to do:

Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type.
```python
class OpaqueQueue:
    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
        super().__init__()
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, tensor: torch.Tensor) -> None:
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

    def size(self) -> int:
        return len(self.queue)

register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
```

When creating the custom op, the schema will then use the unique name:
```python
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")

torch.library.define(
    "_TestOpaqueObject::queue_push",
    "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
    tags=torch.Tag.pt2_compliant_tag,
    lib=self.lib,
)

@torch.library.impl(
    "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
    assert isinstance(queue, OpaqueQueue)
    queue.push(b)
```

Using the custom op:
```python
queue = OpaqueQueue([], torch.zeros(3))
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3))
self.assertTrue(queue.size(), 1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004
Approved by: https://github.com/albanD
2025-10-14 20:21:04 +00:00
3f83e8915e [inductor] fix issue for example value with unbacked strides (#163660)
## Issue

During autotune, we're not applying size hints atomically for the example inputs used for benchmarking.

If an unbacked symint shows up in the inputs' strides, this might lead to CUDA IMA.

This can be reproduced by the added unittest: with the stride being `[128 * u0, 128, 1]` and the unbacked fallback being 8192, after calling `benchmark_example_value` we get back a tensor with stride `[8192, 128, 1]` as opposed to `[128 * 8192, 128, 1]`.

## Fix

Using the atomic API when trying to apply size hints to input tensor' strides.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163660
Approved by: https://github.com/ColinPeppler
2025-10-14 20:07:51 +00:00
d7e3f493d9 [ROCm][CI] add mi355 to inductor perf test nightly (#165326)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165326
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-14 20:03:21 +00:00
08f09d9543 Ensure rms_norm decomp generates add.Scalar for pattern match BC (#165437)
Summary: Apparently if I just do `tensor + eps` this turns into add.Tensor, which is bad because the constant Tensor ends up getting hoisted into an input, which is a bozo thing to do. Just make sure it's exactly compatible.
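A small illustration of the overload distinction at play; this is only a sketch, not the actual change to the rms_norm decomposition:

```python
import torch

x = torch.randn(4)
eps = 1e-6

# Scalar overload: eps stays a scalar in the graph
a = torch.ops.aten.add.Scalar(x, eps)

# Tensor overload: the constant tensor is what can end up hoisted into an input
b = torch.ops.aten.add.Tensor(x, torch.tensor(eps))

torch.testing.assert_close(a, b)
```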

Test Plan:
```
buck run 'fbcode//mode/opt' fbcode//bolt/nn/executorch/backends/tests:qnn_test_ar1g1 bolt.nn.executorch.backends.tests.qnn_test_ar1g1.QnnTestAR1G1.test_RMSNorm
```

Reviewed By: tugsbayasgalan

Differential Revision: D84613184

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165437
Approved by: https://github.com/tugsbayasgalan
2025-10-14 19:56:37 +00:00
74acf92648 Forward fix inductor failure (#165363) (#165443)
Summary:

Title

Test Plan: CI

Differential Revision: D84615478

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165443
Approved by: https://github.com/angelayi
2025-10-14 19:31:58 +00:00
cbf212e9c7 [CI] Fix doctest job if build without distributed (#165449)
Guard test with `TORCH_DOCTEST_DISTRIBUTED` and set it to true in
run_test.py to be able to pass doctest for PyTorch build without
distributed support. This is a regression introduced by https://github.com/pytorch/pytorch/pull/164806

Fixes https://github.com/pytorch/pytorch/issues/165343

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165449
Approved by: https://github.com/seemethere
2025-10-14 19:19:03 +00:00
d18e068fd6 [dict] Implement __eq__ for dict_items (#155154)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155154
Approved by: https://github.com/anijain2305
2025-10-14 18:56:51 +00:00
3401665110 Patch the flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for flex attention graph break on create_block_mask) (#164923)
The initial fix for inspect.signature did not use the right approach (https://github.com/pytorch/pytorch/pull/164349#pullrequestreview-3306614010). As @williamwen42 suggests (https://github.com/pytorch/pytorch/pull/164349#issuecomment-3379222885), for now we can just get rid of the `inspect.signature` call in flex_attention to resolve this high-priority issue (https://github.com/pytorch/pytorch/issues/164247#issuecomment-3378673179). In this PR I did exactly this: I limited the scope of the fix to just computing `num_positional_args` in `flex_attention._get_mod_type` based on properties returned by `NestedUserFunctionVariable.const_getattr` (some were missing, so I added them).

Fixes #164247

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164923
Approved by: https://github.com/williamwen42
2025-10-14 18:29:15 +00:00
8c60f4ae08 [Distributed] update table in docs (#165009)
Fixes #162248

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165009
Approved by: https://github.com/ezyang
2025-10-14 18:17:22 +00:00
c4565c3b94 [distributed] Replace 164 assert statements in fsdp directory (#165235)
Replace assert statements with explicit if/raise patterns across 20 files:
- _optim_utils.py (38 asserts)
- _flat_param.py (25 asserts)
- _fully_shard/_fsdp_param.py (23 asserts)
- sharded_grad_scaler.py (12 asserts)
- fully_sharded_data_parallel.py (11 asserts)
- wrap.py (10 asserts)
- _state_dict_utils.py (9 asserts)
- _fully_shard/_fsdp_param_group.py (8 asserts)
- _runtime_utils.py (6 asserts)
- _init_utils.py (6 asserts)
- 10 additional files (16 asserts)

This prevents assertions from being disabled with Python -O flag.
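The rewrite pattern applied throughout, shown on a made-up helper rather than a specific FSDP call site:

```python
def check_handle(handle):
    # before: `assert handle is not None, "expected a flat-param handle"`,
    # which is stripped out entirely under `python -O`
    # after: an explicit check that is always enforced
    if handle is None:
        raise AssertionError("expected a flat-param handle")
    return handle
```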

Fixes partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165235
Approved by: https://github.com/albanD
2025-10-14 18:04:57 +00:00
6918f17114 [FSDP2] provide public API to share cuda streams across roots (#165024)
for pipeline parallel, we can have multiple FSDP roots (chunks)
```
model = nn.Sequential([chunk0, chunk1])
fully_shard(model.chunk0)
fully_shard(model.chunk1)
```

we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation
```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model.chunk0, model.chunk1])
```

unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
2025-10-14 17:50:46 +00:00
9b6be53326 [distributed] Replace 94 assert statements in tensor ops files (#165229)
Replace assert statements with explicit if/raise patterns in:
- _math_ops.py (43 asserts)
- _matrix_ops.py (27 asserts)
- _view_ops.py (24 asserts)

This prevents assertions from being disabled with Python -O flag.

Fixes partially #164878.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165229
Approved by: https://github.com/albanD
2025-10-14 17:28:06 +00:00
7fee6bbf34 [Fix] Completely remove stride normalization on DLPack Tensor (#164161)
A followup on PR #163282
Fixes #163274
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164161
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-14 17:17:11 +00:00
6adaa328f4 [autobucketing] aten autobucketing fix to enable aot_eager pass (#165063)
When the autobucketing pass is registered as the aot_eager backend's `fw_compiler` and `bw_compiler`, this PR ensures the tensors are all-gathered on the "cpu/cuda" device instead of the "meta" device.

When we do `dist.all_gather_object`, it creates a new byte storage outside no_dispatch [here](a2e2e1d8c0/torch/distributed/distributed_c10d.py (L3303)), which is on the meta device. Thus, I updated the code to use `unset_fake_temporarily`, which gathers real tensors from the other ranks.
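Roughly the shape of the fix, assuming this import path and helper name (the actual change lives in the overlap-scheduling pass):

```python
import torch.distributed as dist
from torch._subclasses.fake_tensor import unset_fake_temporarily

def gather_runtime_estimations(runtime_estimations, pg):
    gathered = [None] * dist.get_world_size(pg)
    # leave fake-tensor mode so the temporary byte tensor is created on a
    # real device instead of "meta"
    with unset_fake_temporarily():
        dist.all_gather_object(gathered, runtime_estimations, pg)
    return gathered
```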

 It is needed to unblock the aot_eager+autobucketing pass in this [PR](https://github.com/pytorch/torchtitan/pull/1813).

Otherwise, I hit the error as follows:

```bash
  traceback : Traceback (most recent call last):
    File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
      return f(*args, **kwargs)
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train
      self.train_step(data_iterator)
      ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step
      loss = self.forward_backward_step(input_dict, labels)
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step
      pred = model_parts[0](inputs, **extra_inputs, **extra_args)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
      return super().__call__(*args, **kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper
      raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler
      raise BackendCompilerFailed(
          self.compiler_fn, e, inspect.currentframe()
      ).with_traceback(e.__traceback__) from None
    File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler
      compiled_fn = compiler_fn(gm, example_inputs)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
      compiled_gm = compiler_fn(gm, example_inputs)
    File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__
      return self.compiler_fn(model_, inputs_, **self.kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__
      cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
    File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified
      compiled_fn, _ = aot_stage2_compile(
                       ~~~~~~~~~~~~~~~~~~^
          aot_state,
          ^^^^^^^^^^
      ...<4 lines>...
          inference_compiler,
          ^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile
      return aot_stage2_autograd(aot_state, aot_graph_capture)
    File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd
      compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
    File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass
      schedule_overlap_bucketing(gm)
      ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing
      ).run()
        ~~~^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run
      self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks
      dist.all_gather_object(
      ~~~~~~~~~~~~~~~~~~~~~~^
          gathered_runtime_estimations, runtime_estimations, pg
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper
      return func(*args, **kwargs)
    File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object
      input_tensor, local_size = _object_to_tensor(obj, current_device, group)
                                 ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor
      byte_tensor = torch.ByteTensor(byte_storage).to(device)
                    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
  RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta".  This is no longer allowed; the devices must match.

  Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165063
Approved by: https://github.com/eellison
2025-10-14 17:09:54 +00:00
4a7eed527f Make truediv numerics change external only for now (#165328)
Summary: For D84399286, the ads NE-determinism tests are now failing. These tests are especially brittle with subtle bitwise numerics changes. Will re-enable for fbcode once e2e validation tests are performed.

Test Plan: N/A

Differential Revision: D84514361

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165328
Approved by: https://github.com/izaitsevfb
2025-10-14 17:08:17 +00:00
d2494cbb2b Revert "[distributed] Replace assert statements with AssertionError exceptions (#165216)"
This reverts commit 74db92b21868b7e9e77cc966e5d57a8246723cbd.

Reverted https://github.com/pytorch/pytorch/pull/165216 on behalf of https://github.com/clee2000 due to I think this broke distributed/test_pg_wrapper.py::ProcessGroupNCCLWrapperTest::test_debug_level_detail_no_gloo [GH job link](https://github.com/pytorch/pytorch/actions/runs/18492765290/job/52693842750) [HUD commit link](74db92b218), note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165216#issuecomment-3402838765))
2025-10-14 17:05:16 +00:00
5eddbb5e47 [annotate] Annotation should be mapped across submod (#165202)
The match for backward nodes might be in a different submod, so we should check all submods for potential matches.

In flex attention, this could happen if `mask_mod` has operations (such as index) that increase the seq_nr of the forward graph nodes. Then the backward flex_attention nodes cannot find a match in their own subgraph.

```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```

Also tested on torchtitan joint_graph_runner branch. The flex_attention backward nodes are annotated now.

```
NGPU=8   CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"   LOG_RANK=0   TRAIN_FILE="torchtitan.train"   TORCHFT_LIGHTHOUSE="http://localhost:29510"   PYTORCH_ALLOC_CONF="expandable_segments:True"   torchrun     --nproc_per_node=8     --rdzv_backend c10d     --rdzv_endpoint="localhost:0"     --local-ranks-filter 0     --role rank     --tee 3     -m torchtitan.train     --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml     --model.name joint_graph_runner.llama3     --compile.enable     --parallelism.data_parallel_shard_degree=2     --parallelism.tensor_parallel_degree=4     --model.flavor=debugmodel_flex_attn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165202
Approved by: https://github.com/SherlockNoMad
2025-10-14 16:19:38 +00:00
c9b2a09530 [export] Turn on install_free_tensors flag (#164691)
The final step in removing the discrepancy between
torch.compile(fullgraph=True) and torch.export(strict=True).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164691
Approved by: https://github.com/avikchaudhuri
2025-10-14 15:33:50 +00:00
bf5aeb3148 [torch/utils][Code Clean] Clean asserts in hipify/, jit/, model_dump and tensorboard of torch/utils (#165311)
Including:
- `torch/utils/hipify/`
- `torch/utils/jit/`
- `torch/utils/model_dump/`
- `torch/utils/tensorboard/`

Fixes part of #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165311
Approved by: https://github.com/albanD
2025-10-14 15:26:23 +00:00
45b8c0f75c [distributed] Replace 54 assert statements in tensor/_ops/_tensor_ops.py (#165226)
Replace assert statements with explicit if/raise patterns to prevent assertions from being disabled with Python -O flag.

Fixes partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165226
Approved by: https://github.com/albanD
2025-10-14 15:10:03 +00:00
c733072874 Fix IValue from SymBool on big-endian system (#163647)
Skip test_compiled_autograd_attribution on s390x

It fails both on s390x and x86_64, at least under some circumstances. Disable it on s390x for now until it works reliably.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163647
Approved by: https://github.com/malfet
2025-10-14 15:07:48 +00:00
fbe0d20a17 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-14 14:22:54 +00:00
1fa11f42b1 [Bugfix][vLLM] Explicitly do not support instead of crashing for named tuples in infer schema (#165191)
Fixes https://github.com/vllm-project/vllm/issues/25270 by erroring explicitly; previously we had a cryptic `__origin__ undefined` error, but now we give a proper error message saying that we don't support NamedTuples in the schema.

Test with
```
python test/test_custom_ops.py TestCustomOp.test_unsupported_param_types
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165191
Approved by: https://github.com/zou3519
2025-10-14 14:18:42 +00:00
6f713e25bb [CodeClean] Replace std::runtime_error with TORCH_CHECK (#164130)
As the title stated.

**Changes**:
- torch/csrc/inductor(Part 1)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164130
Approved by: https://github.com/albanD, https://github.com/Skylion007
2025-10-14 14:09:53 +00:00
09a4187b8e Update windows cuda build to use 12.8 (#165345)
As title

Motivation: The rest of the pytorch and inductor build is using 12.8 and we're deprecating cuda 12.6 builds soon per https://github.com/pytorch/pytorch/issues/165111

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165345
Approved by: https://github.com/atalman, https://github.com/malfet
2025-10-14 13:58:20 +00:00
306c55ba27 [atomically_apply_size_hint] Make unbacked replacements reconciles to a single expr (#164324)
## Problem
Okay, there are limitations with today's `atomically_apply_size_hint`, though it works for most of the failures we've observed so far. However, it's easy to come up with an edge case.

Suppose you encounter this setup.
```
a: [s0 + u0]
b: [s1 + u1]
c: [u2 + u3]
d: [u100]
```

Today, we use a few heuristics to specify the LHS and RHS for replacements.

10d2734d9b/torch/_inductor/sizevars.py (L730-L759)

It's possible to end up with these replacement rules. Notice how there's no replacement for `s1 + u1` and `u2 + u3` :( That's because today picking the LHS and RHS matters a lot, and `s1 + u1` & `u2 + u3` happened to end up on the RHS.
```
s0 + u0 => s1 + u1
s0 + u0 => u2 + u3         # overrides previous replacement; each expr only gets one replacement
s0 + u0 => u100            # overrides previous replacement; ditto
```

I believe what we really want is this: everybody gets a replacement! And they all should (eventually) settle at the same canonical expr (i.e. `u100`) when running the replacement several times.
```
s1 + u1 ==> s0 + u0
u2 + u3 ==> s0 + u0
s0 + u0 ==> u100
```

We can just short-cut this by using the canonical expr as the replacement.
```
s1 + u1 ==> u100
u2 + u3 ==> u100
s0 + u0 ==> u100
```

## Implementation

I offer one way to deal with this:
1. ensure every expression has one canonical replacement (i.e. `u100`)
2. if two expressions are equal (inferred from `deferred_runtime_asserts`), then they must have the same canonical replacement

We can implement the above with union find; a minimal sketch follows the list below.
* Whenever you see `Eq(lhs, rhs)` then do `union(lhs, rhs)`.
* Whenever you want to find the canonical replacement for a given expr then do `find(expr)`.
* When picking the canonical replacement we can use a few heuristics like (1) prefer a fully backed expr, (2) replacing with sub-expressions, and whatever we'd like.
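A minimal union-find sketch of this bookkeeping, with strings standing in for sympy expressions and a simplified preference heuristic:

```python
class Replacements:
    def __init__(self):
        self.parent = {}

    def find(self, expr):
        self.parent.setdefault(expr, expr)
        while self.parent[expr] != expr:
            # path halving keeps chains short
            self.parent[expr] = self.parent[self.parent[expr]]
            expr = self.parent[expr]
        return expr

    def union(self, lhs, rhs, prefer=None):
        a, b = self.find(lhs), self.find(rhs)
        if a == b:
            return
        root = prefer if prefer in (a, b) else a
        other = b if root == a else a
        self.parent[other] = root

r = Replacements()
r.union("s0 + u0", "s1 + u1")
r.union("s0 + u0", "u2 + u3")
r.union("s0 + u0", "u100", prefer="u100")  # u100 becomes the canonical expr
assert {r.find(e) for e in ("s0 + u0", "s1 + u1", "u2 + u3")} == {"u100"}
```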

Differential Revision: [D84549260](https://our.internmc.facebook.com/intern/diff/D84549260)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164324
Approved by: https://github.com/laithsakka
2025-10-14 13:57:33 +00:00
56d6229ff9 [MPS] fix comment for normcdf (#165233)
Just a small comment fix for normcdf
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165233
Approved by: https://github.com/malfet
2025-10-14 13:56:31 +00:00
74db92b218 [distributed] Replace assert statements with AssertionError exceptions (#165216)
Replaces 71 assert statements across 11 files in `torch.distributed` with explicit if-checks raising AssertionError to prevent assertions from being disabled with Python -O flag.

Fixes #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165216
Approved by: https://github.com/albanD
2025-10-14 09:58:59 +00:00
c48843e4c6 [CP][BE] Docstrings, comments polish and remove unused variables (#165039)
No logic change, just polish the docstrings, comments and remove unused variables

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165039
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542, #164500, #163185
2025-10-14 09:35:32 +00:00
9e89b1c4c7 Update torch-xpu-ops commit pin (#165321)
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@ce9db1](ce9db15136), includes:

- Fix test_barrier hang by using static global rank in ProcessGroupXCCL
- Update install_xpu_headers only when content should change to speedup recompilation
- Add global rank information to communication logging
- Remove duplicate normalization from FFT methods
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165321
Approved by: https://github.com/EikanWang
2025-10-14 09:07:24 +00:00
c5972ebdfb Revert "Update windows cuda build to use 12.8 (#165345)"
This reverts commit ca96c675001fa87b9d9c648972415ab8b1591f11.

Reverted https://github.com/pytorch/pytorch/pull/165345 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165345#issuecomment-3400344079))
2025-10-14 06:46:33 +00:00
18b3658df9 [inductor][ez] properly print Pointwise (#165369)
Previously when we print a ComputedBuffer for reduction, we get something like:
```
ComputedBuffer(name='buf0', layout=FixedLayout('cuda:0', torch.float32, size=[1, 768], stride=[768, 1]), data=Reduction(
  'cuda',
  torch.float32,
  def inner_fn(index, rindex):
      _, i1 = index
      r0_0 = rindex
      tmp0 = ops.load(tangents_1, i1 + 768 * r0_0)
      tmp1 = ops.to_dtype(tmp0, torch.float32, src_dtype=torch.bfloat16)
      tmp2 = ops.load(primals_1, i1 + 768 * r0_0)
      tmp3 = ops.to_dtype(tmp2, torch.float32, src_dtype=torch.bfloat16)
      tmp4 = ops.load(rsqrt, r0_0)
      tmp5 = tmp3 * tmp4
      tmp6 = tmp1 * tmp5
      return tmp6
  ,
```
But if we print a ComputedBuffer for a pointwise, we get something like
```
ComputedBuffer(name='buf2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32768, 768], stride=[768, 1]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.bfloat16, inner_fn=<function make_pointwise.<locals>.inner.<locals>.inner_fn at 0x7f12922c5bc0>, ranges=[32768, 768]))

```

Note that the inner function str is not printed.

With the change, we get the inner_fn string printed in this case:
```

ComputedBuffer(name='buf2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32768, 768], stride=[768, 1]), data=Pointwise(
  'cuda',
  torch.bfloat16,
  def inner_fn(index):
      i0, i1 = index
      tmp0 = ops.load(tangents_1, i1 + 768 * i0)
      tmp1 = ops.to_dtype(tmp0, torch.float32, src_dtype=torch.bfloat16)
      tmp2 = ops.load(primals_2, i1)
      tmp3 = tmp1 * tmp2
      tmp4 = ops.load(rsqrt, i0)
      tmp5 = tmp3 * tmp4
      tmp6 = ops.load(buf1, i0)
      tmp7 = ops.constant(-0.5, torch.float32)
      tmp8 = tmp6 * tmp7
      tmp9 = ops.load(rsqrt, i0)
      tmp10 = tmp9 * tmp9
      tmp11 = tmp10 * tmp9
      tmp12 = tmp8 * tmp11
      tmp13 = ops.constant(0.0013020833333333333, torch.float32)
      tmp14 = tmp12 * tmp13
      tmp15 = ops.load(primals_1, i1 + 768 * i0)
      tmp16 = ops.to_dtype(tmp15, torch.float32, src_dtype=torch.bfloat16)
      tmp17 = tmp14 * tmp16
      tmp18 = tmp5 + tmp17
      tmp19 = ops.load(buf1, i0)
      tmp20 = ops.constant(-0.5, torch.float32)
      tmp21 = tmp19 * tmp20
      tmp22 = ops.load(rsqrt, i0)
      tmp23 = tmp22 * tmp22
      tmp24 = tmp23 * tmp22
      tmp25 = tmp21 * tmp24
      tmp26 = ops.constant(0.0013020833333333333, torch.float32)
      tmp27 = tmp25 * tmp26
      tmp28 = ops.load(primals_1, i1 + 768 * i0)
      tmp29 = ops.to_dtype(tmp28, torch.float32, src_dtype=torch.bfloat16)
      tmp30 = tmp27 * tmp29
      tmp31 = tmp18 + tmp30
      tmp32 = ops.to_dtype(tmp31, torch.bfloat16, src_dtype=torch.float32)
      return tmp32
  ,
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165369
Approved by: https://github.com/eellison
2025-10-14 06:08:12 +00:00
5fbf93b774 Introduce automatic wrapper to run DTensor tests under local tensor mode (#165383)
The wrapper enables sharing the test body implementation while eliminating the need to write a test class by hand. As an example, this change converts the whole DTensorTest to use local tensor mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165383
Approved by: https://github.com/ezyang
2025-10-14 06:08:03 +00:00
a856a17799 bf16 support for per_channel bwd (#165325)
Follow up to #165098 - adding bf16 support for the backward pass. To avoid BC breaking changes/losing precision, we upcast the parameters to fp32 after the op gets called, and downcast the gradients to bf16 before returning.
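A rough sketch of the dtype round-trip described above (a hypothetical wrapper, not the actual per-channel op):

```python
import torch

def per_channel_bwd_bf16(grad_bf16: torch.Tensor, param_bf16: torch.Tensor) -> torch.Tensor:
    # upcast to fp32 so the backward math does not lose precision
    grad = grad_bf16.to(torch.float32)
    param = param_bf16.to(torch.float32)
    grad_param = grad * param  # stand-in for the real per-channel gradient math
    # downcast the gradient to bf16 before returning it to the caller
    return grad_param.to(torch.bfloat16)
```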

For testing, we upcast to fp32 before calling the reference function. We increase the tolerance to 1e-2 for bf16 inputs because of a difference in casting calculations between python's `x.to(torch.bfloat16)` and cpp's `x.to(at::kBFloat16)` (after comparing intermediate tensors, we found that the numerics diverge after the final casting). We don't explicitly cast in the CPP op but rather let autograd/optimizer handle it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165325
Approved by: https://github.com/andrewor14
2025-10-14 05:47:32 +00:00
bc6e08954d [user-cuda-streams] Add fork/join custom ops (#162900)
Creates the fork/join stream ops. These ops are passthrough ops which mutate all of their args (without actually performing any computation on them) so that during functionalization, implicit dependencies are added on all of their args. This allows us to prevent reordering during our pre/post grad graph passes.

Make custom ops inplace

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162900
Approved by: https://github.com/anijain2305
ghstack dependencies: #163027, #162899, #163028
2025-10-14 05:43:19 +00:00
45a96b2081 [user-streams] Handle aliasing properly (#163028)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163028
Approved by: https://github.com/williamwen42, https://github.com/anijain2305
ghstack dependencies: #163027, #162899
2025-10-14 05:43:19 +00:00
04e36611bb [user-cuda-streams] Pass streams/events to the graph via lookup table (#162899)
Stores streams in a global object lookup table that maps a dynamo-selected index to objects. This index is generated during tracing, and at runtime a helper function is called from the bytecode to populate this map.

This differs from the previous implementation that simply mapped IDs to the associated objects. This required specialization on the IDs of the specific objects, while this new approach does not.
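Conceptually, the runtime side looks something like this hypothetical sketch (illustrative names, not the real dynamo internals):

```python
import torch

# global table mapping a dynamo-assigned index to the user's stream object
_STREAM_TABLE: dict[int, torch.cuda.Stream] = {}

def _register_stream(index: int, stream: torch.cuda.Stream) -> None:
    # called from the generated bytecode to populate the table at runtime
    _STREAM_TABLE[index] = stream

def _lookup_stream(index: int) -> torch.cuda.Stream:
    # the compiled code only refers to the index, not the object identity
    return _STREAM_TABLE[index]
```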

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162899
Approved by: https://github.com/anijain2305
ghstack dependencies: #163027
2025-10-14 05:43:19 +00:00
f15c25d5c3 [user-streams] Move stream code to streams module (#163027)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163027
Approved by: https://github.com/StrongerXi, https://github.com/anijain2305
2025-10-14 05:43:19 +00:00
e93981c243 [PyTorch][aarch64] Cast to signed char to fix aarch64 build (#165021)
Summary:
Initial fix: D39198776
Reverted by clang-tidy bot: D83948172

Test Plan:
Can now build on aarch64
{P1983767795}

Reviewed By: bigning

Differential Revision: D84203406

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165021
Approved by: https://github.com/cyyever, https://github.com/Skylion007
2025-10-14 05:37:34 +00:00
496adf9f9c Replace insert with std::rotate_copy for RingBuffer (#165348)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165348
Approved by: https://github.com/eqy, https://github.com/Skylion007
2025-10-14 05:11:28 +00:00
33bfec27ff Revert "use sym_numel, to allow fake tensors to work (#163831)"
This reverts commit e71c75680f2d6ce5f61ad4b2125f4934087762eb.

Reverted https://github.com/pytorch/pytorch/pull/163831 on behalf of https://github.com/isuruf due to test failure on mps introduced ([comment](https://github.com/pytorch/pytorch/pull/163831#issuecomment-3400131730))
2025-10-14 05:10:56 +00:00
f44935cc14 [torch/utils][Code Clean] Clean asserts in torch/utils/_sympy (#165279)
Including: `torch/utils/_sympy/`

Fixes part of #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165279
Approved by: https://github.com/albanD
2025-10-14 04:52:23 +00:00
39116409a1 [torch/utils][Code Clean] Clean asserts in benchmark/ and data/ in torch/utils/ (#165299)
Including:
- `torch/utils/benchmarks/`
- `torch/utils/data/`

Fixes part of #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165299
Approved by: https://github.com/albanD
2025-10-14 04:50:39 +00:00
515d1326c1 Add CLAUDE_CONTEXT directory to gitignore (#165358)
Claude often adds a bunch of MD files or other stuff that is specific to a local session; add a folder for Claude to put this stuff in so that it doesn't get checked into the repo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165358
Approved by: https://github.com/oulgen
2025-10-14 04:47:21 +00:00
ac529df244 Native matmul (#157743)
### Implementation of #151705

This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly—without relying on predefined templates.

To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:

1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)

### Summary of This PR

This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.

### Notable Changes

1. Adds a new config flag: `config.triton.enable_native_matmul` (see the usage sketch after this list)
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tiling suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I’ve copied the configuration from the existing matmul templates. However, this may not be optimal—it currently takes a long time to tune, and I think there must be a better way to tackle this.
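A hypothetical usage sketch; the flag name comes from this PR's description, and its availability and defaults are assumptions:

```python
import torch
from torch._inductor import config as inductor_config

# opt in to generating tl.dot directly instead of using the matmul templates
inductor_config.triton.enable_native_matmul = True

@torch.compile
def mm(a, b):
    return a @ b

out = mm(torch.randn(64, 64, device="cuda"), torch.randn(64, 64, device="cuda"))
```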

@eellison @jansel @PaulZhang12 @shunting314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
2025-10-14 04:22:30 +00:00
fa3916f466 Revert "[export] Turn on install_free_tensors flag (#164691)"
This reverts commit 220a34118f40fab4f3f517556d6e1434139a1590.

Reverted https://github.com/pytorch/pytorch/pull/164691 on behalf of https://github.com/seemethere due to Breaks some internal things, both me and author agreed that revert was the best course of action ([comment](https://github.com/pytorch/pytorch/pull/164691#issuecomment-3400013759))
2025-10-14 03:58:12 +00:00
267348fe7f Revert "Fix double dispatch to Python for detach (#163671)"
This reverts commit a3e3efe474bef63940ded803e78bb2a382681f1e.

Reverted https://github.com/pytorch/pytorch/pull/163671 on behalf of https://github.com/seemethere due to We should've reverted this when we decided to revert https://github.com/pytorch/pytorch/pull/164691 since they were actually stacked ([comment](https://github.com/pytorch/pytorch/pull/163671#issuecomment-3400009953))
2025-10-14 03:55:36 +00:00
1803d40c99 Reapply "[export] Turn on install_free_tensors flag (#164691)" (#165353)
This reverts commit 9166f6120f63e2d5d76e6ccdbfccb8d6e41cbb43.

Reverted https://github.com/pytorch/pytorch/pull/165353 on behalf of https://github.com/seemethere due to This is causing merge conflicts since a dependent PR wasn't reverted ([comment](https://github.com/pytorch/pytorch/pull/165353#issuecomment-3400006587))
2025-10-14 03:52:50 +00:00
29c5368e0f MTIA _cdist_forward registration (#165333)
Summary: Added registration for _cdist_forward on MTIA

Differential Revision: D84357997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165333
Approved by: https://github.com/albanD
2025-10-14 03:51:31 +00:00
e71c75680f use sym_numel, to allow fake tensors to work (#163831)
Fixes #[163759](https://github.com/pytorch/pytorch/issues/163759)

Replace `numel` with `sym_numel`. Tested with the example in the issue and it works now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163831
Approved by: https://github.com/bobrenjc93
2025-10-14 03:33:28 +00:00
ca96c67500 Update windows cuda build to use 12.8 (#165345)
As title

Motivation: The rest of the pytorch and inductor build is using 12.8 and we're deprecating cuda 12.6 builds soon per https://github.com/pytorch/pytorch/issues/165111

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165345
Approved by: https://github.com/atalman
2025-10-14 02:33:44 +00:00
770e6b910c [DTensor] Extend conv ops to 3D (#165241)
Current implementation hardcodes 4D input and output tensor shapes

Change that by computing `output_conv_shape` for any number of input dims
Replace `[.., .., .., slice]` with `[..., slice]`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165241
Approved by: https://github.com/ezyang
2025-10-14 02:30:46 +00:00
37d57ac9cb Use sym_eq in _check_rms_norm_inputs_symint (#165112)
Summary:
### Problem
ArrayRef's `equals()` does elementwise equality using the `==` operator. This can cause a DDE for unbacked symints since the `==` operator calls `guard_bool`.
```
// SymInt.h
bool operator==(const SymInt& o) const {
  return sym_eq(o).guard_bool(__FILE__, __LINE__);
}
```

### Solution
Adds `sym_equals()` to do elementwise equality for `SymIntArrayRef`. Use this instead of `equals()` for `SymIntArrayRef`.

Reviewed By: guangy10, pianpwk, muchulee8

Differential Revision: D84168401

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165112
Approved by: https://github.com/Skylion007
2025-10-14 00:06:24 +00:00
9166f6120f Revert "[export] Turn on install_free_tensors flag (#164691)" (#165353)
This reverts commit 220a34118f40fab4f3f517556d6e1434139a1590.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165353
Approved by: https://github.com/seemethere
2025-10-13 23:40:11 +00:00
fb0291d14b [pt2][caching] fix runtime error in context on cpu-only machine when compile for gpu (#165220)
re https://github.com/pytorch/pytorch/pull/165186

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165220
Approved by: https://github.com/clee2000
2025-10-13 22:47:41 +00:00
f3683453ae [compile] Regional inductor compilation with fx.annotate (#164776)
This PR introduces a way to compile a region of FX graph using `fx.traceback.annotate`.

### UX

1) In the user code, mark the region that you want to be compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we just rely on the string `compile_with_inductor` and ignore the integer. As the needs arise, we can update the logic.

Example

```
        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)
```

2) You have to instruct the compiler to use the annotations via the `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that just setting the annotation is enough, but for now, to control the blast radius, we need to do this explicitly. One such example is

```

# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )

```

3) Fixable in the short term - you have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable; we just need to make CI happy. A usage sketch combining these pieces follows below.
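
Putting points 1-3 together, a minimal usage sketch (assumes `aot_eager_regional_inductor` from point 2 above is in scope; names and import paths come from this PR description and may change):

```
import torch
import torch.fx.traceback as fx_traceback

def fn(x, y):
    sin = torch.sin(x)
    # Region to be compiled with inductor (point 1)
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)

# Point 3: keep node meta so the annotations reach the compiler
with fx_traceback.preserve_node_meta():
    # Point 2: backend that routes annotated regions through inductor
    compiled = torch.compile(fn, backend=aot_eager_regional_inductor())
    out = compiled(torch.randn(10), torch.randn(10))
```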

### Implementation

1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations, and then create subgraphs in the main graph.
2) Call `torch._inductor.standalone_compile` on these subgraphs, and jam the returned callable into the FX graph in place of the call_module node

Resulting graph looks something like this - search for `torch__inductor_standalone_compile_inner`

Forward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(sin, primals_2)

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
        getitem: "f32[10]" = inner[0];  inner = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
        return (sin_1, primals_1, primals_2, sin, getitem)
```

Backward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1);  primals_1 = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        cos: "f32[10]" = torch.ops.aten.cos.default(add);  add = None
        mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2);  mul_1 = sin = primals_2 = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
        getitem: "f32[10]" = inner[0]
        getitem_1: "f32[10]" = inner[1];  inner = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1);  getitem_1 = cos_1 = None
        return (mul_4, getitem)
```

### Some issue raised in the HOP meeting
1) CSE will not differentiate between different custom meta nodes and may do the wrong thing.
2) SAC - the recomputed forward will be smaller than the forward. Will we then compile a smaller region?
3) What happens if you have an op in the middle which does not disturb the topology - is it still 1 subgraph?
4) What happens with the nesting of `fx_traceback.annotate`? Are there any ordering requirements?
5) What are we going to use the annotations for?
   a) compile flex
   b) streams
   c) nn.Module info to organize MoE components for pipelining
   d) PP stages
   e) Rename graph nodes for more debugging
   f) No nested regional compile

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165188
2025-10-13 22:22:20 +00:00
1191e51c44 [dynamo][annotate] Remove the need of external ctx mgr of preserve_node_meta (#165188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165188
Approved by: https://github.com/yushangdi
2025-10-13 22:22:20 +00:00
3edd94485f [5/N][DTensor device order] Implement graph based redistribution algorithm (#164902)
(Extract out the algorithm from https://github.com/pytorch/pytorch/pull/160266.)

Build a graph to search for the path from the source placement to the destination placement (with device order). The current solution introduces too many all-gathers and misses the opportunity to use all-to-all when redistributing, especially when we consider the device order.

### How to build the graph:
When operating on Shard placements, think of a collective op as an operation on a stack of device axes:
- I, J are tensor dimensions;
- X, Y, Z, Y are ordered mesh dimensions.
<img width="357" height="253" alt="image" src="https://github.com/user-attachments/assets/23bb3cc3-0506-4071-9053-3c525cf0e526" />

The detailed collective op transitions are implemented in `DTensorRedistributePlanner.get_next_state`.

### How to find the min cost path:
Assign weights to the different types of collective ops and use Dijkstra to find the min-cost path in the graph we build, as sketched below.
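
A minimal, illustrative sketch of that search (this is not `DTensorRedistributePlanner`; the state encoding, transition generator, and cost weights are placeholders):

```
import heapq
from itertools import count

def min_cost_redistribution(src_state, dst_state, get_next_state, cost):
    # get_next_state(state) -> iterable of (next_state, collective_op)
    # cost(collective_op)   -> non-negative weight, e.g. all-to-all < all-gather
    tie = count()  # tiebreaker so heapq never compares states directly
    frontier = [(0.0, next(tie), src_state, [])]
    visited = set()
    while frontier:
        total, _, state, path = heapq.heappop(frontier)
        if state == dst_state:
            return total, path
        if state in visited:
            continue
        visited.add(state)
        for nxt, op in get_next_state(state):
            heapq.heappush(frontier, (total + cost(op), next(tie), nxt, path + [op]))
    return float("inf"), []
```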

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164902
Approved by: https://github.com/ezyang
2025-10-13 22:03:57 +00:00
a701c937bf [dynamo][executorch] Return already added nn.Module during registration (#165338)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165338
Approved by: https://github.com/tugsbayasgalan
2025-10-13 21:24:07 +00:00
ecb53078fa Turn some const strings into constexpr in C++ code (#165203)
This PR turns more const strings into constexpr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165203
Approved by: https://github.com/Skylion007
2025-10-13 20:25:20 +00:00
fa95882093 [BE] document distributed apis (#165194)
This PR documents some `torch.distributed.distributed_c10d` APIs. Below are some screenshots of the rendered docs.

<img width="909" height="527" alt="Screenshot 2025-10-10 at 10 18 40 PM" src="https://github.com/user-attachments/assets/555ae886-bead-47f3-8c67-9bc91c14bd11" />
<img width="885" height="548" alt="Screenshot 2025-10-10 at 10 18 47 PM" src="https://github.com/user-attachments/assets/1d6f7af1-db28-40f9-927e-5c47668a1a88" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165194
Approved by: https://github.com/janeyx99
2025-10-13 20:13:59 +00:00
a71ca4dcb9 Revert "[opaque_obj_v2] PyObject custom op schema type (#165004)"
This reverts commit 3faee200674c0c2bca3f395a063264cfd8a9a5b7.

Reverted https://github.com/pytorch/pytorch/pull/165004 on behalf of https://github.com/seemethere due to This fails internal tests, see D84399300 ([comment](https://github.com/pytorch/pytorch/pull/165004#issuecomment-3398906856))
2025-10-13 20:08:38 +00:00
c44d638b15 [Easy][Test][Dynamo] Avoid direct string comparison in MiscTestsDevice::get_device_module (#165314)
Fixes a small issue with string comparison, as the test fails with:
```
AssertionError: String comparison failed: 'cuda' != 'cuda:0'
```
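
A minimal sketch of the kind of fix (hypothetical, not the actual test body): compare the device type rather than the full device string, so both `cuda` and `cuda:0` pass.

```
import torch

dev = torch.device("cuda:0")
assert dev.type == "cuda"      # robust: compares only the device type
# assert str(dev) == "cuda"    # brittle: fails with "'cuda' != 'cuda:0'"
```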

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165314
Approved by: https://github.com/soulitzer
2025-10-13 19:58:59 +00:00
7c015334a3 Remove FIXME comment about reset_max_memory_reserved (#165249)
The function doesn't actually exist: https://github.com/pytorch/pytorch/blob/main/torch/cuda/__init__.py#L1816

Fixes https://github.com/pytorch/pytorch/issues/27785

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165249
Approved by: https://github.com/svekars
2025-10-13 19:44:40 +00:00
cad2d473bf Force inlining into torch_function_mode_enabled (#164617)
This function is relatively hot; inlining here reduces time reported by `python -m timeit --setup 'import torch; t = torch.tensor([1])' 't._cdata'` from about 125 nsec/loop to about 110 nsec/loop. (To be fair, variance is high, but I did confirm with perf that time in this path seems to have roughly halved during torchtitan training.)
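
For reference, the same micro-benchmark as a standalone script (equivalent to the `timeit` one-liner above):

```
import timeit
import torch

t = torch.tensor([1])
# Seconds for 1M attribute accesses; divide by 1e6 for per-access time.
print(timeit.timeit("t._cdata", globals={"t": t}, number=1_000_000))
```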

Note that locally I am getting bit by a GCC bug that I documented in a comment. Would be interested to hear if this does anything for clang.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164617
Approved by: https://github.com/ezyang
2025-10-13 19:25:51 +00:00
cb328c0b20 [ONNX] TorchTensor supports tofile() (#165195)
Fixes #165120

ref: 43ebf47bb5/src/onnx_ir/tensor_adapters.py (L171-L200)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165195
Approved by: https://github.com/justinchuby
2025-10-13 19:12:06 +00:00
64699b8042 [trymerge] Do not check for rules when reverting (#165342)
Why do we need to check for merge rules when reverting?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165342
Approved by: https://github.com/malfet
2025-10-13 19:07:00 +00:00
dcce473352 [BE] Fix unused parameter warning (#165272)
Fixes
```
[23/1155] Compiling /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal to EmbeddingBag_31.air
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal:252:62: warning: unused parameter 'bag_size' [-Wunused-parameter]
  inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
                                                             ^
1 warning generated.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165272
Approved by: https://github.com/Skylion007
2025-10-13 18:52:51 +00:00
c41e52118d Fix loop pipelining for 2d/2d case of Triton grouped MM (#165265)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165265
Approved by: https://github.com/ngimel
2025-10-13 18:45:39 +00:00
955cd7060b Revert "Update round size with 1 division behavior (#162203)"
This reverts commit 12d2ef557f6e127100267c31a31572d8ab5cc788.

Reverted https://github.com/pytorch/pytorch/pull/162203 on behalf of https://github.com/izaitsevfb due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/162203#issuecomment-3398622898))
2025-10-13 18:32:37 +00:00
0ce945790e [NJT] Fix schema validation error in jagged functions (#165307)
Fixes #161812
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165307
Approved by: https://github.com/soulitzer
2025-10-13 17:59:18 +00:00
70ec464c16 [BE] document some quantization public apis (#165160)
This PR documents some apis in `torch.ao.quantization.utils`

<img width="885" height="296" alt="Screenshot 2025-10-10 at 4 38 10 PM" src="https://github.com/user-attachments/assets/4323a6f5-ac3a-4f2e-ba00-35f3b208bef4" />
<img width="876" height="319" alt="Screenshot 2025-10-10 at 4 38 14 PM" src="https://github.com/user-attachments/assets/164822c3-9740-46f9-953d-bb20c77bcf69" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165160
Approved by: https://github.com/janeyx99
2025-10-13 17:24:42 +00:00
2c600bb665 [torchfuzz] fix some errors when walkthroughing README.md (#165225)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165225
Approved by: https://github.com/soulitzer
2025-10-13 17:17:50 +00:00
e93343cfab [CP] Introduce flex_cp_forward custom op for FlexAttention CP (#163185)
The custom op will fetch the required K and V. Currently, the forward pass is just an all-gather, and the backward pass is a reduce-scatter.  While the logic is the same as all_gather_tensor_autograd, the custom op avoids the Autograd warning that wait_tensor() is registered to autograd.
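
A minimal sketch of the communication pattern described above (not the actual `flex_cp_forward` op; it assumes the default process group and gathering along dim 0):

```
import torch
import torch.distributed as dist

class GatherKV(torch.autograd.Function):
    # All-gather in forward, reduce-scatter (sum) of the gradient in backward.

    @staticmethod
    def forward(ctx, local_kv):
        world = dist.get_world_size()
        out = local_kv.new_empty(world * local_kv.shape[0], *local_kv.shape[1:])
        dist.all_gather_into_tensor(out, local_kv.contiguous())
        return out

    @staticmethod
    def backward(ctx, grad_out):
        world = dist.get_world_size()
        local_grad = grad_out.new_empty(grad_out.shape[0] // world, *grad_out.shape[1:])
        dist.reduce_scatter_tensor(local_grad, grad_out.contiguous())
        return local_grad
```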

For the next step, we should explore how to interpolate the required communication based on the information from BlockMask.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163185
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542, #164500
2025-10-13 17:16:32 +00:00
c86a7c5f5e Disable failing test_int8_woq_mm_concat_cuda on slow grad check (#165331)
Same as https://github.com/pytorch/pytorch/pull/165147, I missed some

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165331
Approved by: https://github.com/bbeckca
2025-10-13 17:08:00 +00:00
4e420415e8 Avoids calling builtin iter if object is a generator (#162521)
The `iter(gen)` call simply returns the given `gen` object, so we avoid this call and shave off a few ms of tracing time.
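
A one-line check of the property being relied on:

```
def gen():
    yield 1

g = gen()
assert iter(g) is g  # iter() on a generator returns the same object back
```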

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162521
Approved by: https://github.com/mlazos
2025-10-13 17:07:54 +00:00
83cbba8759 [MPS] Support large tensors in torch.cat (#164416)
Fixes #164415
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164416
Approved by: https://github.com/malfet
2025-10-13 16:56:56 +00:00
684df93975 [CI] Default keep-going true for tags of form ciflow/something/commitsha (#165180)
Tags of the form `ciflow/something/commitsha` are usually created by running the workflow from HUD
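
For illustration, the tag pattern that now also enables keep-going (the regex is the one added in this PR; the tag itself is made up):

```
import re

pattern = r"^ciflow/[^/]+/[a-f0-9]{40}$"
sha = "a1b2c3d4" * 5  # a 40-character hex commit sha
assert re.match(pattern, f"ciflow/inductor-perf-compare/{sha}")
assert not re.match(pattern, f"ciflow/inductor-perf-compare/extra/{sha}")
```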

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165180
Approved by: https://github.com/huydhn
2025-10-13 16:12:37 +00:00
a3e3efe474 Fix double dispatch to Python for detach (#163671)
This fixes #71725.

Differential Revision: [D83857880](https://our.internmc.facebook.com/intern/diff/D83857880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163671
Approved by: https://github.com/ezyang, https://github.com/albanD
2025-10-13 16:10:17 +00:00
6bda3bb286 [PP] Fix split_args_kwargs_into_chunks issues (#165306)
1. https://github.com/pytorch/pytorch/pull/164111/ adds support for splitting BlockMask. But BlockMask actually has a B=1 case in which the BlockMask is broadcast. This PR adds support for the B=1 case.

2. The original split_args_kwargs_into_chunks doesn't initialize the default specs correctly. Since we now use tree_flatten and tree_unflatten to do the split, we should also use tree_map to initialize the default spec (see the sketch below). This also supports the case where the values are not torch.Tensor, which previously worked only if users explicitly provided the shard spec.
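
A minimal sketch of the tree_map idea (hypothetical helper and spec markers, not the actual pipelining code, which uses its own chunk-spec types):

```
import torch
from torch.utils._pytree import tree_map

def default_chunk_specs(kwargs):
    # Tensors default to "shard along dim 0"; everything else is replicated.
    def spec_for(leaf):
        return "shard_dim_0" if isinstance(leaf, torch.Tensor) else "replicate"
    return tree_map(spec_for, kwargs)

specs = default_chunk_specs({"x": torch.randn(4, 8), "scale": 0.5})
# -> {"x": "shard_dim_0", "scale": "replicate"}
```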

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165306
Approved by: https://github.com/H-Huang
2025-10-13 15:52:39 +00:00
8580112682 Revert "[dynamo][DebugMode] mask python keys in dispatch_key_set guard checks (#164992)"
This reverts commit 306b344a1847749f0baf085dcd92560f4e99cd1b.

Reverted https://github.com/pytorch/pytorch/pull/164992 on behalf of https://github.com/jeffdaily due to broke ROCm CI test/inductor/test_inductor_scheduler.py::TestSchedulerCUDA::test_flop_counter_op_options0_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18417066364/job/52485636942) [HUD commit link](306b344a18) ([comment](https://github.com/pytorch/pytorch/pull/164992#issuecomment-3397927142))
2025-10-13 15:14:34 +00:00
4874cce52f [xla hash update] update the pinned xla hash (#165302)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned xla hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165302
Approved by: https://github.com/pytorchbot
2025-10-13 12:36:29 +00:00
c509a78645 Update slow tests (#165301)
This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/weekly.yml).
Update the list of slow tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165301
Approved by: https://github.com/pytorchbot
2025-10-13 11:47:32 +00:00
524 changed files with 12988 additions and 4412 deletions

View File

@ -187,19 +187,22 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
export USE_CUFILE=0
else
DEPS_LIST+=(
"/usr/local/cuda/lib64/libnvToolsExt.so.1"
"/usr/local/cuda/lib64/libcublas.so.12"
"/usr/local/cuda/lib64/libcublasLt.so.12"
"/usr/local/cuda/lib64/libcudart.so.12"
"/usr/local/cuda/lib64/libnvrtc.so.12"
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
DEPS_SONAME+=(
"libnvToolsExt.so.1"
"libcublas.so.12"
"libcublasLt.so.12"
"libcudart.so.12"
"libnvrtc.so.12"
"libcupti.so.12")
if [[ $CUDA_VERSION != 12.9* ]]; then
DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
DEPS_SONAME+=("libnvToolsExt.so.1")
fi
fi
else
echo "Using nvidia libs from pypi."

View File

@ -8,6 +8,7 @@ assignees: ''
---
> NOTE: Remember to label this issue with "`ci: sev`"
> If you want autorevert to be disabled, keep the ci: disable-autorevert label
<!-- Add the `merge blocking` label to this PR to prevent PRs from being merged while this issue is open -->

View File

@ -1,7 +1,7 @@
---
name: DISABLE AUTOREVERT
name: "D❌\U0001F519 ISABLE AUTOREVERT"
about: Disables autorevert when open
title: "❌​\U0001F519 [DISABLE AUTOREVERT]"
title: "[DISABLE AUTOREVERT]"
labels: 'ci: disable-autorevert'
assignees: ''

View File

@ -65,7 +65,7 @@ runs:
cd .ci/lumen_cli
python3 -m pip install -e .
)
MAX_JOBS="$(nproc --ignore=6)"
MAX_JOBS="$(nproc --ignore=10)"
export MAX_JOBS
# Split the comma-separated list and build each target

View File

@ -1 +1 @@
8ad2aa5d354d1bf432339113860185d5a5d1abbd
1b013f5b5a87a1882eb143c26d79d091150d6a37

View File

@ -1 +1 @@
f5c6c2ec6490455e86f67b2a25c10390d60a27f7
faffd5cf673615583da6517275e361cb3dbc77e6

View File

@ -1 +1 @@
2a9138a26ee257fef05310ad3fecf7c55fe80d73
0fa6e3129e61143224663e1ec67980d12b7ec4eb

View File

@ -3,6 +3,7 @@ ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
@ -15,7 +16,8 @@ ciflow_push_tags:
- ciflow/inductor-micro-benchmark
- ciflow/inductor-micro-benchmark-cpu-x86
- ciflow/inductor-perf-compare
- ciflow/inductor-perf-test-nightly-rocm
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-periodic
- ciflow/inductor-rocm

View File

@ -512,6 +512,8 @@ def perform_misc_tasks(
"keep-going",
branch == MAIN_BRANCH
or bool(tag and re.match(r"^trunk/[a-f0-9]{40}$", tag))
# Pattern for tags created via manual run on HUD
or bool(tag and re.match(r"^ciflow/[^/]+/[a-f0-9]{40}$", tag))
or check_for_setting(labels, pr_body, "keep-going"),
)
set_output(

View File

@ -2042,10 +2042,6 @@ def validate_revert(
f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
)
# Raises exception if matching rule is not found, but ignores all status checks
find_matching_merge_rule(
pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
)
commit_sha = get_pr_commit_sha(repo, pr)
return (author_login, commit_sha)

62
.github/workflows/b200-distributed.yml vendored Normal file
View File

@ -0,0 +1,62 @@
name: CI for distributed tests on B200
on:
pull_request:
paths:
- .github/workflows/b200-distributed.yml
workflow_dispatch:
push:
tags:
- ciflow/b200-distributed/*
schedule:
- cron: 46 8 * * * # about 1:46am PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
with:
timeout-minutes: 1200
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
secrets: inherit

View File

@ -27,9 +27,8 @@ jobs:
fail-fast: false
matrix:
python-version: [ '3.12' ]
# TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
device: [ 'cu128', 'cu129' ]
device: [ 'cu128', 'cu129', 'cu130' ]
include:
- platform: manylinux_2_28_x86_64
device: cu128
@ -39,6 +38,10 @@ jobs:
device: cu129
manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
runner: linux.12xlarge.memory
- platform: manylinux_2_28_x86_64
device: cu130
manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
runner: linux.12xlarge.memory
- platform: manylinux_2_28_aarch64
device: cu128
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
@ -47,6 +50,11 @@ jobs:
device: cu129
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
runner: linux.arm64.r7g.12xlarge.memory
exclude:
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
# xformers is update to support 13.0
- platform: manylinux_2_28_aarch64
device: cu130
name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
runs-on: ${{ matrix.runner }}
timeout-minutes: 480
@ -169,7 +177,12 @@ jobs:
fail-fast: false
matrix:
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
device: [ 'cu128', 'cu129' ]
device: [ 'cu128', 'cu129', 'cu130' ]
exclude:
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
# xformers is update to support 13.0
- platform: manylinux_2_28_aarch64
device: cu130
env:
PLATFORM: ${{ matrix.platform }}
BUILD_DEVICE: ${{ matrix.device }}

View File

@ -0,0 +1,132 @@
name: inductor-perf-nightly-rocm-mi300
on:
push:
tags:
- ciflow/inductor-perf-test-nightly-rocm-mi300/*
schedule:
- cron: 15 0 * * *
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
workflow_dispatch:
inputs:
training:
description: Run training (on by default)?
required: false
type: boolean
default: true
inference:
description: Run inference (on by default)?
required: false
type: boolean
default: true
default:
description: Run inductor_default?
required: false
type: boolean
default: false
dynamic:
description: Run inductor_dynamic_shapes?
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
type: boolean
default: true
freezing_cudagraphs:
description: Run inductor_cudagraphs with freezing for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
type: boolean
default: false
maxautotune:
description: Run inductor_max_autotune?
required: false
type: boolean
default: false
benchmark_configs:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf_rocm_mi300,inductor_timm_perf_rocm_mi300,inductor_torchbench_perf_rocm_mi300
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
linux-jammy-rocm-py3_10-inductor-benchmark-build:
if: github.repository_owner == 'pytorch'
name: rocm-py3_10-inductor-benchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-rocm-py3_10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_huggingface_perf_rocm_mi300", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm_mi300", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm_mi300", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm_mi300", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm_mi300", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm_mi300", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm_mi300", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm_mi300", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm_mi300", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-inductor-benchmark-test:
permissions:
id-token: write
contents: read
name: rocm-py3_10-inductor-benchmark-test
uses: ./.github/workflows/_rocm-test.yml
needs: linux-jammy-rocm-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-rocm-py3_10
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }}
timeout-minutes: 720
# Disable monitor in perf tests for more investigation
disable-monitor: true
monitor-log-interval: 10
monitor-data-collect-interval: 2
secrets: inherit

View File

@ -1,11 +1,11 @@
name: inductor-perf-nightly-rocm
name: inductor-perf-nightly-rocm-mi355
on:
push:
tags:
- ciflow/inductor-perf-test-nightly-rocm/*
- ciflow/inductor-perf-test-nightly-rocm-mi355/*
schedule:
- cron: 0 7 * * 0,3
- cron: 15 0 * * *
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
workflow_dispatch:
@ -59,7 +59,7 @@ on:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf_rocm,inductor_timm_perf_rocm,inductor_torchbench_perf_rocm
default: inductor_huggingface_perf_rocm_mi355,inductor_timm_perf_rocm_mi355,inductor_torchbench_perf_rocm_mi355
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
@ -88,23 +88,27 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
]}
secrets: inherit

View File

@ -7,9 +7,11 @@ on:
workflow_dispatch:
inputs:
test_mode:
required: false
type: string
default: 'short'
type: choice
options:
- 'short'
- 'long'
- 'all'
description: tag filter for operator benchmarks, options from long, short, all
schedule:
# Run at 07:00 UTC every Sunday
@ -37,20 +39,7 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
]}
secrets: inherit
opbenchmark-on-demand-build:
if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
name: opbenchmark-on-demand-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-py3.10-gcc11-build
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
{ config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
]}
secrets: inherit

View File

@ -180,13 +180,13 @@ jobs:
disable-monitor: false
secrets: inherit
win-vs2022-cuda12_6-py3-build:
name: win-vs2022-cuda12.6-py3
win-vs2022-cuda12_8-py3-build:
name: win-vs2022-cuda12.8-py3
uses: ./.github/workflows/_win-build.yml
needs: get-label-type
with:
build-environment: win-vs2022-cuda12.6-py3
cuda-version: "12.6"
build-environment: win-vs2022-cuda12.8-py3
cuda-version: "12.8"
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
secrets: inherit

1
.gitignore vendored
View File

@ -395,3 +395,4 @@ android/pytorch_android_torchvision/.cxx
CLAUDE.local.md
/test_*.py
/debug_*.py
CLAUDE_CONTEXT/

View File

@ -256,6 +256,7 @@ endif()
IF(USE_FBGEMM_GENAI)
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@ -292,58 +293,64 @@ IF(USE_FBGEMM_GENAI)
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
)
target_include_directories(fbgemm_genai PUBLIC
target_include_directories(fbgemm_genai PRIVATE
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${fbgemm_genai_mx8mx8bf16_grouped}
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
else()
if(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
# Add FBGEMM_GENAI include directories for torch_ops.h
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
elseif(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
endif()
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PUBLIC
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
endif()
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PRIVATE
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
# Add FBGEMM_GENAI include directories for torch_ops.h
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
endif()
endif()
@ -692,12 +699,6 @@ if(USE_CUDA AND NOT USE_ROCM)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
# Add FBGEMM_GENAI include directories for torch_ops.h
if(USE_FBGEMM_GENAI)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
endif()
if($ENV{ATEN_STATIC_CUDA})
if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS

View File

@ -389,37 +389,16 @@ void fillVersion<DLManagedTensorVersioned>(
// constructed out of ATen tensor
template <class T>
T* toDLPackImpl(const Tensor& src) {
auto view = src;
// Detect whether there is need to normalize the strides
// Background: gh-83069
//
// However, normalizing strides can come at a high-cost
// to slow down toDLPack conversion 3x, so we
// only normalize if needed.
//
// The following code detects whether the src follows
// a continuous pattern. If the src follows such pattern (common-case)
// then we do not need to normalize the strides.
bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
// less common case, try normalizing the strides
if (need_normalize_strides) {
// create a new tensor with possibly normalized strides
// gh-83069
auto shape = src.sizes();
view = src.as_strided(shape, {1}, src.storage_offset());
}
ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
atDLMTensor->handle = view;
atDLMTensor->handle = src;
atDLMTensor->tensor.manager_ctx = atDLMTensor;
atDLMTensor->tensor.deleter = &deleter<T>;
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
fillVersion(&atDLMTensor->tensor);

View File

@ -52,16 +52,16 @@ struct DLPackTraits {};
template <>
struct DLPackTraits<DLManagedTensor> {
inline static const char* capsule = "dltensor";
inline static const char* used = "used_dltensor";
inline static constexpr const char* capsule = "dltensor";
inline static constexpr const char* used = "used_dltensor";
inline static auto toDLPack = at::toDLPack;
inline static auto fromDLPack = at::fromDLPack;
};
template <>
struct DLPackTraits<DLManagedTensorVersioned> {
inline static const char* capsule = "dltensor_versioned";
inline static const char* used = "used_dltensor_versioned";
inline static constexpr const char* capsule = "dltensor_versioned";
inline static constexpr const char* used = "used_dltensor_versioned";
inline static auto toDLPack = at::toDLPackVersioned;
inline static auto fromDLPack = at::fromDLPackVersioned;
};

View File

@ -42,8 +42,14 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
}
bool torch_function_mode_enabled() {
return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
PythonTorchFunctionTLS::stack_len() > 0;
// Manually flatten because gcc is refusing to inline here. Note
// that we are still calling __tls_get_addr twice here with GCC,
// presumably because of
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81501 (which says
// the fix ships in GCC 16), but forcing inlining still improves
// performance.
const auto& ptfs = pythonTorchFunctionState;
return ptfs.disabled_state_ != TorchFunctionDisabledState::ALL_DISABLED && !ptfs.stack_.empty();
}
// This is needed to disambiguate the ternary torch function disabled states

View File

@ -27,6 +27,7 @@ struct TORCH_API PythonTorchFunctionTLS {
TorchFunctionDisabledState disabled_state_ =
TorchFunctionDisabledState::ENABLED;
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
friend TORCH_API bool torch_function_mode_enabled();
};
TORCH_API bool torch_function_mode_enabled();

View File

@ -624,7 +624,14 @@ struct TORCH_API IValue final {
IValue(const c10::SymBool& i) {
if (auto mi = i.maybe_as_bool()) {
tag = Tag::Bool;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
payload.u.as_int = *mi;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
/* due to byteorder if value assigned as_int, as_bool actually is not set correctly */
payload.u.as_bool = *mi;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
} else {
tag = Tag::SymBool;
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();

View File

@ -13,6 +13,7 @@
#include <c10/core/ScalarType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
BLASType = "unknown";
}
return BLASType;
}
// Similar to Compute Type in GemmRocblas.h
@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio
namespace detail {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
if (!config.enabled) {
return true; // skip when disabled
}
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {size}, options);
at::Tensor oth = at::from_blob(other_c, {size}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
if (last_succeed_atol == 1) {
return false;
}
else {
TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}
return true;
const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
if (ok) {
TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
} else {
TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
}
return ok;
}
}
@ -355,8 +349,10 @@ struct GemmParams : OpParams {
}
TuningStatus NumericalCheck(GemmParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
}
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
}
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
}
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};

View File

@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@ -173,10 +173,9 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5.
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| write_file_on_exit(val: bool) -> None | Default is True. |
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None: -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |

View File

@ -107,14 +107,30 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
}
void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
std::scoped_lock l{lock_};
bool is_new = false;
ResultEntry inserted = ResultEntry::Null();
auto it = results_.find(op_signature);
if (it == results_.end()) {
it = results_.insert({op_signature, {}}).first;
// ---- mutate maps under results lock ----
{
std::scoped_lock l{lock_};
auto& km = results_[op_signature]; // creates if missing
is_new = (km.find(params_signature) == km.end());
AddImpl(op_signature, params_signature, std::move(best), km);
if (is_new) {
inserted = km.at(params_signature); // snapshot for I/O after unlocking
}
}
if (!is_new) return; // only write once per unique (op, params)
TuningContext* ctx = getTuningContext();
if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());
if (is_new && realtime_out_ && realtime_out_->good()) {
AppendResultLine(op_signature, params_signature, inserted);
}
}
AddImpl(op_signature, params_signature, std::move(best), it->second);
}
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
@ -150,6 +166,77 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
}
}
void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
std::scoped_lock fl{realtime_file_mutex_};
if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
return;
}
if (realtime_out_ && realtime_filename_ != filename) {
realtime_out_->flush();
realtime_out_->close();
realtime_out_.reset();
validators_written_ = false;
}
bool file_exists = false;
bool file_empty = true;
{
std::ifstream check_file(filename);
if (check_file.good()) {
file_exists = true;
file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
}
}
realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
if (!realtime_out_->good()) {
TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
realtime_out_.reset();
return;
}
if(!file_exists || file_empty) {
for(const auto& [key, val] : validators) {
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
realtime_out_->flush();
}
validators_written_ = true;
TUNABLE_LOG2("Wrote validators to realtime output file");
}
realtime_filename_ = filename;
}
void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
std::scoped_lock fl{realtime_file_mutex_};
if(!realtime_out_ || !realtime_out_->good()) {
return;
}
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
realtime_out_->flush(); //ensure immediate write to disk
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
}
void TuningResultsManager::CloseRealtimeAppend() {
std::scoped_lock fl{realtime_file_mutex_};
if(realtime_out_) {
realtime_out_->flush();
realtime_out_->close();
realtime_out_.reset();
TUNABLE_LOG2("Closed realtime output file");
}
}
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
std::scoped_lock l{lock_};
@ -396,7 +483,6 @@ TuningContext::TuningContext() :
tuning_enable_{true},
record_untuned_enable_{false},
manager_initialized_{false},
write_file_on_exit_{true},
numerics_check_enable_{false},
max_tuning_duration_ms_{30},
max_tuning_iterations_{100},
@ -417,20 +503,8 @@ TuningContext::~TuningContext() {
// but doesn't do any computation itself.
return;
}
auto filename = GetFilename();
if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
if (results_count_from_input_file_ > 0) {
TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
}
else {
TUNABLE_LOG1("writing file ", filename);
}
if (!WriteFile(filename)) {
TUNABLE_LOG1("failed to write file ", filename);
}
}
}
TUNABLE_LOG1("Closing File");
GetTuningResultsManager().CloseRealtimeAppend(); // Since, we do instant logging by default now.
if (untuned_file_.good()) {
untuned_file_.close();
@ -511,20 +585,54 @@ std::ofstream& TuningContext::GetUntunedFile(){
return untuned_file_;
}
void TuningContext::WriteFileOnExit(bool value) {
write_file_on_exit_ = value;
}
void TuningContext::EnableNumericsCheck(bool value) {
numerics_check_enable_ = value;
}
bool TuningContext::IsNumericsCheckEnabled() const {
const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (env == "1") {
return true;
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (!env_opt.has_value()) {
return numerics_cfg_;
}
return numerics_check_enable_;
const std::string& env = env_opt.value();
if (env == "0") {
return NumericalCheckConfig(false, 1e-5, 1e-5);
}
const size_t underscore = env.find('_');
TORCH_CHECK(
underscore != std::string::npos,
"Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
"Expected 'atol_rtol', got: ",
env);
double atol = 0.0;
double rtol = 0.0;
try {
atol = std::stod(env.substr(0, underscore));
rtol = std::stod(env.substr(underscore + 1));
} catch (const std::exception& e) {
TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
}
TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
return NumericalCheckConfig(true, atol, rtol);
}
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
numerics_cfg_ = {enabled, atol, rtol};
}
bool TuningContext::IsNumericsCheckEnabled() const {
const auto cfg = GetNumericalCheckConfig();
return cfg.enabled || numerics_check_enable_;
}
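// Illustrative examples (not part of the diff) of how the accessor above
// interprets PYTORCH_TUNABLEOP_NUMERICAL_CHECK:
//
//   unset           -> returns numerics_cfg_ (whatever SetNumericalCheckConfig set)
//   "0"             -> {enabled=false, atol=1e-5, rtol=1e-5}
//   "1e-4_1e-2"     -> {enabled=true,  atol=1e-4, rtol=1e-2}
//   "1e-4"          -> TORCH_CHECK failure: expected the 'atol_rtol' format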
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
@ -634,11 +742,6 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
auto filename = GetFilename();
if (!filename.empty() && !IsRecordUntunedEnabled()) {
ReadFile(filename);
// attempt immediately to open file for writing to catch errors early
std::ofstream file(filename, std::ios::out | std::ios::app);
if (!file.good()) {
TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
}
}
});
return manager_;
@ -744,27 +847,6 @@ bool TuningContext::ReadFile(const std::string& filename_) {
return true;
}
bool TuningContext::WriteFile(const std::string& filename_) {
std::string filename = filename_.empty() ? GetFilename() : filename_;
std::ofstream file(filename, std::ios::out | std::ios::trunc);
if (!file.good()) {
TUNABLE_LOG1("error opening tuning results file for writing ", filename);
return false;
}
auto validators = GetTuningResultsValidator().GetAllValidators();
for (const auto& [key, val] : validators) {
file << "Validator," << key << "," << val << std::endl;
}
auto results = GetTuningResultsManager().Dump();
for (const auto& [op_sig, kernelmap] : results) {
for (const auto& [param_sig, result] : kernelmap) {
file << op_sig << "," << param_sig << "," << result << std::endl;
}
}
file.close();
return true;
}
namespace {
struct MaybeDelete {

View File

@ -103,10 +103,24 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
const std::string& params_signature, const std::string& blas_signature);
void InitRealtimeAppend(
const std::string& filename,
const std::unordered_map<std::string, std::string>& validators);
void AppendResultLine(const std::string& op_sig,
const std::string& param_sig,
const ResultEntry& result);
void CloseRealtimeAppend(); // For clean shutdown
private:
std::mutex lock_;
std::mutex realtime_file_mutex_;
std::unique_ptr<std::ofstream> realtime_out_;
std::string realtime_filename_;
ResultsMap results_;
UntunedMap untuned_results_;
bool validators_written_ = false;
};
@ -134,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
GetValidateFuncs validators_;
};
struct NumericalCheckConfig {
bool enabled{false};
double atol{1e-5};
double rtol{1e-5};
NumericalCheckConfig() = default;
NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
};
class TORCH_CUDA_CPP_API TuningContext {
public:
TuningContext();
@ -155,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {
void EnableNumericsCheck(bool value);
bool IsNumericsCheckEnabled() const;
void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
NumericalCheckConfig GetNumericalCheckConfig() const;
void SetMaxTuningDurationMs(int max_duration_ms);
int GetMaxTuningDurationMs() const;
@ -185,10 +211,7 @@ class TORCH_CUDA_CPP_API TuningContext {
void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
std::string GetFilename() const;
void WriteFileOnExit(bool value);
bool ReadFile(const std::string& filename={});
bool WriteFile(const std::string& filename={});
template<class... Types>
void Log(int level, Types... args) {
@ -207,7 +230,6 @@ class TORCH_CUDA_CPP_API TuningContext {
bool tuning_enable_;
bool record_untuned_enable_;
bool manager_initialized_;
bool write_file_on_exit_;
bool numerics_check_enable_;
int max_tuning_duration_ms_;
int max_tuning_iterations_;
@ -222,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
std::ofstream untuned_file_;
size_t results_count_from_input_file_;
bool is_shutting_down_;
NumericalCheckConfig numerics_cfg_{};
};
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
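// Minimal usage sketch (assumption, not part of the diff): the tolerances can be
// configured programmatically through the context declared above, e.g.
//
//   auto* ctx = at::cuda::tunable::getTuningContext();
//   ctx->SetNumericalCheckConfig(/*enabled=*/true, /*atol=*/1e-4, /*rtol=*/1e-2);
//   TORCH_INTERNAL_ASSERT(ctx->IsNumericsCheckEnabled());  // true unless the env var overrides it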

View File

@ -267,27 +267,10 @@ class TunableOp {
for (size_t i = 0; i < op_names_.size(); i++) {
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
else {
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
// collect a small profile
@ -310,6 +293,22 @@ class TunableOp {
continue;
}
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
// for warmup does user set max duration, max iters, or both?
// warmup is skipped by default, i.e. warmup_iter = 0
// warmup will be set to the non-zero value of max_warmup_duration

View File

@ -213,40 +213,22 @@ static cudnn_grid_sample_backward_batch_rule(
return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
}
// TODO: replace with targetable functionalization
// uses functional formulation for one_hot under vmap to be compatible with
// fakeTensor/dynamic shapes and compiled functorch transforms.
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
// but requires explicit positive num_classes under vmap to avoid
// data-dependent output shapes.
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
auto shape = self.sym_sizes().vec();
// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.sym_numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
}
}
// disallow implicit inference under vmap; this would be data-dependent
// and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
"provide an explicit positive num_classes argument.");
// Disabling all of the following checks. This is OK because scatter has checks too.
// Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
// // non-empty tensor
// if (self.device().type() != at::kCUDA) {
// //for cuda, rely on device assert thrown by scatter
// TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
// }
// if (self.device().type() != at::kCUDA) {
// //rely on device asserts from scatter to avoid sync here
// TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
// }
shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
return ret.scatter(-1, self.unsqueeze(-1), 1);
const auto options = self.options();
at::Tensor index = at::arange(num_classes, options);
return at::eq(self.unsqueeze(-1), index).to(at::kLong);
}
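// Worked example (illustrative, not part of the diff) of the functional formulation:
// for self = [2, 0] and num_classes = 3,
//   index = arange(3)            -> [0, 1, 2]
//   self.unsqueeze(-1) == index  -> [[false, false, true], [true, false, false]]
//   .to(kLong)                   -> [[0, 0, 1], [1, 0, 0]]
// i.e. the same result as the scatter-based decomposition, but with no in-place
// writes and no data-dependent shapes, so it composes with vmap and FakeTensor.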
template <typename A, A a, typename C>

View File

@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
}
}
auto shape = self.sizes().vec();
auto shape = self.sym_sizes().vec();
// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.numel() == 0) {
if (self.sym_numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.push_back(num_classes);
return at::empty(shape, self.options());
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
}
}
@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
}
}
shape.push_back(num_classes);
Tensor ret = at::zeros(shape, self.options());
shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
ret.scatter_(-1, self.unsqueeze(-1), 1);
return ret;
}

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,

View File

@ -1230,8 +1230,205 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
);
}
Tensor&
_tunable_scaled_gemm_rocm(
cublasCommonArgs& args,
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
const at::ScalarType out_dtype,
Tensor& out) {
#ifdef USE_ROCM
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype) ? at::ScalarType::Half : out_dtype;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_gemm_rocm only callable on ROCM devices");
#endif
}
Tensor&
_scaled_gemm(
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
// Only ROCm enables the TunableOp path; when tuning is disabled it falls back
// to at::cuda::blas::scaled_gemm (which CUDA builds always use).
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
bool tunable_op_enabled = tuning_ctx->IsTunableOpEnabled();
#else
bool tunable_op_enabled = false;
#endif
if (tunable_op_enabled) {
// Only available on ROCM
return _tunable_scaled_gemm_rocm(
args,
mat1, mat2,
scale_a, scale_b,
scaling_choice_a, scaling_choice_b,
bias,
use_fast_accum,
out_dtype_,
out);
}
else
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
return out;
}
}
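// Illustrative note (assumption, not part of the diff): on ROCm the TunableOp
// branch above is only taken when tuning is switched on, e.g.
//
//   PYTORCH_TUNABLEOP_ENABLED=1 python run_fp8_matmul.py   # hypothetical script name
//
// otherwise, and always on CUDA builds, _scaled_gemm falls through to
// at::cuda::blas::scaled_gemm.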
} // namespace
// NOTE(slayton58): This is defined as part of the _v2 code (way) below - declare the signature here
// to help cleanup v1 call structure.
Tensor&
_scaled_rowwise_rowwise(
const Tensor&, const Tensor&,
const Tensor&, const Tensor&,
const std::optional<Tensor>&,
const c10::ScalarType,
bool,
Tensor&);
// Computes matrix multiply + bias while applying scaling to input and output matrices
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
@ -1273,6 +1470,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// by decreasing priority. We prefer "simpler" schemes as they are supported
// more broadly (more GPU archs, more CUDA versions) and because they are more
// efficient. This tends to matter only for small matmuls (e.g., 1x1x128).
// List of supported BlockWise pairs for FP8:
// https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
@ -1305,7 +1506,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type());
#ifndef USE_ROCM
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
TORCH_CHECK_VALUE(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
"Multiplication of two Float8_e5m2 matrices is not supported");
#endif
if (use_fast_accum) {
@ -1371,41 +1572,44 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
// and only for compute capability 9.0+. In other cases we use CUTLASS.
#ifndef USE_ROCM
// We are doing row-wise scaling
auto dprops = at::cuda::getCurrentDeviceProperties();
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
&& ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat1,
mat2,
scale_a,
scale_b,
bias,
use_fast_accum,
out);
return out;
}
#else
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) {
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
return _scaled_rowwise_rowwise(
mat1,
mat2,
scale_a,
scale_b,
bias,
out.scalar_type(),
use_fast_accum,
out);
}
#else
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
Tensor b = mat2;
if (_scaled_mm_is_fnuz()) {
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fnuz,
"Expected b.dtype() == at::kFloat8_e4m3fnuz, got: ", b.dtype());
}
else {
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fn,
"Expected b.dtype() == at::kFloat8_e4m3fn, got: ", b.dtype());
}
// Until more than bf16 is supported.
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16,
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
#endif
}
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
#ifdef USE_ROCM
#if ROCM_VERSION >= 70000
TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
int packed_factor = 1;
@ -1414,163 +1618,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// effectively packing two elements into one byte.
packed_factor = 2;
}
TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
TORCH_CHECK_VALUE(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
mat2.size(1) % 16 == 0,
"M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling");
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
#endif
}
#endif
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
}
return out;
return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
}
namespace {
@ -1910,159 +1971,6 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
Tensor&
_cutlass_scaled_gemm(
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
}
return out;
}
Tensor&
_scaled_tensorwise_tensorwise(
const Tensor& mat_a, const Tensor& mat_b,
@ -2082,7 +1990,7 @@ _scaled_tensorwise_tensorwise(
auto scaling_choice_a = ScalingType::TensorWise;
auto scaling_choice_b = ScalingType::TensorWise;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@ -2118,7 +2026,7 @@ _scaled_rowwise_rowwise(
if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat_a,
mat_b,
@ -2144,11 +2052,38 @@ _scaled_rowwise_rowwise(
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
#endif
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
if (scale_type == ScalingType::BlockWise1x128) {
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
} else if (scale_type == ScalingType::BlockWise128x128) {
TORCH_CHECK_VALUE(check_size_stride(
scale,
0,
ceil_div<int64_t>(t.size(0), 128),
ceil_div<int64_t>(t.size(1), 128)),
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
TORCH_CHECK_VALUE(check_size_stride(
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), " elements and stride(1) ", 1, " if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
}
}
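// Worked example (illustrative, not part of the diff): for a 256 x 512 operand t,
//   BlockWise1x128   expects a scale of shape [256, ceil(512/128)] = [256, 4]
//                    with stride(0) == 1 and stride(1) == 256 (column-major),
//   BlockWise128x128 expects a scale of shape [ceil(256/128), ceil(512/128)] = [2, 4]
//                    with stride(0) == 4 and stride(1) == 1 (row-major).
// Dimensions of size 1 are exempt from the stride check ("if ... > 1" above), which
// is why this helper wraps check_size_stride instead of comparing strides directly.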
Tensor&
_scaled_block1x128_block1x128(
const Tensor& mat_a, const Tensor& mat_b,
@ -2166,15 +2101,14 @@ _scaled_block1x128_block1x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.size(1));
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@ -2189,6 +2123,8 @@ _scaled_block128x128_block1x128(
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
@ -2196,15 +2132,14 @@ _scaled_block128x128_block1x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));
auto scaling_choice_a = ScalingType::BlockWise128x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@ -2226,15 +2161,14 @@ _scaled_block1x128_block128x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0));
TORCH_CHECK_VALUE(scale_b.stride(1) == scale_b.size(0),
"expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1));
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise128x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@ -2288,7 +2222,7 @@ _scaled_mxfp8_mxfp8(
#endif
#endif
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
Tensor&
@ -2325,7 +2259,7 @@ _scaled_nvfp4_nvfp4(
auto scaling_choice_a = ScalingType::BlockWise1x16;
auto scaling_choice_b = ScalingType::BlockWise1x16;
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
@ -2574,7 +2508,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType& swizzle_a,
const Tensor& scale_b,
const SwizzleType& swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;
@ -2585,6 +2521,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
#ifdef USE_ROCM
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
#else
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
#endif
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
fbgemm_gpu::mx8mx8bf16_grouped_mm(
@ -2669,6 +2615,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
const std::optional<Tensor>& bias,
bool use_fast_accum,
Tensor& out) {
// FP8 per-tensor and per-row scaling expect fp32 scales.
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"For grouped FP8 rowwise, both scales must be float32 tensors");
#ifndef USE_ROCM
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
mat_a,
@ -2768,11 +2717,15 @@ _scaled_grouped_mm_cuda(
#endif
if (is_mx8mx8bf16) {
// Note: Passing implied SwizzleType here, correctness of scale previously checked
// in `check_scale` call
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a,
SwizzleType::SWIZZLE_32_4_4,
scale_b,
SwizzleType::SWIZZLE_32_4_4,
offs.value(),
out);
}
@ -2789,6 +2742,140 @@ _scaled_grouped_mm_cuda(
out);
}
namespace {
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
} // anonymous namespace
Tensor
_scaled_grouped_mm_cuda_v2(
const Tensor& mat_a, const Tensor& mat_b,
ArrayRef<Tensor> scale_a,
IntArrayRef scale_recipe_a,
IntArrayRef swizzle_a,
ArrayRef<Tensor> scale_b,
IntArrayRef scale_recipe_b,
IntArrayRef swizzle_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
const std::optional<c10::ScalarType> out_dtype,
IntArrayRef contraction_dim,
bool use_fast_accum) {
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
if (contraction_dim.size() > 0) {
const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
mat_b.size(dim_b));
// Note: only (-1, -2) is currently supported
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
} else {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
}
TORCH_CHECK_VALUE(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
// routines
if (offs.has_value()) {
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
}
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
// Conversion of implicitly-defined enums to explicit
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
// at this point we can start working out what we want to be doing
// Try to do as few steps as possible.
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
bool ok = accept_fn(mat_a.scalar_type(),
scale_recipe_a_enum,
scale_a,
mat_b.scalar_type(),
scale_recipe_b_enum,
scale_b);
if (ok) {
gemm_impl = scaled_gemm_impl;
break;
}
}
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
"No gemm implementation was found");
switch (gemm_impl) {
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
return _f8_f8_bf16_rowwise_grouped_mm(
mat_a,
mat_b,
scale_a[0],
scale_b[0],
offs,
bias,
use_fast_accum,
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0],
swizzle_a_enum[0],
scale_b[0],
swizzle_b_enum[0],
offs.value(),
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
}
}
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,

View File

@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
cret_t,
carg0_t,
carg1_t>(
numel,
f,
data,
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};
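// Illustrative note (not part of the diff): ScalarTypeToCPPTypeT<N> is an alias for
// the same type the old decltype(...::t) spelling named, e.g.
//
//   using A = decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
//   using B = c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
//   static_assert(std::is_same_v<A, B>);   // both are at::Half
//
// so the launcher above resolves the same C++ types, just more readably.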

View File

@ -655,8 +655,14 @@ struct ReduceOp {
}
__syncthreads();
// Intra-warp reduction. On CUDA, iterate with decreasing offsets for better
// numerics (matching Triton, etc.); TODO: apply the same ordering on AMD.
#ifdef USE_ROCM
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = ops.warp_shfl_down(value[i], offset);
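// Illustrative note (not part of the diff): for dim_x = 32 the two loop variants
// visit the same lanes but in opposite orders of shuffle distance:
//   ROCm:  offset = 1, 2, 4, 8, 16   (increasing)
//   CUDA:  offset = 16, 8, 4, 2, 1   (decreasing, matching Triton's order)
// Both are 5-step shuffle-down tree reductions whose result lands in lane 0;
// only the accumulation order, and hence the floating-point rounding, differs.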

View File

@ -77,8 +77,8 @@ struct nansum_functor_complex {
#if AT_USE_JITERATOR()
void operator()(TensorIterator& iter) {
std::string func = jiterator_stringify(
arg_t combine(arg_t a, scalar_t b) {
return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
arg_t combine(arg_t a, arg_t b) {
return a + (std::isnan(b) ? arg_t{0.} : b);
}
);
jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(

View File

@ -464,6 +464,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
}
#endif
int32_t trailingSize;
int nDimsLocal = nDims;
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
if (isInOutAligned) {
// in this case we can and should flatten the tensors after the cat dim
@ -477,7 +478,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
// and divide all strides except last by elems_per_vec (last stride is 1 always)
// for input, we will fix up the sizes and strides in the kernel directly
kernelOutputParam = outputParam;
nDims = dimension + 1;
nDimsLocal = dimension + 1;
constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
@ -494,7 +495,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
case 0:
break;
case 1:
cat_dim = nDims - cat_dim;
cat_dim = nDimsLocal - cat_dim;
break;
default:
cat_dim--;
@ -525,7 +526,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
}\
C10_CUDA_KERNEL_LAUNCH_CHECK();
switch (nDims) {
switch (nDimsLocal) {
case 1:
HANDLE_CASE(1);
break;

View File

@ -21,9 +21,15 @@ namespace {
struct offset_t {
int stride;
int begin;
__device__ int operator[](int i) {
__device__ int operator[](int i) const {
return stride * (begin + i);
}
#if CCCL_VERSION >= 3001000
__device__ offset_t& operator+=(int i) {
begin += i;
return *this;
}
#endif
};
// Segmented sort by full sort algorithm:.
// Say we are sorting a (2, 3) tensor. We have in flattened form:

View File

@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
}
}
#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
int input_pos,
accscalar_t scale,
int output_size,
bool align_corners,
int& min_output,
int& max_output) {
accscalar_t lo, hi;
if (align_corners) {
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
} else {
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
}
min_output = max(0, static_cast<int>(ceil(lo)));
max_output = min(output_size - 1, static_cast<int>(floor(hi)));
}
#endif
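// Worked example (illustrative, not part of the diff): upsampling H = 4 -> 8 with
// align_corners = false gives scale = rheight = 4 / 8 = 0.5. For input row h1 = 3:
//   lo = (3 - 0.5) / 0.5 - 0.5 = 4.5      hi = (3 + 1.5) / 0.5 - 0.5 = 8.5
//   min_output = ceil(4.5) = 5            max_output = min(8 - 1, floor(8.5)) = 7
// so only output rows 5..7 can read from input row 3, which is exactly the range
// the ROCm backward kernel below walks per input pixel instead of using atomics.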
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
const size_t o_numel = nc * width2 * height2;
// Total number of input elements; the ROCm path below iterates over input pixels.
const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
index += blockDim.x * gridDim.x) {
// Decode input pixel coordinates
size_t index_temp = index;
const int w1 = index_temp % width1;
index_temp /= width1;
const int h1 = index_temp % height1;
const size_t nc_idx = index_temp / height1;
accscalar_t grad_sum = 0;
// Find range of output pixels that could interpolate from this input pixel
int h2_min, h2_max, w2_min, w2_max;
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
// Iterate over potential output pixels
for (int h2 = h2_min; h2 <= h2_max; h2++) {
for (int w2 = w2_min; w2 <= w2_max; w2++) {
// Compute source coordinates for this output pixel
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
rheight, h2, align_corners, /*cubic=*/false);
const int h1_base = (int)h1r;
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
const accscalar_t h1lambda = h1r - h1_base;
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
rwidth, w2, align_corners, /*cubic=*/false);
const int w1_base = (int)w1r;
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
const accscalar_t w1lambda = w1r - w1_base;
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
// Check if our input pixel participates in this interpolation and accumulate all weights
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
// to the same pixel, so we need to accumulate weights from all matching positions
accscalar_t weight = 0;
// Check all four interpolation positions and accumulate weights
if (h1 == h1_base && w1 == w1_base) {
weight += h0lambda * w0lambda; // top-left
}
if (h1 == h1_base && w1 == w1_base + w1p) {
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base) {
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
}
if (weight > 0) {
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
}
}
}
// Write accumulated gradient (no atomics needed)
idata[index] = static_cast<scalar_t>(grad_sum);
}
#else
const size_t o_numel = nc * width2 * height2;
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
index += blockDim.x * gridDim.x) {
size_t index_temp = index;
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
true);
}
#endif
}
template <typename scalar_t, typename accscalar_t>
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
// threads are not covering the whole input tensor.
grad_input.zero_();
const size_t num_kernels = nbatch * channels * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
return;
}
#ifdef USE_ROCM
constexpr bool use_input = true;
#else
constexpr bool use_input = false;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * output_height * output_width;
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
input_height,
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
num_threads,


@ -662,7 +662,7 @@ void svd_cusolver(const Tensor& A,
const auto n = A.size(-1);
const auto k = std::min(m, n);
static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
static constexpr const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
// The default heuristic is to use gesvdj driver
#ifdef USE_ROCM


@ -466,7 +466,11 @@ struct ReduceJitOp {
__syncthreads();
#ifdef USE_ROCM
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = reducer::warp_shfl_down(value[i], offset);
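// Illustrative sketch (not part of this diff): both shuffle-offset schedules
// are butterfly reductions that leave the full sum in lane 0 when dim_x is a
// power of two; they only differ in the order partial sums combine. Simulated
// here with a plain array standing in for warp lanes and warp_shfl_down:
#include <vector>
int reduce_lane0(std::vector<int> lanes, bool ascending_offsets) {
  const int dim_x = static_cast<int>(lanes.size());  // assume a power of two
  auto step = [&](int offset) {
    // lane i reads lane i+offset (a "shfl_down") and accumulates it
    for (int i = 0; i + offset < dim_x; ++i) lanes[i] += lanes[i + offset];
  };
  if (ascending_offsets) {
    for (int offset = 1; offset < dim_x; offset <<= 1) step(offset);       // ROCm order
  } else {
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) step(offset);  // CUDA order
  }
  return lanes[0];  // identical result for both schedules
}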


@ -3,6 +3,9 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/accumulate.h>
#include <c10/core/SymBool.h>
#include <c10/util/StringUtil.h>
namespace at::native {
@ -19,28 +22,30 @@ C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint(
"Expected normalized_shape to be at least 1-dimensional, i.e., ",
"containing at least one element, but got normalized_shape = ",
normalized_shape);
TORCH_CHECK(
!weight.defined() || weight.sym_sizes().equals(normalized_shape),
"Expected weight to be of same shape as normalized_shape, but got ",
"weight of shape ",
weight.sym_sizes(),
" and normalized_shape = ",
normalized_shape);
if (weight.defined()) {
TORCH_SYM_CHECK(
sym_equals(weight.sym_sizes(), normalized_shape),
"Expected weight to be of same shape as normalized_shape, but got ",
"weight of shape ",
weight.sym_sizes(),
" and normalized_shape = ",
normalized_shape);
}
const auto input_ndim = input.dim();
const auto input_shape = input.sym_sizes();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim)
.equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
TORCH_CHECK(false, ss.str());
}
TORCH_CHECK_VALUE(
input_ndim >= normalized_ndim,
"Input tensor must have at least ", normalized_ndim, " dimensions, but got ", input_ndim);
auto expect_input_shape_msg = c10::str(
"Given normalized_shape=", normalized_shape,
", expected input with shape [*", c10::Join(", ", normalized_shape),
"], but got input of size", input_shape);
TORCH_SYM_CHECK(
sym_equals(input_shape.slice(input_ndim - normalized_ndim), normalized_shape),
expect_input_shape_msg);
}
C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(


@ -99,6 +99,9 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape);
MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
// Determines whether a tensor is too large to use MPSGraph
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI = true);
static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}


@ -439,6 +439,22 @@ static void check_mps_shape(MPSShape* shape) {
}
}
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI) {
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
if ((!tensor.is_contiguous() || tensor.storage_offset()) && useMPSStridedAPI && is_macOS_15_0_or_newer) {
auto storage_numel = tensor.storage().nbytes() / tensor.element_size() - tensor.storage_offset();
if (storage_numel > std::numeric_limits<int32_t>::max()) {
return true;
}
}
for (auto size : tensor.sizes()) {
if (size > std::numeric_limits<int32_t>::max()) {
return true;
}
}
return false;
}
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) {
id<MTLBuffer> srcBuf = getMTLBufferStorage(t);


@ -249,7 +249,7 @@ kernel void embedding_bag(
template <EmbeddingBagMode M, typename T>
struct MaybeDivBagSize {
inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> /*bag_size*/) {
return val;
}
};


@ -0,0 +1,18 @@
#pragma once
#include <c10/metal/common.h>
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeSharedParams {
int32_t ndim;
int32_t cat_dim;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
};
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeInputParams {
idx_type_t cat_dim_offset;
idx_type_t input_element_offset;
::c10::metal::array<idx_type_t, N> input_strides;
::c10::metal::array<idx_type_t, N> input_sizes;
};


@ -0,0 +1,82 @@
#include <ATen/native/mps/kernels/Shape.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
template <typename T_in, typename T_out>
kernel void cat_large(
constant T_in* input [[buffer(0)]],
device T_out* output [[buffer(1)]],
constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
constant CatLargeInputParams<>& input_params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto ndim = shared_params.ndim;
auto cat_dim = shared_params.cat_dim;
constant auto& output_strides = shared_params.output_strides;
constant auto& output_sizes = shared_params.output_sizes;
auto cat_dim_offset = input_params.cat_dim_offset;
auto input_element_offset = input_params.input_element_offset;
constant auto& input_strides = input_params.input_strides;
constant auto& input_sizes = input_params.input_sizes;
auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
int64_t input_offset = 0;
int64_t output_offset = 0;
for (auto dim = ndim - 1; dim >= 0; dim--) {
auto dim_size = input_sizes[dim];
auto input_dim_idx = input_element_idx % dim_size;
auto output_dim_idx =
input_dim_idx + ((dim == cat_dim) ? cat_dim_offset : 0);
input_offset += input_strides[dim] * input_dim_idx;
output_offset += output_strides[dim] * output_dim_idx;
input_element_idx = input_element_idx / dim_size;
}
output[output_offset] = static_cast<T_out>(input[input_offset]);
}
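// Illustrative sketch (not part of this diff): the same index arithmetic as the
// kernel above, written in host C++. A linear element index is peeled into
// per-dimension coordinates (innermost dimension first), the coordinate along
// cat_dim is shifted by this input's offset in the output, and both coordinate
// sets are folded back through their strides into flat offsets.
#include <cstdint>
#include <vector>
struct Offsets { int64_t input; int64_t output; };
Offsets map_element(int64_t element_idx, int64_t cat_dim, int64_t cat_dim_offset,
                    const std::vector<int64_t>& input_sizes,
                    const std::vector<int64_t>& input_strides,
                    const std::vector<int64_t>& output_strides) {
  Offsets off{0, 0};
  for (int64_t dim = static_cast<int64_t>(input_sizes.size()) - 1; dim >= 0; dim--) {
    const int64_t dim_size = input_sizes[dim];
    const int64_t input_dim_idx = element_idx % dim_size;
    const int64_t output_dim_idx = input_dim_idx + (dim == cat_dim ? cat_dim_offset : 0);
    off.input += input_strides[dim] * input_dim_idx;
    off.output += output_strides[dim] * output_dim_idx;
    element_idx /= dim_size;
  }
  return off;
}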
#define REGISTER_CAT_LARGE_OP(T_in, T_out) \
template [[host_name("cat_large_" #T_in "_" #T_out)]] \
kernel void cat_large<T_in, T_out>( \
constant T_in * input [[buffer(0)]], \
device T_out * output [[buffer(1)]], \
constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
constant CatLargeInputParams<> & input_params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
REGISTER_CAT_LARGE_OP(float, T_out); \
REGISTER_CAT_LARGE_OP(half, T_out); \
REGISTER_CAT_LARGE_OP(bfloat, T_out); \
REGISTER_CAT_LARGE_OP(int, T_out); \
REGISTER_CAT_LARGE_OP(uint, T_out); \
REGISTER_CAT_LARGE_OP(long, T_out); \
REGISTER_CAT_LARGE_OP(ulong, T_out); \
REGISTER_CAT_LARGE_OP(short, T_out); \
REGISTER_CAT_LARGE_OP(ushort, T_out); \
REGISTER_CAT_LARGE_OP(char, T_out); \
REGISTER_CAT_LARGE_OP(uchar, T_out); \
REGISTER_CAT_LARGE_OP(bool, T_out);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
REGISTER_CAT_LARGE_OP(float2, float2);
REGISTER_CAT_LARGE_OP(half2, half2);


@ -512,7 +512,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
}
static MPSGraphTensor* normcdf(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
// (1.0f + erf(x*SQRT1_2)) * 0.5f * x;
// (1.0f + erf(x*SQRT1_2)) * 0.5f;
auto dataType = [inputTensor dataType];
const float SQRT1_2 = 0.707106781186547524400844362104849039f;
MPSGraphTensor* sqrt1_2 = [mpsGraph constantWithScalar:SQRT1_2 shape:@[ @1 ] dataType:dataType];


@ -54,6 +54,10 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
using namespace mps;
using CachedGraph = MPSBinaryCachedGraph;
if (self.numel() == 0 && other.numel() == 0) {
return zeros({}, self.options());
}
dot_check(self, other);
auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);


@ -2,9 +2,13 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/Shape.h>
#include <fmt/format.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -16,6 +20,13 @@
#endif
namespace at::native {
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Shape_metallib.h>
#endif
namespace mps {
// Produces a shape with the `dim` dimension set to 0.
@ -57,6 +68,70 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
")");
}
}
// This implementation of cat is used only if one of the inputs or the output is
// too large to use MPSGraph.
// NOTE: `output` is expected to already have the correct size.
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
CatLargeSharedParams shared_params;
shared_params.ndim = output.dim();
shared_params.cat_dim = dimension;
for (const auto dim : c10::irange(output.dim())) {
shared_params.output_strides[dim] = output.stride(dim);
shared_params.output_sizes[dim] = output.size(dim);
}
int64_t cat_dim_offset = 0;
size_t input_idx = 0;
MPSStream* stream = getCurrentMPSStream();
// Launch a separate kernel for each input. This will produce some overhead,
// but that should be relatively minimal since at least one of the inputs is
// very large. In order to launch only one kernel to process all inputs, we
// would have to copy all the input tensor data into a packed buffer, which
// would not be ideal.
for (const Tensor& input : inputs) {
if (input.numel() == 0) {
continue;
}
// Metal can only launch up to MAX_INT threads at one time. If the input has
// more than that number of elements, launch multiple kernels with different
// offsets into the data.
const int64_t max_num_threads = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
auto num_threads = std::min(max_num_threads, numel_remaining);
CatLargeInputParams input_params;
input_params.cat_dim_offset = cat_dim_offset;
input_params.input_element_offset = input.numel() - numel_remaining;
for (const auto dim : c10::irange(input.dim())) {
input_params.input_strides[dim] = input.stride(dim);
input_params.input_sizes[dim] = input.size(dim);
}
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(
fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
[computeEncoder setComputePipelineState:pipeline_state];
mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
getMPSProfiler().endProfileKernel(pipeline_state);
}
});
}
cat_dim_offset += input.size(dimension);
input_idx++;
}
}
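// Illustrative sketch (not part of this diff): the chunking loop above, reduced
// to its offset arithmetic. Each launch covers at most INT32_MAX elements, and
// the element offset handed to the kernel is how many elements earlier launches
// already covered.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>
std::vector<std::pair<int64_t, int64_t>> plan_launches(int64_t numel) {
  const int64_t max_threads = std::numeric_limits<int32_t>::max();
  std::vector<std::pair<int64_t, int64_t>> launches;  // {element_offset, thread_count}
  for (int64_t remaining = numel; remaining > 0; remaining -= max_threads) {
    launches.push_back({numel - remaining, std::min(max_threads, remaining)});
  }
  return launches;
}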
} // namespace mps
// topk
@ -231,7 +306,11 @@ TORCH_IMPL_FUNC(cat_out_mps)
// Compute size of the result in the cat dimension
int64_t cat_dim_size = 0;
idx = 0;
bool has_large_tensor = false;
for (const Tensor& tensor : materialized_inputs) {
if (isTooLargeForMPSGraph(tensor)) {
has_large_tensor |= true;
}
if (!should_skip(tensor)) {
// TODO: Factor out `check_shape_except_dim`
check_shape_except_dim(notSkippedTensor, tensor, dimension, idx);
@ -249,6 +328,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
return;
}
has_large_tensor |= isTooLargeForMPSGraph(out);
if (has_large_tensor) {
return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;


@ -4545,6 +4545,7 @@
- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
dispatch:
CPU, CUDA: _cdist_forward
MTIA: _cdist_forward_mtia
MPS: _cdist_forward_mps
autogen: _cdist_forward.out
tags: core
@ -7182,6 +7183,12 @@
CUDA: _scaled_grouped_mm_cuda
tags: needs_exact_strides
- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_grouped_mm_cuda_v2
tags: needs_exact_strides
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
variants: function
dispatch:


@ -178,24 +178,30 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_b
0 & \text{ else }
\end{cases}
*/
auto zero_point_rounded = _get_rounded_zero_point(zero_point, quant_min, quant_max);
bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
auto zero_point_rounded = _get_rounded_zero_point(zero_point_, quant_min, quant_max);
TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);
TORCH_CHECK(X_.sizes() == dY_.sizes(), "`X` and `dY` are not the same size");
TORCH_CHECK(
quant_min <= 0 && quant_max >= 0,
"Expecting `quant_min` <= 0 and `quant_max` >= 0");
TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
TORCH_CHECK(scale_.dim() == 1, "scale should be a 1-D tensor");
TORCH_CHECK(zero_point_.dim() == 1, "zero point should be a 1-D tensor");
TORCH_CHECK(
scale.numel() == zero_point.numel(),
scale_.numel() == zero_point_.numel(),
"scale and zero-point need to have the same dimensions");
TORCH_CHECK(
scale.numel() == X.size(axis),
scale_.numel() == X_.size(axis),
"dimensions of scale and zero-point are not consistent with input tensor")
TORCH_CHECK(
@ -204,42 +210,42 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_channel_affine_b
"`zero_point` must be between `quant_min` and `quant_max`.");
TORCH_CHECK(
axis >= 0 && axis < X.dim(),
axis >= 0 && axis < X_.dim(),
"`axis` must be between 0 and number of dimensions of input");
if (X.numel() <= 0) {
if (X_.numel() <= 0) {
return std::make_tuple(X, scale, zero_point);
}
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto numDimensions = X.ndimension();
auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto numDimensions = X_.ndimension();
// Create an axis mask for vectorizing and reshaping the scale and zero point tensors
// into the same shapes as X along the channel axis.
c10::DimVector axis_mask(numDimensions);
for (const auto i : c10::irange(numDimensions)) {
axis_mask[i] = (i == axis) ? X.size(axis) : 1;
axis_mask[i] = (i == axis) ? X_.size(axis) : 1;
}
auto X_shape = X.sizes();
auto scale_vectorized = scale.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
auto X_shape = X_.sizes();
auto scale_vectorized = scale_.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
auto zero_point_vectorized = zero_point_rounded.reshape(at::IntArrayRef(axis_mask.data(), numDimensions)).expand(X_shape);
auto iter = TensorIteratorConfig()
.add_output(dX)
.add_output(dScale_vec)
.add_output(dZeroPoint_vec)
.add_input(X)
.add_input(dY)
.add_input(X_)
.add_input(dY_)
.add_input(scale_vectorized)
.add_input(zero_point_vectorized)
.build();
fake_quant_grad_learnable_channel_stub(
X.device().type(), iter, quant_min, quant_max, grad_factor);
X_.device().type(), iter, quant_min, quant_max, grad_factor);
auto numElements = X.ndimension() - 1;
auto numElements = X_.ndimension() - 1;
// Create a collection of axes that include all but the channel axis for
// reduction when summing over the dScale and dZeroPoint tensors.
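// Illustrative sketch (not part of this diff): the shape arithmetic behind the
// axis mask and the matching reduction, shown as a minimal ATen-C++ helper with
// a hypothetical 4-D input. For X of shape [N, C, H, W] and axis = 1, axis_mask
// is {1, C, 1, 1}, so scale.reshape(axis_mask).expand(X.sizes()) broadcasts the
// per-channel parameters across X; dScale is then summed over every axis except
// the channel axis.
#include <ATen/ATen.h>
#include <cstdint>
#include <vector>
at::Tensor broadcast_per_channel(const at::Tensor& X, const at::Tensor& scale, int64_t axis) {
  std::vector<int64_t> axis_mask(X.dim(), 1);
  axis_mask[axis] = X.size(axis);                      // e.g. {1, C, 1, 1}
  return scale.reshape(axis_mask).expand(X.sizes());   // broadcast to X's shape
}
at::Tensor reduce_all_but(const at::Tensor& dScale_vec, int64_t axis) {
  std::vector<int64_t> dims;
  for (int64_t d = 0; d < dScale_vec.dim(); ++d)
    if (d != axis) dims.push_back(d);                  // all axes except the channel axis
  return dScale_vec.sum(dims);
}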


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,fail_accuracy,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,pass,7
vit_base_patch16_siglip_256,pass,7


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
deit_base_distilled_patch16_224,pass,0
deit_tiny_patch16_224.fb_in1k,pass,0
dm_nfnet_f0,pass,0
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
visformer_small,pass,0
vit_base_patch14_dinov2.lvd142m,pass,0
vit_base_patch16_siglip_256,pass,0


@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
deit_base_distilled_patch16_224,pass,7
deit_tiny_patch16_224.fb_in1k,pass,7
dm_nfnet_f0,pass,6
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
visformer_small,pass,7
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
vit_base_patch16_siglip_256,pass,7


@ -1060,6 +1060,8 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
frozen_model_iter_fn = export_nativert(model, example_inputs)
elif args.torchscript_jit_trace:
frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs)
elif args.aot_precompile:
frozen_model_iter_fn = aot_precompile(model, example_inputs)
else:
if kwargs["hf_llm"]:
# If it's an llm, we want to optimize model.forward, and use
@ -1495,6 +1497,37 @@ def export(model, example_inputs):
return opt_export
def aot_precompile(model, example_inputs):
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
save_path = f.name
with fresh_cache(), torch._dynamo.config.patch("enable_aot_compile", True):
compiled_fn = torch.compile(
model,
fullgraph=True,
options={"guard_filter_fn": lambda guards: [False for _ in guards]},
).forward.aot_compile((example_args, example_kwargs))
compiled_fn.save_compiled_function(save_path)
torch._dynamo.reset()
with open(save_path, "rb") as f:
load_start_time = time.perf_counter()
loaded_fn = torch.compiler.load_compiled_function(f)
load_end_time = time.perf_counter()
print(
f"AOT Precompile loading time: {load_end_time - load_start_time} seconds"
)
def opt_aot_precompile(_, example_inputs, collect_outputs=False):
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
return loaded_fn(model, *example_args, **example_kwargs)
return opt_aot_precompile
def export_nativert(model, example_inputs):
optimized = NativeRTCache.load(model, example_inputs)
@ -2274,6 +2307,7 @@ class BenchmarkRunner:
or self.args.export_aot_inductor
or self.args.export_nativert
or self.args.torchscript_jit_trace
or self.args.aot_precompile
):
# apply export on module directly
# no need for n iterations
@ -2729,6 +2763,7 @@ class BenchmarkRunner:
self.args.export_aot_inductor
or self.args.export_nativert
or self.args.torchscript_jit_trace
or self.args.aot_precompile
):
optimized_model_iter_fn = optimize_ctx
else:
@ -3505,6 +3540,11 @@ def parse_args(args=None):
action="store_true",
help="Measure pass rate with Export+AOTInductor",
)
group.add_argument(
"--aot-precompile",
action="store_true",
help="Measure pass rate with AOT Precompile",
)
group.add_argument(
"--export-nativert",
action="store_true",
@ -3935,6 +3975,10 @@ def run(runner, args, original_dir=None):
optimize_ctx = export
experiment = speedup_experiment
output_filename = "export.csv"
elif args.aot_precompile:
optimize_ctx = aot_precompile
experiment = speedup_experiment
output_filename = "aot_precompile.csv"
elif args.export_nativert:
optimize_ctx = export_nativert
experiment = speedup_experiment


@ -271,8 +271,6 @@ class TimmRunner(BenchmarkRunner):
memory_format=torch.channels_last if channels_last else None,
)
self.num_classes = model.num_classes
data_config = resolve_data_config(
vars(self._args) if timmversion >= "0.8.0" else self._args,
model=model,
@ -302,7 +300,6 @@ class TimmRunner(BenchmarkRunner):
example_inputs = [
example_inputs,
]
self.target = self._gen_target(batch_size, device)
self.loss = torch.nn.CrossEntropyLoss().to(device)
@ -370,11 +367,6 @@ class TimmRunner(BenchmarkRunner):
tolerance = 1e-2
return tolerance, cosine
def _gen_target(self, batch_size, device):
return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
self.num_classes
)
def compute_loss(self, pred):
# High loss values make gradient checking harder, as small changes in
# accumulation order upsets accuracy checks.


@ -1,6 +1,8 @@
adv_inception_v3 128
beit_base_patch16_224 128
convnextv2_nano.fcmae_ft_in22k_in1k 128
deit_base_distilled_patch16_224 128
deit_tiny_patch16_224.fb_in1k 128
dm_nfnet_f0 128
ghostnet_100 512
inception_v3 128
@ -12,3 +14,5 @@ repvgg_a2 128
swin_base_patch4_window7_224 128
tf_efficientnet_b0 128
visformer_small 128
vit_base_patch14_dinov2.lvd142m 128
vit_base_patch16_siglip_256 128


@ -1,6 +1,8 @@
adv_inception_v3,128
beit_base_patch16_224,64
convnextv2_nano.fcmae_ft_in22k_in1k,128
deit_base_distilled_patch16_224,64
deit_tiny_patch16_224.fb_in1k,128
dm_nfnet_f0,128
ghostnet_100,128
inception_v3,128
@ -12,3 +14,5 @@ repvgg_a2,128
swin_base_patch4_window7_224,64
tf_efficientnet_b0,128
visformer_small,128
vit_base_patch14_dinov2.lvd142m,128
ViT-B-16-SigLIP-i18n-256,128


@ -28,101 +28,8 @@
namespace c10 {
// [dtype Macros note] For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
//
// For adding a new dtype: In the beginning, we had an idea that there was a
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them. But over the years we added weird types which couldn't
// be handled uniformly everywhere and so in the end we ended up with some
// mish-mosh of some helper macros, but mostly use sites making a call about
// what dtypes they can or can't support. So if you want to add a new dtype,
// the preferred resolution is to find a dtype similar to what you want,
// grep for it and edit all the sites you find this way. If you need to add
// a completely new kind of dtype, you're going to have to laboriously audit
// all of the sites everywhere to figure out how it should work. Consulting
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
//
// TODO: To add unsigned int types here, we must define accumulate type.
// But uint8 currently accumulates into int64, so we would have to make
// an inconsistent choice for the larger types. Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<c10::Half>, ComplexHalf) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
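// Illustrative sketch (not part of this diff): how the AT_FORALL_* X-macros are
// consumed. A caller defines a one-line macro taking (cpp_type, name) and lets
// the list expand it once per dtype, which is the same pattern used by
// toString() and the dispatch macros. Stand-alone analogue with a hypothetical
// three-type list:
#include <cstdio>
enum class MyScalarType { Float, Double, Int };
#define MY_FORALL_TYPES(_) \
  _(float, Float)          \
  _(double, Double)        \
  _(int, Int)
inline const char* myToString(MyScalarType t) {
#define DEFINE_CASE(cpp_type, name) \
  case MyScalarType::name:          \
    return #name;
  switch (t) {
    MY_FORALL_TYPES(DEFINE_CASE)
    default:
      return "UNKNOWN_SCALAR";
  }
#undef DEFINE_CASE
}
int main() {
  std::printf("%s\n", myToString(MyScalarType::Double));  // prints "Double"
}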
namespace impl {
// These are used to map ScalarTypes to C++ types.
template <c10::ScalarType N>
struct ScalarTypeToCPPType;
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
template <> \
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
using type = cpp_type; \
\
/* This is a workaround for the CUDA bug which prevents */ \
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
/* ambiguous reference which can't be resolved. For some reason it */ \
/* can't pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
static type t; \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
#undef SPECIALIZE_ScalarTypeToCPPType
template <c10::ScalarType N>
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
} // namespace impl
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
// regarding macros.
template <typename T>
struct CppTypeToScalarType;
@ -138,130 +45,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
// NB: despite its generic sounding name, the macros that don't take _AND
// are mostly only used by tensorexpr
#define AT_FORALL_INT_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long)
#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double)
// These macros are often controlling how many template instantiations we
// create for kernels. It is typically inappropriate to add new dtypes here,
// instead, new types should be added to use sites on a case-by-case basis.
// We generally are not accepting new dtypes due to binary size concerns.
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE4>::t), \
SCALARTYPE4) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE5>::t), \
SCALARTYPE5) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE6>::t), \
SCALARTYPE6) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE7>::t), \
SCALARTYPE7)
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
_(c10::qint32, QInt32) \
_(c10::quint4x2, QUInt4x2) \
_(c10::quint2x4, QUInt2x4)
#define AT_FORALL_FLOAT8_TYPES(_) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble)
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;
@ -269,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
case ScalarType::name: \
return #name;
switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
#undef DEFINE_CASE
}
inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
case ScalarType::name: \
@ -525,12 +295,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
inline std::ostream& operator<<(
std::ostream& stream,
at::ScalarType scalar_type) {
return stream << toString(scalar_type);
}
// Returns a pair of strings representing the names for each dtype.
// The returned pair is (name, legacy_name_if_applicable)
C10_API std::pair<std::string, std::string> getDtypeNames(


@ -86,4 +86,23 @@ inline SymIntArrayRef fromIntArrayRefSlow(IntArrayRef array_ref) {
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}
inline c10::SymBool sym_equals(SymIntArrayRef LHS, SymIntArrayRef RHS) {
if (LHS.size() != RHS.size()) {
return c10::SymBool(false);
}
c10::SymBool result = sym_eq(LHS.size(), RHS.size());
for (size_t i = 0; i < RHS.size(); ++i) {
c10::SymBool equals = sym_eq(LHS[i], RHS[i]);
std::optional<bool> equals_bool = equals.maybe_as_bool();
if (equals_bool.has_value() && !*equals_bool) {
// Early return if element comparison is known to be false
return equals;
}
result = result.sym_and(equals);
}
return result;
}
} // namespace c10
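// Illustrative sketch (not part of this diff): the early-return pattern in
// sym_equals above, written against std::optional<bool> as a stand-in for
// c10::SymBool (nullopt ~ "not known until guarded on"). If any element-wise
// comparison is known to be false we can stop immediately; otherwise we keep
// AND-ing the possibly-symbolic results. All names here are hypothetical.
#include <optional>
#include <vector>
using MaybeBool = std::optional<bool>;
MaybeBool maybe_and(MaybeBool a, MaybeBool b) {
  if ((a && !*a) || (b && !*b)) return false;  // a definitely-false operand wins
  if (a && b) return *a && *b;                 // both known
  return std::nullopt;                         // still symbolic / unknown
}
MaybeBool shapes_equal(const std::vector<MaybeBool>& elementwise_eq) {
  MaybeBool result = true;
  for (const auto& eq : elementwise_eq) {
    if (eq.has_value() && !*eq) {
      return false;  // early exit: this comparison is known to be false
    }
    result = maybe_and(result, eq);
  }
  return result;
}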


@ -1080,19 +1080,12 @@ class RingBuffer {
void getEntries(std::vector<T>& result) const {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
result.reserve(alloc_trace->size());
result.insert(
result.end(),
alloc_trace->begin() +
static_cast<typename std::vector<T>::difference_type>(
alloc_trace_next),
alloc_trace->end());
result.insert(
result.end(),
result.reserve(result.size() + alloc_trace->size());
std::rotate_copy(
alloc_trace->begin(),
alloc_trace->begin() +
static_cast<typename std::vector<T>::difference_type>(
alloc_trace_next));
std::next(alloc_trace->begin(), alloc_trace_next),
alloc_trace->end(),
std::back_inserter(result));
}
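// Illustrative sketch (not part of this diff): std::rotate_copy emits the ring
// buffer's elements in chronological order, oldest entry first, starting at the
// "next write" position, which is what the two insert() calls it replaces did
// in two steps. Plain-vector analogue:
//   buffer = {D, E, A, B, C}, next = 2  =>  out = {A, B, C, D, E}
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>
std::vector<int> linearize(const std::vector<int>& buffer, std::size_t next) {
  std::vector<int> out;
  out.reserve(buffer.size());
  std::rotate_copy(buffer.begin(),
                   std::next(buffer.begin(), static_cast<std::ptrdiff_t>(next)),
                   buffer.end(),
                   std::back_inserter(out));
  return out;
}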
void clear() {
@ -2502,8 +2495,6 @@ class DeviceCachingAllocator {
auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
if (divisions > 1 && size > (kMinBlockSize * divisions)) {
return roundup_power2_next_division(size, divisions);
} else if (divisions == 1) {
return llvm::PowerOf2Ceil(size);
} else {
return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}
@ -4468,10 +4459,7 @@ struct BackendStaticInitializer {
if (kv[0] == "backend") {
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (kv[1] ==
"cud"
"aMallocAsync" ||
kv[1] == "hipMallocAsync")
if (kv[1] == "cudaMallocAsync" || kv[1] == "hipMallocAsync")
#else
if (kv[1] == "cudaMallocAsync")
#endif
@ -4493,9 +4481,7 @@ struct BackendStaticInitializer {
// HIPAllocatorMasqueradingAsCUDA because it needs to happen during static
// initialization, and doing so there may introduce static initialization
// order (SIOF) issues.
#define HIP_MASQUERADING_AS_CUDA \
"cud" \
"a"
#define HIP_MASQUERADING_AS_CUDA "cuda"
at::SetAllocator(c10::Device(HIP_MASQUERADING_AS_CUDA).type(), r, 0);
allocator.store(r);
#undef HIP_MASQUERADING_AS_CUDA


@ -65,7 +65,7 @@ struct default_constructible
namespace impl {
template <typename T>
constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>*)
constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>* /*unused*/)
{
return true;
}
@ -76,7 +76,7 @@ class type : public modifier<M, type<T, Tag, M...>>...
{
public:
template <typename TT = T, typename = std::enable_if_t<std::is_trivially_constructible<TT>{}>>
explicit type(uninitialized_t)
explicit type(uninitialized_t /*unused*/)
noexcept
{
}
@@ -138,7 +138,7 @@ private:
namespace impl {
template <typename T, typename Tag, typename ... Ms>
constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>*) { return true;}
constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>* /*unused*/) { return true;}
constexpr bool is_strong_type_func(...) { return false;}
template <typename T, typename Tag, typename ... Ms>
constexpr T underlying_type(strong::type<T, Tag, Ms...>*);
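is_strong_type_func above relies on plain overload resolution: the pointer overload matches any strong::type instantiation, while the ellipsis overload catches everything else, so the answer can presumably be computed at compile time from a null pointer of the type under test. A standalone sketch of the idiom with hypothetical names (wrapper, is_wrapper_func, is_wrapper), not the strong_type library itself:
```cpp
template <typename T, typename Tag>
struct wrapper {
  T value;
};

// Matches pointers to any wrapper<...> instantiation.
template <typename T, typename Tag>
constexpr bool is_wrapper_func(const wrapper<T, Tag>* /*unused*/) {
  return true;
}
// Fallback for every other type.
constexpr bool is_wrapper_func(...) {
  return false;
}

template <typename T>
constexpr bool is_wrapper = is_wrapper_func(static_cast<T*>(nullptr));

struct MeterTag {};
static_assert(is_wrapper<wrapper<int, MeterTag>>, "wrapper types are detected");
static_assert(!is_wrapper<double>, "other types fall through to the ellipsis overload");
```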

View File

@@ -217,9 +217,7 @@ coverage_ignore_functions = [
"is_available",
# torch.distributed.checkpoint.state_dict
"gc_context",
"state_dict",
# torch.distributed.elastic.events
"construct_and_record_rdzv_event",
"record_rdzv_event",
# torch.distributed.elastic.metrics
"initialize_metrics",
@@ -430,7 +428,6 @@ coverage_ignore_functions = [
"get_default_qconfig_dict",
"qconfig_equals",
# torch.ao.quantization.quantization_mappings
"get_default_compare_output_module_list",
"get_default_dynamic_quant_module_mappings",
"get_default_dynamic_sparse_quant_module_mappings",
"get_default_float_to_quantized_operator_mappings",
@@ -473,29 +470,13 @@ coverage_ignore_functions = [
"get_weight_qspec",
"propagate_annotation",
"register_annotator",
# torch.ao.quantization.utils
"activation_dtype",
"activation_is_dynamically_quantized",
"activation_is_int32_quantized",
"activation_is_int8_quantized",
"activation_is_statically_quantized",
"calculate_qmin_qmax",
"check_min_max_valid",
"check_node",
"determine_qparams",
"get_combined_dict",
"get_fqn_to_example_inputs",
"get_qconfig_dtypes",
"get_qparam_dict",
"get_quant_type",
"get_swapped_custom_module_class",
"getattr_from_fqn",
"has_no_children_ignoring_parametrizations",
"is_per_channel",
"is_per_tensor",
"op_is_int8_dynamically_quantized",
"to_underlying_dtype",
"validate_qmin_qmax",
"weight_dtype",
"weight_is_quantized",
"weight_is_statically_quantized",
@@ -553,42 +534,6 @@ coverage_ignore_functions = [
# torch.distributed.checkpoint.utils
"find_state_dict_object",
"find_tensor_shard",
# torch.distributed.collective_utils
"all_gather",
"all_gather_object_enforce_type",
"broadcast",
# torch.distributed.distributed_c10d
"all_gather",
"all_gather_coalesced",
"all_gather_into_tensor",
"all_gather_object",
"all_reduce",
"all_reduce_coalesced",
"all_to_all",
"all_to_all_single",
"barrier",
"batch_isend_irecv",
"broadcast",
"broadcast_object_list",
"destroy_process_group",
"gather",
"gather_object",
"get_backend",
"get_backend_config",
"get_global_rank",
"get_group_rank",
"get_process_group_ranks",
"get_rank",
"get_world_size",
"init_process_group",
"irecv",
"is_backend_available",
"is_gloo_available",
"is_initialized",
"is_mpi_available",
"is_nccl_available",
"is_torchelastic_launched",
"is_ucc_available",
"isend",
"monitored_barrier",
"new_group",
@@ -662,15 +607,8 @@ coverage_ignore_functions = [
"transformer_auto_wrap_policy",
"wrap",
# torch.distributed.nn.functional
"all_gather",
"all_reduce",
"all_to_all",
"all_to_all_single",
"broadcast",
"gather",
"reduce",
"reduce_scatter",
"scatter",
# torch.distributed.nn.jit.instantiator
"get_arg_return_types_from_interface",
"instantiate_non_scriptable_remote_module_template",
@@ -1081,6 +1019,8 @@ coverage_ignore_functions = [
"loop_pass",
"these_before_those_pass_constraint",
"this_before_that_pass_constraint",
# torch.fx.passes.regional_inductor
"regional_inductor",
# torch.fx.passes.reinplace
"reinplace",
# torch.fx.passes.split_module

View File

@@ -10,6 +10,7 @@ torch.cpu
current_device
current_stream
is_available
is_initialized
synchronize
stream
set_device

View File

@@ -176,10 +176,6 @@
.. autoclass:: torch.cuda.use_mem_pool
```
% FIXME The following doesn't seem to exist. Is it supposed to?
% https://github.com/pytorch/pytorch/issues/27785
% .. autofunction:: reset_max_memory_reserved
## NVIDIA Tools Extension (NVTX)
```{eval-rst}
@@ -299,4 +295,4 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
:hidden:
cuda.aliases.md
```
```

View File

@@ -68,14 +68,6 @@
.. autofunction:: get_validators
```
```{eval-rst}
.. autofunction:: write_file_on_exit
```
```{eval-rst}
.. autofunction:: write_file
```
```{eval-rst}
.. autofunction:: read_file
```
@@ -95,3 +87,7 @@
```{eval-rst}
.. autofunction:: get_rotating_buffer_size
```
```{eval-rst}
.. autofunction:: set_numerical_check_tolerances
```

View File

@@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
.. autoclass:: CPUOffloadPolicy
:members:
```
```{eval-rst}
.. autofunction:: share_comm_ctx
```

View File

@@ -51,7 +51,7 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it.
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | ✘ | ✓ |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| all_to_all | | | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
| all_to_all | | | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
@@ -221,6 +221,16 @@ inconsistent 'UUID' assignment across ranks, and to prevent races during initial
```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.is_xccl_available
.. autofunction:: torch.distributed.distributed_c10d.batch_isend_irecv
.. autofunction:: torch.distributed.distributed_c10d.destroy_process_group
.. autofunction:: torch.distributed.distributed_c10d.is_backend_available
.. autofunction:: torch.distributed.distributed_c10d.irecv
.. autofunction:: torch.distributed.distributed_c10d.is_gloo_available
.. autofunction:: torch.distributed.distributed_c10d.is_initialized
.. autofunction:: torch.distributed.distributed_c10d.is_mpi_available
.. autofunction:: torch.distributed.distributed_c10d.is_nccl_available
.. autofunction:: torch.distributed.distributed_c10d.is_torchelastic_launched
.. autofunction:: torch.distributed.distributed_c10d.is_ucc_available
```
```{eval-rst}

View File

@@ -1169,6 +1169,7 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.passes.operator_support
.. py:module:: torch.fx.passes.param_fetch
.. py:module:: torch.fx.passes.pass_manager
.. py:module:: torch.fx.passes.regional_inductor
.. py:module:: torch.fx.passes.reinplace
.. py:module:: torch.fx.passes.runtime_assert
.. py:module:: torch.fx.passes.shape_prop

View File

@@ -23,6 +23,7 @@ Submodules
flex_attention
bias
experimental
varlen
.. toctree::
:hidden:
@@ -30,3 +31,4 @@ Submodules
nn.attention.flex_attention
nn.attention.bias
nn.attention.experimental
nn.attention.varlen

View File

@@ -0,0 +1,17 @@
```{eval-rst}
.. role:: hidden
:class: hidden-section
```
# torch.nn.attention.varlen
```{eval-rst}
.. automodule:: torch.nn.attention.varlen
.. currentmodule:: torch.nn.attention.varlen
```
```{eval-rst}
.. autofunction:: varlen_attn
```
```{eval-rst}
.. autoclass:: AuxRequest
```

View File

@@ -228,3 +228,4 @@ Low-Precision functions
ScalingType
SwizzleType
scaled_mm
scaled_grouped_mm

View File

@@ -52,6 +52,26 @@ This module contains Eager mode quantization APIs.
default_eval_fn
```
## torch.ao.quantization.utils
```{eval-rst}
.. automodule:: torch.ao.quantization.utils
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
activation_is_dynamically_quantized
activation_is_int32_quantized
activation_is_int8_quantized
activation_is_statically_quantized
determine_qparams
check_min_max_valid
calculate_qmin_qmax
validate_qmin_qmax
```
## torch.ao.quantization.quantize_fx
This module contains FX graph mode quantization APIs (prototype).
@@ -150,7 +170,7 @@ This module contains a few CustomConfig classes that's used in both eager mode a
## torch.ao.quantization.pt2e.export_utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.pt2e.export_utils
.. automodule:: torch.ao.quantization.pt2e.export_utils
```
```{eval-rst}

View File

@@ -134,7 +134,6 @@ and supported quantized modules and functions.
.. py:module:: torch.ao.quantization.fx.utils
.. py:module:: torch.ao.quantization.observer
.. py:module:: torch.ao.quantization.pt2e.duplicate_dq_pass
.. py:module:: torch.ao.quantization.pt2e.export_utils
.. py:module:: torch.ao.quantization.pt2e.graph_utils
.. py:module:: torch.ao.quantization.pt2e.port_metadata_pass
.. py:module:: torch.ao.quantization.pt2e.prepare
@@ -158,7 +157,6 @@ and supported quantized modules and functions.
.. py:module:: torch.ao.quantization.quantizer.xnnpack_quantizer
.. py:module:: torch.ao.quantization.quantizer.xnnpack_quantizer_utils
.. py:module:: torch.ao.quantization.stubs
.. py:module:: torch.ao.quantization.utils
.. py:module:: torch.nn.intrinsic.modules.fused
.. py:module:: torch.nn.intrinsic.qat.modules.conv_fused
.. py:module:: torch.nn.intrinsic.qat.modules.linear_fused

View File

@@ -1,14 +1,12 @@
```{eval-rst}
.. currentmodule:: torch.compiler.config
```
# torch.compiler.config
```{eval-rst}
.. automodule:: torch.compiler.config
```
```{eval-rst}
.. autodata:: torch.compiler.config.job_id
:members:
:undoc-members:
:show-inheritance:
```

Some files were not shown because too many files have changed in this diff.