Compare commits

...

186 Commits

Author SHA1 Message Date
a4aa741d88 Update on "Fix all gather bucketing fusion in of dtype casts"
The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype.

This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. 

Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up.



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 15:09:58 -08:00
2f92952af4 Update base for Update on "Fix all gather bucketing fusion in of dtype casts"
The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype.

This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. 

Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up.



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 15:09:58 -08:00
8d44a39e0a Update on "Fix all gather bucketing fusion in of dtype casts"
The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype.

This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. 

Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up.



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 11:23:14 -08:00
c96f4288fc Update base for Update on "Fix all gather bucketing fusion in of dtype casts"
The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype.

This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. 

Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up.



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 11:23:14 -08:00
1d7b1616c0 Update on "Fix all gather bucketing fusion in of dtype casts"
The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype.

This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. 

Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up.



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 09:53:40 -08:00
14cfb9fb10 Update base for Update on "Fix all gather bucketing fusion in of dtype casts"
The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype.

This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. 

Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up.



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 09:53:40 -08:00
567dcdba75 Fix longstanding race condition around getAllOperatorsFor (#167860)
getAllOperatorsFor returns a const reference to internal state that is protected by a lock. Presuming that the lock is necessary in the first place (about which I offer no opinion because it's unclear to what extent the GIL should help here), this is a straightforward way to cause callers to create race conditions.

This should fix those race conditions by copying the state instead. I modified calling code to stop binding a const reference to the result for clarity.

Differential Revision: [D87088731](https://our.internmc.facebook.com/intern/diff/D87088731/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D87088731/)!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167860
Approved by: https://github.com/zou3519
2025-11-17 17:37:02 +00:00
77acc66df9 [ROCm][CI] Upgrade ROCm CI to 7.1 (#166743)
Upgrade all the ROCm docker images to ROCm 7.1 release version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166743
Approved by: https://github.com/atalman, https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Prachi Gupta <prachi.gupta@amd.com>
2025-11-17 17:17:25 +00:00
95d1df7d4e Disable CUDA MXFP4 on non-B200 GPUs (#167857)
Summary:

MXFP4 unit tests pass on B200, fail on RTX 5090 - disable non-B200
cases.

Also add a fail w/a not implemented error for non-B200 to avoid
unhelpful failure messages.

Test Plan:

```
pytest -sv -k "mxfp4" test/test_scaled_matmul_cuda.py
```

Reviewers:

@nWEIdia

Subscribers:

Tasks:

Fixes https://github.com/pytorch/pytorch/issues/167850

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167857
Approved by: https://github.com/nWEIdia, https://github.com/malfet
2025-11-17 17:14:53 +00:00
094e529c64 [MPS] Fix repeat_interleave with slices (#167961)
Alas, one can not use `repeat_interleave_common` for MPS tensors, as `data_offset` is not a valid pointer to `id<MTLTensor>`
On the other hand, one does not need to use `AT_DISPATCH_INDEX_TYPES` as dispatching is happening on the shader side

Fixes https://github.com/pytorch/pytorch/issues/167924
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167961
Approved by: https://github.com/manuelcandales
2025-11-17 17:10:59 +00:00
a4c7bf7e8d Revert "Use c10::filesystem (#167821)"
This reverts commit deabb3e36de207aa497b035a8bdf6ec1b37d17fe.

Reverted https://github.com/pytorch/pytorch/pull/167821 on behalf of https://github.com/jeanschmidt due to Breaks internal tests, see D87148810. @Skylion007 may you help the author to get this PR merged? ([comment](https://github.com/pytorch/pytorch/pull/167821#issuecomment-3542877623))
2025-11-17 16:48:57 +00:00
22ccd44d73 Revert "Improve char printing (#167899)"
This reverts commit 2245d7d3b90162ae2958929a22c140537cfc4b42.

Reverted https://github.com/pytorch/pytorch/pull/167899 on behalf of https://github.com/jeanschmidt due to need to revert in order to revert https://github.com/pytorch/pytorch/pull/167899 ([comment](https://github.com/pytorch/pytorch/pull/167899#issuecomment-3542869096))
2025-11-17 16:46:44 +00:00
39ebab1dd9 Revert "Remove python workaround for ContextDecorator (#167049)"
This reverts commit e20ca3bc2e6ef9935c782fe548348f81fabc5bd7.

Reverted https://github.com/pytorch/pytorch/pull/167049 on behalf of https://github.com/jeanschmidt due to breaks internal tests see D87120562, @Skylion007 please thelp the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/167049#issuecomment-3542847796))
2025-11-17 16:41:26 +00:00
4c152a71ad Revert "add device generalization support for distributed tests (#165067)"
This reverts commit 96a4c4b3d1c533b36cfa7259524b91a0eaf4254f.

Reverted https://github.com/pytorch/pytorch/pull/165067 on behalf of https://github.com/jeanschmidt due to breaks internal tests see D87036515, @albanD please help the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/165067#issuecomment-3542820651))
2025-11-17 16:37:07 +00:00
35f12801b7 Update on "Fix all gather bucketing fusion in of dtype casts"
The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype.

This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. 

Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up.



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 08:32:43 -08:00
a44daec0a9 Update base for Update on "Fix all gather bucketing fusion in of dtype casts"
The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype.

This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. 

Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up.



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 08:32:43 -08:00
1b43d6cd4e [ROCm] enable fastSpecializedAtomicAdd for gfx950 (#167661)
Use standard HIP headers for unsafeAtomicAdd. Removes copy/paste of unsafeAtomicAdd as "preview" implementation for gfx942.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167661
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-11-17 16:18:49 +00:00
2b69673bbf [CD] Add libopenblas to dep list for AArch64+CPU whl (#167841)
#166044 removes openblas from whl dependency list for AArch64+CPU build so this PR adds it back. Only affects CPU build since AArch64+CUDA uses NVPL.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167841
Approved by: https://github.com/tinglvv, https://github.com/malfet
2025-11-17 16:11:39 +00:00
2f74916e36 Do not hardfail on use nccl estimations for non-nccl (#167827)
Previously we hard failed if pg was "gloo".
Fallback on hardcoded formulas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167827
Approved by: https://github.com/eellison
2025-11-17 16:06:26 +00:00
2b5eabc74b Rework PyObject preservation (v2) (#167564)
Make the PyObject preservation scheme thread-safe with free threaded (nogil) Python. The general idea is:

* Python Tensor and Storage objects always hold a strong reference to their underlying c10 object
* c10 objects hold a strong reference to their Python objects if there's at least one other reference to the c10 object

This is implemented in `intrusive_ptr`:

* The top most bit (`kHasPyObject`) from the weakref count is now used to indicate if the `intrusive_ptr_target` has an associated PyObject. So `kHasPyObject` is one bit, the weakref count is now 31 bits and the strong refcount remains 32 bits.
* When the reference count increases from one to two and `kHasPyObject` is set, we incref the associated Python object to ensure that it's kept alive.
* When the reference count decreases from two to one (i.e., there are no C++ reference to the `intrusive_ptr_target` other than from the Python object), we decre the associated Python object to break the cycle.

Other benefits:

* We can delete a lot of the copypasta from Python internal `subtype_dealloc`
* This fixes the weakref and GC bugs we had in the previous scheme. Python weakrefs on Tensors and Storages should just work as expected now.

Risks:

* Extra branch for reference count operations on `intrusive_ptr<TensorImpl>`, `intrusive_ptr<StorageImpl>`, and the generic `intrusive_ptr<intrusive_ptr_target>` even when we're not using Python.
* It's a big change

(Second attempt at https://github.com/pytorch/pytorch/pull/166342)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167564
Approved by: https://github.com/albanD, https://github.com/Skylion007
2025-11-17 14:52:02 +00:00
9ff95f6835 [inductor] Expose config for fx bucket all_reduces (#167634)
Exposing `_inductor.config.bucket_all_reduces_fx` similar to all_gathers, reduce_scatters with only option "all".

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167634
Approved by: https://github.com/eellison
2025-11-17 13:10:36 +00:00
6fdb974f4a Update torch-xpu-ops commit pin (#167698)
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@1e69f4](1e69f40b3c), includes:

- Add PTL in the default AOT target list for both Win and Lin
- Use PyTorch p2p API in Copy kernel
- Add event cache and event timing to XCCL
- Add Float8_e8m0fnu support for copy
- Add CMAKE_SYCL_COMPILER_LAUNCHER for sccache
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167698
Approved by: https://github.com/EikanWang
2025-11-17 12:58:42 +00:00
661d1653aa [xla hash update] update the pinned xla hash (#167968)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned xla hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167968
Approved by: https://github.com/pytorchbot
2025-11-17 12:20:32 +00:00
53809f9640 [ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (#158250)
Co-authored-by: Nikhil Gupta [nikhil.gupta2@arm.com](mailto:nikhil.gupta2@arm.com)

This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch to boost LLM prefill & decode performance

**This change improves decode throughput by ~15% & reduces memory required to inference the model by 50%**

### Benchmark Setup
```
Model: meta-llama/Llama-3.1-8B
Test Platform: Neoverse V2
```
### Detailed Results

| Metric                           | With `--compile`         | Without `--compile`      |
|----------------------------------|---------------------------|---------------------------|
| Quantization Scheme              | INT4 symmetric channelwise | INT4 symmetric channelwise |
| Input Precision                  | BF16                      | BF16                      |
| Number of Layers Quantized       | 32                        | 32                        |
| Average Compression Ratio        | 87.49%                    | 87.49%                    |
| Total Quantization Time (s)      | 9.62                      | 10.32                     |
| Compile Time (First) (s)         | 134.48                    | 1.69                      |
| Compile Time (Second) (s)        | 80.44                     | 1.60                      |
| Compile Time (Subsequent) (s)    | 0.19                      | 0.22                      |
| Prefill Tokens                   | 54                        | 54                        |
| Decoded Tokens                   | 33                        | 33                        |
| Prefill Time (s)                 | 0.19                      | 0.22                      |
| Decode Time (s)                  | 0.76                      | 1.38                      |
| E2E Generation Time (s)          | 0.95                      | 1.60                      |
| Prefill Throughput (tokens/s)    | 288.13                    | 249.91                    |
| Decode Throughput (tokens/s)     | 43.42                     | 23.83                     |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158250
Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/fadara01

Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-17 12:06:33 +00:00
93ddd38ecd Re-land#2 "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (#167928)
Summary:
getCurrentCUDABlasHandle() and getCUDABlasLtWorkspace() use static mutable maps that are not protected from concurrent read-and-write. This leads to crashes.

This diff adds mutexes to synchronize access to the static maps.

Re-land context:

This is a re-land of https://github.com/pytorch/pytorch/pull/167248.

A few issues were addressed:
- fix for a bug in fast path: premature return in getCurrentCUDABlasHandle)
- fix for test flakiness (https://github.com/pytorch/pytorch/pull/167884)

Test Plan:
1. regression tests:
buck2 test \mode/opt //caffe2/test\:test_transformers_cuda
https://www.internalfb.com/intern/testinfra/testrun/6192449759713581

2. Use a GPU OD, run multi-threaded tests with TSAN:

buck test fbcode//mode/dev-tsan fbcode//caffe2:cuda_cublas_handle_pool_test  -- --stress-runs 100
https://www.internalfb.com/intern/testinfra/testrun/14355223937501118

Differential Revision: D87111985

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167928
Approved by: https://github.com/Skylion007
2025-11-17 12:05:08 +00:00
5804408f1b [1/3][XPU][feature] The implementation of memory private pool in XPU device allocator (#166831)
The implementation plan of MemPool for XPU, which is the dependance of [XPUGraph](https://github.com/pytorch/pytorch/pull/166285), following the [RFC](https://github.com/pytorch/pytorch/issues/162143).

- [ ] ->#166831
- [ ] #166833
- [ ] #166843

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166831
Approved by: https://github.com/EikanWang, https://github.com/gujinghui

Co-authored-by: Eikan Wang <eikan.wang@intel.com>
2025-11-17 11:11:23 +00:00
99117c1238 Remove old NVTX interface (#167637)
The PR #167401 reminded me that the removal of old NVTX interface is long overdue, as the header-only NVTX3 has been around for more than 5 years and is shipped with all CUDA Toolkit versions of 12+. In addition to that, `libnvToolsExt.so` was removed in CUDA Toolkit 13 and onward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167637
Approved by: https://github.com/eqy
2025-11-17 08:07:20 +00:00
b9bccec3bc Revert "[ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)"
This reverts commit 226850cc66217e591c706397dd212b457ed61e22.

Reverted https://github.com/pytorch/pytorch/pull/167734 on behalf of https://github.com/Aidyn-A due to fails on CUDA 12.8 ([comment](https://github.com/pytorch/pytorch/pull/167734#issuecomment-3540410067))
2025-11-17 07:56:28 +00:00
ca3aaef66e Fix clamp broadcasting on MPS (Fixes #160734) (#165058)
This PR fixes a bug where `torch.clamp` on MPS fails when min/max tensors have more dimensions than the input tensor.
CPU already supports this broadcasting, but MPS raised a RuntimeError.

Example of failing case before the fix:
```python
x = torch.randn(2, 3, device="mps")
min_t = torch.randn(1, 2, 3, device="mps")
max_t = torch.randn(1, 2, 3, device="mps")
torch.clamp(x, min=min_t, max=max_t)  # RuntimeError
```
After this fix, MPS matches CPU behavior.

Fixes #160734

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165058
Approved by: https://github.com/malfet
2025-11-17 07:40:39 +00:00
f2e6f94081 deprecate check_is_size and guard_size_oblivious (#167198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167198
Approved by: https://github.com/bobrenjc93
2025-11-17 05:47:40 +00:00
aa504d4d2a [audio hash update] update the pinned audio hash (#167914)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167914
Approved by: https://github.com/pytorchbot
2025-11-17 05:21:29 +00:00
d8ce6f8df9 Enable PyTorch OSS numerics changes, inductor heuristics (#167799)
Test Plan: CI

Differential Revision: D86211542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167799
Approved by: https://github.com/njriasan, https://github.com/eellison
2025-11-17 04:31:44 +00:00
4322354770 [Inductor] optimize scalar welford_reduce (#162709)
**Summary:**
Optimize scalar welford_reduce implementation, combining Welford algorithm with cascade sum to improve numerical stability. Specifically:

1. Use Welford algorithm to compute mean and variance.
2. Use cascade summation when computing sum over input for both mean and variance.

**Example:**
Take https://github.com/pytorch/pytorch/issues/141541 as an example:
```
import torch
import torch.nn as nn
torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=32, num_channels=32)

    def forward(self, x):
        return self.gn(x)

model = Model().eval()
x = torch.randn(1, 32, 128, 128, 128)

with torch.no_grad():
    output = model(x)
    with torch._inductor.config.patch({"cpp.simdlen": 0}):
        c_model = torch.compile(model)
        c_output = c_model(x)

print(torch.max(torch.abs(output - c_output)))
print(torch.allclose(output, c_output, 1.3e-6, 1e-5))
```
**logs**

- before
```
tensor(0.0005)
False
```
- After
```
tensor(1.4305e-06)
True
```

**Generated code:**
- before
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['float*', 'float*', 'const float*', 'const float*', 'const float*', 'float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       float* in_out_ptr1,
                       const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr2)
{
    auto out_ptr1 = in_out_ptr0;
    auto out_ptr0 = in_out_ptr1;
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<float> tmp_acc0_arr[4];
                for (int i = 0; i < 4; i++)
                {
                    tmp_acc0_arr[i] = Welford<float>();
                }
                #pragma omp parallel num_threads(4)
                {
                    int tid = omp_get_thread_num();
                    Welford<float> tmp_acc0_local = Welford<float>();
                    #pragma omp for
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            {
                                auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                                tmp_acc0_local = welford_combine(tmp_acc0_local, tmp0);
                            }
                        }
                    }
                    tmp_acc0_arr[tid] = tmp_acc0_local;
                }
                for (int tid = 0; tid < 4; tid++)
                {
                    tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]);
                }
                in_out_ptr1[static_cast<int64_t>(x0)] = tmp_acc0.mean;
                in_out_ptr0[static_cast<int64_t>(x0)] = tmp_acc0.m2;
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                {
                    auto tmp0 = out_ptr1[static_cast<int64_t>(x0)];
                    auto tmp6 = in_ptr1[static_cast<int64_t>(x0)];
                    auto tmp8 = out_ptr0[static_cast<int64_t>(x0)];
                    auto tmp11 = in_ptr2[static_cast<int64_t>(x0)];
                    auto tmp1 = static_cast<float>(2097152.0);
                    auto tmp2 = tmp0 / tmp1;
                    auto tmp3 = static_cast<float>(1e-05);
                    auto tmp4 = float(tmp2 + tmp3);
                    auto tmp5 = 1 / std::sqrt(tmp4);
                    auto tmp7 = float(tmp5 * tmp6);
                    auto tmp9 = decltype(tmp8)(-tmp8);
                    auto tmp10 = float(tmp9 * tmp7);
                    auto tmp12 = float(tmp10 + tmp11);
                    in_out_ptr0[static_cast<int64_t>(x0)] = tmp7;
                    in_out_ptr1[static_cast<int64_t>(x0)] = tmp12;
                }
            }
        }
    }
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
            {
                #pragma GCC ivdep
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                            auto tmp1 = in_out_ptr0[static_cast<int64_t>(x0)];
                            auto tmp3 = in_out_ptr1[static_cast<int64_t>(x0)];
                            auto tmp2 = float(tmp0 * tmp1);
                            auto tmp4 = float(tmp2 + tmp3);
                            out_ptr2[static_cast<int64_t>(x1 + 2097152L*x0)] = tmp4;
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1 = args
        args.clear()
        assert_size_stride(arg0_1, (32, ), (1, ))
        assert_size_stride(arg1_1, (32, ), (1, ))
        assert_size_stride(arg2_1, (1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1))
        buf0 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf1 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf3 = reinterpret_tensor(buf1, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf1  # reuse
        buf4 = reinterpret_tensor(buf0, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf0  # reuse
        buf5 = empty_strided_cpu((1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1), torch.float32)
        # [Provenance debug handles] cpp_fused_native_group_norm_0:1
        cpp_fused_native_group_norm_0(buf3, buf4, arg2_1, arg0_1, arg1_1, buf5)
        del arg0_1
        del arg1_1
        del arg2_1
        return (buf5, )
```

- After
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['float*', 'float*', 'const float*', 'const float*', 'const float*', 'float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       float* in_out_ptr1,
                       const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr2)
{
    auto out_ptr1 = in_out_ptr0;
    auto out_ptr0 = in_out_ptr1;
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<float> tmp_acc0_arr[4];
                for (int i = 0; i < 4; i++)
                {
                    tmp_acc0_arr[i] = Welford<float>();
                }
                #pragma omp parallel num_threads(4)
                {
                    int tid = omp_get_thread_num();
                    WelfordHelper<float, float, 4096> scalar_welford_helper0(static_cast<int64_t>(524288L));
                    Welford<float> tmp_acc0_local = Welford<float>();
                    #pragma omp for
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            {
                                auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                                tmp_acc0_local = welford_combine(tmp_acc0_local, tmp0, &scalar_welford_helper0);
                            }
                        }
                    }
                    tmp_acc0_local = welford_combine(tmp_acc0_local, &scalar_welford_helper0);
                    tmp_acc0_arr[tid] = tmp_acc0_local;
                }
                for (int tid = 0; tid < 4; tid++)
                {
                    tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]);
                }
                in_out_ptr1[static_cast<int64_t>(x0)] = tmp_acc0.mean;
                in_out_ptr0[static_cast<int64_t>(x0)] = tmp_acc0.m2;
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                {
                    auto tmp0 = out_ptr1[static_cast<int64_t>(x0)];
                    auto tmp6 = in_ptr1[static_cast<int64_t>(x0)];
                    auto tmp8 = out_ptr0[static_cast<int64_t>(x0)];
                    auto tmp11 = in_ptr2[static_cast<int64_t>(x0)];
                    auto tmp1 = static_cast<float>(2097152.0);
                    auto tmp2 = tmp0 / tmp1;
                    auto tmp3 = static_cast<float>(1e-05);
                    auto tmp4 = float(tmp2 + tmp3);
                    auto tmp5 = 1 / std::sqrt(tmp4);
                    auto tmp7 = float(tmp5 * tmp6);
                    auto tmp9 = decltype(tmp8)(-tmp8);
                    auto tmp10 = float(tmp9 * tmp7);
                    auto tmp12 = float(tmp10 + tmp11);
                    in_out_ptr0[static_cast<int64_t>(x0)] = tmp7;
                    in_out_ptr1[static_cast<int64_t>(x0)] = tmp12;
                }
            }
        }
    }
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
            {
                #pragma GCC ivdep
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                            auto tmp1 = in_out_ptr0[static_cast<int64_t>(x0)];
                            auto tmp3 = in_out_ptr1[static_cast<int64_t>(x0)];
                            auto tmp2 = float(tmp0 * tmp1);
                            auto tmp4 = float(tmp2 + tmp3);
                            out_ptr2[static_cast<int64_t>(x1 + 2097152L*x0)] = tmp4;
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1 = args
        args.clear()
        assert_size_stride(arg0_1, (32, ), (1, ))
        assert_size_stride(arg1_1, (32, ), (1, ))
        assert_size_stride(arg2_1, (1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1))
        buf0 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf1 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf3 = reinterpret_tensor(buf1, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf1  # reuse
        buf4 = reinterpret_tensor(buf0, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf0  # reuse
        buf5 = empty_strided_cpu((1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1), torch.float32)
        # [Provenance debug handles] cpp_fused_native_group_norm_0:1
        cpp_fused_native_group_norm_0(buf3, buf4, arg2_1, arg0_1, arg1_1, buf5)
        del arg0_1
        del arg1_1
        del arg2_1
        return (buf5, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162709
Approved by: https://github.com/CaoE, https://github.com/jansel
2025-11-17 02:52:33 +00:00
363385ad3e s/Stragety/Strategy/ (#167916)
Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167916
Approved by: https://github.com/Skylion007
2025-11-16 19:47:23 +00:00
e2e10753d7 Allow same triton kernels in export (#167862)
Summary: This diff would be a follow-up diff for D85883723.

Test Plan:
See D86719598. We are now able to publish the model.

Unit test:
```
buck run fbcode//mode/opt -c remoteexecution.local=enabled fbcode//sigmoid/inference/test:test_passes -m ovr_config//triton:experimental -- -r test_triton_hop_cpu
```

Differential Revision: D87091238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167862
Approved by: https://github.com/XueningXu
2025-11-16 17:51:23 +00:00
5d99a795f5 [xpu][test] Migrated two test files to XPU (#166684)
# Description
Fixes #114850, we will port test utils and schema check to Intel GPU
We could enable Intel GPU with following methods and try the best to keep the original code styles:

# Changes
1. Get device type with from accelerator and get_devtype helper method
2. Replace the requires cuda statement to device_type.
3. Add HAS_XPU and HAS GPU check to replace some of the HAS_XPU etc.

# Notify

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166684
Approved by: https://github.com/ezyang, https://github.com/guangyey

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
2025-11-16 14:15:28 +00:00
2245d7d3b9 Improve char printing (#167899)
This PR outputs chars to stream without building temporary strings.
They were modified by (on fish)
```
sed  -i -e 's/<< "\([^\\\']\)"/<< \'\1\'/g' (grep '<< "."' -r torch c10 aten -l)
```
and revert some invalid changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167899
Approved by: https://github.com/Skylion007
2025-11-16 07:19:16 +00:00
98b94b90dd [pallas backend] implement gpu tiles/mask for power of 2 (#167584)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167584
Approved by: https://github.com/jansel
2025-11-16 07:01:51 +00:00
5cdbda140c [vision hash update] update the pinned vision hash (#167890)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vision hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167890
Approved by: https://github.com/pytorchbot
2025-11-16 04:58:47 +00:00
0ec53beaeb Refactor TensorAccessor for headeronly. (#166855)
This PR moves the implementations of Tensor accessor classes to headeronly with the following modifications:
- Add ArrayRef and IndexBoundsCheck template parameters to refactor out the usages of `IntArrayRef` and `TORCH_CHECK_INDEX` from Tensor accessor implementations.
- Eliminate usage of `c10::irange` as it is not headeronly-compatible.
- Introduce `torch::headeronly::{TensorAccessorBase,TensorAccessor, GenericPackedTensorAccessorBase, GenericPackedTensorAccessor}` that are headeronly-equivalent to `at::{TensorAccessorBase,TensorAccessor, GenericPackedTensorAccessorBase, GenericPackedTensorAccessor}`. Both these sets of template classes use original implementations from `torch::headeronly::detail` that have new template parameters `ArrayRefCls` and `IndexBoundsCheck` to facilitate `at` and `torch::headeronly` implementations of ArrayRef and checking indices.

TODO:
- ~when https://github.com/pytorch/pytorch/pull/164991 lands, eliminate the placeholder class HeaderOnlyArrayRef~ UPDATE: done.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166855
Approved by: https://github.com/janeyx99
2025-11-15 22:37:24 +00:00
79fc0a9141 [xpu][fix]Fall back deterministic index_copy to index_put on XPU (#167830)
A minor update has been made to the deterministic behavior checks in the `index_copy_out` implementation. This change ensures that deterministic  `index_copy` is dispatched to `index_put` not only for CUDA tensors but also for XPU tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167830
Approved by: https://github.com/guangyey, https://github.com/ezyang
2025-11-15 18:09:25 +00:00
d01a7b0241 Back out "MatMal - fix folding logic" (#167884)
Summary:
For sepcific hardware (A100), Autocast will generate a relatively large error on Transformer (torch.nn.TransformerEncoder) when using no_grad decorator on dim=256 (and larger presuably).

H100 seems fine, as does A100 with mig (so less than full SMs).

For now backing out, and revisting next week.

Test Plan:
failed jobs:
https://fburl.com/scuba/remote_execution_action/jzcmujgk

 {F1983543613}

Reviewed By: t-ivan-gr

Differential Revision: D87111518

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167884
Approved by: https://github.com/malfet
2025-11-15 08:29:08 +00:00
deabb3e36d Use c10::filesystem (#167821)
This PR fixes code to use c10::filesystem functionality instead of manually implemented functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167821
Approved by: https://github.com/Skylion007
2025-11-15 06:01:01 +00:00
79d2397b6b Fix grammar issues in C++ frontend documentation (#167702)
Corrected minor grammatical errors in the documentation.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167702
Approved by: https://github.com/jerryzh168
2025-11-15 05:55:08 +00:00
6ef3a62c36 Fix typo in FP16 accumulation section (#167703)
Fix typo error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167703
Approved by: https://github.com/jerryzh168
2025-11-15 05:34:07 +00:00
530e782239 [codemod][lowrisk] Remove unused exception parameter from caffe2/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm (#167604)
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.

This:
```
try {
    ...
} catch (exception& e) {
    // no use of e
}
```
should instead be written as
```
} catch (exception&) {
```

If the code compiles, this is safe to land.

Test Plan: Sandcastle

Differential Revision: D85813836

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167604
Approved by: https://github.com/malfet, https://github.com/seemethere
2025-11-15 05:17:48 +00:00
c66a6c432e [HOP][print] Add functionalization (make sure ordering) for print (#167016)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167016
Approved by: https://github.com/angelayi
2025-11-15 05:06:05 +00:00
3d7a8b7e61 MPS: Fix clamp scalar cache key to store floats in hex representation (#167777)
Fixes #167767.

Original issue was that using std::to_string(value) does not work intended here if the value is smaller than 1e-6. The caching keys ended up as `clamp_out_mps_min:0.000000_scalar::f32[1]` instead of `clamp_out_mps_min:0.0000001_scalar::f32[1]`. After the change the values are stored as the hex representation for the floating point number. So for min_value 1e-7 the key will be `impl_min:0x1.ad7f2ap-24_scalar::f32[1]` and for min_value 0.0 `clamp_out_mps_min:0x0p+0_scalar::f32[1]`

Output of the repro code before the change:

```
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
```

Output for the repro code after the change:

```
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
```
which matches the expected CPU reference.

Snippet to test with:
```
import torch

device='mps'
dtype=torch.float32
a = torch.zeros(1, device=device, dtype=dtype)

# the following line triggers the incorrect behavior, when commented, the remainder of the script appears to work as expected
a_clamped = a.clamp(min=0.0)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7, max=None)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7, max=torch.inf)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp_min(1e-7)
print(c)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167777
Approved by: https://github.com/malfet
2025-11-15 03:26:38 +00:00
de0d69b2c4 Remove useless super() delegation (#167791)
This PR removes useless super() delegations detected by pylint.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167791
Approved by: https://github.com/albanD
2025-11-15 02:50:51 +00:00
bc60b86066 Skip stable diffusion models in torchbench, get tests and benchmarks green (#167896)
Test Plan:
- wait for CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167896
Approved by: https://github.com/aorenste, https://github.com/shunting314
ghstack dependencies: #167609
2025-11-15 02:44:36 +00:00
d7782ddde7 [ATEN][CUDA] Reduce register pressure introduced by CUDA_KERNEL_ASSERT to improve torch.EmbeddingBag performance (#167834)
# Summary

This PR optimizes the CUDA kernels for `torch.nn.EmbeddingBag` by reducing GPU register pressure introduced by `CUDA_KERNEL_ASSERT`, which improves kernel occupancy and overall performance. The optimization separates input validation into a dedicated loop before the main processing loop, allowing the compiler to better optimize register allocation. By extensively testing on various GPUs and CUDA versions, `torch.nn.EmbeddingBag` performance improves by 29% to 111% with this PR.

# Performance Results

The following table shows the performance improvements on various input distributions and GPUs. All benchmarks use PyTorch 2.9.0 compiled with CUDA 12.8.

**Input Distribution Types (simulating recommendation system ID patterns):**
- **random id**: Randomly sampled embedding indices from the full vocabulary (uniform distribution)
- **one-hot**: One ID appears with very high frequency across all bags, simulating a popular item in recommendation systems
- **multi-hot**: Multiple IDs appear with high frequency across all bags, simulating multiple popular items in recommendation systems

**Test Configuration:**
- Embedding shape: `(5000000, 128)` (5M vocabulary size, 128-dimensional embeddings)
- Batch size: 2048 bags
- Average bag size: 150 indices per bag

| GPU  | Input Distribution | Before (µs) | After (µs) | Speedup |
| ---- | ------------------ | ----------- | ---------- | ------- |
| H100 | random id          | 162.4       | 105.9      | 1.53×   |
| H100 | one-hot            | 120.4       | 88.6       | 1.36×   |
| H100 | multi-hot          | 113.1       | 87.8       | 1.29×   |
| H20  | random id          | 278.6       | 132.2      | 2.11×   |
| H20  | one-hot            | 189.7       | 110.3      | 1.72×   |
| H20  | multi-hot          | 172.4       | 107.4      | 1.61×   |

# Motivation

The original implementation performed bounds checking using `CUDA_KERNEL_ASSERT` inline within the main processing loop, which increased register pressure and limited GPU occupancy. From NSight Compute analysis on H20, using PyTorch 2.9 compiled with CUDA 12.8, removing the `CUDA_KERNEL_ASSERT` from the main loop with this PR increases the overall occupancy from 50% to 75%(registers per thread 52->40).

By separating validation into a dedicated loop, we:

1. **Reduce register pressure in the main loop**: The validation loop uses minimal registers, allowing the compiler to optimize the main processing loop independently with better register allocation.
2. **Maintain correctness**: All input validation is still performed, but in a more register-efficient manner.

# Changes

## Modified Kernels

1. **`EmbeddingBag_updateOutputKernel_max`**: Added separate validation loop before main processing
2. **`EmbeddingBag_updateOutputKernel_sum_mean`**: Added separate validation loop before main processing

## Key Implementation Details

- **Separate validation loop**: Input indices are validated in a dedicated loop that checks all indices before processing begins
- **No early exit**: The validation loop intentionally avoids using `break` for early exit, as benchmarking showed that early exit degrades performance, possibly due to increased branch divergence and reduced instruction-level parallelism
- **Consistent error messages**: Improved error message clarity for invalid input indices
- **Design choice: validation loop vs. separate kernel**: We considered removing `CUDA_KERNEL_ASSERT` entirely and performing bounds checking in a separate GPU kernel, which would achieve even better performance (e.g., on H20 with random id distribution: 132.2 µs → 124.6 µs). However, this approach is harder to maintain as it requires coordinating two separate kernel launches and managing additional kernel launch overhead. Instead, we chose the current approach of using a separate validation loop within the same kernel, which provides a good balance between performance improvement and code maintainability.

## Code Changes

```cpp
// Separate validation loop reduces register pressure in the main loop below.
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
bool has_invalid_index = false;
for (int64_t emb = begin; emb < end; emb++) {
  index_t input_idx = input[emb];
  has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
}
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");

// Main processing loop (now with reduced register pressure)
for (int64_t emb = begin; emb < end; emb++) {
  // ... processing logic ...
}
```

# Testing & Compatibility

## Performance Testing

I conducted extensive performance testing across multiple configurations. All tests show significant performance improvements:

**Tested CUDA Versions:**
- CUDA 12.6, 12.8, 13.0

**Tested GPU Architectures:**
- A100, H20, H100

**Tested Input Configurations:**
- **Embedding shapes**: Various sizes including `[5000000, 128]` and `[128000, 4096]`
- **Embedding dtypes**: `torch.float32`, `torch.float16`
- **Input distributions**: Random indices, one-hot (high-frequency single ID), and multi-hot (high-frequency multiple IDs) patterns, simulating recommendation system workloads
- **Input sizes**: Average bag sizes of 150, 20, and 10 indices per bag

## Correctness Testing

-  Correctness tests pass for various embedding types (bfloat16, float32), shapes, and input distributions
-  Register usage reduction verified with NSight Compute
-  Linter passes

## Compatibility

-  No API/ABI changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167834
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-11-15 02:03:38 +00:00
eqy
fb04e9ad03 [CUDA][CUDA Graphs] Respect node-priority in cudaGraphInstantiate (#167346)
Needed for e.g., stream priority-based implementations of comm-compute overlap

CC @galv

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167346
Approved by: https://github.com/ngimel
2025-11-15 01:59:36 +00:00
cfe799b4aa Revert "Ops convolution_backward optional flag bug (#165008)"
This reverts commit c429b1fc5c60a6819b041f1a881ab09735689fbe.

Reverted https://github.com/pytorch/pytorch/pull/165008 on behalf of https://github.com/clee2000 due to I think this broke some tests in the slow workflow? test/test_ops.py::TestCommonCUDA::test_compare_cpu_convolution_backward_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/19375318020/job/55443680773) [HUD commit link](c429b1fc5c) ([comment](https://github.com/pytorch/pytorch/pull/165008#issuecomment-3535354672))
2025-11-15 01:50:09 +00:00
b7f52773e6 Add meta registration for scaled_mm_v2 and test (#167653)
Summary:

`torch._scaled_mm_v2` didn't have a valid meta registration, or
`FakeTensor` tests, so anything expecting inductor to work (like
torch.ao tests) would fail horribly.

Test Plan:

```
pytest -sv -k "scaled_mm_v2" test/test_ops.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167653
Approved by: https://github.com/drisspg
2025-11-15 01:21:04 +00:00
f6b54d8899 flight_recorder: move to torch.distributed (#167782)
Summary: This moves torchfrtrace to be under `torch.distributed.flight_recorder` instead of `tools.flight_recorder` as the `tools` package is not included in the torch wheels. This makes it so you can use fr trace analyze without using it from a source checkout

Test Plan:
```
buck run //caffe2/fb/flight_recorder:fr_trace
```

CI

Differential Revision: D87022129

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167782
Approved by: https://github.com/fduwjj
2025-11-15 01:16:59 +00:00
da91bf5262 Fix incorrect attention example in ONNX exporter docstring (#167646)
Fixes #167627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167646
Approved by: https://github.com/malfet, https://github.com/titaiwangms

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-15 00:31:48 +00:00
1c1638297e Revert "distributed/debug: add an HTTP server for debugging running jobs (#167395)"
This reverts commit 4ed26f7382bc3e5217121f5085af070e57f2ef40.

Reverted https://github.com/pytorch/pytorch/pull/167395 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167395#issuecomment-3535150292))
2025-11-15 00:25:51 +00:00
ee0b5b4b1c Add new CI jobs to run dynamo tests on all python versions supported (#166978)
This PR adds 2 new CI jobs to run dynamo core (`test/dynamo/*`) and
`dynamo_wrapped` tests on Python 3.11/3.12.

**Selected Machine**
Tests are executed on `linux.c7i.2xlarge` without GPU. Which means all
cuda tests (if any) are skipped.

**Runtime**
- The core tests takes 30 minutes to run
- The `dynamo_wrapped` test is divided into three shards and each one
  takes around 1.5 hours to execute

**Schedule**
Tests are executed every day at 1:29 PDT or in the presence of
`ciflow/dynamo` label

Co-authored-by: Rob Timpe <rtimpe@openteams.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166978
Approved by: https://github.com/atalman, https://github.com/malfet
ghstack dependencies: #167092
2025-11-15 00:15:36 +00:00
fcfb213c5a [inductor] layout constraint for weight-norm-bwd (#167667)
fix https://github.com/pytorch/pytorch/issues/165749

The weight_norm backward kernel requires its inputs to be contiguous. Add those constraints to the lowering/fallback rule.

A better fix is maybe add decomposition rule for the op. But since we already fallback, this fix does no harm and can fix the attached issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167667
Approved by: https://github.com/eellison
2025-11-14 23:59:24 +00:00
08042bbb9c [6/N] Use Python 3.10 typing (#167649)
This PR applies new Union typing syntax to some python files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167649
Approved by: https://github.com/albanD
2025-11-14 23:55:08 +00:00
e20ca3bc2e Remove python workaround for ContextDecorator (#167049)
This PR removes the import workaround for ContextDecorator because the import always succeeds in Py 3.10+.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049
Approved by: https://github.com/Skylion007
2025-11-14 23:54:52 +00:00
4ed26f7382 distributed/debug: add an HTTP server for debugging running jobs (#167395)
This adds a debug HTTP server for debugging stuck or slow jobs. It runs the WorkerServer on every worker and then launches a separate flask process on rank 0 to have users connect to for debugging.

This can easily be improved to trigger profilers as well as visualize the data much better.

Initial handlers:
* pytorch profiler
* FlightRecorder data
* Python stacks

```
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"

from torch.distributed.debug import enable_debug_server

enable_debug_server()
```

Test plan:

```
torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/debug_test.py
```

<img width="1499" height="1629" alt="20251107_17h10m47s_grim" src="https://github.com/user-attachments/assets/a8b9a0cb-3bbf-4558-be12-5253e418214e" />
<img width="1192" height="1337" alt="20251107_17h10m39s_grim" src="https://github.com/user-attachments/assets/ac5d7011-4acb-4401-bf2c-f9b22c1466bd" />

<img width="984" height="851" alt="20251107_18h35m38s_grim" src="https://github.com/user-attachments/assets/98b3eb31-ed01-4345-90dd-c79345cf82ce" />
<img width="2880" height="777" alt="20251107_18h35m31s_grim" src="https://github.com/user-attachments/assets/8de84b8b-9d06-4bc8-a1bf-280a2958315b" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167395
Approved by: https://github.com/fduwjj
2025-11-14 23:14:38 +00:00
4c79305b87 [targets2buck] Clean up get_pt_ops_deps (#167690)
Summary: I didn't understand what this macro was doing so I created a bit of a mess, mess be gone!

Test Plan: `buck2 ctargets fbcode//caffe2/... fbsource//xplat/caffe2/...`

Reviewed By: mzlee

Differential Revision: D86460608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167690
Approved by: https://github.com/seemethere
2025-11-14 23:05:24 +00:00
f4b8c4f907 backed size oblivious checks for expand() (#167689)
Summary:
Support semantics when using backed_size_oblivious, similar to https://github.com/pytorch/pytorch/pull/167232

We see errors in a model exported with dynamic shapes, like
```
RuntimeError: non-broadcasting semantics require s67 == 41

While executing %expand : [num_users=1] = call_method[target=expand](args = (%reshape_5, -1, -1, %getitem_9), kwargs = {})
```

Test Plan:
test_dynamic_shapes:
```
test_backed_size_oblivious_expand (test_dynamic_shapes.TestUbackedOps) ... I1112 14:07:54.724596 1386932 Logger.cpp:995] Dropping logs in unit tests.
ok
```

Differential Revision: D86902546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167689
Approved by: https://github.com/laithsakka
2025-11-14 22:31:28 +00:00
d629b7a459 Move CppTypeToScalarType to torch/headeronly (#167610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167610
Approved by: https://github.com/pearu, https://github.com/janeyx99
2025-11-14 22:21:45 +00:00
0922ba5f42 [BE] No need to pass const enum values by reference (#167868)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167868
Approved by: https://github.com/slayton58
2025-11-14 21:56:19 +00:00
c87295c044 [precompile] Support captured global tensors. (#167846)
Summary:
In vllm we saw cases where user intialize a tensor in the global scope and reference it in the forward body. This should be supported by pruning the used globals in the scope and serialize them along the artifacts similar to how we handle closure.

Use case example: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3n.py#L65

Test Plan:
test_aot_compile.py

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167846
Approved by: https://github.com/jamesjwu
2025-11-14 21:40:07 +00:00
7aa210d215 Revert "[CodeClean] Remove the Unused MACRO for AOT Inductor Runtime (#165139)"
This reverts commit fcd5f8c352b5b75bd32e57fa044ec5df095032da.

Reverted https://github.com/pytorch/pytorch/pull/165139 on behalf of https://github.com/jeanschmidt due to trying to hevert in the hopes it fixes internal errors, will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
5a368b8010 Revert "[CodeClean] Replace std::runtime_error with TORCH_CHECK (#165119)"
This reverts commit 398775a43e9808205f75c81d36f5087117d3f3f4.

Reverted https://github.com/pytorch/pytorch/pull/165119 on behalf of https://github.com/jeanschmidt due to trying to hevert in the hopes it fixes internal errors, will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
602102be50 Revert "Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (#167496)"
This reverts commit bc09a84150eaadaadab8a8ecd76cd9afc60d8a19.

Reverted https://github.com/pytorch/pytorch/pull/167496 on behalf of https://github.com/jeanschmidt due to trying to revert 165139, my intention is to land it again, so, will land this once both are reverted ([comment](https://github.com/pytorch/pytorch/pull/167496#issuecomment-3534641209))
2025-11-14 21:33:02 +00:00
200156e385 DTensor: avoid unnecessary DTensorSpec creation in _ToTorchTensor.backward (#167588)
Looks like the check here is cheap and has a potentially large payoff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167588
Approved by: https://github.com/ezyang
2025-11-14 21:08:12 +00:00
a2daf3fc86 [Inductor] Add support bound methods in pattern matcher (#167795)
Fixes: #167776

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167795
Approved by: https://github.com/mlazos
2025-11-14 20:55:51 +00:00
52b45c16de Add reshape, view, flatten to torch/csrc/stable (#167600)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167600
Approved by: https://github.com/janeyx99
ghstack dependencies: #167592
2025-11-14 20:35:53 +00:00
2ef85bed5a Add empty to stable ops (#167592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167592
Approved by: https://github.com/janeyx99
2025-11-14 20:35:53 +00:00
d99c6bcf69 [export] Disable side effects on dynamo_graph_capture_for_export and warn user. (#167763)
Summary:
as title.

Test Plan:
test_dynamo_graph_capture_side_effects

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167763
Approved by: https://github.com/tugsbayasgalan
2025-11-14 20:35:22 +00:00
8378abda84 [torch.export] Fix for flaky test_annotate_on_assert (#167805)
Summary: test_annotate_on_assert become flaky with PR 166341 (Details in https://github.com/pytorch/pytorch/issues/167432). Torchdynamo related metadata can vary depending on the caller. Removing the those metadata before comparison.

Test Plan:
```
buck test mode/opt caffe2/test:test_export -- 'test_annotate_on_assert'
```
https://www.internalfb.com/intern/testinfra/testrun/7036874728749661

Differential Revision: D87036890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167805
Approved by: https://github.com/yushangdi
2025-11-14 19:56:51 +00:00
5b42a5d9a6 [doc] Add example for torch.is_storage (#161898)
Fixes #161858

### Summary:
Added comprehensive documentation examples for `torch.is_storage()` to help users understand how to check if an object is a PyTorch storage object.

### Impact:

- Enhances API Documentation
- Helps users distinguish between PyTorch storage objects and other types

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161898
Approved by: https://github.com/isuruf, https://github.com/malfet
2025-11-14 19:45:54 +00:00
caca3f2eec Revert "Re-land "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (#167722)"
This reverts commit 40e6f090d91026947fbec92a42564ad492f37eae.

Reverted https://github.com/pytorch/pytorch/pull/167722 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167722#issuecomment-3534282212))
2025-11-14 19:38:22 +00:00
fbb59a83dc Update on "Fix all gather bucketing fusion in of dtype casts"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-14 11:34:18 -08:00
a1b4e31536 Update on "Fix all gather bucketing fusion in of dtype casts"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-14 11:30:05 -08:00
9e2bf129e1 [MPS] addmm complex fix (#167826)
Fixes #167727

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167826
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-14 19:29:09 +00:00
c429b1fc5c Ops convolution_backward optional flag bug (#165008)
Fixes #89629

When using torch.ops.aten.convolution_backward, the optional argument bias_sizes was being used in the python function registration without checking whether it was defined.

## For the fix
there are two modes to consider with different results.

First @dynamo.optimize("inductor") is the most demanding.
We cannot be wrong about the size passed into the function. But we should not ignore what the user wants/thinks they are doing. For this case, we want to throw an error when the user is wrong. If the user passes in None, we calculate the expected size directly.

Second @dynamo.optimize("eager") is very lenient.
We really can provide any value we want here. If the user is wrong about bias shape in eager mode, the op will just reshape the bias to the proper size so no error is thrown here.

## For testing
An OpInfo was added for torch.ops.aten.convolution_backward.default.
For the CUDA test_noncontiguous_samples test, a slightly updated error tolerance was necessary for the compounded add multiply (for 2x2 kernel).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165008
Approved by: https://github.com/bdhirsh
2025-11-14 19:24:45 +00:00
1176b2b0b7 [BE]: Update NVTX submodule to 3.3.0 (#167751)
Update NVTX to 3.3.0. Mostly fixes some errors in the bindings, improve C++20 support, and improve C++ bindings to NVTX. Header only library upgrade so should be mostly safe.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167751
Approved by: https://github.com/albanD, https://github.com/eqy
2025-11-14 19:24:37 +00:00
dd37a1a434 Fix NaN gradients in atan2_backward when both inputs are zero (#166787)
Fixes #165427

## Description of Bug 🐛

As reported in #165427, When both the input of  `atan2` function is zero the gradient becomes `NaN`. During the forward pass, `atan2` successfully avoids division-by-zero issue, but during backpropagation gradients become `NaN`.

This is because the backward pass calculates `(self * self + other * other).reciprocal()`, which becomes `inf` at `(0, 0)`. The subsequent multiplication by zero `(0 * inf)` results in `NaN`.

## Changes
- Added an `at::where` condition to handle zero denominators in `atan2_backward`.
- If denom is zero return 0 for the reciprocal; otherwise, use the original value.

## Testing
- Added` test_atan2_zero_gradient` in `test/test_autograd.py` to verify `atan2` returns `0.0` gradients for `(0,0)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166787
Approved by: https://github.com/soulitzer
2025-11-14 19:23:33 +00:00
a74adcf80e [codemod][lowrisk] Remove unused exception parameter from caffe2/caffe2/serialize/inline_container.cc (#167612)
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.

This:
```
try {
    ...
} catch (exception& e) {
    // no use of e
}
```
should instead be written as
```
} catch (exception&) {
```

If the code compiles, this is safe to land.

Test Plan: Sandcastle

Differential Revision: D85813824

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167612
Approved by: https://github.com/seemethere, https://github.com/malfet
2025-11-14 19:11:01 +00:00
5eac46a011 add assume_32bit_indexing inductor config (#167784)
when we know all tensor and intermediate tensors fit in 32 bit but use unbacked DS
we want a way to assume that we can use 32 bit indexing(we will runtime assert on it).

It is not practical to torch check every possible intermediate tensor size ahead of time.

This is needed to enhance vLLM perf with unbacked,  since in vLLM all tensors and
intermediates assumed to fit in 32 bits.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167784
Approved by: https://github.com/jansel
2025-11-14 19:04:22 +00:00
989fa5b102 Update on "Fix all gather bucketing fusion in of dtype casts"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-14 10:38:23 -08:00
db282e751d Update on "Fix all gather bucketing fusion in of dtype casts"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-14 10:35:01 -08:00
7f1accf93e Update on "Fix all gather bucketing fusion in of dtype casts"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-14 10:32:32 -08:00
f71a414ac3 Fix all gather bucketing fusion in of dtype casts
[ghstack-poisoned]
2025-11-14 10:30:33 -08:00
bd04dc7750 small changes
[ghstack-poisoned]
2025-11-14 10:30:28 -08:00
e0fff31ae3 [dynamo] Make global state guards and torch function stack guards droppable. (#167674)
Summary:
Prior to this PR we will always build global and torch funciton guards in all cases.

In this PR we did 2 changes to dynamo guards:
1. Created a new guard called "GLOBAL_STATE" which corresponds to the global state guard and can be filtered out using guard_filter_fn
2. Repurpose the existing "TORCH_FUNCTION_STATE" guard for checking torch function mode stack.

Also added a new helper `torch.compiler.skip_all_guards_unsafe` which can be useful for use cases like vllm

Test Plan:
CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167674
Approved by: https://github.com/anijain2305
2025-11-14 18:11:44 +00:00
7ede33b8e3 Tiling bug fix (#167771)
Fix for https://github.com/pytorch/pytorch/issues/166653.

Two fixes:
- We were inducing a split for broadcasted loads. e.g. (x // 16). While a split of 16 here will make the load coalesced in one of the tile vars, since the load is already in cache it's not worth splitting. And it would make the other tile var load from memory that isnt in cache.
- Add a slight term for uncoalesced memory. This prevents doing tiling for loads which are a small % of the overall kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167771
Approved by: https://github.com/v0i0
2025-11-14 17:32:42 +00:00
065176cd97 [export] Add pytree input check for dynamo_graph_capture_for_export (#167731)
Summary:
as title.

Test Plan:
pytest test/export/test_export.py -k test_invalid_pytree_dynamo_graph_capture

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167731
Approved by: https://github.com/tugsbayasgalan
2025-11-14 17:29:55 +00:00
eqy
02ee7dd7d3 [CUDA][Test] Add serialTest() to some largeTensorTest tests (#167471)
Try to prevent two big tests from overlapping in their memory usage

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167471
Approved by: https://github.com/soulitzer
2025-11-14 17:13:14 +00:00
99fdca8f4d [ROCm] Enable StaticCudaLauncher for ROCm (#166492)
This PR enables ROCm/HIP support for PyTorch's StaticCudaLauncher, which provides static compilation and launching of Triton kernels. The implementation has been tested on AMD MI300 and MI200 hardware.

**Changes**

**Python (torch/_inductor/runtime/)**
- static_cuda_launcher.py: Added ROCm detection, .hsaco binary support, and ROCm-specific scratch parameter handling
- triton_heuristics.py: Updated device type checks to support both cuda and hip

**C++ (torch/csrc/)**
- Module.cpp: Enabled StaticCudaLauncher for ROCm builds
- inductor/static_cuda_launcher.cpp: Added HIP API equivalents for all CUDA driver calls
- inductor/static_cuda_launcher.h: Updated header guard

**Tests (test/inductor/)**
- test_static_cuda_launcher.py: Removed @skipIfRocm decorators and updated binary file handling

**Enabled Unit Tests**
All tests in test/inductor/test_static_cuda_launcher.py now pass on ROCm:
1. test_basic
2. test_unsigned_integers
3. test_signed_integers
4. test_basic_1arg
5. test_constexpr
6. test_implied_constant
7. test_kernel_no_args
8. test_high_shared_mem
9. test_too_high_shared_mem
10. test_kernel_empty_tensor
11. test_kernel_many_args
12. test_basic_compile
13. test_incompatible_code
14. test_static_launch_user_defined_triton_kernels
15. test_empty_tensor
16. test_any
17. test_disable_static_cuda_launcher

In addition to this, the following tests from test/inductor/test_codecache.py also pass:
1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False
2. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False
3. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True
4. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False
5. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False
6. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True

The following tests are skipped since triton bundling is necessary for StaticCudaLauncher:
1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True
2. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166492
Approved by: https://github.com/jeffdaily
2025-11-14 17:11:45 +00:00
9d1a74cb0c Fix mvlgamma_ FPE crash on x86 with integer input (#164230)
Fixes #161871.

Behaviour on arm:

```
PyTorch version: 2.10.0a0+gitdef3b05
Architecture: arm64
Platform: Darwin
Processor: arm

Testing mvlgamma_ with integer tensor on arm64...
 Got expected error: mvlgamma: result type Long can't be cast to the desired output type Float
```

and on x86:

```
PyTorch version: 2.10.0a0+git1310d6a
Architecture: x86_64
Platform: Linux
Processor: x86_64

Testing mvlgamma_ with integer tensor on x86_64...
 Got expected error: mvlgamma: result type Long can't be cast to the desired output type Float
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164230
Approved by: https://github.com/albanD
2025-11-14 17:09:10 +00:00
3b90bf36f9 Add multiple hiding nodes
[ghstack-poisoned]
2025-11-14 09:04:39 -08:00
40e6f090d9 Re-land "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (#167722)
Summary:
getCurrentCUDABlasHandle() and getCUDABlasLtWorkspace() use static mutable maps that are not protected from concurrent read-and-write. This leads to crashes.
This diff adds mutexes to synchronize access to the static maps.

Note: this is a re-land of D86316117 / https://github.com/pytorch/pytorch/pull/167248 (see comments for details)

Test Plan:
Use a GPU OD, run multi-threaded tests (cuda_cublas_handle_pool_test) with TSAN:
```
buck test fbcode//mode/dev-tsan fbcode//caffe2:cuda_cublas_handle_pool_test  -- --stress-runs 100
```
https://www.internalfb.com/intern/testinfra/testrun/14355223937501118

TSAN output (before synchronization was added): P2026731804

Differential Revision: D86964261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167722
Approved by: https://github.com/malfet
2025-11-14 16:16:35 +00:00
bfddfde50c Add basic spin config and linting commands (#167226)
This PR adds a basic spin configuration to allow for linting. It is designed as a drop-in replacement for the current Makefile based solution, i.e. it sets up and updates lintrunner based on the hashes of certain configuration files.

Lintrunner is called via Uv's `uvx` command, separating its environment from the general development environment in an effort to reduce instances of competing requirements breaking environments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167226
Approved by: https://github.com/atalman, https://github.com/albanD
2025-11-14 15:35:42 +00:00
b6570615f8 [precompile] Integrate AOTI as a backend. (#167338)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167338
Approved by: https://github.com/jamesjwu
2025-11-14 15:33:11 +00:00
226850cc66 [ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)
This PR add a sm_121a flag for row-wise scaled matmuls on DGX Spark.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167734
Approved by: https://github.com/eqy, https://github.com/cyyever
2025-11-14 08:44:04 +00:00
f8a2ce3b9a Fix inplace ops on Partial DTensors to preserve aliasing semantics (#164729)
Fixes #163374.

Here is the output from reproducible code:

```
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811]
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
  aten::clamp_(dt: f32[][R], None, 2)
    redistribute_input(0, [P] -> [R])
      redistribute_input(t: f32[], [P] -> [R])
        _c10d_functional::all_reduce(t: f32[], sum, 0)
        _c10d_functional::wait_tensor(t: f32[])
    aten::clamp_(t: f32[], None, 2)
    aten::view(t: f32[], [])
(Replicate(),)
tensor(2., device='cuda:0')
```

The behavior is now matching what you were expecting in issue #163374:

Expected behavior (from the issue):
  1. Placement should change from Partial(sum) to Replicate()
  2. Value should be tensor(2.) instead of tensor(144.)

  Actual output from this build:
  1. (Replicate(),) - placement is correct
  2. tensor(2., device='cuda:0') - value is correct

so the inplace operation now properly redistributes the partial DTensor to replicate before performing the clamp snd maintains the correct aliasing semantics. It also produces the expected clamped value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164729
Approved by: https://github.com/ezyang
2025-11-14 07:46:35 +00:00
e2c6834584 Revert "deprecate check_is_size and guard_size_oblivious (#167198)"
This reverts commit 50bf1f0b819f0b1cc9acbb0646ac9555bb9d44b9.

Reverted https://github.com/pytorch/pytorch/pull/167198 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167198#issuecomment-3531149912))
2025-11-14 06:46:15 +00:00
0e7235ed73 [xpu][feature] [1/3] add fp8 scaled_mm implementation for XPU (#165978)
This PR implements `scaled_mm` for XPU. It enables the following data types:
1. TensorWise Scaling: `fp8_e4m3` and `fp8_e5m2`
2. RowWise Scaling:  `fp8_e4m3` and `fp8_e5m2`

It leaves the BlockWise Scaling to next PR, so that it will have less reviewing efforts.

This is the first PR that only adds `scaled_mm_xpu` but does not registered. We separate this out for less reviewing efforts.

Secondly, there is a `scaled_mm_v2` API in #164141 . We will align with it once the v1 is cleaned up.

**Co-author:** @yuchengliu1, @carsonwang

## PR stack:

- -> https://github.com/pytorch/pytorch/pull/165978 : implementation of XPU scaled_mm and oneDNN kernel
- https://github.com/pytorch/pytorch/pull/167518 : implementation of XPU scaled_mm_v2
- https://github.com/pytorch/pytorch/pull/166056 : Op registration

## Test Status:

1. Relies on the changes in https://github.com/intel/torch-xpu-ops/pull/1746/, Otherwise the op will fallback to CPU.
2. This PR does not include tests, the tests are enabled in #166056.

## Credit:

This work is based on @yuchengliu1's work at #140972 . The purpose that we created a new PR is to align with the API / checks with CUDA, so there will be less porting efforts.

## FP8 Task tracker:
We will track all the scaled_mm related tasks in: https://github.com/pytorch/pytorch/issues/167170

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165978
Approved by: https://github.com/liangan1, https://github.com/EikanWang

Co-authored-by: Eikan Wang <eikan.wang@intel.com>
2025-11-14 06:41:18 +00:00
3522e0ce74 Revert "Fix different seq length (#167481)"
This reverts commit c78e64622e62eb93a03a9c3762df3290d6c65362.

Reverted https://github.com/pytorch/pytorch/pull/167481 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167481#issuecomment-3530992724))
2025-11-14 06:05:45 +00:00
50bf1f0b81 deprecate check_is_size and guard_size_oblivious (#167198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167198
Approved by: https://github.com/bobrenjc93
2025-11-14 05:35:29 +00:00
c78e64622e Fix different seq length (#167481)
Differential Revision: D86685546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167481
Approved by: https://github.com/eellison
2025-11-14 05:31:29 +00:00
5623628894 [SymmMem] op to get remote tensors (#167779)
To support use case in https://github.com/pytorch/helion/pull/1122, i.e.
```
@helion.kernel
def foo(
    x: Tensor,
    group_name: str
):
    x_remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
    for t in x_remotes:
        ...
````

Helion uses fake tensor to trace a program, thus we cannot use the following code in a Helion function:
```
hdl = rendezvous(tensor)
remote_tensors = tuple(
    hdl.get_remote_tensor(peer, ...) for peer in range(world_size)
)
```
The reason is that when `tensor` is fake, the returned `hdl` is None, thus any subsequent call on it will fail.

This PR wraps the above functionality as an op:
```
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
```
so that things like `hdl` is not exposed to Helion. The op also provides a `meta` implementation so that Helion can trace it without actually running the rendezvous.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167779
Approved by: https://github.com/yf225
2025-11-14 05:01:55 +00:00
2aba180114 Always track _local_scalar_dense output in tensorify_python_scalars. (#166573)
We need to track all symbols, we used to skip
u = item()
and fail with
```
 File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp
    expr_to_sym_proxy[expr]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: u0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166573
Approved by: https://github.com/bobrenjc93
2025-11-14 03:51:43 +00:00
45b2c3d312 [OpenReg][Feat][Docs] Enrich OpenReg device management implementation and add focused documentation (#165897)
## Summary
This PR enriches OpenReg device management codes and adds focused documentation.

## Key Changes
- Introduced device management documentation in `device.md`.
- Updated `OpenRegFunctions.h` and `OpenRegFunctions.cpp` to use `DeviceIndex` and added error handling.
- Implemented `check_device_index` function for validating device indices.
- Enhanced Python bindings in `Module.cpp` for device management.
- Added tests for invalid device index handling in `test_device.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165897
Approved by: https://github.com/fffrog
2025-11-14 03:08:23 +00:00
5b1e112cf9 [Dynamo] Imporve-graph-break-skip-logs (#167067)
Fixes #150477

### Summary:

- Added frame information (function name, file, line number) to all graph break/skip messages
- Standardized message format: "torch.compile will skip tracing the frame <name> (<file> line <N>) and fall back to eager. Reason: <reason>"

### Impacts:
module: dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167067
Approved by: https://github.com/williamwen42
2025-11-14 03:06:37 +00:00
5e6ac5c6e1 [Pytorch] Improve conversion to bfloat16 on aarch64/NEON (#166958)
Summary:
Autovectorization of casting to bfloat16_t is broken in clang-[17, 20], fixed in clang-21.

We are adding a workaround vectorized code, which improves conversion speed from smaller int data types.

We've observed the following performance improvements, when compiling with clang-19 and targeting armv9a+sve2:

before:

uint8->bfloat16_t  ===> 319.433us
int8->bfloat16_t  ===> 320.216us
int16->bfloat16_t  ===> 326.899us
int32->bfloat16_t  ===> 327.925us

after:

uint8->bfloat16_t  ===> 185.189us  -----> 72% higher throughput
int8->bfloat16_t  ===> 169.790us  -----> 89% higher throughput
int16->bfloat16_t  ===> 180.744us  -----> 81% higher throughput
int32->bfloat16_t  ===> 185.129us  -----> 77% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D86207189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166958
Approved by: https://github.com/mcfi
2025-11-14 02:40:08 +00:00
79317dc7a7 Fix no source name in backward kernel names; Add flex_attention HOP to "original_aten" node meta (#167749)
Fixes #167706

- Add `torch.fx.experimental.proxy_tensor.set_original_aten_op()` around flex_atention HOP dispatch so we have `original_aten` populated for flex_attention
- Update the usages of `original_aten` to also expect HOP in addition to OpOverload

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167749
Approved by: https://github.com/drisspg
2025-11-14 02:24:22 +00:00
96a4c4b3d1 add device generalization support for distributed tests (#165067)
## MOTIVATION
To generalize Distributed test cases for non-CUDA devices

## CHANGES
- Replaced hard coded device/backends with torch.accelerator.current_accelerator() and dist.get_default_backend_for_device
- Use DistributedTestBase instead of MultiProcessTestCase to use common utilities
- Remove instantiate_device_tests and make use of torch.accelerator.current_accelerator for test/distributed/test_c10d_object_collectives.py
- fix deterministic context issue for non-cuda devices in test/distributed/optim/test_zero_redundancy_optimizer.py
- use torch.accelerator.device_count() for multi-gpu check in torch/testing/_internal/distributed/_tensor/common_dtensor.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165067
Approved by: https://github.com/guangyey, https://github.com/albanD
2025-11-14 02:21:11 +00:00
05bcfcc5d1 [Profiler] Add Documentation for FunctionEvent (#167688)
Summary:
Adds documentation for EventList, FunctionEvent and FunctionEventAvg.

Closes https://github.com/pytorch/pytorch/issues/165907

Test Plan: N/A Documentation

Differential Revision: D86913697

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167688
Approved by: https://github.com/sanrise
2025-11-14 02:03:19 +00:00
8cf0bdde45 [xpu][fix] Fix conv1d precision error (#162944)
Currently, conv1d converts the 3D view to 4D before calling onednn::convolution().
However, this function converts the 4D tensor to a channel-last memory format for computation, resulting in incorrect return results (the correct result should be channel-first).
This PR fixes this issue, ensuring that the output return value format is consistent with the expected format.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162944
Approved by: https://github.com/EikanWang
2025-11-14 01:12:21 +00:00
813e5eae9b [fx, 3.14] fix assert detection for 3.14 (#167700)
Failing test was `pytest test/export/test_export.py -k test_python_asserts_with_sym_int`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167700
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #167382, #167383, #167384, #167387, #167396, #167669
2025-11-14 01:00:43 +00:00
2ef236e3e3 [3.14, jit] skip jit tests on 3.14+, add jit deprecation warnings to user-facing API (#167669)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167669
Approved by: https://github.com/malfet, https://github.com/atalman
ghstack dependencies: #167382, #167383, #167384, #167387, #167396
2025-11-14 01:00:43 +00:00
532389fe9e [torchelastic] Add flush option to TailLog (#167169)
Differential Revision: D86366889

This PR adds the `flush` option to `TailLog`, and it will automatically flush (by setting `buffering=1`) the files opened by that `TailLog` instance.

This is mainly to resolve the race condition between the default flushing of `TailLog` and where we read the duplicated error files in the termination handler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167169
Approved by: https://github.com/fduwjj
2025-11-14 00:21:26 +00:00
08de54f1ea [3.14] Skip failing spherical_bessel_j0 tests (#167691)
Starting with scipy 1.15, bool inputs error out.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167691
Approved by: https://github.com/williamwen42
2025-11-14 00:06:42 +00:00
0cd0bd7217 address DDE in matmul decomp (#166541)
Address https://github.com/pytorch/pytorch/issues/165081
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166541
Approved by: https://github.com/mlazos
2025-11-13 23:50:00 +00:00
fe33d7cadf Revert "address DDE in matmul decomp (#166541)"
This reverts commit c940b1fbbca8da7e526bf610ce007f8af75f6cd5.

Reverted https://github.com/pytorch/pytorch/pull/166541 on behalf of https://github.com/zou3519 due to broke Inductor CI ([comment](https://github.com/pytorch/pytorch/pull/166541#issuecomment-3530162518))
2025-11-13 23:29:06 +00:00
a9542426d0 [MPS] Add Metal complex mm implementation (#167755)
As MPSGraph one returns incorrect results if matrix inner dimention exceed 4K
Add regression test

Fixes https://github.com/pytorch/pytorch/issues/167727
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167755
Approved by: https://github.com/manuelcandales
2025-11-13 22:40:59 +00:00
f79cdc89db [CD] [aarch64] unify the build.sh to build for aarch64 wheel (#166044)
related to https://github.com/pytorch/pytorch/issues/163970

Changes:
Below are addressed from review from @malfet and @atalman:

1. Simplified the x86 TORCH_CUDA_ARCH_LIST logic to reuse the base list in`.ci/manywheel/build_cuda.sh`.
2. Added function filter_aarch64_archs() that filters the TORCH_CUDA_ARCH_LIST for aarch64 based on the x86 code.
3. Added function in `.ci/pytorch/build.sh` to report error if ACL is not present.
4. Deprecated previous aarch64 scripts (`.ci/aarch64_linux/` folder).

Improvements:

1. Significant improvement in build time for CUDA ARM wheel build -

Reduced build time from 5.5–6 hours to 1 hour 40–50 minutes
taking this 13.0 build for example, 6h 11m 46s to 1h 50m 1s ≈ 70 % faster build time
old: https://github.com/pytorch/pytorch/actions/runs/19304934204/job/55209695430
new: https://github.com/pytorch/pytorch/actions/runs/19301014750/job/55195226316
Reason: MAX_JOBS=5 is now removed after we move away from original aarch64 build workflow, previously it was OOM in building flash-attn, new MAX_JOBS is 12.
https://github.com/pytorch/pytorch/pull/166044/files#diff-ccef31095e4f2d203710232531c38bff3251e41cf73ec84ee59f224bb64034aeL280

2. Unified workflow for building x86 and sbsa wheels - more maintainable code
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166044
Approved by: https://github.com/atalman
2025-11-13 22:35:00 +00:00
3d063519bf [inductor][ez] skip cache for unit test via envvar (#167237)
It would be surprising to see the cache get hit in Unit Test when TORCHINDUCTOR_FX_GRAPH_CACHE_DEFAULT is set to 1.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167237
Approved by: https://github.com/eellison
2025-11-13 22:28:16 +00:00
0b3bdb0d89 [EZ][BE] Remove unnecessary semicolon in Module.cpp (#167756)
`${subj}`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167756
Approved by: https://github.com/Skylion007
2025-11-13 22:02:08 +00:00
8f00ec31ca [dynamo, nested graph breaks] disallow graph breaks in functorch ops, enable nested graph break tests on test_higher_order_ops.py (#166674)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166674
Approved by: https://github.com/ydwu4
ghstack dependencies: #166673
2025-11-13 21:52:02 +00:00
21f32e4af3 [dynamo] clean up BaseUserFunctionVariable and LocalGeneratorObjectVariable (#166673)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166673
Approved by: https://github.com/Skylion007, https://github.com/guilhermeleobas, https://github.com/mlazos
2025-11-13 21:52:02 +00:00
940979a229 [export, 3.14] handle patching methods with functools.partial correctly in non-strict export (#167396)
Note: dynamo is not affected by this since patching class methods are not supported right now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167396
Approved by: https://github.com/angelayi
ghstack dependencies: #167382, #167383, #167384, #167387
2025-11-13 21:47:30 +00:00
4fc688625a [3.14, dataloader] handle forkserver default mp start method in 3.14 (#167387)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167387
Approved by: https://github.com/malfet
ghstack dependencies: #167382, #167383, #167384
2025-11-13 21:47:30 +00:00
23f4f323ea [dynamo, 3.14] enable dynamo in 3.14 (#167384)
dynamo tests are passing in the CI PR above - so we could probably just enable dynamo right now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167384
Approved by: https://github.com/Skylion007, https://github.com/mlazos
ghstack dependencies: #167382, #167383
2025-11-13 21:47:23 +00:00
9ac3fc0d0a [inductor, 3.14] catch pickle.PicklingError exceptions (#167383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167383
Approved by: https://github.com/aorenste, https://github.com/mlazos
ghstack dependencies: #167382
2025-11-13 21:47:14 +00:00
38806f381a [inductor, 3.14] fix itertools.product pickle error in test_cpu_repro (#167382)
`inductor/test_cpu_cpp_wrapper` was failing since it was attempting to pickle`itertools.product`, and that is no longer picklable in 3.14. We work around by eagerly generating a list.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167382
Approved by: https://github.com/atalman, https://github.com/malfet, https://github.com/mlazos
2025-11-13 21:47:06 +00:00
cfb3a6b3da [2/N][BugFix][Refactor] fix several instances which use f = open(...) without a corresponding f.close() (#167628)
continue in https://github.com/pytorch/pytorch/pull/167423

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167628
Approved by: https://github.com/cyyever, https://github.com/Skylion007
2025-11-13 21:15:45 +00:00
d8384e296e [Inductor] Remove bf16 fallback for atomic_add (#167380)
Fixes: #97016

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167380
Approved by: https://github.com/mlazos
2025-11-13 20:41:35 +00:00
d273422582 [CUDA] Large max pool fix (#167427)
Fixes #167253
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167427
Approved by: https://github.com/eqy, https://github.com/malfet
2025-11-13 20:11:41 +00:00
fadb62f592 [PyTorch] fix profiler issue with empty exported trace file (#167601)
Summary:
The previous implementation incorrectly attempted to read from a `NamedTemporaryFile` file pointer after calling `profiler.export_chrome_trace(fp.name)`. The issue is that `export_chrome_trace()` writes to a file at the path `fp.name`, but doesn't write to the file pointer `fp` itself. This meant when the code tried to read from `fp`, it got empty content.

The fix explicitly closes the temporary file first, then calls `export_chrome_trace(fp.name)` which writes the JSON trace to a file at that path. We then open that file separately for reading and copy its contents to the gzipped output file. This ensures we're reading from the actual file that was written to, not an empty file pointer.

Changes made in both `fbcode/caffe2/torch/profiler/profiler.py` and `xplat/caffe2/torch/profiler/profiler.py`:
- `export_chrome_trace()`: Fixed file reading for gzipped chrome trace exports by opening the written file separately
- `export_memory_timeline()`: Fixed file reading for gzipped memory timeline exports by opening the written file separately

Test Plan:
* run benchmark
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_train_pipeline -- \
    --yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml
```
* upload trace
```
DIFF=D86737513 fbcode/torchrec/fb/scripts/trace_to_manifold.sh
```
======== markdown ============

[manifold folder](https://www.internalfb.com/manifold/explorer/torchrec_benchmark_traces/tree/permanent_traces/DIFF/D86737513)
[trace-sparse_data_dist_base-rank0.json.gz](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/permanent_traces/DIFF/D86737513/trace-sparse_data_dist_base-rank0.json.gz&bucket=torchrec_benchmark_traces)

Differential Revision: D86737513

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167601
Approved by: https://github.com/angelayi
2025-11-13 19:40:09 +00:00
e5eb89e111 remove allocation of new unbacked symbols during mod eval (#167123)
When executing code like torch._check(numel % newsize == 0, ...), we previously allocated a new unbacked symbol due to #113165. However, this allocation is no longer necessary and can cause issues due to inconsistent behavior when tracing torch._check multiple times.

In particular, the allocation can lead to a memo disaster where the previously allocated symbol is returned instead of a new one, causing unexpected behavior.

This PR removes the unnecessary allocation, ensuring consistent behavior and avoiding potential issues. The change is validated by the following code, which now compiles without issues:
```
import torch

def fn(x):
    i0 = x.nonzero().size(0)
    y = torch.zeros((i0, 192))
    return y.view([12, -1, 192])
with torch._dynamo.config.patch({"capture_dynamic_output_shape_ops": True}):
    torch.compile(fn, fullgraph=True)(torch.ones((12,)))
```

By removing this unnecessary allocation, we simplify the code and avoid potential issues."

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167123
Approved by: https://github.com/Lucaskabela
2025-11-13 18:52:41 +00:00
b5e0e6932a Correctly populate storage offset in DTensor constructor (#167597)
The storage offset always matches the local offset because you never have rank dependent offset (your shard may be different, but your view into it will always be the same across all ranks!)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167597
Approved by: https://github.com/malfet
ghstack dependencies: #166868, #166867, #167076
2025-11-13 18:26:11 +00:00
6ea779188c [DebugMode] torch.hash_tensor option (#167486)
Adds `torch.hash_tensor` (#154149) as tensor hashing variant; allows tuple of hashes in log annotations for more info (e.g. `with DebugMode.log_tensor_hashes(hash_fn=["norm", "hash_tensor"]): ...`)

also fixes some corner cases around norm hashing (preserves NaNs/infs, avoids erroring on smaller dtypes)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167486
Approved by: https://github.com/xmfan
2025-11-13 17:46:09 +00:00
460c7e196c Handle only a Tensor for IntList parsing (#167606)
Fixes https://github.com/pytorch/pytorch/issues/167562

Authored with Claude Code

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167606
Approved by: https://github.com/colesbury
2025-11-13 17:39:38 +00:00
7aac506cdc Revert "[precompile] Integrate AOTI as a backend. (#167338)"
This reverts commit 273babeec3c6211f30b806797f35a6e9c47c737f.

Reverted https://github.com/pytorch/pytorch/pull/167338 on behalf of https://github.com/jeanschmidt due to seems to be breaking internal tests and builds, see D86919103 ([comment](https://github.com/pytorch/pytorch/pull/167338#issuecomment-3528950888))
2025-11-13 17:39:03 +00:00
374ee9e867 Fix missing thrust includes (#167450)
CCCL recently dropped a ton of transient includes that blew up thrust compile times

That means we need to include what we use

Fixes build issues found in internal CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167450
Approved by: https://github.com/Skylion007, https://github.com/Aidyn-A
2025-11-13 17:02:43 +00:00
698aa0f3e5 [MPS] sparse_mask_projection (#166260)
Implements sparse mask projection. I'm aware that SparseMPSTensorMath needs some refactoring, which I'll do in a followup PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166260
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-13 17:01:54 +00:00
eqy
d3ca4a3a4f [CUDA][64-bit indexing] Handle 64-bit outer dim cumsum case (#167326)
For #167086, same change more or less as #143696

Let's see if CI wants a large tensor test decorator

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167326
Approved by: https://github.com/ngimel, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-13 17:00:00 +00:00
c940b1fbbc address DDE in matmul decomp (#166541)
Address https://github.com/pytorch/pytorch/issues/165081
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166541
Approved by: https://github.com/mlazos
2025-11-13 16:41:35 +00:00
4de24bcc56 [Fix XPU typo] Fix a comment typo of FindSYCLToolkit.cmake (#165884)
The character U+ff1a ":" could be confused with the ASCII character U+003a ":", which is more common in source code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165884
Approved by: https://github.com/cyyever, https://github.com/guangyey, https://github.com/EikanWang
2025-11-13 12:32:48 +00:00
f2d0a472ef [xpu][feature] Add XPU support on torch.accelerator.get_memory_info (#162564)
# Motivation
Support XPU for `torch.accelerator.get_memory_info`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162564
Approved by: https://github.com/albanD
ghstack dependencies: #156812
2025-11-13 11:03:17 +00:00
9ae0ecec7d Introduce a new API torch.accelerator.get_memory_info (#156812)
# Motivation
`torch.cuda.mem_get_info` and `torch.xpu.mem_get_info` are widely used in other popular repos, such as
- 076313bd09/python/sglang/srt/utils.py (L378),
- 7ecc2d7f39/src/accelerate/utils/modeling.py (L822),
- 7ba34b1241/vllm/worker/worker.py (L150).
-
This PR introduces a unified API `torch.accelerator.get_memory_info` to cover this scenario.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156812
Approved by: https://github.com/albanD
2025-11-13 11:01:39 +00:00
ce4f31f662 [OpenReg][Feat][Docs] Enrich hook implementation and add focused documentation (#165980)
## Summary
This PR enriches the implementation of `OpenRegHooks.h` and adds focused documentation for `OpenReg` hooks.

## Key Changes
- A new document: `docs/source/accelerator/hooks.md`
- New `OpenReg` hooks like `isBuilt()`, `isAvailable()` and so on...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165980
Approved by: https://github.com/fffrog

Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
2025-11-13 08:36:18 +00:00
2c846bb614 [xpu][test]port embedding indexing and native_mha test files for Intel GPU (#165886)
we port test_indexing, test_native_mha and test_embedding for Intel GPU in this pr.
We could enable Intel GPU with following methods and try the best to keep the original code styles:

Use torch.accelerator for general gpu
Skip the case if running on xpu which has known issues
using torch.nn.attention.sdpa_kernel() to replace torch.backends.cuda.sdp_kernel() for Intel GPU as torch.backends.cuda.sdp_kernel() is depricated and Intel xpu did not support it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165886
Approved by: https://github.com/guangyey, https://github.com/albanD
2025-11-13 08:17:23 +00:00
8c86ccfbc9 [DebugMode] .show_stack_trace inline (#167589)
Shows inline stack traces, with `.debug_string(show_stack_trace=True)`. For bwd ops we use `.fwd_stack_trace` when available.

Needs some improvement for:
- backwards: not all dispatch calls run under an autograd node, so some just have generic traces (e.g. `loss.backward()`)
- compiled regions: stack trace isn't very meaningful to start (e.g. points to codegened line)

Sample for test_nn_module (fwd + bwd):
```
    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])
    aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:405 in forward, code: return self.xyz(self.abc(x))
    aten::t(t: f32[4, 4])
    aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:429 in test_nn_module, code: out = mod(inp).sum()
    aten::sum(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:429 in test_nn_module, code: out = mod(inp).sum()
    aten::expand(t: f32[], [4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:405 in forward, code: return self.xyz(self.abc(x))
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::sum.dim_IntList(t: f32[4, 4], [0], True)
    aten::view(t: f32[1, 4], [4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:405 in forward, code: return self.xyz(self.abc(x))
    aten::t(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::sum.dim_IntList(t: f32[4, 4], [0], True)
    aten::view(t: f32[1, 4], [4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::sum.dim_IntList(t: f32[4, 4], [0], True)
    aten::view(t: f32[1, 4], [4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4, 4])
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167589
Approved by: https://github.com/yushangdi
2025-11-13 08:15:27 +00:00
8f96e7bc1d Only remove_noop in pre_grad passes if remove_noop is not in the remove_passes_list (#167479)
Summary: Only remove_noop in pre_grad passes if remove_noop is not in the remove_passes_list

Test Plan:
Tested as part of lowering for ss_omni_exp model.

f825774360

Unit Tests were run and succeeded as well!

Differential Revision: D86694854

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167479
Approved by: https://github.com/mlazos
2025-11-13 07:27:31 +00:00
782fc3c72b [DTensor] Add CPU instruction count benchmark for dispatch (#167394)
Following example from #149932 and doc in
[README.md](benchmarks/dynamo/pr_time_benchmarks/README.md)

cd benchmarks/dynamo/pr_time_benchmarks
`PYTHONPATH=./:../../../ python benchmarks/dtensor.py a`

Currently outputs:

```
collecting instruction count for dtensor_dispatch_detach
instruction count for iteration 0 is 14919468
instruction count for iteration 1 is 136283
instruction count for iteration 2 is 133750
instruction count for iteration 3 is 133757
instruction count for iteration 4 is 133751
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167394
Approved by: https://github.com/laithsakka
2025-11-13 06:54:08 +00:00
1a67403fc6 Move MemPool out of c10 and into ATen. (#167506)
Necessary to allow CachingHostAllocator, which sits in ATen, to
allocate its memory to a memory pool.

Otherwise, we would have a circular dependency, where libtorch_cuda.so
depends upon libc10_cuda.so, but libc10_cuda.so's MemPool object
references CachingHostAllocator symbols in libtorch_cuda.so.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167506
Approved by: https://github.com/ngimel, https://github.com/malfet
2025-11-13 06:18:29 +00:00
3d801a4c01 DTensor fast path: port return_and_correct_aliasing and inplace/out checks (#167475)
This seems to generate a several-microsecond performance improvement in the detach benchmark I've been using.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167475
Approved by: https://github.com/ezyang
ghstack dependencies: #167051, #166372, #166808
2025-11-13 06:11:38 +00:00
2034ca99ae extend C++ DTensor fast path to local operator dispatch (#166808)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166808
Approved by: https://github.com/ezyang
ghstack dependencies: #167051, #166372
2025-11-13 06:11:38 +00:00
480b4ff882 Avoid creating Python OpSchema in the DTensor dispatch fast path (#166372)
All we need to do is move a few checks around.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166372
Approved by: https://github.com/ezyang
ghstack dependencies: #167051
2025-11-13 06:11:30 +00:00
f570e589da Add C++ fast path for DTensor.__torch_dispatch__ (#167051)
This patches the `__torch_dispatch__` machinery to detect DTensor and hand over control to a C++ fast path. Unlike #166370 and #166369 (which added a DTensor dispatch key and are intended to be replaced by this PR), this approach fundamentally *is* `__torch_dispatch__`, hopefully sidestepping all manner of thorny "does it work just like `__torch_dispatch__`?" that came up during development and review of #166370.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167051
Approved by: https://github.com/ezyang
2025-11-13 06:11:22 +00:00
f9851af59b Add Attention ops to CI (#165915)
This pull request introduces a new attention operator microbenchmark workflow to the CI system, enabling automated benchmarking and reporting for attention-related operations. The main changes include adding a new GitHub Actions workflow, to add attention benchmarks to the existing Pytorch operator microbenchmark [dashboard](https://hud.pytorch.org/benchmark/v3/dashboard/pytorch_operator_microbenchmark?renderGroupId=main&time.start=2025-10-27T00%3A00%3A00.000Z&time.end=2025-10-29T01%3A00%3A00.000Z&filters.device=cuda&filters.arch=NVIDIA+A100-SXM4-40GB&filters.deviceName=cuda%7C%7CNVIDIA+A100-SXM4-40GB&filters.operatorName=&lcommit.commit=665df0bc7288996d638fcc3da750f8cb2addd6d0&lcommit.workflow_id=18888994873&lcommit.date=2025-10-29T00%3A00%3A00Z&lcommit.branch=refs%2Ftags%2Fciflow%2Fop-benchmark%2F165915&rcommit.commit=665df0bc7288996d638fcc3da750f8cb2addd6d0&rcommit.workflow_id=18888994873&rcommit.date=2025-10-29T00%3A00%3A00Z&rcommit.branch=refs%2Ftags%2Fciflow%2Fop-benchmark%2F165915&lbranch=refs%2Ftags%2Fciflow%2Fop-benchmark%2F165915&rbranch=refs%2Ftags%2Fciflow%2Fop-benchmark%2F165915)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165915
Approved by: https://github.com/jbschlosser
2025-11-13 05:30:04 +00:00
eeebf9f664 [dynamo] [3.14] Update broken numpy test (#167681)
This is related to upgrading numpy versions, not 3.14 specifically.  See https://github.com/numpy/numpy/pull/27148
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167681
Approved by: https://github.com/williamwen42
ghstack dependencies: #167619
2025-11-13 04:27:55 +00:00
d9a50bf9a8 [dynamo] [3.14] Support np._CopyMode (#167619)
Upgrading scipy to 1.16 introduced errors related to the `copy` parameter of
`np.array`.  Add special handling for `np._CopyMode.IF_NEEDED`, which is not
handled correctly, but matches the existing behavior when `copy=None`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167619
Approved by: https://github.com/williamwen42
2025-11-13 04:27:55 +00:00
2984331c87 [inductor][NFC][2/X] extract do_autotuning/autotune/benchmark from AlgorithmSelectorCache.__call__ (#167489)
Summary: see https://github.com/pytorch/pytorch/pull/167487 for context

Test Plan: CI

Differential Revision: D86714833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167489
Approved by: https://github.com/aorenste
2025-11-13 03:29:39 +00:00
9b68682df2 [ROCm] Enable several DISABLED issues (#167183)
Profiler:
Fixes #166422

Default:
Fixes #165386
Fixes #145019
Fixes #145069
Fixes #165295
Fixes #165294
Fixes #165093
Fixes #164235
Fixes #164194
Fixes #164193
Fixes #155217
Fixes #163918
Fixes #163917
Fixes #155235
Fixes #122352
Fixes #121576
Fixes #121806
Fixes #104366

Inductor:
Fixes #164337
Fixes #148523
Fixes #115002
Fixes #111066
Fixes #107774

Distributed
Fixes #161612
Fixes #161502
Fixes #161459
Fixes #161402
Fixes #155711
Fixes #152201
Fixes #152367
Fixes #152349
Fixes #152168
Fixes #152169
Fixes #151153
Fixes #151077
Fixes #112815

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167183
Approved by: https://github.com/jeffdaily
2025-11-13 02:50:35 +00:00
8f5f89c9a0 Revert "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace (#167248)"
This reverts commit 537167aa1e50a4379dca244163aaf369ed8e5161.

Reverted https://github.com/pytorch/pytorch/pull/167248 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167248#issuecomment-3524925727))
2025-11-13 02:46:35 +00:00
8919f69362 [Inductor][2/2] Decouple flags for optimization and debug symbols (#167575)
Summary:
What: Decouple flags for compile (unoptimized build) and symbols (optimized build)
Why: Reduce confusion around naming and usage

Test Plan: Unit test & CI

Differential Revision: D86683526

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167575
Approved by: https://github.com/jansel, https://github.com/hl475
2025-11-13 00:59:15 +00:00
19c867873a [opqaue obj] Add attribute support (#167230)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167230
Approved by: https://github.com/zou3519
ghstack dependencies: #163284, #163714, #163936
2025-11-13 00:35:20 +00:00
e3dadb1d36 [opaque obj] torch.compile support (#163936)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163936
Approved by: https://github.com/zou3519
ghstack dependencies: #163284, #163714
2025-11-13 00:35:20 +00:00
c9b09a31e8 [opaque obj] Allow non-effectful scriptobjs (#163714)
Fixes functionalization so that we can run ops using ScriptObjects w/o needing effects. Previously we would run into an error when running functionalization on the TorchBindOpOverloads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163714
Approved by: https://github.com/zou3519
ghstack dependencies: #163284
2025-11-13 00:35:20 +00:00
35571fe94b [effects] Add register_effectful_op (#163284)
Refactored register_effectful_op to return a handler to match how fake kernels are registered. This makes it easier to deregister effects

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163284
Approved by: https://github.com/zou3519
2025-11-13 00:35:20 +00:00
485f2b607a ProxyTorchDispatchMode: Decomposing missing sympy.SymExpr should handle constant literals (#167585)
The previous work to decompose missing sympy.SymExpr (#164717) handled combinations of sub-nodes (like `s1*s2`) but I forgot to handle explicit literals (like `2*s2`).

Added a unit test based on the report.

Fixes T244632748

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167585
Approved by: https://github.com/bobrenjc93
2025-11-13 00:27:10 +00:00
0c5d5c7e9a [dynamo][invoke_subgraph] Do not restore side effects on invoke_subgraph (#167446)
Test that checks non proxy-able outputs. Also add a test that fails to
be fixed later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167446
Approved by: https://github.com/zou3519
ghstack dependencies: #167438, #167442
2025-11-13 00:16:40 +00:00
5f98a0363a [dynamo] Make HintsWrapperHigherOrderVariable follow wrap semantics (#167442)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167442
Approved by: https://github.com/zou3519
ghstack dependencies: #167438
2025-11-13 00:16:40 +00:00
2d739001d3 [dynamo] speculate_subgraph_with_auto_output_flattening (#167438)
Summary

  This PR refactors the wrap higher-order operator infrastructure in PyTorch's Dynamo to introduce automatic output flattening for subgraph speculation. The key change is the addition of
  speculate_subgraph_with_auto_output_flattening() which separates the output variable trackers (VTs) that Dynamo continues tracing with from the actual FX graph outputs.

  Key Changes

  New speculate_subgraph_with_auto_output_flattening() function

  - Introduces a new approach for handling HOPs (Higher-Order Operators) that are just "subgraph placeholders", i.e. the HOP essentially just runs the subgraph with inputs (e.g., invoke_subgraph, activation checkpointing,
   autograd.Function)
  - Disentangles output VTs from graph outputs: Allows the subgraph to return complex Python objects (like custom user-defined objects containing tensors) while only registering tensor/symint VTs as actual FX
  graph outputs
  - Mirrors typical Dynamo processing where VTs can "run ahead" for continued tracing while the graph is a side data structure

  Benefits

  1. Handles non-proxyable outputs: Supports HOPs that return custom Python objects containing tensors
  2. Cleaner separation of concerns: Output VTs for continued tracing vs. graph outputs for FX representation
  3. More flexible: Returns graph_output_vts instead of treespec, giving more control over what becomes a graph output

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167438
Approved by: https://github.com/zou3519
2025-11-13 00:16:40 +00:00
273babeec3 [precompile] Integrate AOTI as a backend. (#167338)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167338
Approved by: https://github.com/jamesjwu
2025-11-13 00:02:26 +00:00
a76dd6b7c6 [MPS] SparseMps mv op (#166708)
Should be merged after #166561
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166708
Approved by: https://github.com/malfet
2025-11-12 22:44:29 +00:00
2fa18d1545 [export] Codemod more tests to use dynamo_graph_capture_for_export (#167663)
Summary:
as title.

Test Plan:
CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167663
Approved by: https://github.com/tugsbayasgalan
2025-11-12 22:44:18 +00:00
537167aa1e Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace (#167248)
Summary:
getCurrentCUDABlasHandle() and getCUDABlasLtWorkspace() use static mutable maps that are not protected from concurrent read-and-write. This leads to crashes.

This diff adds mutexes to synchronize access to the static maps.

Test Plan:
Use a GPU OD, run multi-threaded tests with TSAN:
```
buck test fbcode//mode/dev-tsan fbcode//caffe2:cuda_cublas_handle_pool_test  -- --stress-runs 100
```
https://www.internalfb.com/intern/testinfra/testrun/14355223937501118

TSAN: P2026731804

Differential Revision: D86316117

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167248
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-11-12 22:43:56 +00:00
0dac408f43 MatMal - fix folding logic (#166891)
Summary:
Folding logic on Matmal can be decomposed to BMM or folding + MM.

Current common Training path for 3D * 2D matmul: library will always fold, since Tensor1 or Tensor2 BOTH require a grad, so we fold since Tensor2 has grad.   But reasoning isn't really sound, it was done as a memory optimization - when its also generally same/more performant.

However, in Chemistry / Modular Modeling its common to directly calculate Forces as derivate of Energy (ie. dl/dX, but NOT dl/dW) in inference.  This exposed bug where we only have 1 of 2 Tensors requires grad, and may choose NOT to fold, resulting in 30% regression due to suboptimal BMM decomposition of torch.nn.Linear (-> calls into matmul).

I actually think even in cases we need either dl/dX or dl/dW, we should be folding when working with inputs of [B, M, N] and weights of [N, K].  Its strictly better for memory and same/faster when you consider both forward + backward runtime, and M's that are not multiples of 8 are particularly brutally slow using BMM vs MM.

Also, compiler out of box could not solve this issue, which raise another concern (was actually highlighted 2 years ago in comments, but seems still case today: (https://github.com/pytorch/pytorch/issues/118548#issuecomment-1919528910)

Differential Revision: D86128493

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166891
Approved by: https://github.com/ngimel
2025-11-12 22:18:03 +00:00
158e72427b [torch] Update caffe2/c10/cuda to build under CUDA 13 (#167534)
Summary:
Update caffe2/c10/cuda to build under CUDA 13

As of CUDA 13, the cudaMemAdvise() has been updated to take in `cudaMemLocation` as argument instead of `int` device id

This is needed for building FBGEMM_GPU under CUDA 13 (see D86372925)

Test Plan:
```
# Default build
buck build  @//mode/opt fbcode//caffe2/c10/cuda:cuda

# CUDA 13 build
buck build  @//mode/opt -c fbcode.arch=aarch64 -c fbcode.nvcc_arch=b200 -c fbcode.platform010_cuda_version=13.0  fbcode//caffe2/c10/cuda:cuda

# AMD build
buck build --flagfile fbcode//mode/dev-nosan-amd-gpu fbcode//caffe2/c10/cuda:cuda
```

Reviewed By: atalman

Differential Revision: D86578286

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167534
Approved by: https://github.com/seemethere
2025-11-12 22:12:40 +00:00
0184ef291d [inductor][NFC][1/X] extract create_no_valid_choices from AlgorithmSelectorCache.__call__ (#167487)
Summary:
What: moves `create_no_valid_choices` out of `AlgorithmSelectorCache.__call__` and into the body of `AlgorithmSelectorCache`
Why: nested function definitions make it harder to understand what `AlgorithmSelectorCache.__call__` is doing, on top of making patching/testing/etc more difficult

Test Plan: CI

Differential Revision: D86712921

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167487
Approved by: https://github.com/aorenste
2025-11-12 22:03:37 +00:00
2ca428c721 [CD] Preload libnvrtc-builtinso.so (#167614)
Which is a regression introduced by https://github.com/pytorch/pytorch/pull/167046
That causes CuDNN SDPA fail with actionable `cuDNN Frontend error: [cudnn_frontend] Error: No valid execution plans built.` error

Change `cuda_libs` from dict to list, and add `test_sdpa` regression test to binary smoke tests

Fixes https://github.com/pytorch/pytorch/issues/167602
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167614
Approved by: https://github.com/Aidyn-A, https://github.com/atalman, https://github.com/nWEIdia
2025-11-12 21:50:13 +00:00
1311385f9d Revert "fix failure of exporting compiled model with nested dynamic shapes (#166358)"
This reverts commit 416421c7c455e3befb0772fcc3379661a24aff71.

Reverted https://github.com/pytorch/pytorch/pull/166358 on behalf of https://github.com/jeanschmidt due to seems to be breaking internal signals, see D86790405, @angelayi may you help the author get this change landed? ([comment](https://github.com/pytorch/pytorch/pull/166358#issuecomment-3524052822))
2025-11-12 21:46:38 +00:00
5f0a5b8f87 Revert "Use stable topological sort in fuse_by_partitions (#167397)"
This reverts commit 7886070fc5cdbc9b51b7e2b6432c80ccae01c4fc.

Reverted https://github.com/pytorch/pytorch/pull/167397 on behalf of https://github.com/jeanschmidt due to seems to be breaking executorch signals internally, see D86780724 ([comment](https://github.com/pytorch/pytorch/pull/167397#issuecomment-3523992343))
2025-11-12 21:26:57 +00:00
74e85c6944 Add TORCH_BOX helper for STABLE_TORCH_LIBRARY_IMPL (#167582)
Implementation greatly adapted from @lw's https://github.com/pytorch/pytorch/pull/163505. TORCH_BOX is the StableIValue version of `make_boxed_from_unboxed_functor`.

the differences:
- uses headeronly concepts
- adds an unbox type mapping to support user kernels taking in torch::headeronly::HeaderOnlyArrayRef<T> (by calling to<std::vector<T>> in those cases)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167582
Approved by: https://github.com/swolchok
ghstack dependencies: #167386
2025-11-12 20:29:21 +00:00
401 changed files with 13867 additions and 6641 deletions

View File

@ -1,19 +0,0 @@
# Aarch64 (ARM/Graviton) Support Scripts
Scripts for building aarch64 PyTorch PIP Wheels. These scripts build the following wheels:
* torch
* torchvision
* torchaudio
* torchtext
* torchdata
## Aarch64_ci_build.sh
This script is design to support CD operations within PyPi manylinux aarch64 container, and be executed in the container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script "assumes" the PyTorch repo is located at: ```/pytorch``` and will put the wheels into ```/artifacts```.
### Usage
```DESIRED_PYTHON=<PythonVersion> aarch64_ci_build.sh```
__NOTE:__ CI build is currently __EXPERMINTAL__
## Build_aarch64_wheel.py
This app allows a person to build using AWS EC3 resources and requires AWS-CLI and Boto3 with AWS credentials to support building EC2 instances for the wheel builds. Can be used in a codebuild CD or from a local system.
### Usage
```build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch <RCtag>```

View File

@ -1,53 +0,0 @@
#!/bin/bash
set -eux -o pipefail
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
# Set CUDA architecture lists to match x86 build_cuda.sh
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
fi
# Compress the fatbin with -compress-mode=size for CUDA 13
if [[ "$DESIRED_CUDA" == *"13"* ]]; then
export TORCH_NVCC_FLAGS="-compress-mode=size"
# Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801
export BUILD_BUNDLE_PTXAS=1
fi
SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh
###############################################################################
# Run aarch64 builder python
###############################################################################
cd /
# adding safe directory for git as the permissions will be
# on the mounted pytorch repo
git config --global --add safe.directory /pytorch
pip install -r /pytorch/requirements.txt
pip install auditwheel==6.2.0 wheel
if [ "$DESIRED_CUDA" = "cpu" ]; then
echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
else
echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
export USE_SYSTEM_NCCL=1
# Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic)
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
echo "Bundling CUDA libraries with wheel for aarch64."
else
echo "Using nvidia libs from pypi for aarch64."
echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
export USE_NVIDIA_PYPI_LIBS=1
fi
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
fi

View File

@ -1,21 +0,0 @@
#!/bin/bash
set -eux -o pipefail
# This script is used to prepare the Docker container for aarch64_ci_wheel_build.py python script
# By creating symlinks from desired /opt/python to /usr/local/bin/
NUMPY_VERSION=2.0.2
if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then
NUMPY_VERSION=2.1.2
fi
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
source $SCRIPTPATH/../manywheel/set_desired_python.sh
pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2
for tool in python python3 pip pip3 ninja scons patchelf; do
ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin;
done
python --version

View File

@ -1,333 +0,0 @@
#!/usr/bin/env python3
# encoding: UTF-8
import os
import shutil
from subprocess import check_call, check_output
def list_dir(path: str) -> list[str]:
"""'
Helper for getting paths for Python
"""
return check_output(["ls", "-1", path]).decode().split("\n")
def replace_tag(filename) -> None:
with open(filename) as f:
lines = f.readlines()
for i, line in enumerate(lines):
if line.startswith("Tag:"):
lines[i] = line.replace("-linux_", "-manylinux_2_28_")
print(f"Updated tag from {line} to {lines[i]}")
break
with open(filename, "w") as f:
f.writelines(lines)
def patch_library_rpath(
folder: str,
lib_name: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Apply patchelf to set RPATH for a library in torch/lib"""
lib_path = f"{folder}/tmp/torch/lib/{lib_name}"
if use_nvidia_pypi_libs:
# For PyPI NVIDIA libraries, construct CUDA RPATH
cuda_rpaths = [
"$ORIGIN/../../nvidia/cudnn/lib",
"$ORIGIN/../../nvidia/nvshmem/lib",
"$ORIGIN/../../nvidia/nccl/lib",
"$ORIGIN/../../nvidia/cusparselt/lib",
]
if "130" in desired_cuda:
cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib")
else:
cuda_rpaths.extend(
[
"$ORIGIN/../../nvidia/cublas/lib",
"$ORIGIN/../../nvidia/cuda_cupti/lib",
"$ORIGIN/../../nvidia/cuda_nvrtc/lib",
"$ORIGIN/../../nvidia/cuda_runtime/lib",
"$ORIGIN/../../nvidia/cufft/lib",
"$ORIGIN/../../nvidia/curand/lib",
"$ORIGIN/../../nvidia/cusolver/lib",
"$ORIGIN/../../nvidia/cusparse/lib",
"$ORIGIN/../../nvidia/nvtx/lib",
"$ORIGIN/../../nvidia/cufile/lib",
]
)
# Add $ORIGIN for local torch libs
rpath = ":".join(cuda_rpaths) + ":$ORIGIN"
else:
# For bundled libraries, just use $ORIGIN
rpath = "$ORIGIN"
if os.path.exists(lib_path):
os.system(
f"cd {folder}/tmp/torch/lib/; "
f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}"
)
def copy_and_patch_library(
src_path: str,
folder: str,
use_nvidia_pypi_libs: bool = False,
desired_cuda: str = "",
) -> None:
"""Copy a library to torch/lib and patch its RPATH"""
if os.path.exists(src_path):
lib_name = os.path.basename(src_path)
shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}")
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
"""
Package the cuda wheel libraries
"""
folder = os.path.dirname(wheel_path)
os.mkdir(f"{folder}/tmp")
os.system(f"unzip {wheel_path} -d {folder}/tmp")
# Delete original wheel since it will be repackaged
os.system(f"rm {wheel_path}")
# Check if we should use PyPI NVIDIA libraries or bundle system libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Using nvidia libs from pypi - skipping CUDA library bundling")
# For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages
# We only need to bundle non-NVIDIA libraries
minimal_libs_to_copy = [
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
]
# Copy minimal libraries to unzipped_folder/torch/lib
for lib_path in minimal_libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Patch torch libraries used for searching libraries
torch_libs_to_patch = [
"libtorch.so",
"libtorch_cpu.so",
"libtorch_cuda.so",
"libtorch_cuda_linalg.so",
"libtorch_global_deps.so",
"libtorch_python.so",
"libtorch_nvshmem.so",
"libc10.so",
"libc10_cuda.so",
"libcaffe2_nvrtc.so",
"libshm.so",
]
for lib_name in torch_libs_to_patch:
patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda)
else:
print("Bundling CUDA libraries with wheel")
# Original logic for bundling system CUDA libraries
# Common libraries for all CUDA versions
common_libs = [
# Non-NVIDIA system libraries
"/lib64/libgomp.so.1",
"/usr/lib64/libgfortran.so.5",
"/acl/build/libarm_compute.so",
"/acl/build/libarm_compute_graph.so",
# Common CUDA libraries (same for all versions)
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
"/usr/local/lib/libnvpl_lapack_core.so.0",
"/usr/local/lib/libnvpl_blas_core.so.0",
"/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so",
"/usr/local/cuda/lib64/libcudnn.so.9",
"/usr/local/cuda/lib64/libcusparseLt.so.0",
"/usr/local/cuda/lib64/libcurand.so.10",
"/usr/local/cuda/lib64/libnccl.so.2",
"/usr/local/cuda/lib64/libnvshmem_host.so.3",
"/usr/local/cuda/lib64/libcudnn_adv.so.9",
"/usr/local/cuda/lib64/libcudnn_cnn.so.9",
"/usr/local/cuda/lib64/libcudnn_graph.so.9",
"/usr/local/cuda/lib64/libcudnn_ops.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9",
"/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9",
"/usr/local/cuda/lib64/libcudnn_heuristic.so.9",
"/usr/local/cuda/lib64/libcufile.so.0",
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
"/usr/local/cuda/lib64/libcusparse.so.12",
]
# CUDA version-specific libraries
if "13" in desired_cuda:
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
"/usr/local/cuda/lib64/libcublas.so.13",
"/usr/local/cuda/lib64/libcublasLt.so.13",
"/usr/local/cuda/lib64/libcudart.so.13",
"/usr/local/cuda/lib64/libcufft.so.12",
"/usr/local/cuda/lib64/libcusolver.so.12",
"/usr/local/cuda/lib64/libnvJitLink.so.13",
"/usr/local/cuda/lib64/libnvrtc.so.13",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
]
elif "12" in desired_cuda:
# Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
minor_version = desired_cuda[-1]
version_specific_libs = [
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12",
"/usr/local/cuda/lib64/libcublas.so.12",
"/usr/local/cuda/lib64/libcublasLt.so.12",
"/usr/local/cuda/lib64/libcudart.so.12",
"/usr/local/cuda/lib64/libcufft.so.11",
"/usr/local/cuda/lib64/libcusolver.so.11",
"/usr/local/cuda/lib64/libnvJitLink.so.12",
"/usr/local/cuda/lib64/libnvrtc.so.12",
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
]
else:
raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")
# Combine all libraries
libs_to_copy = common_libs + version_specific_libs
# Copy libraries to unzipped_folder/torch/lib
for lib_path in libs_to_copy:
copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda)
# Make sure the wheel is tagged with manylinux_2_28
for f in os.scandir(f"{folder}/tmp/"):
if f.is_dir() and f.name.endswith(".dist-info"):
replace_tag(f"{f.path}/WHEEL")
break
os.system(f"wheel pack {folder}/tmp/ -d {folder}")
os.system(f"rm -rf {folder}/tmp/")
def complete_wheel(folder: str) -> str:
"""
Complete wheel build and put in artifact location
"""
wheel_name = list_dir(f"/{folder}/dist")[0]
# Please note for cuda we don't run auditwheel since we use custom script to package
# the cuda dependencies to the wheel file using update_wheel() method.
# However we need to make sure filename reflects the correct Manylinux platform.
if "pytorch" in folder and not enable_cuda:
print("Repairing Wheel with AuditWheel")
check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder)
repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0]
print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist")
os.rename(
f"/{folder}/wheelhouse/{repaired_wheel_name}",
f"/{folder}/dist/{repaired_wheel_name}",
)
else:
repaired_wheel_name = list_dir(f"/{folder}/dist")[0]
print(f"Copying {repaired_wheel_name} to artifacts")
shutil.copy2(
f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}"
)
return repaired_wheel_name
def parse_arguments():
"""
Parse inline arguments
"""
from argparse import ArgumentParser
parser = ArgumentParser("AARCH64 wheels python CD")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
parser.add_argument("--enable-mkldnn", action="store_true")
parser.add_argument("--enable-cuda", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
"""
Entry Point
"""
args = parse_arguments()
enable_mkldnn = args.enable_mkldnn
enable_cuda = args.enable_cuda
branch = check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch"
).decode()
print("Building PyTorch wheel")
build_vars = ""
# MAX_JOB=5 is not required for CPU backend (see commit 465d98b)
if enable_cuda:
build_vars += "MAX_JOBS=5 "
# Handle PyPI NVIDIA libraries vs bundled libraries
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
if use_nvidia_pypi_libs:
print("Configuring build for PyPI NVIDIA libraries")
# Configure for dynamic linking (matching x86 logic)
build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 "
else:
print("Configuring build for bundled NVIDIA libraries")
# Keep existing static linking approach - already configured above
override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
desired_cuda = os.getenv("DESIRED_CUDA")
if override_package_version is not None:
version = override_package_version
build_vars += (
f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 "
)
elif branch in ["nightly", "main"]:
build_date = (
check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch")
.decode()
.replace("-", "")
)
version = (
check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2]
)
if enable_cuda:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 "
else:
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
elif branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
if enable_mkldnn:
print("build pytorch with mkldnn+acl backend")
build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON "
build_vars += "ACL_ROOT_DIR=/acl "
if enable_cuda:
build_vars += "BLAS=NVPL "
else:
build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS "
else:
print("build pytorch without mkldnn backend")
os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation")
if enable_cuda:
print("Updating Cuda Dependency")
filename = os.listdir("/pytorch/dist/")
wheel_path = f"/pytorch/dist/{filename[0]}"
package_cuda_wheel(wheel_path, desired_cuda)
pytorch_wheel_name = complete_wheel("/pytorch/")
print(f"Build Complete. Created {pytorch_wheel_name}..")

View File

@ -1,999 +0,0 @@
#!/usr/bin/env python3
# This script is for building AARCH64 wheels using AWS EC2 instances.
# To generate binaries for the release follow these steps:
# 1. Update mappings for each of the Domain Libraries by adding new row to a table like this:
# "v1.11.0": ("0.11.0", "rc1"),
# 2. Run script with following arguments for each of the supported python versions and required tag, for example:
# build_aarch64_wheel.py --key-name <YourPemKey> --use-docker --python 3.8 --branch v1.11.0-rc3
import os
import subprocess
import sys
import time
from typing import Optional, Union
import boto3
# AMI images for us-east-1, change the following based on your ~/.aws/config
os_amis = {
"ubuntu20_04": "ami-052eac90edaa9d08f", # login_name: ubuntu
"ubuntu22_04": "ami-0c6c29c5125214c77", # login_name: ubuntu
"redhat8": "ami-0698b90665a2ddcf1", # login_name: ec2-user
}
ubuntu20_04_ami = os_amis["ubuntu20_04"]
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
if key_name is None:
key_name = os.getenv("AWS_KEY_NAME")
if key_name is None:
return os.getenv("SSH_KEY_PATH", ""), ""
homedir_path = os.path.expanduser("~")
default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem")
return os.getenv("SSH_KEY_PATH", default_path), key_name
ec2 = boto3.resource("ec2")
def ec2_get_instances(filter_name, filter_value):
return ec2.instances.filter(
Filters=[{"Name": filter_name, "Values": [filter_value]}]
)
def ec2_instances_of_type(instance_type="t4g.2xlarge"):
return ec2_get_instances("instance-type", instance_type)
def ec2_instances_by_id(instance_id):
rc = list(ec2_get_instances("instance-id", instance_id))
return rc[0] if len(rc) > 0 else None
def start_instance(
key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50
):
inst = ec2.create_instances(
ImageId=ami,
InstanceType=instance_type,
SecurityGroups=["ssh-allworld"],
KeyName=key_name,
MinCount=1,
MaxCount=1,
BlockDeviceMappings=[
{
"DeviceName": "/dev/sda1",
"Ebs": {
"DeleteOnTermination": True,
"VolumeSize": ebs_size,
"VolumeType": "standard",
},
}
],
)[0]
print(f"Create instance {inst.id}")
inst.wait_until_running()
running_inst = ec2_instances_by_id(inst.id)
print(f"Instance started at {running_inst.public_dns_name}")
return running_inst
class RemoteHost:
addr: str
keyfile_path: str
login_name: str
container_id: Optional[str] = None
ami: Optional[str] = None
def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"):
self.addr = addr
self.keyfile_path = keyfile_path
self.login_name = login_name
def _gen_ssh_prefix(self) -> list[str]:
return [
"ssh",
"-o",
"StrictHostKeyChecking=no",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}",
"--",
]
@staticmethod
def _split_cmd(args: Union[str, list[str]]) -> list[str]:
return args.split() if isinstance(args, str) else args
def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
def check_ssh_output(self, args: Union[str, list[str]]) -> str:
return subprocess.check_output(
self._gen_ssh_prefix() + self._split_cmd(args)
).decode("utf-8")
def scp_upload_file(self, local_file: str, remote_file: str) -> None:
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
local_file,
f"{self.login_name}@{self.addr}:{remote_file}",
]
)
def scp_download_file(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if local_file is None:
local_file = "."
subprocess.check_call(
[
"scp",
"-i",
self.keyfile_path,
f"{self.login_name}@{self.addr}:{remote_file}",
local_file,
]
)
def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
self.run_ssh_cmd("sudo apt-get install -y docker.io")
self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
self.run_ssh_cmd("sudo service docker start")
self.run_ssh_cmd(f"docker pull {image}")
self.container_id = self.check_ssh_output(
f"docker run -t -d -w /root {image}"
).strip()
def using_docker(self) -> bool:
return self.container_id is not None
def run_cmd(self, args: Union[str, list[str]]) -> None:
if not self.using_docker():
return self.run_ssh_cmd(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd)
def check_output(self, args: Union[str, list[str]]) -> str:
if not self.using_docker():
return self.check_ssh_output(args)
assert self.container_id is not None
docker_cmd = self._gen_ssh_prefix() + [
"docker",
"exec",
"-i",
self.container_id,
"bash",
]
p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
(out, err) = p.communicate(
input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode(
"utf-8"
)
)
rc = p.wait()
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err)
return out.decode("utf-8")
def upload_file(self, local_file: str, remote_file: str) -> None:
if not self.using_docker():
return self.scp_upload_file(local_file, remote_file)
tmp_file = os.path.join("/tmp", os.path.basename(local_file))
self.scp_upload_file(local_file, tmp_file)
self.run_ssh_cmd(
["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"]
)
self.run_ssh_cmd(["rm", tmp_file])
def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
if not self.using_docker():
return self.scp_download_file(remote_file, local_file)
tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
self.run_ssh_cmd(
["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file]
)
self.scp_download_file(tmp_file, local_file)
self.run_ssh_cmd(["rm", tmp_file])
def download_wheel(
self, remote_file: str, local_file: Optional[str] = None
) -> None:
if self.using_docker() and local_file is None:
basename = os.path.basename(remote_file)
local_file = basename.replace(
"-linux_aarch64.whl", "-manylinux2014_aarch64.whl"
)
self.download_file(remote_file, local_file)
def list_dir(self, path: str) -> list[str]:
return self.check_output(["ls", "-1", path]).split("\n")
def wait_for_connection(addr, port, timeout=15, attempt_cnt=5):
import socket
for i in range(attempt_cnt):
try:
with socket.create_connection((addr, port), timeout=timeout):
return
except (ConnectionRefusedError, TimeoutError): # noqa: PERF203
if i == attempt_cnt - 1:
raise
time.sleep(timeout)
def update_apt_repo(host: RemoteHost) -> None:
time.sleep(5)
host.run_cmd("sudo systemctl stop apt-daily.service || true")
host.run_cmd("sudo systemctl stop unattended-upgrades.service || true")
host.run_cmd(
"while systemctl is-active --quiet apt-daily.service; do sleep 1; done"
)
host.run_cmd(
"while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done"
)
host.run_cmd("sudo apt-get update")
time.sleep(3)
host.run_cmd("sudo apt-get update")
def install_condaforge(
host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh"
) -> None:
print("Install conda-forge")
host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}")
host.run_cmd(f"sh -f {os.path.basename(suffix)} -b")
host.run_cmd(f"rm -f {os.path.basename(suffix)}")
if host.using_docker():
host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
else:
host.run_cmd(
[
"sed",
"-i",
"'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'",
".bashrc",
]
)
def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None:
if python_version == "3.6":
# Python-3.6 EOLed and not compatible with conda-4.11
install_condaforge(
host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh"
)
host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")
else:
install_condaforge(
host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh"
)
# Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer
host.run_cmd(
f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0"
)
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
host.run_cmd("pip3 install auditwheel")
host.run_cmd(
"conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf"
)
from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as tmp:
tmp.write(embed_library_script.encode("utf-8"))
tmp.flush()
host.upload_file(tmp.name, "embed_library.py")
print("Embedding libgomp into wheel")
if host.using_docker():
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
else:
host.run_cmd(f"python3 embed_library.py {wheel_name}")
def checkout_repo(
host: RemoteHost,
*,
branch: str = "main",
url: str,
git_clone_flags: str,
mapping: dict[str, tuple[str, str]],
) -> Optional[str]:
for prefix in mapping:
if not branch.startswith(prefix):
continue
tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}"
host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}")
return mapping[prefix][0]
host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}")
return None
def build_torchvision(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str,
run_smoke_tests: bool = True,
) -> str:
print("Checking out TorchVision repo")
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/vision",
git_clone_flags=git_clone_flags,
mapping={
"v1.7.1": ("0.8.2", "rc2"),
"v1.8.0": ("0.9.0", "rc3"),
"v1.8.1": ("0.9.1", "rc1"),
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.1", "rc1"),
"v1.10.1": ("0.11.2", "rc1"),
"v1.10.2": ("0.11.3", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc4"),
"v1.12.1": ("0.13.1", "rc6"),
"v1.13.0": ("0.14.0", "rc4"),
"v1.13.1": ("0.14.1", "rc2"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchVision wheel")
# Please note libnpg and jpeg are required to build image.so extension
if use_conda:
host.run_cmd("conda install -y libpng jpeg")
# Remove .so files to force static linking
host.run_cmd(
"rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so"
)
# And patch setup.py to include libz dependency for libpng
host.run_cmd(
[
'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py'
]
)
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]
).strip()
if len(version) == 0:
# In older revisions, version was embedded in setup.py
version = (
host.check_output(["grep", '"version = \'"', "vision/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd vision && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation")
vision_wheel_name = host.list_dir("vision/dist")[0]
embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name))
print("Copying TorchVision wheel")
host.download_wheel(os.path.join("vision", "dist", vision_wheel_name))
if run_smoke_tests:
host.run_cmd(
f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}"
)
host.run_cmd("python3 vision/test/smoke_test.py")
print("Delete vision checkout")
host.run_cmd("rm -rf vision")
return vision_wheel_name
def build_torchdata(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchData repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/data",
git_clone_flags=git_clone_flags,
mapping={
"v1.13.1": ("0.5.1", ""),
"v2.0.0": ("0.6.0", "rc5"),
"v2.0.1": ("0.6.1", "rc1"),
},
)
print("Building TorchData wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f data/version.txt ]; then cat data/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd data && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("data/dist")[0]
embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name))
print("Copying TorchData wheel")
host.download_wheel(os.path.join("data", "dist", wheel_name))
return wheel_name
def build_torchtext(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchText repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/text",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.10.0", "rc1"),
"v1.10.0": ("0.11.0", "rc2"),
"v1.10.1": ("0.11.1", "rc1"),
"v1.10.2": ("0.11.2", "rc1"),
"v1.11.0": ("0.12.0", "rc1"),
"v1.12.0": ("0.13.0", "rc2"),
"v1.12.1": ("0.13.1", "rc5"),
"v1.13.0": ("0.14.0", "rc3"),
"v1.13.1": ("0.14.1", "rc1"),
"v2.0.0": ("0.15.1", "rc2"),
"v2.0.1": ("0.15.2", "rc2"),
},
)
print("Building TorchText wheel")
build_vars = ""
if branch == "nightly":
version = host.check_output(
["if [ -f text/version.txt ]; then cat text/version.txt; fi"]
).strip()
build_date = (
host.check_output("cd text && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation")
wheel_name = host.list_dir("text/dist")[0]
embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name))
print("Copying TorchText wheel")
host.download_wheel(os.path.join("text", "dist", wheel_name))
return wheel_name
def build_torchaudio(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> str:
print("Checking out TorchAudio repo")
git_clone_flags += " --recurse-submodules"
build_version = checkout_repo(
host,
branch=branch,
url="https://github.com/pytorch/audio",
git_clone_flags=git_clone_flags,
mapping={
"v1.9.0": ("0.9.0", "rc2"),
"v1.10.0": ("0.10.0", "rc5"),
"v1.10.1": ("0.10.1", "rc1"),
"v1.10.2": ("0.10.2", "rc1"),
"v1.11.0": ("0.11.0", "rc1"),
"v1.12.0": ("0.12.0", "rc3"),
"v1.12.1": ("0.12.1", "rc5"),
"v1.13.0": ("0.13.0", "rc4"),
"v1.13.1": ("0.13.1", "rc2"),
"v2.0.0": ("2.0.1", "rc3"),
"v2.0.1": ("2.0.2", "rc2"),
},
)
print("Building TorchAudio wheel")
build_vars = ""
if branch == "nightly":
version = (
host.check_output(["grep", '"version = \'"', "audio/setup.py"])
.strip()
.split("'")[1][:-2]
)
build_date = (
host.check_output("cd audio && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
host.run_cmd(
f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
&& ./packaging/ffmpeg/build.sh \
&& {build_vars} python3 -m build --wheel --no-isolation"
)
wheel_name = host.list_dir("audio/dist")[0]
embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))
print("Copying TorchAudio wheel")
host.download_wheel(os.path.join("audio", "dist", wheel_name))
return wheel_name
def configure_system(
host: RemoteHost,
*,
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
) -> None:
if use_conda:
install_condaforge_python(host, python_version)
print("Configuring the system")
if not host.using_docker():
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
else:
host.run_cmd("yum install -y sudo")
host.run_cmd("conda install -y ninja scons")
if not use_conda:
host.run_cmd(
"sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip"
)
host.run_cmd("pip3 install dataclasses typing-extensions")
if not use_conda:
print("Installing Cython + numpy from PyPy")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
def build_domains(
host: RemoteHost,
*,
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> tuple[str, str, str, str]:
vision_wheel_name = build_torchvision(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
audio_wheel_name = build_torchaudio(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
data_wheel_name = build_torchdata(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
text_wheel_name = build_torchtext(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
def start_build(
host: RemoteHost,
*,
branch: str = "main",
compiler: str = "gcc-8",
use_conda: bool = True,
python_version: str = "3.8",
pytorch_only: bool = False,
pytorch_build_number: Optional[str] = None,
shallow_clone: bool = True,
enable_mkldnn: bool = False,
) -> tuple[str, str, str, str, str]:
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
if host.using_docker() and not use_conda:
print("Auto-selecting conda option for docker images")
use_conda = True
if not host.using_docker():
print("Disable mkldnn for host builds")
enable_mkldnn = False
configure_system(
host, compiler=compiler, use_conda=use_conda, python_version=python_version
)
if host.using_docker():
print("Move libgfortant.a into a standard location")
# HACK: pypa gforntran.a is compiled without PIC, which leads to the following error
# libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17' # noqa: E501, B950
# Workaround by copying gfortran library from the host
host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
host.run_ssh_cmd(
[
"docker",
"cp",
"/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/",
]
)
print("Checking out PyTorch repo")
host.run_cmd(
f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}"
)
host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh")
print("Building PyTorch wheel")
build_opts = ""
if pytorch_build_number is not None:
build_opts += f" -C--build-option=--build-number={pytorch_build_number}"
# Breakpad build fails on aarch64
build_vars = "USE_BREAKPAD=0 "
if branch == "nightly":
build_date = (
host.check_output("cd pytorch && git log --pretty=format:%s -1")
.strip()
.split()[0]
.replace("-", "")
)
version = host.check_output("cat pytorch/version.txt").strip()[:-2]
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
if branch.startswith(("v1.", "v2.")):
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
if enable_mkldnn:
host.run_cmd("pytorch/.ci/docker/common/install_acl.sh")
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
build_vars += " BLAS=OpenBLAS"
build_vars += " OpenBLAS_HOME=/opt/OpenBLAS"
build_vars += " ACL_ROOT_DIR=/acl"
host.run_cmd(
f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Repair the wheel")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
ld_library_path = "/acl/build:$HOME/pytorch/build/lib"
host.run_cmd(
f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}"
)
print("replace the original wheel with the repaired one")
pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0]
host.run_cmd(
f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}"
)
else:
print("build pytorch without mkldnn backend")
host.run_cmd(
f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}"
)
print("Deleting build folder")
host.run_cmd("cd pytorch && rm -rf build")
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Copying the wheel")
host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name))
print("Installing PyTorch wheel")
host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
if pytorch_only:
return (pytorch_wheel_name, None, None, None, None)
domain_wheels = build_domains(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
return (pytorch_wheel_name, *domain_wheels)
embed_library_script = """
#!/usr/bin/env python3
from auditwheel.patcher import Patchelf
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.repair import copylib
from auditwheel.lddtree import lddtree
from subprocess import check_call
import os
import shutil
import sys
from tempfile import TemporaryDirectory
def replace_tag(filename):
with open(filename, 'r') as f:
lines = f.read().split("\\n")
for i,line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f'Updated tag from {line} to {lines[i]}')
with open(filename, 'w') as f:
f.write("\\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
ctx.out_wheel=tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, elf in elf_file_filter(ctx.iter_files()):
if not filename.startswith('torch/lib'):
continue
libtree = lddtree(filename)
if lib_soname not in libtree['needed']:
continue
lib_path = libtree['libs'][lib_soname]['path']
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != 'WHEEL':
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == '__main__':
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
"""
def run_tests(host: RemoteHost, whl: str, branch="main") -> None:
print("Configuring the system")
update_apt_repo(host)
host.run_cmd("sudo apt-get install -y python3-pip git")
host.run_cmd("sudo pip3 install Cython")
host.run_cmd("sudo pip3 install numpy")
host.upload_file(whl, ".")
host.run_cmd(f"sudo pip3 install {whl}")
host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'")
host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
def get_instance_name(instance) -> Optional[str]:
if instance.tags is None:
return None
for tag in instance.tags:
if tag["Key"] == "Name":
return tag["Value"]
return None
def list_instances(instance_type: str) -> None:
print(f"All instances of type {instance_type}")
for instance in ec2_instances_of_type(instance_type):
ifaces = instance.network_interfaces
az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None
print(
f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}"
)
def terminate_instances(instance_type: str) -> None:
print(f"Terminating all instances of type {instance_type}")
instances = list(ec2_instances_of_type(instance_type))
for instance in instances:
print(f"Terminating {instance.id}")
instance.terminate()
print("Waiting for termination to complete")
for instance in instances:
instance.wait_until_terminated()
def parse_arguments():
from argparse import ArgumentParser
parser = ArgumentParser("Build and test AARCH64 wheels using EC2")
parser.add_argument("--key-name", type=str)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--build-only", action="store_true")
parser.add_argument("--test-only", type=str)
group = parser.add_mutually_exclusive_group()
group.add_argument("--os", type=str, choices=list(os_amis.keys()))
group.add_argument("--ami", type=str)
parser.add_argument(
"--python-version",
type=str,
choices=[f"3.{d}" for d in range(6, 12)],
default=None,
)
parser.add_argument("--alloc-instance", action="store_true")
parser.add_argument("--list-instances", action="store_true")
parser.add_argument("--pytorch-only", action="store_true")
parser.add_argument("--keep-running", action="store_true")
parser.add_argument("--terminate-instances", action="store_true")
parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
parser.add_argument("--ebs-size", type=int, default=50)
parser.add_argument("--branch", type=str, default="main")
parser.add_argument("--use-docker", action="store_true")
parser.add_argument(
"--compiler",
type=str,
choices=["gcc-7", "gcc-8", "gcc-9", "clang"],
default="gcc-8",
)
parser.add_argument("--use-torch-from-pypi", action="store_true")
parser.add_argument("--pytorch-build-number", type=str, default=None)
parser.add_argument("--disable-mkldnn", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
ami = (
args.ami
if args.ami is not None
else os_amis[args.os]
if args.os is not None
else ubuntu20_04_ami
)
keyfile_path, key_name = compute_keyfile_path(args.key_name)
if args.list_instances:
list_instances(args.instance_type)
sys.exit(0)
if args.terminate_instances:
terminate_instances(args.instance_type)
sys.exit(0)
if len(key_name) == 0:
raise RuntimeError("""
Cannot start build without key_name, please specify
--key-name argument or AWS_KEY_NAME environment variable.""")
if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
raise RuntimeError(f"""
Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")
# Starting the instance
inst = start_instance(
key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size
)
instance_name = f"{args.key_name}-{args.os}"
if args.python_version is not None:
instance_name += f"-py{args.python_version}"
inst.create_tags(
DryRun=False,
Tags=[
{
"Key": "Name",
"Value": instance_name,
}
],
)
addr = inst.public_dns_name
wait_for_connection(addr, 22)
host = RemoteHost(addr, keyfile_path)
host.ami = ami
if args.use_docker:
update_apt_repo(host)
host.start_docker()
if args.test_only:
run_tests(host, args.test_only)
sys.exit(0)
if args.alloc_instance:
if args.python_version is None:
sys.exit(0)
install_condaforge_python(host, args.python_version)
sys.exit(0)
python_version = args.python_version if args.python_version is not None else "3.10"
if args.use_torch_from_pypi:
configure_system(host, compiler=args.compiler, python_version=python_version)
print("Installing PyTorch wheel")
host.run_cmd("pip3 install torch")
build_domains(
host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules"
)
else:
start_build(
host,
branch=args.branch,
compiler=args.compiler,
python_version=python_version,
pytorch_only=args.pytorch_only,
pytorch_build_number=args.pytorch_build_number,
enable_mkldnn=not args.disable_mkldnn,
)
if not args.keep_running:
print(f"Waiting for instance {inst.id} to terminate")
inst.terminate()
inst.wait_until_terminated()

View File

@ -1,87 +0,0 @@
#!/usr/bin/env python3
import os
import shutil
import sys
from subprocess import check_call
from tempfile import TemporaryDirectory
from auditwheel.elfutils import elf_file_filter
from auditwheel.lddtree import lddtree
from auditwheel.patcher import Patchelf
from auditwheel.repair import copylib
from auditwheel.wheeltools import InWheelCtx
def replace_tag(filename):
with open(filename) as f:
lines = f.read().split("\\n")
for i, line in enumerate(lines):
if not line.startswith("Tag: "):
continue
lines[i] = line.replace("-linux_", "-manylinux2014_")
print(f"Updated tag from {line} to {lines[i]}")
with open(filename, "w") as f:
f.write("\\n".join(lines))
class AlignedPatchelf(Patchelf):
def set_soname(self, file_name: str, new_soname: str) -> None:
check_call(
["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name]
)
def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
check_call(
[
"patchelf",
"--page-size",
"65536",
"--replace-needed",
soname,
new_soname,
file_name,
]
)
def embed_library(whl_path, lib_soname, update_tag=False):
patcher = AlignedPatchelf()
out_dir = TemporaryDirectory()
whl_name = os.path.basename(whl_path)
tmp_whl_name = os.path.join(out_dir.name, whl_name)
with InWheelCtx(whl_path) as ctx:
torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib")
ctx.out_wheel = tmp_whl_name
new_lib_path, new_lib_soname = None, None
for filename, _ in elf_file_filter(ctx.iter_files()):
if not filename.startswith("torch/lib"):
continue
libtree = lddtree(filename)
if lib_soname not in libtree["needed"]:
continue
lib_path = libtree["libs"][lib_soname]["path"]
if lib_path is None:
print(f"Can't embed {lib_soname} as it could not be found")
break
if lib_path.startswith(torchlib_path):
continue
if new_lib_path is None:
new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
patcher.replace_needed(filename, lib_soname, new_lib_soname)
print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}")
if update_tag:
# Add manylinux2014 tag
for filename in ctx.iter_files():
if os.path.basename(filename) != "WHEEL":
continue
replace_tag(filename)
shutil.move(tmp_whl_name, whl_path)
if __name__ == "__main__":
embed_library(
sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag"
)

View File

@ -4,14 +4,17 @@ set -ex
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Source the common build script for architecture-specific configurations (MKLDNN, ACL, etc.)
source "${SCRIPTPATH}/../pytorch/build.sh" || true
case "${GPU_ARCH_TYPE:-BLANK}" in
cuda)
cuda | cuda-aarch64)
bash "${SCRIPTPATH}/build_cuda.sh"
;;
rocm)
bash "${SCRIPTPATH}/build_rocm.sh"
;;
cpu | cpu-cxx11-abi | cpu-s390x)
cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x)
bash "${SCRIPTPATH}/build_cpu.sh"
;;
xpu)

View File

@ -18,12 +18,31 @@ retry () {
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
# Detect architecture first
ARCH=$(uname -m)
echo "Detected architecture: $ARCH"
PLATFORM=""
# TODO move this into the Docker images
OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release)
if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
retry yum install -q -y zip openssl
PLATFORM="manylinux_2_28_x86_64"
# Set platform based on architecture
case $ARCH in
x86_64)
PLATFORM="manylinux_2_28_x86_64"
;;
aarch64)
PLATFORM="manylinux_2_28_aarch64"
;;
s390x)
PLATFORM="manylinux_2_28_s390x"
;;
*)
echo "Unsupported architecture: $ARCH"
exit 1
;;
esac
elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
retry dnf install -q -y zip openssl
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
@ -38,6 +57,8 @@ else
exit 1
fi
echo "Platform set to: $PLATFORM"
# We use the package name to test the package by passing this to 'pip install'
# This is the env variable that setup.py uses to name the package. Note that
# pip 'normalizes' the name first by changing all - to _
@ -299,8 +320,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
# ROCm workaround for roctracer dlopens
if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
patchedpath=$(fname_without_so_number $destpath)
# Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
# Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load
elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then
patchedpath=$destpath
else
patchedpath=$(fname_with_sha256 $destpath)
@ -346,9 +367,22 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
done
# create Manylinux 2_28 tag this needs to happen before regenerate the RECORD
if [[ $PLATFORM == "manylinux_2_28_x86_64" && $GPU_ARCH_TYPE != "cpu-s390x" && $GPU_ARCH_TYPE != "xpu" ]]; then
# Support all architectures (x86_64, aarch64, s390x)
if [[ "$IS_MANYLINUX2_28" == "1" && $GPU_ARCH_TYPE != "xpu" ]]; then
wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g')
sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file;
echo "Updating wheel tag for $ARCH architecture"
# Replace linux_* with manylinux_2_28_* based on architecture
case $ARCH in
x86_64)
sed -i -e 's#linux_x86_64#manylinux_2_28_x86_64#g' $wheel_file
;;
aarch64)
sed -i -e 's#linux_aarch64#manylinux_2_28_aarch64#g' $wheel_file
;;
s390x)
sed -i -e 's#linux_s390x#manylinux_2_28_s390x#g' $wheel_file
;;
esac
fi
# regenerate the RECORD file with new hashes

View File

@ -15,6 +15,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building CPU wheel for architecture: $ARCH"
WHEELHOUSE_DIR="wheelhousecpu"
LIBTORCH_HOUSE_DIR="libtorch_housecpu"
if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then
@ -34,8 +38,10 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then
elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then
LIBGOMP_PATH="/usr/lib64/libgomp.so.1"
elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
if [[ "$(uname -m)" == "s390x" ]]; then
if [[ "$ARCH" == "s390x" ]]; then
LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1"
elif [[ "$ARCH" == "aarch64" ]]; then
LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1"
else
LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
fi
@ -49,6 +55,34 @@ DEPS_SONAME=(
"libgomp.so.1"
)
# Add ARM-specific library dependencies for CPU builds
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific CPU library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library for CPU"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
"/opt/OpenBLAS/lib/libopenblas.so.0"
)
DEPS_SONAME+=(
"libgfortran.so.5"
"libopenblas.so.0"
)
fi
rm -rf /usr/local/cuda*
SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"

View File

@ -29,6 +29,10 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then
EXTRA_CAFFE2_CMAKE_FLAGS=()
fi
# Detect architecture
ARCH=$(uname -m)
echo "Building for architecture: $ARCH"
# Determine CUDA version and architectures to build for
#
# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`,
@ -53,34 +57,60 @@ fi
cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
# Function to remove architectures from a list
remove_archs() {
local result="$1"
shift
for arch in "$@"; do
result="${result//${arch};/}"
done
echo "$result"
}
# Function to filter CUDA architectures for aarch64
# aarch64 ARM GPUs only support certain compute capabilities
# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer)
# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only)
filter_aarch64_archs() {
local arch_list="$1"
# Explicitly remove architectures not needed on aarch64
arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6")
echo "$arch_list"
}
# Base: Common architectures across all modern CUDA versions
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0"
case ${CUDA_VERSION} in
#removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases
#however we would like to keep sm_70 architecture see: https://github.com/pytorch/pytorch/issues/157517
12.8)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0"
;;
12.9)
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX"
# WAR to resolve the ld error in libtorch build with CUDA 12.9
12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases
12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support
12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility
if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then
TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX"
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch
fi
;;
13.0)
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX"
;;
12.6)
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0"
;;
*)
echo "unknown cuda version $CUDA_VERSION"
exit 1
TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX"
export TORCH_NVCC_FLAGS="-compress-mode=size"
export BUILD_BUNDLE_PTXAS=1
;;
*) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;;
esac
# Filter for aarch64: Remove < 8.0 and 8.6
[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST")
echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST"
export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
echo "${TORCH_CUDA_ARCH_LIST}"
# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only
if [[ "$ARCH" == "aarch64" ]]; then
echo "Disabling MAGMA for aarch64 architecture"
export USE_MAGMA=0
fi
# Package directories
WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot"
LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot"
@ -244,6 +274,51 @@ else
exit 1
fi
# Add ARM-specific library dependencies
if [[ "$ARCH" == "aarch64" ]]; then
echo "Adding ARM-specific library dependencies"
# ARM Compute Library (if available)
if [[ -d "/acl/build" ]]; then
echo "Adding ARM Compute Library"
DEPS_LIST+=(
"/acl/build/libarm_compute.so"
"/acl/build/libarm_compute_graph.so"
)
DEPS_SONAME+=(
"libarm_compute.so"
"libarm_compute_graph.so"
)
fi
# ARM system libraries
DEPS_LIST+=(
"/lib64/libgomp.so.1"
"/usr/lib64/libgfortran.so.5"
)
DEPS_SONAME+=(
"libgomp.so.1"
"libgfortran.so.5"
)
# NVPL libraries (ARM optimized BLAS/LAPACK)
if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then
echo "Adding NVPL libraries for ARM"
DEPS_LIST+=(
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0"
"/usr/local/lib/libnvpl_lapack_core.so.0"
"/usr/local/lib/libnvpl_blas_core.so.0"
)
DEPS_SONAME+=(
"libnvpl_lapack_lp64_gomp.so.0"
"libnvpl_blas_lp64_gomp.so.0"
"libnvpl_lapack_core.so.0"
"libnvpl_blas_core.so.0"
)
fi
fi
# run_tests.sh requires DESIRED_CUDA to know what tests to exclude
export DESIRED_CUDA="$cuda_version_nodot"
@ -251,9 +326,11 @@ export DESIRED_CUDA="$cuda_version_nodot"
rm -rf /usr/local/cuda || true
ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda
# Switch `/usr/local/magma` to the desired CUDA version
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64)
if [[ "$ARCH" != "aarch64" ]]; then
rm -rf /usr/local/magma || true
ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma
fi
export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130
export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0

View File

@ -86,10 +86,20 @@ else
fi
fi
# Enable MKLDNN with ARM Compute Library for ARM builds
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
export USE_MKLDNN=1
# ACL is required for aarch64 builds
if [[ ! -d "/acl" ]]; then
echo "ERROR: ARM Compute Library not found at /acl"
echo "ACL is required for aarch64 builds. Check Docker image setup."
exit 1
fi
export USE_MKLDNN_ACL=1
export ACL_ROOT_DIR=/acl
echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl"
fi
if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then

View File

@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
)
def _compile_and_extract_symbols(
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
"""
Helper to compile a C++ file and extract all symbols.
Args:
cpp_content: C++ source code to compile
compile_flags: Compilation flags
exclude_list: List of symbol names to exclude. Defaults to ["main"].
Returns:
List of all symbols found in the object file (excluding those in exclude_list).
"""
import subprocess
import tempfile
if exclude_list is None:
exclude_list = ["main"]
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
cpp_file = tmppath / "test.cpp"
obj_file = tmppath / "test.o"
cpp_file.write_text(cpp_content)
result = subprocess.run(
compile_flags + [str(cpp_file), "-o", str(obj_file)],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode != 0:
raise RuntimeError(f"Compilation failed: {result.stderr}")
symbols = get_symbols(str(obj_file))
# Return all symbol names, excluding those in the exclude list
return [name for _addr, _stype, name in symbols if name not in exclude_list]
def check_stable_only_symbols(install_root: Path) -> None:
"""
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
This approach tests:
1. WITHOUT macros -> many torch symbols exposed
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
4. WITH both macros -> zero torch symbols (all hidden)
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>
// ATen tensor library
#include <ATen/ATen.h>
// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>
int main() { return 0; }
"""
base_compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c", # Compile only, don't link
]
# Compile WITHOUT any macros
symbols_without = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=base_compile_flags,
)
# We expect constexpr symbols, inline functions used by other headers etc.
# to produce symbols
num_symbols_without = len(symbols_without)
print(f"Found {num_symbols_without} symbols without any macros defined")
assert num_symbols_without != 0, (
"Expected a non-zero number of symbols without any macros"
)
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
symbols_with_stable_only = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_stable_only,
)
num_symbols_with_stable_only = len(symbols_with_stable_only)
assert num_symbols_with_stable_only == 0, (
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
)
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
compile_flags_with_target_version = base_compile_flags + [
"-DTORCH_TARGET_VERSION=1"
]
symbols_with_target_version = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_target_version,
)
num_symbols_with_target_version = len(symbols_with_target_version)
assert num_symbols_with_target_version == 0, (
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
)
# Compile WITH both macros (expect 0 symbols)
compile_flags_with_both = base_compile_flags + [
"-DTORCH_STABLE_ONLY",
"-DTORCH_TARGET_VERSION=1",
]
symbols_with_both = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_both,
)
num_symbols_with_both = len(symbols_with_both)
assert num_symbols_with_both == 0, (
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
)
def check_stable_api_symbols(install_root: Path) -> None:
"""
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
stable_dir = include_dir / "torch" / "csrc" / "stable"
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
stable_headers = list(stable_dir.rglob("*.h"))
if not stable_headers:
raise RuntimeError("Could not find any stable headers")
includes = []
for header in stable_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable = _compile_and_extract_symbols(
cpp_content=test_stable_content,
compile_flags=compile_flags,
)
num_symbols_stable = len(symbols_stable)
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
assert num_symbols_stable > 0, (
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable} symbols"
)
def check_headeronly_symbols(install_root: Path) -> None:
"""
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Find all headers in torch/headeronly
headeronly_dir = include_dir / "torch" / "headeronly"
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
headeronly_headers = list(headeronly_dir.rglob("*.h"))
if not headeronly_headers:
raise RuntimeError("Could not find any headeronly headers")
# Filter out platform-specific headers that may not compile everywhere
platform_specific_keywords = [
"cpu/vec",
]
filtered_headers = []
for header in headeronly_headers:
rel_path = header.relative_to(include_dir).as_posix()
if not any(
keyword in rel_path.lower() for keyword in platform_specific_keywords
):
filtered_headers.append(header)
includes = []
for header in filtered_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_headeronly = _compile_and_extract_symbols(
cpp_content=test_headeronly_content,
compile_flags=compile_flags,
)
num_symbols_headeronly = len(symbols_headeronly)
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
assert num_symbols_headeronly > 0, (
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_headeronly} symbols"
)
def check_aoti_shim_symbols(install_root: Path) -> None:
"""
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
int32_t (*fp2)() = &aoti_torch_dtype_float32;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_shim = _compile_and_extract_symbols(
cpp_content=test_shim_content,
compile_flags=compile_flags,
)
num_symbols_shim = len(symbols_shim)
assert num_symbols_shim > 0, (
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_shim} symbols"
)
def check_stable_c_shim_symbols(install_root: Path) -> None:
"""
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Check if the stable C shim exists
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
if not stable_shim.exists():
raise RuntimeError("Could not find stable c shim")
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
// Reference stable C API functions to create undefined symbols
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable_shim = _compile_and_extract_symbols(
cpp_content=test_stable_shim_content,
compile_flags=compile_flags,
)
num_symbols_stable_shim = len(symbols_stable_shim)
assert num_symbols_stable_shim > 0, (
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable_shim} symbols"
)
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@ -460,13 +129,6 @@ def main() -> None:
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
# Check symbols when TORCH_STABLE_ONLY is defined
check_stable_only_symbols(install_root)
check_stable_api_symbols(install_root)
check_headeronly_symbols(install_root)
check_aoti_shim_symbols(install_root)
check_stable_c_shim_symbols(install_root)
if __name__ == "__main__":
main()

View File

@ -353,6 +353,17 @@ def test_linalg(device="cpu") -> None:
torch.linalg.svd(A)
def test_sdpa(device="cpu", dtype=torch.float16) -> None:
"""Regression test for https://github.com/pytorch/pytorch/issues/167602
Without nvrtc_builtins on CuDNN-9.13 on CUDA-13 fails with ` No valid execution plans built.`
"""
print(f"Testing SDPA on {device} using type {dtype}")
k, q, v = torch.rand(3, 1, 16, 77, 64, dtype=dtype, device=device).unbind(0)
attn = torch.rand(1, 1, 77, 77, dtype=dtype, device=device)
rc = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn)
assert rc.isnan().any().item() is False
def smoke_test_compile(device: str = "cpu") -> None:
supported_dtypes = [torch.float16, torch.float32, torch.float64]
@ -489,10 +500,12 @@ def main() -> None:
smoke_test_conv2d()
test_linalg()
test_numpy()
test_sdpa()
if is_cuda_system:
test_linalg("cuda")
test_cuda_gds_errors_captured()
test_sdpa("cuda")
if options.package == "all":
smoke_test_modules()

View File

@ -389,6 +389,13 @@ test_lazy_tensor_meta_reference_disabled() {
export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
}
test_dynamo_core() {
time python test/run_test.py \
--include-dynamo-core-tests \
--verbose \
--upload-artifacts-while-running
assert_git_not_dirty
}
test_dynamo_wrapped_shard() {
if [[ -z "$NUM_TEST_SHARDS" ]]; then
@ -1680,6 +1687,22 @@ test_operator_microbenchmark() {
done
}
test_attention_microbenchmark() {
TEST_REPORTS_DIR=$(pwd)/test/test-reports
mkdir -p "$TEST_REPORTS_DIR"
TEST_DIR=$(pwd)
# Install attention-gym dependency
echo "Installing attention-gym..."
python -m pip install git+https://github.com/meta-pytorch/attention-gym.git@main
pip show triton
cd "${TEST_DIR}"/benchmarks/transformer
$TASKSET python score_mod.py --config configs/config_basic.yaml \
--output-json-for-dashboard "${TEST_REPORTS_DIR}/attention_microbenchmark.json"
}
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
(cd test && python -c "import torch; print(torch.__config__.show())")
(cd test && python -c "import torch; print(torch.__config__.parallel_info())")
@ -1737,6 +1760,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
fi
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
test_operator_microbenchmark
elif [[ "${TEST_CONFIG}" == *attention_microbenchmark* ]]; then
test_attention_microbenchmark
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
@ -1796,6 +1821,8 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
test_inductor_shard "${SHARD_NUMBER}"
elif [[ "${TEST_CONFIG}" == *einops* ]]; then
test_einops
elif [[ "${TEST_CONFIG}" == *dynamo_core* ]]; then
test_dynamo_core
elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then
install_torchvision
test_dynamo_wrapped_shard "${SHARD_NUMBER}"

View File

@ -1 +1 @@
07b6cbde121417a70e4dc871adb6d27030e0ce3f
ee1a1350eb37804b94334768f328144f058f14e9

View File

@ -1 +1 @@
acccf86477759b2d3500f1ae1be065f7b1e409ec
2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc

View File

@ -1 +1 @@
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
94631807d22c09723dd006f7be5beb649d5f88d0

View File

@ -7,6 +7,7 @@ ciflow_push_tags:
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
- ciflow/dynamo
- ciflow/h100
- ciflow/h100-cutlass-backend
- ciflow/h100-distributed

View File

@ -50,7 +50,7 @@ def get_tag() -> str:
def get_base_version() -> str:
root = get_pytorch_root()
dirty_version = open(root / "version.txt").read().strip()
dirty_version = Path(root / "version.txt").read_text().strip()
# Strips trailing a0 from version.txt, not too sure why it's there in the
# first place
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)

View File

@ -260,11 +260,8 @@ jobs:
"${DOCKER_IMAGE}"
)
docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh"
else
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
fi
# Unified build script for all architectures (x86_64, aarch64, s390x)
docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh"
- name: Chown artifacts
if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }}

View File

@ -326,7 +326,7 @@ jobs:
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}

View File

@ -0,0 +1,73 @@
name: attention_op_microbenchmark
on:
push:
tags:
- ciflow/op-benchmark/*
workflow_dispatch:
schedule:
# Run at 06:00 UTC everyday
- cron: 0 7 * * *
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
attn-microbenchmark-build:
if: github.repository_owner == 'pytorch'
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '8.0 9.0'
test-matrix: |
{ include: [
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
{ config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
]}
secrets: inherit
attn-microbenchmark-test:
name: attn-microbenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: attn-microbenchmark-build
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }}
secrets: inherit
# B200 runner
opmicrobenchmark-build-b200:
if: github.repository_owner == 'pytorch'
name: opmicrobenchmark-build-b200
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
]}
secrets: inherit
opmicrobenchmark-test-b200:
name: opmicrobenchmark-test-b200
uses: ./.github/workflows/_linux-test.yml
needs: opmicrobenchmark-build-b200
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
secrets: inherit

70
.github/workflows/dynamo-unittest.yml vendored Normal file
View File

@ -0,0 +1,70 @@
# Workflow: Dynamo Unit Test
# runs unit tests for dynamo.
name: dynamo-unittest
on:
push:
tags:
- ciflow/dynamo/*
workflow_call:
schedule:
- cron: 29 8 * * * # about 1:29am PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
dynamo-build:
name: dynamo-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
strategy:
matrix:
python-version: ['3.11', '3.12']
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
test-matrix: |
{ include: [
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
]}
secrets: inherit
dynamo-test:
name: dynamo-test
uses: ./.github/workflows/_linux-test.yml
needs: [get-label-type, dynamo-build]
strategy:
matrix:
python-version: ['3.11', '3.12']
with:
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
test-matrix: |
{ include: [
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
]}
secrets: inherit

330
.spin/cmds.py Normal file
View File

@ -0,0 +1,330 @@
import hashlib
import subprocess
import sys
from pathlib import Path
import click
import spin
def file_digest(file, algorithm: str):
try:
return hashlib.file_digest(file, algorithm)
except AttributeError:
pass # Fallback to manual implementation below
hash = hashlib.new(algorithm)
while chunk := file.read(8192):
hash.update(chunk)
return hash
def _hash_file(file):
with open(file, "rb") as f:
hash = file_digest(f, "sha256")
return hash.hexdigest()
def _hash_files(files):
hashes = {file: _hash_file(file) for file in files}
return hashes
def _read_hashes(hash_file: Path):
if not hash_file.exists():
return {}
with hash_file.open("r") as f:
lines = f.readlines()
hashes = {}
for line in lines:
hash = line[:64]
file = line[66:].strip()
hashes[file] = hash
return hashes
def _updated_hashes(hash_file, files_to_hash):
old_hashes = _read_hashes(hash_file)
new_hashes = _hash_files(files_to_hash)
if new_hashes != old_hashes:
return new_hashes
return None
@click.command()
def regenerate_version():
"""Regenerate version.py."""
cmd = [
sys.executable,
"-m",
"tools.generate_torch_version",
"--is-debug=false",
]
spin.util.run(cmd)
TYPE_STUBS = [
(
"Pytorch type stubs",
Path(".lintbin/.pytorch-type-stubs.sha256"),
[
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"tools/autograd/deprecated.yaml",
],
[
sys.executable,
"-m",
"tools.pyi.gen_pyi",
"--native-functions-path",
"aten/src/ATen/native/native_functions.yaml",
"--tags-path",
"aten/src/ATen/native/tags.yaml",
"--deprecated-functions-path",
"tools/autograd/deprecated.yaml",
],
),
(
"Datapipes type stubs",
None,
[],
[
sys.executable,
"torch/utils/data/datapipes/gen_pyi.py",
],
),
]
@click.command()
def regenerate_type_stubs():
"""Regenerate type stubs."""
for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
if hash_file:
if hashes := _updated_hashes(hash_file, files_to_hash):
click.echo(
f"Changes detected in type stub files for {name}. Regenerating..."
)
spin.util.run(cmd)
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Type stubs and hashes updated.")
else:
click.echo(f"No changes detected in type stub files for {name}.")
else:
click.echo(f"No hash file for {name}. Regenerating...")
spin.util.run(cmd)
click.echo("Type stubs regenerated.")
@click.command()
def regenerate_clangtidy_files():
"""Regenerate clang-tidy files."""
cmd = [
sys.executable,
"-m",
"tools.linter.clang_tidy.generate_build_files",
]
spin.util.run(cmd)
#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
"ATEN_CPU_GPU_AGNOSTIC",
"BAZEL_LINTER",
"C10_NODISCARD",
"C10_UNUSED",
"CALL_ONCE",
"CMAKE_MINIMUM_REQUIRED",
"CONTEXT_DECORATOR",
"COPYRIGHT",
"CUBINCLUDE",
"DEPLOY_DETECTION",
"ERROR_PRONE_ISINSTANCE",
"EXEC",
"HEADER_ONLY_LINTER",
"IMPORT_LINTER",
"INCLUDE",
"LINTRUNNER_VERSION",
"MERGE_CONFLICTLESS_CSV",
"META_NO_CREATE_UNBACKED",
"NEWLINE",
"NOQA",
"NO_WORKFLOWS_ON_FORK",
"ONCE_FLAG",
"PYBIND11_INCLUDE",
"PYBIND11_SPECIALIZATION",
"PYPIDEP",
"PYPROJECT",
"RAWCUDA",
"RAWCUDADEVICE",
"ROOT_LOGGING",
"TABS",
"TESTOWNERS",
"TYPEIGNORE",
"TYPENOSKIP",
"WORKFLOWSYNC",
}
#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
"CMAKE",
"DOCSTRING_LINTER",
"GHA",
"NATIVEFUNCTIONS",
"RUFF",
"SET_LINTER",
"SHELLCHECK",
"SPACES",
}
#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
"ACTIONLINT",
"CLANGFORMAT",
"CLANGTIDY",
"CODESPELL",
"FLAKE8",
"GB_REGISTRY",
"PYFMT",
"PYREFLY",
"TEST_DEVICE_BIAS",
"TEST_HAS_MAIN",
}
ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS
LINTRUNNER_CACHE_INFO = (
Path(".lintbin/.lintrunner.sha256"),
[
"requirements.txt",
"pyproject.toml",
".lintrunner.toml",
],
)
LINTRUNNER_BASE_CMD = [
"uvx",
"--python",
"3.10",
"lintrunner@0.12.7",
]
@click.command()
def setup_lint():
"""Set up lintrunner with current CI version."""
cmd = LINTRUNNER_BASE_CMD + ["init"]
subprocess.run(cmd, check=True, capture_output=True, text=True)
def _check_linters():
cmd = LINTRUNNER_BASE_CMD + ["list"]
ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
unknown_linters = linters - ALL_LINTERS
missing_linters = ALL_LINTERS - linters
if unknown_linters:
click.secho(
f"Unknown linters found; please add them to the correct category "
f"in .spin/cmds.py: {', '.join(unknown_linters)}",
fg="yellow",
)
if missing_linters:
click.secho(
f"Missing linters found; please update the corresponding category "
f"in .spin/cmds.py: {', '.join(missing_linters)}",
fg="yellow",
)
return unknown_linters, missing_linters
@spin.util.extend_command(
setup_lint,
doc=f"""
If configuration has changed, update lintrunner.
Compares the stored old hashes of configuration files with new ones and
performs setup via setup-lint if the hashes have changed.
Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
""",
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
click.echo(
"Changes detected in lint configuration files. Setting up linting tools..."
)
parent_callback(**kwargs)
hash_file = LINTRUNNER_CACHE_INFO[0]
hash_file.parent.mkdir(parents=True, exist_ok=True)
with hash_file.open("w") as f:
for file, hash in hashes.items():
f.write(f"{hash} {file}\n")
click.echo("Linting tools set up and hashes updated.")
else:
click.echo("No changes detected in lint configuration files. Skipping setup.")
click.echo("Regenerating version...")
ctx.invoke(regenerate_version)
click.echo("Regenerating type stubs...")
ctx.invoke(regenerate_type_stubs)
click.echo("Done.")
_check_linters()
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
"""Lint all files."""
ctx.invoke(lazy_setup_lint)
all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
changed_files_linters = SLOW_LINTERS
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
all_files_cmd = cmd + [
"--take",
",".join(all_files_linters),
"--all-files",
]
spin.util.run(all_files_cmd)
changed_files_cmd = cmd + [
"--take",
",".join(changed_files_linters),
]
spin.util.run(changed_files_cmd)
@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
"""Autofix all files."""
ctx.invoke(lint, apply_patches=True)
@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
"""Lint changed files."""
ctx.invoke(lazy_setup_lint)
cmd = LINTRUNNER_BASE_CMD
if apply_patches:
cmd += ["--apply-patches"]
spin.util.run(cmd)
@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
"""Autofix changed files."""
ctx.invoke(quicklint, apply_patches=True)

View File

@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
}
} // namespace at::accelerator
namespace at {

View File

@ -1,5 +1,6 @@
#pragma once
#include <torch/headeronly/core/TensorAccessor.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Deprecated.h>
@ -11,252 +12,37 @@
namespace at {
// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
// is used to enable the __restrict__ keyword/modifier for the data
// passed to cuda.
template <typename T>
struct DefaultPtrTraits {
typedef T* PtrType;
};
using torch::headeronly::DefaultPtrTraits;
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename T>
struct RestrictPtrTraits {
typedef T* __restrict__ PtrType;
};
using torch::headeronly::RestrictPtrTraits;
#endif
// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
// For CUDA tensors it is used in device code (only). This means that we restrict ourselves
// to functions and types available there (e.g. IntArrayRef isn't).
// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class TensorAccessorBase {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
using TensorAccessorBase = torch::headeronly::detail::TensorAccessorBase<c10::IntArrayRef, T, N, PtrTraits, index_t>;
C10_HOST_DEVICE TensorAccessorBase(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: data_(data_), sizes_(sizes_), strides_(strides_) {}
C10_HOST IntArrayRef sizes() const {
return IntArrayRef(sizes_,N);
}
C10_HOST IntArrayRef strides() const {
return IntArrayRef(strides_,N);
}
C10_HOST_DEVICE index_t stride(index_t i) const {
return strides_[i];
}
C10_HOST_DEVICE index_t size(index_t i) const {
return sizes_[i];
}
C10_HOST_DEVICE PtrType data() {
return data_;
}
C10_HOST_DEVICE const PtrType data() const {
return data_;
}
protected:
PtrType data_;
const index_t* sizes_;
const index_t* strides_;
};
// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
// `Tensor.accessor<T, N>()`.
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
// indexing on the device uses `TensorAccessor`s.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
using TensorAccessor = torch::headeronly::detail::TensorAccessor<c10::IntArrayRef, T, N, PtrTraits, index_t>;
C10_HOST_DEVICE TensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
namespace detail {
C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
}
C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
}
};
template<typename T, template <typename U> class PtrTraits, typename index_t>
class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST_DEVICE TensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
C10_HOST_DEVICE T & operator[](index_t i) {
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
return this->data_[this->strides_[0]*i];
}
C10_HOST_DEVICE const T & operator[](index_t i) const {
return this->data_[this->strides_[0]*i];
}
};
// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host
// and as
// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
// in order to transfer them on the device when calling kernels.
// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s.
// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
// on the device, so those functions are host only.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class GenericPackedTensorAccessorBase {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessorBase(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: data_(data_) {
std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
std::copy(strides_, strides_ + N, std::begin(this->strides_));
}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessorBase(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: data_(data_) {
for (const auto i : c10::irange(N)) {
this->sizes_[i] = sizes_[i];
this->strides_[i] = strides_[i];
}
}
C10_HOST_DEVICE index_t stride(index_t i) const {
return strides_[i];
}
C10_HOST_DEVICE index_t size(index_t i) const {
return sizes_[i];
}
C10_HOST_DEVICE PtrType data() {
return data_;
}
C10_HOST_DEVICE const PtrType data() const {
return data_;
}
protected:
PtrType data_;
// NOLINTNEXTLINE(*c-arrays*)
index_t sizes_[N];
// NOLINTNEXTLINE(*c-arrays*)
index_t strides_[N];
C10_HOST void bounds_check_(index_t i) const {
TORCH_CHECK_INDEX(
template <size_t N, typename index_t>
struct IndexBoundsCheck {
IndexBoundsCheck(index_t i) {
TORCH_CHECK_INDEX(
0 <= i && i < index_t{N},
"Index ",
i,
" is not within bounds of a tensor of dimension ",
N);
}
}
};
} // namespace detail
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
index_t* new_sizes = this->sizes_ + 1;
index_t* new_strides = this->strides_ + 1;
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
}
C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
const index_t* new_sizes = this->sizes_ + 1;
const index_t* new_strides = this->strides_ + 1;
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
}
/// Returns a PackedTensorAccessor of the same dimension after transposing the
/// two dimensions given. Does not actually move elements; transposition is
/// made by permuting the size/stride arrays. If the dimensions are not valid,
/// asserts.
C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
index_t dim1,
index_t dim2) const {
this->bounds_check_(dim1);
this->bounds_check_(dim2);
GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
this->data_, this->sizes_, this->strides_);
std::swap(result.strides_[dim1], result.strides_[dim2]);
std::swap(result.sizes_[dim1], result.sizes_[dim2]);
return result;
}
};
template<typename T, template <typename U> class PtrTraits, typename index_t>
class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
C10_DEVICE T & operator[](index_t i) {
return this->data_[this->strides_[0] * i];
}
C10_DEVICE const T& operator[](index_t i) const {
return this->data_[this->strides_[0]*i];
}
// Same as in the general N-dimensional case, but note that in the
// 1-dimensional case the returned PackedTensorAccessor will always be an
// identical copy of the original
C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
index_t dim1,
index_t dim2) const {
this->bounds_check_(dim1);
this->bounds_check_(dim2);
return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
this->data_, this->sizes_, this->strides_);
}
};
using GenericPackedTensorAccessorBase = torch::headeronly::detail::GenericPackedTensorAccessorBase<detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
using GenericPackedTensorAccessor = torch::headeronly::detail::GenericPackedTensorAccessor<TensorAccessor<T, N-1, PtrTraits, index_t>, detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
// Can't put this directly into the macro function args because of commas
#define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>

View File

@ -245,6 +245,9 @@ class TORCH_API TensorBase {
size_t weak_use_count() const noexcept {
return impl_.weak_use_count();
}
bool is_uniquely_owned() const noexcept {
return impl_.is_uniquely_owned();
}
std::string toString() const;

View File

@ -223,6 +223,62 @@ CONVERT_FROM_BF16_TEMPLATE(double)
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif
#ifdef __ARM_FEATURE_BF16
// clang-[17, 20] crashes when autovectorizing static cast to bf16
// Below is a workaround to have some vectorization
// Works decently well for smaller int types
template <typename from_type>
inline void convertToBf16Impl(
const from_type* __restrict src,
c10::BFloat16* __restrict dst,
uint64_t n) {
bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
uint64_t loopBound = n - (n % 16);
uint64_t i = 0;
for (; i < loopBound; i += 16) {
float32x4_t a, b, c, d;
a[0] = static_cast<float>(src[i]);
a[1] = static_cast<float>(src[i + 1]);
a[2] = static_cast<float>(src[i + 2]);
a[3] = static_cast<float>(src[i + 3]);
b[0] = static_cast<float>(src[i + 4]);
b[1] = static_cast<float>(src[i + 5]);
b[2] = static_cast<float>(src[i + 6]);
b[3] = static_cast<float>(src[i + 7]);
c[0] = static_cast<float>(src[i + 8]);
c[1] = static_cast<float>(src[i + 9]);
c[2] = static_cast<float>(src[i + 10]);
c[3] = static_cast<float>(src[i + 11]);
d[0] = static_cast<float>(src[i + 12]);
d[1] = static_cast<float>(src[i + 13]);
d[2] = static_cast<float>(src[i + 14]);
d[3] = static_cast<float>(src[i + 15]);
vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
}
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
for (; i < n; i++) {
float a = static_cast<float>(src[i]);
dstPtr[i] = vcvth_bf16_f32(a);
}
}
#define CONVERT_TO_BF16_TEMPLATE(from_type) \
template <> \
inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
return convertToBf16Impl<from_type>(src, dst, n); \
}
CONVERT_TO_BF16_TEMPLATE(uint8_t)
CONVERT_TO_BF16_TEMPLATE(int8_t)
CONVERT_TO_BF16_TEMPLATE(int16_t)
CONVERT_TO_BF16_TEMPLATE(int32_t)
#endif
inline void convertBoolToBfloat16Impl(
const bool* __restrict src,
c10::BFloat16* __restrict dst,

View File

@ -3,6 +3,7 @@
#include <cstdint>
#include <map>
#include <shared_mutex>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
struct WorkspaceMapWithMutex {
std::map<std::tuple<void*, void*>, at::DataPtr> map;
std::shared_mutex mutex;
};
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();

View File

@ -1,6 +1,7 @@
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/MemPool.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDAFunctions.h>
@ -13,7 +14,7 @@ static bool _cuda_graphs_debug = false;
MempoolId_t graph_pool_handle() {
// Sets just the second value, to distinguish it from MempoolId_ts created from
// cudaStreamGetCaptureInfo id_s in capture_begin.
return c10::cuda::MemPool::graph_pool_handle();
return at::cuda::MemPool::graph_pool_handle();
}
/**
@ -90,7 +91,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
} else {
// User did not ask us to share a mempool. Create graph pool handle using is_user_created=false.
// Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle().
mempool_id_ = c10::cuda::MemPool::graph_pool_handle(false);
mempool_id_ = at::cuda::MemPool::graph_pool_handle(false);
TORCH_INTERNAL_ASSERT(mempool_id_.first > 0);
}
@ -174,17 +175,24 @@ void CUDAGraph::instantiate() {
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
// who prefer not to report error message through these arguments moving forward
// (they prefer return value, or errors on api calls internal to the capture)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
// ROCM appears to fail with HIP error: invalid argument
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && !defined(USE_ROCM)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, cudaGraphInstantiateFlagUseNodePriority));
#else
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
#endif
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
} else {
#if !defined(USE_ROCM)
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch | cudaGraphInstantiateFlagUseNodePriority));
#else
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch));
#endif
}
has_graph_exec_ = true;
}

View File

@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
// - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
cublasDestroy(handle);
cublasDestroy(handle);
#endif
}
@ -107,19 +107,27 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}
std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}
void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace().clear();
cublaslt_handle_stream_to_workspace().clear();
{
auto& workspace = cublas_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
{
auto& workspace = cublaslt_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
}
size_t parseChosenWorkspaceSize() {
@ -233,6 +241,38 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
}
void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto& workspace = cublas_handle_stream_to_workspace();
size_t workspace_size = getChosenWorkspaceSize();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
handle, workspace_it->second.get(), workspace_size));
return;
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
TORCH_CUDABLAS_CHECK(
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
}
}
void* getCUDABlasLtWorkspace() {
#ifndef USE_ROCM
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
@ -241,8 +281,10 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
return workspace_it->second.mutable_get();
}
#endif
@ -250,11 +292,29 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
auto& workspace = cublaslt_handle_stream_to_workspace();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
return workspace_it->second.mutable_get();
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewCUDABlasLtWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it =
workspace.map.try_emplace(key, std::move(new_workspace)).first;
return workspace_it->second.mutable_get();
}
return workspace_it->second.mutable_get();
}
cublasHandle_t getCurrentCUDABlasHandle() {
@ -298,13 +358,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// will allocate memory dynamically (even if they're cheap) outside
// PyTorch's CUDA caching allocator. It's possible that CCA used up
// all the memory and cublas's cudaMallocAsync will return OOM
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
setWorkspaceForHandle(handle, stream);
#if !defined(USE_ROCM)
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.

View File

@ -0,0 +1,69 @@
#include <ATen/core/CachingHostAllocator.h>
#include <ATen/cuda/MemPool.h>
namespace at::cuda {
// uid_ is incremented when a user creates a MemPool,
// for example: using graph_pool_handle() or c10::cuda::MemPool().
//
// uuid_ is incremented when CUDAGraph creates a MemPool
// as a result of a user not providing a pool.
//
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
// passed to a function, either by user or CUDAGraphs. For example,
// default value of MempoolId_t for capture_begin function is {0, 0}.
// That's why uid_ and uuid_ start at 1.
std::atomic<CaptureId_t> MemPool::uid_{1};
std::atomic<CaptureId_t> MemPool::uuid_{1};
MemPool::MemPool(
CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom)
: allocator_(allocator), is_user_created_(is_user_created) {
if (is_user_created_) {
id_ = {0, uid_++};
} else {
id_ = {uuid_++, 0};
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
if (use_on_oom) {
CUDACachingAllocator::setUseOnOOM(device_, id_);
}
}
MemPool::~MemPool() {
// TORCH_INTERNAL_ASSERT(use_count() == 1);
// We used to assert that TORCH_INTERNAL_ASSERT(use_count() == 1);
// However, this assertion is not true if a memory pool is shared
// with a cuda graph. That CUDAGraph will increase the use count
// until it is reset.
CUDACachingAllocator::releasePool(device_, id_);
c10::cuda::CUDACachingAllocator::emptyCache(id_);
}
MempoolId_t MemPool::id() {
return id_;
}
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
return allocator_;
}
int MemPool::use_count() {
return CUDACachingAllocator::getPoolUseCount(device_, id_);
}
c10::DeviceIndex MemPool::device() {
return device_;
}
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
if (is_user_created) {
return {0, uid_++};
}
return {uuid_++, 0};
}
} // namespace at::cuda

View File

@ -0,0 +1,44 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDACachingAllocator.h>
namespace at::cuda {
// Keep BC only
using c10::CaptureId_t;
using c10::MempoolId_t;
// MemPool represents a pool of memory in a caching allocator. Currently,
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
//
// An allocator pointer can be passed to the MemPool to define how the
// allocations should be done in the pool. For example: using a different
// system allocator such as ncclMemAlloc.
struct TORCH_CUDA_CPP_API MemPool {
MemPool(
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true,
bool use_on_oom = false);
MemPool(const MemPool&) = delete;
MemPool(MemPool&&) = default;
MemPool& operator=(const MemPool&) = delete;
MemPool& operator=(MemPool&&) = default;
~MemPool();
MempoolId_t id();
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator();
int use_count();
c10::DeviceIndex device();
static MempoolId_t graph_pool_handle(bool is_user_created = true);
private:
static std::atomic<CaptureId_t> uid_;
static std::atomic<CaptureId_t> uuid_;
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator_;
bool is_user_created_;
MempoolId_t id_;
c10::DeviceIndex device_;
};
} // namespace at::cuda

View File

@ -3541,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
const int64_t out_features) {
auto M = inp.size(0);
TORCH_CHECK(
inp.dtype() == kFloat,
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
__func__,
" : expect input to be 32-bit float tensor.");
" : expect input to be float32 or bfloat16 tensor.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),

View File

@ -1087,7 +1087,8 @@ TORCH_IMPL_FUNC(index_copy_out)
result.copy_(self);
// See Note [Enabling Deterministic Operations]
if (result.is_cuda() && globalContext().deterministicAlgorithms()) {
if ((result.is_cuda() || result.is_xpu()) &&
globalContext().deterministicAlgorithms()) {
torch::List<std::optional<Tensor>> indices;
indices.resize(dim + 1);
indices.set(dim, index);

View File

@ -904,19 +904,11 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
}
// since mvlgamma_ has different signature from its
// out and functional variant, we explicitly
// define it (instead of using structured kernel).
Tensor& mvlgamma_(Tensor& self, int64_t p) {
mvlgamma_check(self, p);
Tensor args = native::arange(
-p *HALF + HALF,
HALF,
HALF,
optTypeMetaToScalarType(self.options().dtype_opt()),
self.options().layout_opt(),
self.options().device_opt(),
self.options().pinned_memory_opt());
args = args.add(self.unsqueeze(-1));
const auto p2_sub_p = static_cast<double>(p * (p - 1));
return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
return at::mvlgamma_out(self, self, p);
}
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {

View File

@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <cmath>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>
@ -793,6 +794,139 @@ bool can_use_kleidiai(
}
#endif
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
size_t m,
size_t n,
size_t k,
const uint16_t* lhs_bf16,
const uint8_t* rhs_qs4cx,
const float* rhs_scales,
uint16_t* dst_bf16,
float scalar_min,
float scalar_max,
const float* bias) {
// Roundup lambda for internal stride calculations
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
// Cast bfloat16 to float32 inline
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
float f;
std::memcpy(&f, &tmp, sizeof(f));
return f;
};
// Cast float32 to bfloat16 inline
auto cast_f32_to_bf16 = [](float f) {
uint32_t bits;
std::memcpy(&bits, &f, sizeof(bits));
return static_cast<uint16_t>(bits >> 16);
};
// Quantization pack lambda (channelwise QA8DX)
auto quant_pack_8bit_channelwise =
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
for (size_t i = 0; i < M; ++i) {
const uint16_t* row_ptr = src_bf16 + i * K;
// find min/max
float mn = FLT_MAX, mx = -FLT_MAX;
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
mn = std::min(mn, v);
mx = std::max(mx, v);
}
float rmin = std::min(0.0f, mn);
float rmax = std::max(0.0f, mx);
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
float recip = scale ? 1.0f / scale : 0.0f;
int32_t zp;
float des_min = rmin * scale;
float des_max = rmax * scale;
float err_min = qmin + des_min;
float err_max = qmax + des_max;
float zp_f =
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
zp_f = std::clamp(zp_f, qmin, qmax);
zp = std::lrintf(zp_f);
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
// store header
*reinterpret_cast<float*>(out_ptr) = recip;
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
out_ptr += sizeof(float) + sizeof(int32_t);
// quantize
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
q = std::clamp(
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
*out_ptr++ = static_cast<int8_t>(q);
}
}
};
// MatMul lambda (MXN x MXK -> MNXK BF16)
auto matmul_kernel = [&](size_t M,
size_t N,
size_t K,
const int8_t* lhs,
const uint8_t* rhs,
const float* scales,
uint16_t* dst,
float lo,
float hi) {
const size_t lhs_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
const size_t rhs_stride = roundup(K, 2) / 2;
for (size_t i = 0; i < M; ++i) {
const int8_t* lhs_row = lhs + i * lhs_stride;
for (size_t j = 0; j < N; ++j) {
int32_t acc = 0;
const int8_t* lptr = lhs_row;
const uint8_t* rptr = rhs + j * rhs_stride;
float lhs_scale = *reinterpret_cast<const float*>(lptr);
int32_t lhs_off =
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
lptr += sizeof(float) + sizeof(int32_t);
for (size_t t = 0; t < K; ++t) {
int32_t lv = static_cast<int32_t>(lptr[t]);
uint8_t bv = rptr[t / 2];
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
: (static_cast<int32_t>(bv >> 4) - 8);
acc += lv * rv + lhs_off * rv;
}
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
if (bias) {
res += bias[j];
}
res = std::clamp(res, lo, hi);
*dst++ = cast_f32_to_bf16(res);
}
}
};
// allocate and run
std::unique_ptr<int8_t[]> packed(
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
matmul_kernel(
m,
n,
k,
packed.get(),
rhs_qs4cx,
rhs_scales,
dst_bf16,
scalar_min,
scalar_max);
}
/**
* The Int4 quantized weights must be represented as a uint8 tensor
* For matrix multiplication with a weight shape of (N x K)
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
#if AT_KLEIDIAI_ENABLED()
if (can_use_kleidiai(scales_zeros, K, block_size)) {
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
packed_weights.resize_({weight_packed_size});
kleidiai::kai_pack_int4_rhs(
packed_weights, weights, scales_zeros, bias, N, K, block_size);
} else
#endif
{
TORCH_CHECK(
bias.has_value() == 0,
__func__,
" : Bias is unsupported in reference implementation");
packed_weights = packed_weights.to(kFloat);
auto weight_reshaped = weights.view({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
if (bias.has_value()) {
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
}
auto res = at::cat(tensors_to_cat, 0);
packed_weights.resize_(res.sizes()).copy_(res);
}
}
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
const float* rhs_scales_f32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
// required format for matmul
auto input_quant_pack_8bit_channelwise =
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
}
// Maximum/minimum int8 values
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
zero_point0 = std::min(zero_point0, qmax);
// Round to nearest integer
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = v0_s32 + nudged_zero_point0;
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[n_idx];
}
// Clamp (min-max) operation
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float* rhs_scales_fp32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
// Lambda for LHS quantization
auto lhs_quant_pack = [&](size_t m,
size_t k,
const float* lhs_f32,
int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
min0 = std::min(src0_0, min0);
}
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
zero_point0 = std::max(zero_point0, qmin);
zero_point0 = std::min(zero_point0, qmax);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float src0_0 = src_ptr[k_idx];
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = std::max(
std::min(
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
static_cast<int32_t>(INT8_MIN));
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
static_cast<int32_t>(kI8Min));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[col_idx];
}
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
/**
* Dynamic Input Quant 4 bit weights matmul execution flow
(INT4 Weights + FP scales + FP32 Bias)
FP32 Input Packed Buffer
| |
Quantize Cast
to INT8 to INT8
| |
v v
INT8 Input INT8 Weights
\ /
\ /
\ /
INT8 Matrix Multiplication
|
v
FP32 Dequantized and Accumulate in FP32
|
v
FP32 Final Output
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
* Dynamic INT4 weight-only MatMul with per-row input quantization.
*
* Execution Flow:
*
* (INT4 Weights + FP Scales [+ optional Bias])
*
* Input (FP32 or BF16) Packed Weight Buffer
* | |
* Row-wise Quantization (INT8) |
* | |
* INT8 Input Activation INT4 Quantized Weights + Scales
* \ /
* \ /
* Quantized Matrix Multiply
* |
* Output Tensor (BF16 or FP32)
*
* Notes:
* - Groupwise kernels expect BF16 scales
* - Channelwise kernels expect FP32 scales
* - Bias is currently unsupported in fallback path
*/
void dyn_quant_matmul_4bit_kernel(
const Tensor& output,
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
if (weight_packed_size == packed_weights.numel()) {
// KleidiAI interface internally handles the Channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
} else
#endif
{
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
const auto weights_size = N * K / 2;
// The weights needs to be in uint8_t data type after quantization
auto extracted_weights =
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
auto float32_scales =
(packed_weights.narrow(
0, weights_size, packed_weights.size(0) - weights_size))
.to(kFloat);
uint8_t* rhs_4bit =
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M,
N,
K,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M,
N,
K,
block_size,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else {
TORCH_CHECK(
block_size == K || (!(block_size % 32) && !(K % block_size)),
__func__,
": Group size should be multiple 32 or in_features [",
K,
"]. Provided ",
block_size);
{
void* input = inp.data_ptr();
void* dst = output.data_ptr();
// Extract weights, sclaes and biases form from packed tensor
const int weights_elements = N * K / 2;
const int scale_elements = N * (K / block_size);
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
float* weight_scales = float32_scales.data_ptr<float>();
void* bias_data = nullptr;
if (bias_elements) {
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
bias_data = float32_bias.data_ptr();
}
// 2 elements of 4 bit weights are packed into 1 uint8 packet
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
// Dispatch to reference kernels
if (inp.scalar_type() == at::kBFloat16) {
// BF16 input, BF16 output
constexpr float BF16_MAX = 3.38953139e+38f;
constexpr float BF16_MIN = -BF16_MAX;
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
M, N, K,
(uint16_t*)input, weights_4bit, weight_scales,
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
}
} else if (inp.scalar_type() == at::kFloat) {
// FP32 input, FP32 output
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M, N, K,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M, N, K, block_size,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
}
} else {
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
}
}
}
}
} // anonymous namespace
}
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/native/CompositeRandomAccessorCommon.h>
#include <thrust/swap.h>
#include <thrust/tuple.h>
namespace at { namespace native {

View File

@ -75,30 +75,52 @@ static inline bool can_use_int32_nhwc(
return true;
}
static inline bool can_use_int32_nchw(
int64_t nbatch, int64_t channels,
int64_t height, int64_t width,
int64_t pooled_height, int64_t pooled_width) {
int64_t hw = height * width;
return can_use_int32_nhwc(
nbatch, channels, height, width,
pooled_height, pooled_width,
channels * hw, // in_stride_n
hw, // in_stride_c
width, // in_stride_h
1 // in_stride_w
);
}
// kernels borrowed from Caffe
template <typename scalar_t>
__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data,
const int64_t channels, const int64_t height,
const int64_t width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, scalar_t* top_data,
template <typename scalar_t, typename index_t>
__global__ void max_pool_forward_nchw(
const index_t nthreads,
const scalar_t* bottom_data,
const int64_t channels,
const int64_t height,
const int64_t width,
const int pooled_height,
const int pooled_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
scalar_t* top_data,
int64_t* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
index_t pw = index % pooled_width;
index_t ph = (index / pooled_width) % pooled_height;
index_t c = (index / pooled_width / pooled_height) % channels;
index_t n = index / pooled_width / pooled_height / channels;
index_t hstart = ph * stride_h - pad_h;
index_t wstart = pw * stride_w - pad_w;
index_t hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
index_t wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
wstart += dilation_w;
scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
int maxidx = hstart * width + wstart;
index_t maxidx = hstart * width + wstart;
const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
for (int h = hstart; h < hend; h += dilation_h) {
for (int w = wstart; w < wend; w += dilation_w) {
@ -251,32 +273,39 @@ __global__ void max_pool_forward_nhwc(
static constexpr int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
template <typename scalar_t, typename accscalar_t, typename index_t>
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
#endif
__global__ void max_pool_backward_nchw(const scalar_t* top_diff,
const int64_t* top_mask, const int num, const int64_t channels,
const int64_t height, const int64_t width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
__global__ void max_pool_backward_nchw(
const scalar_t* top_diff,
const int64_t* top_mask,
const index_t num,
const index_t channels,
const index_t height,
const index_t width,
const index_t pooled_height,
const index_t pooled_width,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
scalar_t* bottom_diff) {
CUDA_KERNEL_LOOP(index, height*width) {
int h = index / width;
int w = index - h * width;
int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
int phend = p_end(h, pad_h, pooled_height, stride_h);
int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
int pwend = p_end(w, pad_w, pooled_width, stride_w);
for (int n = blockIdx.y; n < num; n += gridDim.y) {
for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
CUDA_KERNEL_LOOP_TYPE(index, height*width, index_t) {
index_t h = index / width;
index_t w = index - h * width;
index_t phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
index_t phend = p_end(h, pad_h, pooled_height, stride_h);
index_t pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
index_t pwend = p_end(w, pad_w, pooled_width, stride_w);
for (index_t n = blockIdx.y; n < num; n += gridDim.y) {
for (index_t c = blockIdx.z; c < channels; c += gridDim.z) {
accscalar_t gradient = accscalar_t(0);
int offset = (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
index_t offset = (n * channels + c) * pooled_height * pooled_width;
for (index_t ph = phstart; ph < phend; ++ph) {
for (index_t pw = pwstart; pw < pwend; ++pw) {
if (top_mask[ph * pooled_width + pw + offset] == h * width + w) {
gradient += static_cast<accscalar_t>(top_diff[ph * pooled_width + pw + offset]);
}
@ -469,8 +498,6 @@ const Tensor& indices) {
const int64_t in_stride_h = input.stride(-2);
const int64_t in_stride_w = input.stride(-1);
const int count = safe_downcast<int, int64_t>(output.numel());
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"max_pool2d_with_indices_out_cuda_frame",
[&] {
@ -553,14 +580,42 @@ const Tensor& indices) {
break;
}
case MemoryFormat::Contiguous: {
const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
max_pool_forward_nchw<scalar_t>
<<<ceil_div(count, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count, input_data,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
const int threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
const int64_t nthreads = output.numel();
bool use_int32 = can_use_int32_nchw(
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
const int blocks = static_cast<int>(std::min<int64_t>(
ceil_div(nthreads, static_cast<int64_t>(threads)),
static_cast<int64_t>(maxGridX)));
auto stream = at::cuda::getCurrentCUDAStream();
if (use_int32) {
max_pool_forward_nchw<scalar_t, int32_t>
<<<blocks, threads, 0, stream>>>(
static_cast<int32_t>(nthreads),
input_data,
static_cast<int32_t>(nInputPlane),
static_cast<int32_t>(inputHeight),
static_cast<int32_t>(inputWidth),
static_cast<int32_t>(outputHeight),
static_cast<int32_t>(outputWidth),
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
} else {
max_pool_forward_nchw<scalar_t, int64_t>
<<<blocks, threads, 0, stream>>>(
nthreads,
input_data,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
output_data, indices_data);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}
@ -633,8 +688,6 @@ const Tensor& gradInput) {
gradInput.zero_();
int64_t count = input.numel();
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"max_pool2d_with_indices_out_cuda_frame",
[&] {
@ -692,25 +745,45 @@ const Tensor& gradInput) {
break;
}
case MemoryFormat::Contiguous: {
int imgcount = inputWidth * inputHeight;
dim3 grid;
const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS;
grid.x = blocks;
grid.y = nbatch;
uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
if (maxGridY < grid.y) grid.y = maxGridY;
grid.z = nInputPlane;
uint64_t maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
if (maxGridZ < grid.z) grid.z = maxGridZ;
max_pool_backward_nchw<scalar_t, accscalar_t>
<<<grid, BLOCK_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
gradOutput_data,
indices_data,
nbatch,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
const int threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
BLOCK_THREADS);
const int imgcount = inputWidth * inputHeight;
const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
const int maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const int maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
const int blocks_x = std::min(ceil_div(imgcount, threads), maxGridX);
dim3 grid(blocks_x, static_cast<unsigned>(std::min<int64_t>(nbatch, maxGridY)), static_cast<unsigned>(std::min<int64_t>(nInputPlane, maxGridZ)));
bool use_int32 = can_use_int32_nchw(
nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
auto stream = at::cuda::getCurrentCUDAStream();
if (use_int32) {
max_pool_backward_nchw<scalar_t, accscalar_t, int32_t>
<<<grid, threads, 0, stream>>>(
gradOutput_data,
indices_data,
static_cast<int32_t>(nbatch),
static_cast<int32_t>(nInputPlane),
static_cast<int32_t>(inputHeight),
static_cast<int32_t>(inputWidth),
static_cast<int32_t>(outputHeight),
static_cast<int32_t>(outputWidth),
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
} else {
max_pool_backward_nchw<scalar_t, accscalar_t, int64_t>
<<<grid, threads, 0, stream>>>(
gradOutput_data,
indices_data,
nbatch,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
gradInput_data);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}

View File

@ -78,9 +78,18 @@ __global__ void EmbeddingBag_updateOutputKernel_max(
scalar_t weightFeatMax = 0;
int64_t bag_size_ = 0;
int64_t maxWord = -1;
// Separate validation loop reduces register pressure in the main loop below.
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
bool has_invalid_index = false;
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
}
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
for (int64_t emb = begin; emb < end; emb++) {
bool pad = (input[emb] == padding_idx);
CUDA_KERNEL_ASSERT(input[emb] < numRows);
const int64_t weightRow = input[emb] * weight_stride0;
scalar_t weightValue = weightFeat[weightRow];
if (bag_size_ == 0 || weightValue > weightFeatMax) {
@ -129,10 +138,19 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
CUDA_KERNEL_ASSERT(end >= begin);
accscalar_t weightFeatSum = 0;
int64_t bag_size_ = 0;
// Separate validation loop reduces register pressure in the main loop below.
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
bool has_invalid_index = false;
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
}
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
bool pad = (input_idx == padding_idx);
CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
const int64_t weightRow = input_idx * weight_stride0;
scalar_t weightValue = weightFeat[weightRow];
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;

View File

@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType& swizzle_a,
const SwizzleType swizzle_a,
const Tensor& scale_b,
const SwizzleType& swizzle_b,
const SwizzleType swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;

View File

@ -5,69 +5,11 @@
#include <cuda_bf16.h>
#endif
// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM)
#include <device_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
union {
__hip_bfloat162_raw bf162_raw;
vec_short2 vs2;
} u{static_cast<__hip_bfloat162_raw>(value)};
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
return static_cast<__hip_bfloat162>(u.bf162_raw);
#else
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
union u_hold {
__hip_bfloat162_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
// The api expects an ext_vector_type of half
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
union {
__half2_raw h2r;
vec_fp162 fp16;
} u {static_cast<__half2_raw>(value)};
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
return static_cast<__half2>(u.h2r);
#else
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
union u_hold {
__half2_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
#define ATOMICADD preview_unsafeAtomicAdd
#define ATOMICADD unsafeAtomicAdd
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
#else
#define ATOMICADD atomicAdd

View File

@ -740,7 +740,12 @@ _scaled_rowwise_rowwise(
TORCH_CHECK_VALUE(scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat, "scale_a must have ", mat_a.size(0), " Float elements, got ", scale_a.numel())
TORCH_CHECK_VALUE(scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat, "scale_b must have ", mat_b.size(1), " Float elements, got ", scale_b.numel())
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
// if we have a scale of shape [256, 1] (say), then stride can be [1, 0] - handle this case
TORCH_CHECK_VALUE(
scale_a.stride(1) == 1 ||
scale_a.size(1) == 1,
"expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)
);
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
auto scaling_choice_a = ScalingType::RowWise;
@ -1096,6 +1101,19 @@ _scaled_mxfp8_mxfp8(
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
void
_check_mxfp4_support() {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
// Only on B200 GPUs
TORCH_CHECK_NOT_IMPLEMENTED(
// B200 = 10.0, B300 = 10.3
dprops->major == 10,
"MXFP4 scaling only supported in CUDA for B200/B300"
);
#endif
}
Tensor&
_scaled_mxfp4_mxfp4(
@ -1108,6 +1126,7 @@ _scaled_mxfp4_mxfp4(
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
_check_mxfp4_support();
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",

View File

@ -267,15 +267,15 @@ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, con
* outer dimensions, which contains several "inner rows").
* Each thread processes a single inner row at a time.
*/
template<typename scalar_t, class BinaryOp>
template<typename scalar_t, typename index_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
const scalar_t init, BinaryOp binary_op)
{
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
const scalar_t *src = src_ + orow * row_size * num_irows + irow;
scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
const scalar_t *src = src_ + static_cast<index_t>(orow) * row_size * num_irows + irow;
scalar_t *tgt = tgt_ + (index_t) orow * row_size * num_irows + irow;
scalar_t acc = init;
for (uint32_t col = 0; col < row_size; ++col) {
@ -409,10 +409,15 @@ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
check_fits_in_unsigned(num_irows, "num_irows");
check_fits_in_unsigned(num_orows, "num_orows");
check_fits_in_unsigned(row_size, "row_size");
tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
if (static_cast<size_t>(num_irows) * num_orows * row_size <= UINT_MAX) {
tensor_kernel_scan_outer_dim<scalar_t, uint32_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
num_orows, num_irows, row_size, init, binary_op);
} else {
tensor_kernel_scan_outer_dim<scalar_t, size_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
num_orows, num_irows, row_size, init, binary_op);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

View File

@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
if (weight.scalar_type() == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl) {
const int64_t bl,
at::ScalarType tensor_dtype) {
size_t packed_size = n * k;
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
if (tensor_dtype == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
});
}
void kai_quant_pack_lhs_int4_mm(
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k) {
// Kernel IDs for GEMM and GEMV
constexpr kai_kernel_id gemm_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
constexpr kai_kernel_id gemv_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
// Get total threads and select kernel
const int64_t total_threads = at::get_num_threads();
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
if (cpuinfo_has_arm_i8mm() && m > 1) {
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
}
// Thread blocking parameters
const int64_t n_step = kernel_packet.ukernel.get_n_step();
const size_t mr = kernel_packet.ukernel.get_mr();
const size_t kr = kernel_packet.ukernel.get_kr();
const size_t sr = kernel_packet.ukernel.get_sr();
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
const uint8_t* lhs_native_mtx_bf16 =
reinterpret_cast<const uint8_t*>(input.data_ptr());
const uint8_t* rhs_packed_mtx_qs4cx =
reinterpret_cast<const uint8_t*>(weight.data_ptr());
uint8_t* lhs_packed_base = lhs_packed.get();
constexpr int32_t element_size = sizeof(uint16_t);
const size_t lhs_stride = k * element_size;
const size_t dst_stride = n * element_size;
// LHS quantization packing
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
const size_t src_stride = vec_per_thread * lhs_stride;
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.kai_run_lhs_quant_pack(
vec_num,
k,
mr,
kr,
sr,
0,
(const uint16_t*)lhs_src_ptr,
lhs_stride,
lhs_packed_ptr);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
lhs_quant_pack(thread_id);
}
});
// Matrix multiplication
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
auto mm = [=, &kernel_packet](int64_t thread_id) {
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
kernel_packet.ukernel.get_rhs_packed_offset(
thread_id * vec_per_thread, k);
auto dst_ptr = dst_act_mtx_bf16 +
kernel_packet.ukernel.get_dst_offset(
0, thread_id * vec_per_thread, dst_stride);
const int64_t vec_num = (thread_id == num_threads - 1)
? (n - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.ukernel.run_matmul(
m,
vec_num,
k,
lhs_packed_base,
rhs_packed_ptr,
(uint16_t*)dst_ptr,
dst_stride,
element_size, // dst_stride_col
-FLT_MAX,
FLT_MAX);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
mm(thread_id);
}
});
}
void kai_quant_pack_lhs_int4_mm(
const at::Tensor& output,
const at::Tensor& input,
const at::Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
const auto input_dtype = input.dtype();
if (input_dtype == at::kBFloat16) {
if (cpuinfo_has_arm_bf16()) {
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
}
} else if (input_dtype == at::kFloat) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
}
} else if ((bl % 32 == 0) && (k % bl == 0)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}

View File

@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl);
const int64_t bl,
at::ScalarType tensor_dtype = at::kFloat);
/**
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )

View File

@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
const auto scales_data = scales.data_ptr<float>();
const auto scales_data = scales.to(kFloat).data_ptr<float>();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(

View File

@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return channelwise_8bit_4bit_kernels.at(id);
}
// Kernel Mapping - BF16 Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
bf16_channelwise_8bit_4bit_kernels = {
{kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return bf16_channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif

View File

@ -10,21 +10,32 @@
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
namespace at::native::kleidiai {
enum class kai_kernel_id {
// FP32 inputs, 4-bit weights, FP32 output
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4 bit GEMV
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4 bit GEMM
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4 bit GEMV
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3 // Channelwise 4 bit GEMM
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
// BF16 inputs, 4-bit weights, BF16 output
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
4, // Channelwise 4-bit GEMV with BF16 input/output
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
5 // Channelwise 4-bit GEMM with BF16 input/output
};
// Channelwise Kernel mapping
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
// bf16 Channelwise Kernel mapping
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const void* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
const uint8_t* rhs,
const float* bias,
const float* scale,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
};
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(

View File

@ -337,10 +337,6 @@ Tensor _convolution_out(
TORCH_CHECK(
3 == ndim || 4 == ndim || 5 == ndim,
"convolution only supports 3D, 4D, 5D tensor");
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested =
use_channels_last_for_conv(input_r, weight_r);
Tensor input = input_r, weight = weight_r;
// PyTorch does not support ChannelsLast1D case,
// thus we need the transformation here
@ -348,13 +344,8 @@ Tensor _convolution_out(
input = view4d(input_r);
weight = view4d(weight_r);
}
// ensure the input/weight/bias/output are congituous in desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested = use_channels_last_for_conv(input, weight);
auto k = weight.ndimension();
if (k == input.ndimension() + 1) {
@ -388,6 +379,14 @@ Tensor _convolution_out(
expand_param_if_needed(output_padding_, "output_padding", dim);
params.groups = groups_;
}
// ensure the input/weight/bias/output are congituous in desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
check_shape_forward(input, weight, bias, params, true);
Tensor output;
@ -514,18 +513,9 @@ Tensor convolution_overrideable(
at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
auto k = weight_r.ndimension();
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
if (xpu_conv_use_channels_last(input_r, weight_r)) {
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
: at::MemoryFormat::ChannelsLast;
}
Tensor input_c = input_r.contiguous(backend_memory_format);
Tensor weight_c = weight_r.contiguous(backend_memory_format);
return _convolution(
input_c,
weight_c,
input_r,
weight_r,
bias_r,
stride_,
padding_,

View File

@ -0,0 +1,342 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
namespace at::native {
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace {
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - Returns BlockWise (with additional size checks).
*
* - Else if scale.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) ==
* 1:
* - Returns RowWise.
*
* - Otherwise:
* - Returns Error.
*/
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return at::isFloat8Type(t.scalar_type()) &&
scale.scalar_type() == at::kFloat && scale.numel() == 1;
}
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (
at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
scale.is_contiguous());
}
bool is_desired_scaling(
const at::Tensor& t,
const at::Tensor& scale,
ScalingType desired_scaling) {
auto result = desired_scaling == ScalingType::TensorWise
? is_tensorwise_scaling(t, scale)
: is_rowwise_scaling(t, scale);
return result;
}
std::pair<ScalingType, ScalingType> get_joint_scaling(
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
const at::Tensor& a,
const at::Tensor& b,
const at::Tensor& scale_a,
const at::Tensor& scale_b) {
for (auto [lhs, rhs] : options) {
if (is_desired_scaling(a, scale_a, lhs) &&
is_desired_scaling(b.t(), scale_b.t(), rhs)) {
return {lhs, rhs};
}
}
TORCH_CHECK(
false,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
a.size(0),
", 1) and scale_b should be (1, ",
b.size(1),
"), and both should be contiguous.\n"
"Got a.dtype()=",
a.scalar_type(),
", scale_a.dtype()=",
scale_a.scalar_type(),
", scale_a.size()=",
scale_a.sizes(),
", scale_a.stride()=",
scale_a.strides(),
", ",
"b.dtype()=",
b.scalar_type(),
", scale_b.dtype()=",
scale_b.scalar_type(),
", scale_b.size()=",
scale_b.sizes(),
" and scale_b.stride()=",
scale_b.strides());
}
Tensor& _scaled_gemm(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const ScalingType scaling_choice_a,
const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out,
const std::optional<Tensor>& alpha = std::nullopt) {
// TODO: scale_result and alpha is not defined or used!
std::optional<Tensor> scaled_result = std::nullopt;
at::native::onednn::scaled_matmul(
mat1,
mat2,
out,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
scaled_result,
use_fast_accum);
return out;
}
} // namespace
// Computes matrix multiply + bias while applying scaling to input and output
// matrices Scales are only applicable when matrices are of Float8 type and
// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
// type, scale_result is not applied. Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrices sizes are divisible by 32
// - If 1-dimensional tensors are used then scale_a should be size =
// mat1.size(0)
// and scale_b should have size = to mat2.size(1)
// Arguments:
// - `mat1`: the first operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat2`: the second operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher
// precision floating point type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only
// utilized if the output is a float8 type
// - `use_fast_accum`: Not applicable for XPU. For now, it should always be
// false.
// - `out`: a reference to the output tensor
Tensor& _scaled_mm_out_xpu(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
// Note: fast_accum is not supported in XPU for now.
TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0],
"mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
" and ",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
")");
// Check what type of scaling we are doing based on inputs. This list is
// sorted by decreasing priority.
// List of supported datatypes for XPU with oneDNN:
// https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
},
mat1,
mat2,
scale_a,
scale_b);
TORCH_CHECK(
!scale_result ||
(scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
TORCH_CHECK(
!bias || bias->numel() == mat2.sizes()[1],
"Bias must be size ",
mat2.sizes()[1],
" but got ",
bias->numel());
TORCH_CHECK(
mat1.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
").");
TORCH_CHECK(
mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
"mat2 shape (",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
") must be divisible by 16");
// Check types
TORCH_CHECK(
!out_dtype || *out_dtype == out.scalar_type(),
"out_dtype must match output matrix type");
TORCH_CHECK(
at::isFloat8Type(mat1.scalar_type()),
"Expected mat1 to be Float8 matrix got ",
mat1.scalar_type());
TORCH_CHECK(
at::isFloat8Type(mat2.scalar_type()),
"Expected mat2 to be Float8 matrix got ",
mat2.scalar_type());
// TODO: oneDNN Currently only supports e4m3 with group scales on BMG. Not
// support 2D scales, only 1D. Needs to add more checks there.
if (bias) {
TORCH_CHECK(
bias->scalar_type() == kFloat ||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
bias->scalar_type() == c10::ScalarType::Half,
"Bias must be Float32 or BFloat16 or Half, but got ",
bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
auto scale_result_ = scale_result.value_or(Tensor());
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{
{out, "out", 0},
{mat1, "mat1", 1},
{mat2, "mat2", 2},
{bias_, "bias", 3},
{scale_a, "scale_a", 4},
{scale_b, "scale_b", 5},
{scale_result_, "scale_result", 6}};
checkAllSameGPU(__func__, targs);
}
// Validation checks have passed lets resize the output to actual size
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
// kernels do not support this case).
if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
// `out` was created with `at::empty`. In the case where we are multiplying
// MxK by KxN and K is the zero dim, we need to initialize here to properly
// return a tensor of zeros.
if (mat1_sizes[1] == 0) {
out.zero_();
}
return out;
}
// TODO: Scale_result is not supported by now!!
return _scaled_gemm(
mat1,
mat2,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
}
Tensor _scaled_mm_xpu(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_xpu(
mat_a,
mat_b,
scale_a,
scale_b,
bias,
scale_result,
out_dtype,
use_fast_accum,
out);
}
} // namespace at::native

View File

@ -1,3 +1,4 @@
#include <ATen/BlasBackend.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>
@ -8,7 +9,6 @@
#include <oneapi/dnnl/dnnl.hpp>
namespace at::native::onednn {
at::Tensor broadcast_bias2D(
at::Tensor& dst,
at::Tensor& bias,
@ -328,4 +328,236 @@ void quantized_matmul(
result.copy_(dst);
}
// Describes how to configure oneDNN scales for a given role/ScalingType
struct ScaleSpec {
// specifies the way scale values will be applied to an ARG tensor.
int mask;
// specifies how scales are grouped along dimensions where
// multiple scale factors are used.
dnnl::memory::dims groups;
// specifies data type for scale factors.
dnnl::memory::data_type dtype;
// Helper to compute expected number of elements for scale tensors
// arg_type: "src" for SRC (groups pattern {1, X}),
// "wei" for WEIGHTS (groups pattern {X, 1})
int64_t expected_numel(
int64_t outer_dim,
int64_t inner_dim,
const std::string& arg_type) const {
if (groups == dnnl::memory::dims{1, 1})
return 1; // tensorwise scaling
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
// For rowwise: SRC groups={1, K}, WEI groups={K, 1}
TORCH_INTERNAL_ASSERT(
(groups == dnnl::memory::dims{1, inner_dim} ||
groups == dnnl::memory::dims{inner_dim, 1}),
"The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
groups,
".");
return outer_dim;
}
// Normalize an incoming scale tensor to contiguous storage and appropriate
// dtype/view
at::Tensor normalize(const at::Tensor& scale) const {
TORCH_INTERNAL_ASSERT(
dtype == dnnl::memory::data_type::f32,
"tensor scale currently must be f32, but got scale dtype: ",
scale.scalar_type());
return scale.to(at::kFloat).contiguous();
}
};
// This function defines how to set scales mask and groups according to:
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
// The returned value will be used in
// `set_scales(arg, mask, groups, data_type)`.
inline ScaleSpec make_scale_spec(
at::blas::ScalingType scaling_type,
int64_t M,
int64_t K,
int64_t N,
const std::string& arg_type) {
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
TORCH_INTERNAL_ASSERT(
(scaling_type == at::blas::ScalingType::TensorWise ||
scaling_type == at::blas::ScalingType::RowWise),
"Currently only support scaling_type for TensorWise or RowWise");
int64_t dim = K; // Currently only K is used for grouping
bool is_src = (arg_type == "src");
if (scaling_type == at::blas::ScalingType::TensorWise) {
// Scale tensorwise. The same as `--attr-scales=common`.
// mask=0 : scale whole tensor
// groups={1, 1}: indicates that there is only one group for scaling
return {0, {1, 1}, dnnl::memory::data_type::f32};
} else {
// (scaling_type == at::blas::ScalingType::RowWise)
// Scale RowWise. The same as `--attr-scales=per_dim_01`.
// mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
// SRC: groups={1, K}, WEIGHTS: groups={K, 1}
return {
(1 << 0) | (1 << 1),
is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
dnnl::memory::data_type::f32};
}
}
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum) {
auto& engine = GpuEngineManager::Instance().get_engine();
auto& stream = GpuStreamManager::Instance().get_stream();
// This function will do steps with following steps
// 1. create memory descriptor
// 2. call write_to_dnnl_memory() to actually write memory
// 3. execute
const int64_t M = mat1.size(0);
const int64_t K = mat1.size(1);
const int64_t N = mat2.size(1);
// 1.1 Create memory descriptor
dnnl::memory::desc src_md = get_onednn_md(mat1);
dnnl::memory::desc weights_md = get_onednn_md(mat2);
dnnl::memory::desc dst_md = get_onednn_md(result);
// scale_a and scale_b has already be checked in `is_desired_scaling()` call.
// So we could directly get their memory desc and set later.
dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);
dnnl::memory::desc bias_md;
bool with_bias = bias.has_value();
at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
if (with_bias) {
if (possible_reshaped_bias.dim() == 1) {
possible_reshaped_bias =
possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
bias_md = get_onednn_md(possible_reshaped_bias);
} else {
bias_md = get_onednn_md(possible_reshaped_bias);
}
}
// 1.2 Create primitive descriptor and set scales mask
const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");
dnnl::primitive_attr op_attr = dnnl::primitive_attr();
#if ONEDNN_SUPPORT_DETERMINISTIC
if (at::globalContext().deterministicAlgorithms() ||
at::globalContext().deterministicMkldnn())
op_attr.set_deterministic(true);
#endif
std::vector<int64_t> default_groups;
op_attr.set_scales(
DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
op_attr.set_scales(
DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
// scale_result tensor currently only supports scalar(TensorWise Scaling).
bool with_dst_scale = scale_result && scale_result->defined();
if (with_dst_scale) {
op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
}
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// 1.3 Create the matmul primitive descriptor
dnnl::matmul::primitive_desc matmul_pd = with_bias
? dnnl::matmul::primitive_desc(
engine, src_md, weights_md, bias_md, dst_md, op_attr)
: dnnl::matmul::primitive_desc(
engine, src_md, weights_md, dst_md, op_attr);
// 1.4 (Possible) Additional Checks
// TODO: In case there are memory desc does not align with the actual tensor,
// we might need to reorder weights similar to CPU's reorder_if_differ_in()
// call. For example, weights not the same as matmul_pd.weights_desc(),
// 2. Prepare memory
// Create memory
auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
dnnl::memory b_usr_m;
if (with_bias) {
b_usr_m =
make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
}
// Prepare runtime scale memories (flat 1-D views) using the specs
auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
int64_t expected_numel,
const at::Tensor& scale_tensor) {
at::Tensor prepared = spec.normalize(scale_tensor);
TORCH_CHECK(
prepared.numel() == expected_numel,
"Scale buffer length mismatch. Expected ",
expected_numel,
", got ",
prepared.numel());
dnnl::memory::desc scale_md(
{prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
return make_onednn_memory(scale_md, engine, prepared.data_ptr());
};
auto scratchpad =
make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);
// 3. Setup Args for exec
std::unordered_map<int, dnnl::memory> args;
args.insert({DNNL_ARG_SRC, src_usr_m});
args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
args.insert({DNNL_ARG_DST, dst_usr_m});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, b_usr_m});
}
// Attach runtime scales using specs
auto src_sc_mem = make_scale_mem_from_spec(
src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
auto wei_sc_mem = make_scale_mem_from_spec(
wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
if (with_dst_scale) {
// Bind single f32 scalar as DST scale
at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
dnnl::memory::desc dst_sc_md(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto dst_sc_mem =
make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
}
dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
sycl::event matmul_fwd_event =
dnnl::sycl_interop::execute(matmul_p, stream, args);
return matmul_fwd_event;
}
} // namespace at::native::onednn

View File

@ -78,6 +78,10 @@ dnnl::memory::data_type get_onednn_dtype(
return dnnl::memory::data_type::f32;
case at::ScalarType::BFloat16:
return dnnl::memory::data_type::bf16;
case at::ScalarType::Float8_e4m3fn:
return dnnl::memory::data_type::f8_e4m3;
case at::ScalarType::Float8_e5m2:
return dnnl::memory::data_type::f8_e5m2;
default:
if (!allow_undef) {
TORCH_CHECK(

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/BlasBackend.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
@ -202,4 +203,16 @@ void sdpa_backward(
Tensor& grad_query,
Tensor& grad_key,
Tensor& grad_value);
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum);
} // namespace at::native::onednn

View File

@ -82,6 +82,7 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string to_hex_key(float);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);

View File

@ -301,6 +301,10 @@ std::string getArrayRefString(const IntArrayRef s) {
return fmt::to_string(fmt::join(s, ","));
}
std::string to_hex_key(float f) {
return fmt::format("{:a}", f);
}
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
fmt::basic_memory_buffer<char, 100> buffer;
auto buf_iterator = std::back_inserter(buffer);

View File

@ -40,7 +40,7 @@ inline c10::metal::opmath_t<T> matmul_inner(
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint k = 0; k < TILE_DIM; k++) {
sum += A_tile[tid.y][k] * B_tile[k][tid.x];
sum += c10::metal::mul(A_tile[tid.y][k], B_tile[k][tid.x]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -96,7 +96,9 @@ kernel void addmm(
auto bias =
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
static_cast<T>(
c10::metal::mul(alpha_beta[0], sum) +
c10::metal::mul(alpha_beta[1], bias));
}
}
@ -832,6 +834,10 @@ INSTANTIATE_MM_OPS(float);
INSTANTIATE_MM_OPS(half);
INSTANTIATE_MM_OPS(bfloat);
// Complex MM
INSTANTIATE_MM_OPS(float2);
INSTANTIATE_MM_OPS(half2);
// Integral MM
INSTANTIATE_MM_OPS(long);
INSTANTIATE_MM_OPS(int);

View File

@ -121,7 +121,7 @@ Tensor& do_metal_addmm(const Tensor& self,
const Scalar& alpha,
const Scalar& beta,
const Tensor& bias) {
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
if (beta.isFloatingPoint() && alpha.isFloatingPoint() && beta.toDouble() == 0 && alpha.toDouble() == 1) {
return do_metal_mm(self, other, output);
}
auto stream = getCurrentMPSStream();
@ -147,13 +147,15 @@ Tensor& do_metal_addmm(const Tensor& self,
std::array<int64_t, 2> i64;
std::array<int32_t, 2> i32;
std::array<float, 2> f32;
} alpha_beta;
std::array<c10::complex<float>, 2> c64;
} alpha_beta{};
if (output.scalar_type() == kLong) {
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
} else if (c10::isIntegralType(output.scalar_type(), true)) {
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
} else if (c10::isComplexType(output.scalar_type())) {
alpha_beta.c64 = {alpha.toComplexFloat(), beta.toComplexFloat()};
} else {
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
}
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
@ -190,10 +192,16 @@ std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* gr
bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
static bool always_use_metal = c10::utils::has_env("PYTORCH_MPS_PREFER_METAL");
constexpr auto max_stride_size = 32768;
constexpr auto max_complex_inner_size = 2048;
static bool is_macos_14_4_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
if (always_use_metal || c10::isIntegralType(self.scalar_type(), true)) {
return true;
}
// multiplicationWithPrimaryTensor: returns incorrect results if inner size exceeds 2048
// See https://github.com/pytorch/pytorch/issues/167727#issuecomment-3529308548
if (c10::isComplexType(self.scalar_type()) && self.size(1) > max_complex_inner_size) {
return true;
}
return !is_macos_14_4_or_newer &&
(self.stride(0) > max_stride_size || self.stride(1) > max_stride_size || self.size(0) > max_stride_size ||
self.size(1) > max_stride_size || other.stride(0) > max_stride_size || other.stride(1) > max_stride_size ||

View File

@ -91,25 +91,30 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/Repeat_metallib.h>
#endif
template <typename index_t>
void computeRepeatIndices(const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer);
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
std::string scalar_type;
if constexpr (std::is_same_v<index_t, int32_t>) {
if (repeat.scalar_type() == kInt) {
scalar_type = "int32_t";
} else if constexpr (std::is_same_v<index_t, int64_t>) {
} else if (repeat.scalar_type() == kLong) {
scalar_type = "int64_t";
} else {
TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type");
TORCH_CHECK(false, "repeats has to be Long or Int tensor");
}
if (repeat.size(0) == 0) {
return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
Tensor repeat_ = repeat.contiguous();
Tensor cumsum = repeat.cumsum(0);
int64_t total = 0;
if (output_size.has_value()) {
total = output_size.value();
} else {
total = cumsum[-1].item<int64_t>();
TORCH_CHECK((repeat >= 0).all().item<uint8_t>(), "repeats can not be negative");
}
auto result = at::empty({total}, repeat.options());
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@ -121,20 +126,13 @@ void computeRepeatIndices(const index_t* repeat_ptr,
getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false);
[computeEncoder setComputePipelineState:pipelineState];
mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size);
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size);
mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0));
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0));
getMPSProfiler().endProfileKernel(pipelineState);
}
});
}
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
Tensor output;
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});
return output;
return result;
}
} // namespace at::native

View File

@ -5,6 +5,7 @@
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/mps/OperationUtils.h>
#include <algorithm>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -89,13 +90,21 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor&
auto clamp_shape = clamp_opt->sizes();
auto input_shape = input_t.sizes();
TORCH_CHECK(num_clamp_dims <= num_input_dims,
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
if (num_clamp_dims > num_input_dims) {
auto leading_dims = num_clamp_dims - num_input_dims;
for (int64_t i = 0; i < leading_dims; ++i) {
TORCH_CHECK(clamp_shape[i] == 1,
op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor");
}
}
for (int i = 0; i < num_clamp_dims; i++)
auto clamp_idx = num_clamp_dims - 1;
auto input_idx = num_input_dims - 1;
auto common_dims = std::min(num_clamp_dims, num_input_dims);
for (int64_t i = 0; i < common_dims; ++i)
// One of the indices is allowed to be 1; will be handled by broadcast
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 ||
input_shape[input_idx - i] == 1,
op_name + ": clamp tensor trailing shape must match input tensor")
}
}
@ -136,9 +145,6 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
auto result_type = output_t.scalar_type();
IntArrayRef new_min_shape;
IntArrayRef new_max_shape;
auto num_min_dims = min_opt->dim();
auto num_max_dims = max_opt->dim();
auto num_input_dims = input_t.dim();
@ -146,24 +152,32 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
std::vector<int64_t> new_min_arr(num_input_dims);
std::vector<int64_t> new_max_arr(num_input_dims);
if (has_min && num_min_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
new_min_shape = IntArrayRef(new_min_arr);
}
if (has_max && num_max_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
new_max_shape = IntArrayRef(new_max_arr);
}
Tensor min_opt_tensor;
Tensor max_opt_tensor;
auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref,
int64_t num_clamp_dims,
std::vector<int64_t>& new_shape_storage) -> Tensor {
IntArrayRef clamp_shape = clamp_tensor_ref->sizes();
bool requires_view = false;
if (num_clamp_dims > num_input_dims) {
clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims);
requires_view = true;
} else if (num_clamp_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape);
clamp_shape = IntArrayRef(new_shape_storage);
requires_view = true;
}
return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref;
};
if (has_min) {
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr);
}
if (has_max) {
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr);
}
@autoreleasepool {
@ -244,8 +258,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
if (has_min)
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar

View File

@ -4389,7 +4389,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: mv
SparseCPU, SparseCUDA: mv_sparse
SparseCPU, SparseCUDA, SparseMPS: mv_sparse
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
@ -7518,7 +7518,7 @@
- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_mask_projection
SparseCPU, SparseCUDA, SparseMPS: sparse_mask_projection
autogen: _sparse_mask_projection.out
- func: _to_cpu(Tensor[] tensors) -> Tensor[]

View File

@ -30,10 +30,12 @@
#include <thrust/binary_search.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <cuda_runtime_api.h>
#include <cusparse.h>

View File

@ -445,6 +445,33 @@ static SparseTensor& mul_out_dense_sparse_mps(
return out;
}
static std::tuple<Tensor, Tensor, int64_t> mps_intersect_binary_search(
const Tensor& A_keys,
const Tensor& B_keys,
int64_t lenA,
int64_t lenB,
bool boolean_flag) {
auto stream = getCurrentMPSStream();
auto outA_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
auto outB_idx = at::empty({lenA}, A_keys.options().dtype(at::kLong));
auto counter = at::zeros({1}, A_keys.options().dtype(at::kInt));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), boolean_flag);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const auto match_count = static_cast<int64_t>(counter.item<int32_t>());
return std::make_tuple(std::move(outA_idx), std::move(outB_idx), match_count);
}
SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
@ -523,22 +550,10 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
auto counter = at::zeros({1}, at::device(device).dtype(kInt));
auto [outA_idx, outB_idx, M_int64] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_lhs);
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), A_is_lhs);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const uint32_t M = counter.item<int32_t>(); // number of structural matches
const auto M = static_cast<uint32_t>(M_int64); // number of structural matches
r_.resize_as_(lhs);
@ -762,6 +777,14 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
using OptTensor = std::optional<Tensor>;
static Tensor create_sparse_output_values(
const Tensor& template_values,
int64_t output_nnz,
ScalarType dtype) {
auto out_val_sizes = template_values.sizes().vec();
out_val_sizes[0] = output_nnz;
return at::zeros(out_val_sizes, template_values.options().dtype(dtype));
}
static void sparse_mask_apply_out_mps_kernel(
Tensor& result,
@ -783,9 +806,9 @@ static void sparse_mask_apply_out_mps_kernel(
auto src = src_in.coalesce();
auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
const int64_t src_nnz = src._nnz();
const int64_t mask_nnz = mask._nnz();
const int64_t sd = src.sparse_dim();
const auto src_nnz = src._nnz();
const auto mask_nnz = mask._nnz();
const auto sd = src.sparse_dim();
result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
auto commonDtype = at::result_type(src, mask);
@ -814,53 +837,27 @@ static void sparse_mask_apply_out_mps_kernel(
return;
}
auto mask_indices = mask._indices().contiguous();
auto src_values = src._values().to(commonDtype).contiguous();
auto out_values = create_sparse_output_values(src_values, mask_nnz, commonDtype);
if (src_nnz == 0) {
auto out_indices = mask._indices().contiguous();
auto src_values = src._values().to(commonDtype);
auto out_val_sizes = src_values.sizes().vec();
out_val_sizes[0] = mask_nnz;
auto out_values = at::zeros(out_val_sizes, src_values.options());
alias_into_sparse(result, out_indices, out_values);
alias_into_sparse(result, mask_indices, out_values);
result._coalesced_(mask.is_coalesced());
return;
}
auto mask_indices = mask._indices().contiguous();
auto src_indices = src._indices().contiguous();
auto src_values = src._values().to(commonDtype).contiguous();
auto mask_keys = flatten_indices(mask._indices().contiguous(), mask.sizes().slice(0, sd)).contiguous();
auto src_keys = flatten_indices(src._indices().contiguous(), src.sizes().slice(0, sd)).contiguous();
auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
const bool A_is_src = (src_nnz <= mask_nnz);
const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
const auto A_is_src = (src_nnz <= mask_nnz);
const auto lenA = A_is_src ? src_nnz : mask_nnz;
const auto lenB = A_is_src ? mask_nnz : src_nnz;
auto A_keys = A_is_src ? src_keys : mask_keys;
auto B_keys = A_is_src ? mask_keys : src_keys;
const auto device = result.device();
auto stream = getCurrentMPSStream();
auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), A_is_src);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
auto out_val_sizes = src_values.sizes().vec();
out_val_sizes[0] = mask_nnz;
auto out_values = at::zeros(out_val_sizes, src_values.options());
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_src);
if (M > 0) {
auto src_match = outA_idx.narrow(0, 0, M);
@ -878,6 +875,70 @@ static void sparse_mask_apply_out_mps_kernel(
result._coalesced_(mask.is_coalesced());
}
static void sparse_mask_projection_out_mps_kernel(
Tensor& result,
const Tensor& lhs,
const Tensor& rhs,
const OptTensor& /*x_hash_opt*/,
bool accumulate_matches) {
TORCH_CHECK(lhs.is_sparse() && rhs.is_sparse(), "sparse_mask_projection: expected sparse COO");
TORCH_CHECK(lhs.is_mps() && rhs.is_mps(), "sparse_mask_projection: expected MPS tensors");
TORCH_CHECK(lhs.sparse_dim() == rhs.sparse_dim(), "sparse_dim mismatch");
auto lhs_c = lhs.coalesce();
auto rhs_c = rhs.coalesce();
const auto sd = lhs_c.sparse_dim();
const auto lhs_nnz = lhs_c._nnz();
const auto rhs_nnz = rhs_c._nnz();
auto commonDtype = at::result_type(lhs_c, rhs_c);
TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
"Can't convert ", commonDtype, " to output ", result.scalar_type());
result.sparse_resize_(lhs.sizes(), lhs.sparse_dim(), lhs.dense_dim());
auto lhs_indices = lhs_c._indices().contiguous();
auto rhs_values = rhs_c._values().to(commonDtype).contiguous();
auto out_values = create_sparse_output_values(rhs_values, lhs_nnz, commonDtype);
if (lhs_nnz > 0 && rhs_nnz > 0) {
auto lhs_keys = flatten_indices(lhs_indices, lhs_c.sizes().slice(0, sd)).contiguous();
auto rhs_keys = flatten_indices(rhs_c._indices().contiguous(), rhs_c.sizes().slice(0, sd)).contiguous();
const auto A_is_lhs = (lhs_nnz <= rhs_nnz);
const auto lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
const auto lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
auto [outA_idx, outB_idx, M] = mps_intersect_binary_search(
A_keys, B_keys, lenA, lenB, A_is_lhs);
if (M > 0) {
auto idx_in_A = outA_idx.narrow(0, 0, M);
auto idx_in_B = outB_idx.narrow(0, 0, M);
auto idx_in_lhs = A_is_lhs ? idx_in_A : idx_in_B;
auto idx_in_rhs = A_is_lhs ? idx_in_B : idx_in_A;
const auto view_cols = rhs_values.numel() / std::max<int64_t>(rhs_nnz, 1);
auto rhs_rows = rhs_values.index_select(0, idx_in_rhs).contiguous();
auto rhs_rows_2d = rhs_rows.view({M, view_cols});
auto out_2d = out_values.view({lhs_nnz, view_cols});
if (accumulate_matches) {
out_2d.index_add_(0, idx_in_lhs, rhs_rows_2d);
} else {
out_2d.index_copy_(0, idx_in_lhs, rhs_rows_2d);
}
}
}
alias_into_sparse(result, lhs._indices(), out_values);
result._coalesced_(lhs.is_coalesced());
}
static void sparse_mask_intersection_out_mps_kernel(
Tensor& result,
const Tensor& lhs,
@ -1002,4 +1063,5 @@ Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
}
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
REGISTER_MPS_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_mps_kernel);
} // namespace at::native

View File

@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp

View File

@ -0,0 +1,77 @@
#include <gtest/gtest.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <atomic>
#include <thread>
#include <vector>
// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
// to verify that the data race fix is working correctly
TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
if (!at::cuda::is_available()) {
return;
}
constexpr int num_accessor_threads = 15;
constexpr int num_clear_threads = 5;
constexpr int iterations_per_thread = 50;
std::atomic<bool> stop{false};
std::atomic<int> error_count{0};
std::vector<std::thread> threads;
threads.reserve(num_accessor_threads + num_clear_threads);
// Launch accessor threads
for (int i = 0; i < num_accessor_threads; ++i) {
threads.emplace_back([&stop, &error_count]() {
try {
at::cuda::CUDAGuard device_guard(0);
while (!stop.load(std::memory_order_relaxed)) {
const auto handle = at::cuda::getCurrentCUDABlasHandle();
const auto workspace = at::cuda::getCUDABlasLtWorkspace();
if (handle == nullptr || workspace == nullptr) {
error_count++;
}
}
} catch (const std::exception& e) {
error_count++;
}
});
}
// Launch threads that clear workspaces
for (int i = 0; i < num_clear_threads; ++i) {
threads.emplace_back([&error_count]() {
try {
for (int j = 0; j < iterations_per_thread; ++j) {
at::cuda::clearCublasWorkspaces();
std::this_thread::yield();
}
} catch (const std::exception& e) {
error_count++;
}
});
}
// Let them run for a bit
std::this_thread::sleep_for(std::chrono::milliseconds(100));
stop.store(true, std::memory_order_relaxed);
for (auto& thread : threads) {
thread.join();
}
EXPECT_EQ(error_count.load(), 0);
}
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
c10::cuda::CUDACachingAllocator::init(1);
return RUN_ALL_TESTS();
}

View File

@ -10,6 +10,13 @@
...
}
{
ignore_empty_generic_uninitialised_conditional_jump
Memcheck:Cond
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
...
}
{
Cond_cuda
Memcheck:Cond

View File

@ -0,0 +1,62 @@
import sys
from benchmark_base import BenchmarkBase
import torch
from torch.distributed._tensor import DTensor, Replicate
from torch.testing._internal.distributed.fake_pg import FakeStore
class BenchmarkDTensorDispatch(BenchmarkBase):
def __init__(self, operator, world_size) -> None:
super().__init__(
category=f"dtensor_dispatch_{operator}",
device="cuda",
)
self.world_size = world_size
def name(self) -> str:
prefix = f"{self.category()}"
return prefix
def description(self) -> str:
return f"DTensor dispatch time for {self.category()}"
def _prepare_once(self) -> None:
self.mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", (self.world_size,), mesh_dim_names=("dp",)
)
self.a = DTensor.from_local(
torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
)
self.b = DTensor.from_local(
torch.ones(10, 10, device=self.device()), self.mesh, [Replicate()]
)
def _prepare(self) -> None:
pass
class BenchmarkDetach(BenchmarkDTensorDispatch):
def __init__(self, world_size) -> None:
super().__init__(operator="detach", world_size=world_size)
def _work(self) -> None:
self.a.detach()
def main():
world_size = 256
fake_store = FakeStore()
torch.distributed.init_process_group(
"fake", store=fake_store, rank=0, world_size=world_size
)
result_path = sys.argv[1]
BenchmarkDetach(world_size).enable_instruction_count().collect_all().append_results(
result_path
)
torch.distributed.destroy_process_group()
if __name__ == "__main__":
main()

View File

@ -189,6 +189,10 @@ skip:
- hf_Whisper
- hf_distil_whisper
- timm_vision_transformer_large
# https://github.com/pytorch/pytorch/issues/167895
- stable_diffusion
- stable_diffusion_text_encoder
- stable_diffusion_unet
device:
cpu:

View File

@ -125,6 +125,17 @@ AttentionType = Literal[
]
DtypeString = Literal["bfloat16", "float16", "float32"]
SpeedupType = Literal["fwd", "bwd"]
# Operator Name mapping
backend_to_operator_name = {
"math": "math attention kernel",
"efficient": "efficient attention kernel",
"cudnn": "cudnn attention kernel",
"fav2": "flash attention 2 kernel",
"fav3": "flash attention 3 kernel",
"fakv": "flash attention kv cache kernel",
"og-eager": "eager attention kernel",
"flex": "flex attention kernel",
}
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
@ -1265,12 +1276,14 @@ def _output_json_for_dashboard(
model: ModelInfo
metric: MetricInfo
operator_name = backend_to_operator_name.get(backend, backend)
# Benchmark extra info
benchmark_extra_info = {
"input_config": input_config,
"device": device,
"arch": device_arch,
"operator_name": backend,
"operator_name": operator_name,
"attn_type": config.attn_type,
"shape": str(config.shape),
"max_autotune": config.max_autotune,
@ -1288,7 +1301,7 @@ def _output_json_for_dashboard(
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
"operator_name": operator_name,
"attn_type": config.attn_type,
},
),
@ -1315,7 +1328,7 @@ def _output_json_for_dashboard(
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
"operator_name": operator_name,
},
),
metric=MetricInfo(
@ -1341,7 +1354,7 @@ def _output_json_for_dashboard(
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
"operator_name": operator_name,
},
),
metric=MetricInfo(
@ -1371,7 +1384,7 @@ def _output_json_for_dashboard(
type="attention-benchmark",
origins=["pytorch"],
extra_info={
"operator_name": backend,
"operator_name": operator_name,
},
),
metric=MetricInfo(

View File

@ -2,6 +2,7 @@
# These load paths point to different files in internal and OSS environment
load("@bazel_skylib//lib:paths.bzl", "paths")
load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
@ -590,6 +591,9 @@ def pt_operator_query_codegen(
pt_allow_forced_schema_registration = True,
compatible_with = [],
apple_sdks = None):
if get_fbsource_cell() == "fbcode":
return
oplist_dir_name = name + "_pt_oplist"
# @lint-ignore BUCKLINT
@ -865,6 +869,9 @@ def define_buck_targets(
pt_xplat_cxx_library = fb_xplat_cxx_library,
c2_fbandroid_xplat_compiler_flags = [],
labels = []):
if get_fbsource_cell() == "fbcode":
return
# @lint-ignore BUCKLINT
fb_native.filegroup(
name = "metal_build_srcs",

View File

@ -19,6 +19,17 @@
namespace c10 {
using CaptureId_t = unsigned long long;
// first is set if the instance is created by CUDAGraph::capture_begin.
// second is set if the instance is created by at::cuda::graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
// A DataPtr is a unique pointer (with an attached deleter and some
// context for the deleter) to some memory, which also records what
// device is for its data.

View File

@ -96,6 +96,13 @@ struct C10_API DeviceAllocator : public c10::Allocator {
// Resets peak memory usage statistics for the specified device
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
// Return the free memory size and total memory size in bytes for the
// specified device.
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "getMemoryInfo is not implemented for this allocator yet.");
}
};
// This function is used to get the DeviceAllocator for a specific device type

View File

@ -44,7 +44,7 @@ struct C10_API SafePyObject {
(*other.pyinterpreter_)->incref(other.data_);
}
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
(*pyinterpreter_)->decref(data_);
}
data_ = other.data_;
pyinterpreter_ = other.pyinterpreter_;
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
~SafePyObject() {
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
(*pyinterpreter_)->decref(data_);
}
}

View File

@ -34,20 +34,6 @@ namespace c10 {
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
// regarding macros.
template <typename T>
struct CppTypeToScalarType;
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;

View File

@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification.
// Only if the devicetype is in the allowlist,

View File

@ -105,6 +105,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_);
@ -370,4 +376,18 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
bool resizable,
std::optional<at::Device> device_opt);
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
} // namespace c10

View File

@ -277,7 +277,6 @@ void TensorImpl::release_resources() {
if (storage_) {
storage_ = {};
}
pyobj_slot_.maybe_destroy_pyobj();
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -989,6 +988,30 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
}
}
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
namespace impl {
namespace {

View File

@ -2178,6 +2178,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
private:
// See NOTE [std::optional operator usage in CUDA]
// We probably don't want to expose this publicly until
@ -3079,6 +3085,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
friend class C10_TensorImpl_Size_Check_Dummy_Class;
};
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
// Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for

View File

@ -11,8 +11,11 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void incref(PyObject* pyobj) const override {} // do nothing
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
} // do nothing
void decref(PyObject* pyobj) const override {} // do nothing
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
return false;
}
#define PANIC(m) \
TORCH_INTERNAL_ASSERT( \
@ -20,6 +23,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
"attempted to call " #m \
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
size_t refcnt(PyObject* pyobj) const override {
PANIC(refcnt);
}
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
PANIC(detach);
}

View File

@ -18,6 +18,9 @@ namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10
namespace torch::jit {
@ -126,9 +129,12 @@ struct C10_API PyInterpreterVTable {
// Run Py_INCREF on a PyObject.
virtual void incref(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
virtual void decref(PyObject* pyobj) const = 0;
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
// Run Py_REFCNT on a PyObject.
virtual size_t refcnt(PyObject* pyobj) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this

View File

@ -1,56 +0,0 @@
#include <c10/core/impl/PyObjectSlot.h>
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
}
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
// then the PyObject holds an owning reference to the tensor).
// So it is OK to clear pyobj_ here as it is impossible for it to
// be used again (modulo weak reference races)
pyobj_ = nullptr; // for safety
}
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}
void PyObjectSlot::set_owns_pyobj(bool b) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
pyobj_ = reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}
} // namespace c10::impl

View File

@ -8,117 +8,58 @@
#include <atomic>
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10::impl {
struct C10_API PyObjectSlot {
public:
PyObjectSlot();
~PyObjectSlot();
void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
//
// NB: This lives in a header so that we can inline away the switch on status
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
// Query the PyObject interpreter. This may return null if there is no
// interpreter. This is racy!
PyInterpreter* pyobj_interpreter();
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
// interpreter.
PyInterpreter* pyobj_interpreter() const {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyInterpreter& load_pyobj_interpreter() const;
PyInterpreter& load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
TORCH_INTERNAL_ASSERT(
interpreter, "cannot access PyObject for Tensor - no interpreter set");
return *interpreter;
}
bool owns_pyobj();
PyObject* load_pyobj() const {
return pyobj_.load(std::memory_order_acquire);
}
void set_owns_pyobj(bool b);
void store_pyobj(PyObject* obj) {
pyobj_.store(obj, std::memory_order_release);
}
bool has_unique_reference() const {
PyObject* pyobj = load_pyobj();
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
}
void clear() {
pyobj_.store(nullptr, std::memory_order_relaxed);
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
}
private:
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
// This is now always the global interpreter if the PyObject is set.
// Maybe we can remove this field some day...
std::atomic<PyInterpreter*> pyobj_interpreter_;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
// protected by an atomic as it is only allowed to be accessed when you hold
// the GIL, or during destruction of the tensor.
//
// When a PyObject dies, you are obligated to clear this field
// (otherwise, you will try to use-after-free the pyobj); this currently
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
//
// NB: Ordinarily, this should not be a strong reference, as if the
// PyObject owns the Tensor, this would create a reference cycle.
// However, sometimes this ownership flips. To track who owns
// who, this has a single pointer tag indicating whether or not the
// C++ object owns the PyObject (the common case, zero, means PyObject
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
// or check_pyobj for checked access. See references to PyObject
// resurrection in torch/csrc/autograd/python_variable.cpp
PyObject* pyobj_;
// The PyObject representing this Tensor or nullptr. Ownership is managed
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
// reference is already dead.
std::atomic<PyObject*> pyobj_;
friend class torch::utils::PyObjectPreservation;
};
} // namespace c10::impl

View File

@ -1012,12 +1012,6 @@ PrivatePoolState::PrivatePoolState(
}
}
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
*ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
@ -4510,66 +4504,3 @@ std::atomic<CUDAAllocator*> allocator;
static BackendStaticInitializer backend_static_initializer;
} // namespace cuda::CUDACachingAllocator
} // namespace c10
namespace c10::cuda {
// uid_ is incremented when a user creates a MemPool,
// for example: using graph_pool_handle() or c10::cuda::MemPool().
//
// uuid_ is incremented when CUDAGraph creates a MemPool
// as a result of a user not providing a pool.
//
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
// passed to a function, either by user or CUDAGraphs. For example,
// default value of MempoolId_t for capture_begin function is {0, 0}.
// That's why uid_ and uuid_ start at 1.
std::atomic<CaptureId_t> MemPool::uid_{1};
std::atomic<CaptureId_t> MemPool::uuid_{1};
MemPool::MemPool(
CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom)
: allocator_(allocator), is_user_created_(is_user_created) {
if (is_user_created_) {
id_ = {0, uid_++};
} else {
id_ = {uuid_++, 0};
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
if (use_on_oom) {
CUDACachingAllocator::setUseOnOOM(device_, id_);
}
}
MemPool::~MemPool() {
TORCH_INTERNAL_ASSERT(use_count() == 1);
CUDACachingAllocator::releasePool(device_, id_);
c10::cuda::CUDACachingAllocator::emptyCache(id_);
}
MempoolId_t MemPool::id() {
return id_;
}
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
return allocator_;
}
int MemPool::use_count() {
return CUDACachingAllocator::getPoolUseCount(device_, id_);
}
c10::DeviceIndex MemPool::device() {
return device_;
}
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
if (is_user_created) {
return {0, uid_++};
}
return {uuid_++, 0};
}
} // namespace c10::cuda

View File

@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator {
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> pps) = 0;
virtual std::string name() = 0;
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
c10::DeviceGuard device_guard({at::kCUDA, device});
size_t free = 0;
size_t total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
return {free, total};
}
};
// Allocator object, statically initialized
@ -555,41 +562,7 @@ inline std::string getUserMetadata() {
} // namespace c10::cuda::CUDACachingAllocator
namespace c10::cuda {
// Keep BC only
using c10::CaptureId_t;
using c10::MempoolId_t;
// MemPool represents a pool of memory in a caching allocator. Currently,
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
//
// An allocator pointer can be passed to the MemPool to define how the
// allocations should be done in the pool. For example: using a different
// system allocator such as ncclMemAlloc.
struct C10_CUDA_API MemPool {
MemPool(
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true,
bool use_on_oom = false);
MemPool(const MemPool&) = delete;
MemPool(MemPool&&) = default;
MemPool& operator=(const MemPool&) = delete;
MemPool& operator=(MemPool&&) = default;
~MemPool();
MempoolId_t id();
CUDACachingAllocator::CUDAAllocator* allocator();
int use_count();
c10::DeviceIndex device();
static MempoolId_t graph_pool_handle(bool is_user_created = true);
private:
static std::atomic<CaptureId_t> uid_;
static std::atomic<CaptureId_t> uuid_;
CUDACachingAllocator::CUDAAllocator* allocator_;
bool is_user_created_;
MempoolId_t id_;
c10::DeviceIndex device_;
};
} // namespace c10::cuda

View File

@ -295,11 +295,19 @@ DeviceAssertionsData* CUDAKernelLaunchRegistry::
C10_CUDA_CHECK_WO_DSA(
cudaMallocManaged(&uvm_assertions_ptr, sizeof(DeviceAssertionsData)));
#if CUDART_VERSION >= 13000
cudaMemLocation cpuDevice;
cpuDevice.type = cudaMemLocationTypeDevice;
cpuDevice.id = cudaCpuDeviceId;
#else
const auto cpuDevice = cudaCpuDeviceId;
#endif
C10_CUDA_CHECK_WO_DSA(cudaMemAdvise(
uvm_assertions_ptr,
sizeof(DeviceAssertionsData),
cudaMemAdviseSetPreferredLocation,
cudaCpuDeviceId));
cpuDevice));
// GPU will establish direct mapping of data in CPU memory, no page faults
// will be generated
@ -307,7 +315,7 @@ DeviceAssertionsData* CUDAKernelLaunchRegistry::
uvm_assertions_ptr,
sizeof(DeviceAssertionsData),
cudaMemAdviseSetAccessedBy,
cudaCpuDeviceId));
cpuDevice));
// Initialize the memory from the CPU; otherwise, pages may have to be created
// on demand. We think that UVM documentation indicates that first access may

View File

@ -50,7 +50,13 @@ namespace c10 {
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
class ArrayRef final : public HeaderOnlyArrayRef<T> {
// ArrayRef cannot be derived from. Normally, we would use `final`
// specifier to force this constraint at compile time. However, Intel
// compiler does not recognize ArrayRef as a class template (which is
// required in the definition of at::TensorAccessor, for instance)
// when `final` specifier is used. So, we cannot define ArrayRef as
// final because of the Intel compiler issue.
class ArrayRef : public HeaderOnlyArrayRef<T> {
public:
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
/// SmallVector. As inherited constructors won't work with class template

View File

@ -379,7 +379,11 @@ C10_API std::string GetExceptionString(const std::exception& e);
// ----------------------------------------------------------------------------
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_RETHROW(e, ...) throw
#define TORCH_RETHROW(e, ...) \
do { \
(void)e; /* Suppress unused variable warning */ \
throw; \
} while (false)
#else
#define TORCH_RETHROW(e, ...) \
do { \

View File

@ -12,6 +12,10 @@ template <typename, typename...>
class class_;
}
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10 {
class intrusive_ptr_target;
namespace raw {
@ -33,6 +37,8 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
// Indicates whether the object has a PyObject wrapper.
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);
template <class TTarget>
struct intrusive_target_default_null_type final {
@ -55,7 +61,11 @@ inline uint32_t refcount(uint64_t combined_refcount) {
}
inline uint32_t weakcount(uint64_t combined_refcount) {
return static_cast<uint32_t>(combined_refcount >> 32);
return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32);
}
inline bool has_pyobject(uint64_t combined_refcount) {
return (combined_refcount & kHasPyObject) != 0;
}
// The only requirement for refcount increment is that it happens-before
@ -66,12 +76,6 @@ inline uint64_t atomic_combined_refcount_increment(
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
}
inline uint32_t atomic_refcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::refcount(atomic_combined_refcount_increment(
combined_refcount, kReferenceCountOne));
}
inline uint32_t atomic_weakcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_increment(
@ -99,6 +103,11 @@ inline uint32_t atomic_weakcount_decrement(
combined_refcount, kWeakReferenceCountOne));
}
template <class T, class = void>
struct TargetTraits {
static constexpr bool can_have_pyobject = false;
};
} // namespace detail
/**
@ -155,6 +164,23 @@ class C10_API intrusive_ptr_target {
// we can atomically operate on both at the same time for performance
// and defined behaviors.
//
// Note [PyObject preservation for Tensor and Storages]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// intrusive_ptr has special support for preserving PyObject wrappers
// for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
// the combined_refcount_ is used to indicate whether the object has a
// PyObject wrapper.
//
// - The PyObject, if it exists, holds a strong reference to the
// intrusive_ptr_target.
//
// - When the refcount goes from 1 to 2, we incref the PyObject.
//
// - When the refcount goes from 2 to 1, we decref the PyObject.
//
// In other words, the intrusive_ptr keeps the PyObject alive as long as there
// are other C++ references to the intrusive_ptr_target.
mutable std::atomic<uint64_t> combined_refcount_;
static_assert(sizeof(std::atomic<uint64_t>) == 8);
static_assert(alignof(std::atomic<uint64_t>) == 8);
@ -172,6 +198,8 @@ class C10_API intrusive_ptr_target {
template <typename T>
friend struct ExclusivelyOwnedTensorTraits;
friend class torch::utils::PyObjectPreservation;
protected:
// protected destructor. We never want to destruct intrusive_ptr_target*
// directly.
@ -255,6 +283,16 @@ class C10_API intrusive_ptr_target {
*/
virtual void release_resources() {}
/**
* These two methods are called when the refcount transitions between one
* and two and the object has a PyObject wrapper.
*/
virtual void incref_pyobject() const {}
virtual void decref_pyobject() const {}
virtual bool try_incref_pyobject() const {
return false;
}
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
return detail::refcount(combined_refcount_.load(order));
}
@ -265,6 +303,19 @@ class C10_API intrusive_ptr_target {
}
};
namespace detail {
#ifndef C10_MOBILE
template <>
struct TargetTraits<c10::intrusive_ptr_target> {
// A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
// or StorageImpl, so we have to allow for PyObject support.
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
template <class TTarget, class NullType>
class weak_intrusive_ptr;
@ -314,18 +365,34 @@ class intrusive_ptr final {
void retain_() {
if (target_ != NullType::singleton()) {
uint32_t new_refcount =
detail::atomic_refcount_increment(target_->combined_refcount_);
uint64_t combined = detail::atomic_combined_refcount_increment(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_refcount != 1,
"intrusive_ptr: Cannot increase refcount after it reached zero.");
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 1 to 2, we need to incref the
// PyObject. In other words, we need to ensure that the PyObject stays
// alive now that we have a C++ reference to this object in addition to
// the PyObject itself.
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
target_->incref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!detail::has_pyobject(combined),
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
void reset_() noexcept {
if (target_ != NullType::singleton()) {
if (target_->combined_refcount_.load(std::memory_order_acquire) ==
detail::kUniqueRef) {
if (is_uniquely_owned()) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
@ -337,9 +404,10 @@ class intrusive_ptr final {
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
if (detail::refcount(combined_refcount) == 0) {
bool should_delete =
(combined_refcount == detail::kWeakReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
bool should_delete = detail::weakcount(combined_refcount) == 1;
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
@ -356,6 +424,18 @@ class intrusive_ptr final {
if (should_delete) {
delete target_;
}
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
target_->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
@ -522,6 +602,16 @@ class intrusive_ptr final {
return use_count() == 1;
}
/**
* Stronger than unique() in that it must not have any weakrefs as well.
*/
bool is_uniquely_owned() const noexcept {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
uint64_t combined =
target_->combined_refcount_.load(std::memory_order_acquire);
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
}
/**
* Returns an owning (!) pointer to the underlying object and makes the
* intrusive_ptr instance invalid. That means the refcount is not decreased.
@ -932,6 +1022,7 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) {
return intrusive_ptr<TTarget, NullType>();
} else {
bool increfed = false;
auto combined_refcount =
target_->combined_refcount_.load(std::memory_order_relaxed);
do {
@ -940,12 +1031,31 @@ class weak_intrusive_ptr final {
// Return nullptr.
return intrusive_ptr<TTarget, NullType>();
}
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (detail::has_pyobject(combined_refcount) &&
detail::refcount(combined_refcount) == 1 && !increfed) {
// Object has a python wrapper with no other C++ references.
// We need to to incref the Python object before we acquire a
// strong reference to the C++ object to avoid a situation
// where the Python object is deallocated concurrently.
if (!target_->try_incref_pyobject()) {
return intrusive_ptr<TTarget, NullType>();
}
increfed = true;
}
}
} while (!target_->combined_refcount_.compare_exchange_weak(
combined_refcount,
combined_refcount + detail::kReferenceCountOne,
std::memory_order_acquire,
std::memory_order_relaxed));
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (increfed && detail::refcount(combined_refcount) != 1) {
target_->decref_pyobject();
}
}
return intrusive_ptr<TTarget, NullType>(
target_, raw::DontIncreaseRefcount{});
}
@ -1060,7 +1170,18 @@ namespace intrusive_ptr {
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
if (self) {
detail::atomic_refcount_increment(self->combined_refcount_);
uint64_t combined = detail::atomic_combined_refcount_increment(
self->combined_refcount_, detail::kReferenceCountOne);
#ifndef C10_MOBILE
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
self->incref_pyobject();
}
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
#endif
}
}

View File

@ -15,6 +15,8 @@ using namespace c10::CachingDeviceAllocator;
// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
class XPUAllocator;
namespace {
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
@ -23,14 +25,19 @@ typedef bool (*Comparison)(const Block*, const Block*);
bool BlockComparatorSize(const Block* a, const Block* b);
bool BlockComparatorAddress(const Block* a, const Block* b);
struct PrivatePool;
struct BlockPool {
BlockPool(bool small)
BlockPool(bool small, PrivatePool* private_pool = nullptr)
: blocks(BlockComparatorSize),
unmapped(BlockComparatorAddress),
is_small(small) {}
is_small(small),
owner_PrivatePool(private_pool) {}
std::set<Block*, Comparison> blocks;
std::set<Block*, Comparison> unmapped;
const bool is_small;
PrivatePool* owner_PrivatePool;
};
struct ExpandableSegment;
@ -349,6 +356,43 @@ struct AllocParams {
StatTypes stat_types = {};
};
// Internal implementation that manages actual memory blocks.
// high level MemPool interface wraps PrivatePool via MempoolId.
struct PrivatePool {
PrivatePool(MempoolId_t id, XPUAllocator* allocator = nullptr)
: id(std::move(id)),
allocator_(allocator),
large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
PrivatePool(PrivatePool&&) = delete;
PrivatePool& operator=(const PrivatePool&) = delete;
PrivatePool& operator=(PrivatePool&&) = delete;
~PrivatePool() = default;
// default Mempool when no Mempool is specified
MempoolId_t id{0, 0};
// Number of live graphs using this pool
int use_count{1};
// Number of unfreed allocations made for this pool. When use_count and
// allocation_count drop to zero, we can delete this PrivatePool from
// graph_pools.
int allocation_count{0};
XPUAllocator* allocator_;
BlockPool large_blocks;
BlockPool small_blocks;
public:
XPUAllocator* allocator() {
return allocator_;
}
};
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
} // anonymous namespace
class DeviceCachingAllocator {
@ -365,6 +409,13 @@ class DeviceCachingAllocator {
bool set_fraction = false;
std::vector<ExpandableSegment*> expandable_segments;
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
std::vector<std::pair<MempoolId_t, std::function<bool(sycl::queue*)>>>
captures_underway;
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
graph_pools;
// Pools no longer referenced by any graph.
ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
graph_pools_freeable;
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated || src->event_count > 0 ||
@ -463,7 +514,22 @@ class DeviceCachingAllocator {
}
}
BlockPool& get_pool(size_t size) {
BlockPool& get_pool(size_t size, sycl::queue* queue) {
if (C10_UNLIKELY(!captures_underway.empty())) {
for (auto& entry : captures_underway) {
// lookup for mempool id matching current capture graph
if (entry.second(queue)) {
auto it1 = graph_pools.find(entry.first);
// lookup mempool
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
if (size <= kSmallSize) {
return it1->second->small_blocks;
} else {
return it1->second->large_blocks;
}
}
}
}
if (size < kSmallSize) {
return small_blocks;
} else {
@ -669,6 +735,10 @@ class DeviceCachingAllocator {
if (!ptr) {
return false;
}
if (p.pool->owner_PrivatePool) {
p.pool->owner_PrivatePool->allocation_count++;
}
p.block = new Block(device, p.queue(), size, p.pool, ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(size);
@ -677,11 +747,14 @@ class DeviceCachingAllocator {
return true;
}
void synchronize_and_free_events() {
void synchronize_and_free_events(PrivatePool* pool = nullptr) {
for (auto& xe : xpu_events) {
for (auto& e : xe.second) {
auto event = e.first;
auto* block = e.second;
if (pool && block->pool->owner_PrivatePool != pool) {
continue;
}
event.wait();
block->event_count--;
if (block->event_count == 0) {
@ -785,6 +858,13 @@ class DeviceCachingAllocator {
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].decrease(unmapped.size);
});
if (block->pool->owner_PrivatePool) {
// The Freed block belonged to a XPU graph's PrivatePool.
TORCH_INTERNAL_ASSERT(
block->pool->owner_PrivatePool->allocation_count > 0);
block->pool->owner_PrivatePool->allocation_count--;
}
}
void release_blocks(BlockPool& pool) {
@ -812,13 +892,41 @@ class DeviceCachingAllocator {
}
}
bool release_cached_blocks() {
synchronize_and_free_events();
// See Note [Safe to Free Blocks on BlockPool]
c10::xpu::syncStreamsOnDevice(device_index);
bool release_cached_blocks(MempoolId_t mempool_id) {
if (mempool_id.first == 0 && mempool_id.second == 0 &&
captures_underway.empty()) {
synchronize_and_free_events();
// See Note [Safe to Free Blocks on BlockPool]
c10::xpu::syncStreamsOnDevice(device_index);
release_blocks(large_blocks);
release_blocks(small_blocks);
release_blocks(large_blocks);
release_blocks(small_blocks);
}
for (auto it = graph_pools_freeable.begin();
it != graph_pools_freeable.end();) {
if (mempool_id.first != 0 || mempool_id.second != 0) {
if (it->first == mempool_id) {
// If there is an active mempool, we sync only the events
// associated with the pool
synchronize_and_free_events(it->second);
} else {
// otherwise we move on
++it;
continue;
}
}
TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
release_blocks(it->second->small_blocks);
release_blocks(it->second->large_blocks);
if (it->second->allocation_count == 0) {
auto erase_count = graph_pools.erase(it->first);
TORCH_INTERNAL_ASSERT(erase_count == 1);
it = graph_pools_freeable.erase(it);
} else {
++it;
}
}
return true;
}
@ -903,6 +1011,30 @@ class DeviceCachingAllocator {
}
}
void create_or_incref_pool(
MempoolId_t mempool_id,
XPUAllocator* allocator = nullptr) {
auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool.
// Make a new pool for XPU graph capture or memory pool usage.
graph_pools.emplace(
mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
} else {
// mempool_id references an existing pool, which the current XPU graph
// capture will share.
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
TORCH_INTERNAL_ASSERT(allocator == nullptr);
it->second->use_count++;
}
}
PrivatePool* get_private_pool(MempoolId_t mempool_id) {
auto it = graph_pools.find(mempool_id);
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
return it->second.get();
}
public:
DeviceCachingAllocator(DeviceIndex device_index)
: large_blocks(/* small */ false),
@ -911,9 +1043,11 @@ class DeviceCachingAllocator {
Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
std::scoped_lock<std::recursive_mutex> lock(mutex);
process_events();
if (C10_LIKELY(captures_underway.empty())) {
process_events();
}
size_t size = round_size(orig_size);
auto& pool = get_pool(size);
auto& pool = get_pool(size, &queue);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, &queue, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool);
@ -923,18 +1057,17 @@ class DeviceCachingAllocator {
// Can't reuse an existing block, try to get a new one.
if (!block_found) {
block_found = alloc_block(params, false) ||
(release_cached_blocks() && alloc_block(params, true));
(release_cached_blocks({0, 0}) && alloc_block(params, true));
}
if (!block_found) {
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device);
auto device_total = device_prop.global_mem_size;
const auto& raw_device = c10::xpu::get_raw_device(device);
const auto device_total =
raw_device.get_info<sycl::info::device::global_mem_size>();
// Estimate the available device memory when the SYCL runtime does not
// support the corresponding aspect (ext_intel_free_memory).
size_t device_free = device_prop.global_mem_size -
size_t device_free = device_total -
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto& raw_device = c10::xpu::get_raw_device(device);
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
// affected devices.
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@ -1017,9 +1150,9 @@ class DeviceCachingAllocator {
block->stream_uses.insert(stream);
}
void emptyCache() {
void emptyCache(MempoolId_t mempool_id) {
std::scoped_lock<std::recursive_mutex> lock(mutex);
release_cached_blocks();
release_cached_blocks(mempool_id);
}
DeviceStats getStats() {
@ -1052,21 +1185,37 @@ class DeviceCachingAllocator {
}
}
std::pair<size_t, size_t> getMemoryInfo() {
const auto& device = c10::xpu::get_raw_device(device_index);
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
device.get_info<sycl::info::device::name>(),
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
const size_t free =
device.get_info<sycl::ext::intel::info::device::free_memory>();
return {free, total};
}
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_prop.global_mem_size);
static_cast<double>(device_total);
}
void setMemoryFraction(double fraction) {
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
auto device_total = device_prop.global_mem_size;
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
set_fraction = true;
}
@ -1157,9 +1306,9 @@ class XPUAllocator : public DeviceAllocator {
}
}
void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override {
void emptyCache(MempoolId_t mempool_id) override {
for (auto& da : device_allocators) {
da->emptyCache();
da->emptyCache(mempool_id);
}
}
@ -1240,6 +1389,11 @@ class XPUAllocator : public DeviceAllocator {
c10::xpu::get_raw_device(dev_to_access));
}
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getMemoryInfo();
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();
@ -1270,8 +1424,8 @@ void init(DeviceIndex device_count) {
return allocator.init(device_count);
}
void emptyCache() {
return allocator.emptyCache();
void emptyCache(MempoolId_t mempool_id) {
return allocator.emptyCache(mempool_id);
}
void resetPeakStats(DeviceIndex device) {

View File

@ -10,7 +10,7 @@ C10_XPU_API Allocator* get();
C10_XPU_API void init(DeviceIndex device_count);
C10_XPU_API void emptyCache();
C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0});
C10_XPU_API void resetPeakStats(DeviceIndex device);

View File

@ -1643,8 +1643,6 @@ if(USE_CUDA)
target_link_libraries(torch_cuda PUBLIC c10_cuda)
if(TARGET torch::nvtx3)
target_link_libraries(torch_cuda PRIVATE torch::nvtx3)
else()
target_link_libraries(torch_cuda PUBLIC torch::nvtoolsext)
endif()
target_include_directories(
@ -1741,9 +1739,6 @@ if(BUILD_SHARED_LIBS)
if(USE_CUDA)
target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
target_link_libraries(torch_global_deps torch::cudart)
if(TARGET torch::nvtoolsext)
target_link_libraries(torch_global_deps torch::nvtoolsext)
endif()
endif()
install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()

View File

@ -734,7 +734,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
file_name,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary
);
} catch (const std::ios_base::failure& e) {
} catch (const std::ios_base::failure&) {
#ifdef _WIN32
// Windows have verbose error code, we prefer to use it than std errno.
uint32_t error_code = GetLastError();
@ -773,8 +773,20 @@ void PyTorchStreamWriter::writeRecord(
bool compress) {
AT_ASSERT(!finalized_);
AT_ASSERT(!archive_name_plus_slash_.empty());
TORCH_INTERNAL_ASSERT(
files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
if (files_written_.count(name) > 0) {
// Allow multiple writes for triton binaries
bool is_triton_extension =
c10::ends_with(name, ".so") ||
c10::ends_with(name, ".cubin") ||
c10::ends_with(name, ".hsaco");
if (is_triton_extension) {
LOG(WARNING) << "File '" << name << "' is being serialized multiple times";
return;
}
TORCH_INTERNAL_ASSERT(false, "Tried to serialize file twice: ", name);
}
if (name == kSerializationIdRecordName && serialization_id_.empty()) {
// In case of copying records from another file, skip writing a different
// serialization_id than the one computed in this writer.

View File

@ -968,11 +968,8 @@ find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir)
if(nvtx3_FOUND)
add_library(torch::nvtx3 INTERFACE IMPORTED)
target_include_directories(torch::nvtx3 INTERFACE "${nvtx3_dir}")
target_compile_definitions(torch::nvtx3 INTERFACE TORCH_CUDA_USE_NVTX3)
else()
message(WARNING "Cannot find NVTX3, find old NVTX instead")
add_library(torch::nvtoolsext INTERFACE IMPORTED)
set_property(TARGET torch::nvtoolsext PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvToolsExt)
message(FATAL_ERROR "Cannot find NVTX3!")
endif()

View File

@ -1,7 +1,7 @@
# This will define the following variables:
# SYCL_FOUND : True if the system has the SYCL library.
# SYCL_INCLUDE_DIR : Include directories needed to use SYCL.
# SYCL_LIBRARY_DIR The path to the SYCL library.
# SYCL_LIBRARY_DIR : The path to the SYCL library.
# SYCL_LIBRARY : SYCL library fullname.
# SYCL_COMPILER_VERSION : SYCL compiler version.

View File

@ -132,9 +132,6 @@ if(@USE_CUDA@)
else()
set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
endif()
if(TARGET torch::nvtoolsext)
list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
endif()
if(@BUILD_SHARED_LIBS@)
find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")

View File

@ -10,7 +10,7 @@ API. This API can roughly be divided into five parts:
- **TorchScript**: An interface to the TorchScript JIT compiler and interpreter.
- **C++ Extensions**: A means of extending the Python API with custom C++ and CUDA routines.
Combining, these building blocks form a research and
Combined, these building blocks form a research and
production ready C++ library for tensor computation and dynamic neural
networks with strong emphasis on GPU acceleration as well as fast CPU
performance. It is currently in use at Facebook in research and
@ -76,7 +76,7 @@ C++ Frontend
------------
The PyTorch C++ frontend provides a high level, pure C++ modeling interface for
neural network and general ML(Machine Learning) research and production use cases,
neural networks and general ML (Machine Learning) research and production use cases,
largely following the Python API in design and provided functionality. The C++
frontend includes the following:

View File

@ -40,6 +40,7 @@
:nosignatures:
empty_cache
get_memory_info
max_memory_allocated
max_memory_reserved
memory_allocated

View File

@ -0,0 +1,113 @@
# Device Management
## Background
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
## Design
Accelerator vendors need to implement these core functions:
| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
## Implementation
This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs
### C++ Side
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
:linenos:
```
### Binding
Expose the C++ functions to Python using pybind11:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 5
```
### Python Side
Wrap the C++ bindings with user-friendly Python functions:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
:language: python
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
:linenos:
```
Here's the complete mapping from C++ to Python:
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
## Guard
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:linenos:
```
**What needs to be implemented:**
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

View File

@ -0,0 +1,164 @@
# Accelerator Hooks
## Background
OpenReg hooks provide a mechanism for integrating custom accelerator devices into PyTorch's runtime system. OpenReg (Open Registration) is PyTorch's extensibility framework that allows accelerator vendors to register custom device backends without modifying PyTorch core code.
## Design
The following tables list all hooks that accelerator vendors need to implement when integrating a new device backend. These hooks are categorized into two priority levels:
- **High Priority Hooks**: Core APIs that PyTorch runtime directly depends on. Accelerator vendors are recommended to implement all high priority hooks to ensure full PyTorch compatibility and enable basic device functionality.
- **Low Priority Hooks**: Device management and utility APIs that PyTorch does not directly depend on. These hooks enhance user experience and multi-device support but are *optional*. Accelerator vendors can choose to implement them based on their specific requirements and use cases.
### High Priority Hooks
| Hook Method | Description | Application Scenario |
| ---------------------------------- | --------------------------------------------------------- | -------------------------------------------------------------------------------- |
| `init()` | Initializes the accelerator runtime and device contexts | Set up necessary state when PyTorch first accesses the device |
| `hasPrimaryContext(DeviceIndex)` | Checks if a primary context exists for the device | Determine whether device initialization has occurred |
| `getDefaultGenerator(DeviceIndex)` | Returns the default random number generator for a device | Access the device's primary RNG for reproducible random operations |
| `getNewGenerator(DeviceIndex)` | Creates a new independent random number generator | Create isolated RNG instances for parallel operations |
| `getDeviceFromPtr(void*)` | Determines which device a memory pointer belongs to | Identify the accelerator device associated with a memory allocation |
| `getPinnedMemoryAllocator()` | Returns an allocator for pinned (page-locked) host memory | Allocate host memory that can be efficiently transferred to/from the accelerator |
| `isPinnedPtr(void*)` | Checks if a pointer points to pinned memory | Validate memory types before performing operations |
### Low Priority Hooks
| Hook Method | Description | Application Scenario |
| ---------------------------------- | ---------------------------------------------------------------------------- | -------------------------------------------------------------------- |
| `isBuilt()` | Returns whether the accelerator backend is built/compiled into the extension | Check whether the accelerator library is available at compile time |
| `isAvailable()` | Returns whether the accelerator hardware is available at runtime | Verify whether accelerator devices can be detected and initialized |
| `deviceCount()` | Returns the number of available accelerator devices | Enumerate all available accelerator devices for device selection |
| `setCurrentDevice(DeviceIndex)` | Sets the active device for the current thread | Switch the current thread's context to a specific accelerator device |
| `getCurrentDevice()` | Returns the currently active device index | Query which accelerator device is active in the current thread |
| `exchangeDevice(DeviceIndex)` | Atomically exchanges the current device and returns the previous one | Temporarily switch devices and restore the previous device afterward |
| `maybeExchangeDevice(DeviceIndex)` | Conditionally exchanges device only if the index is valid | Safely attempt device switching with validation |
## Implementation
We can just take `getDefaultGenerator` as an implementation example:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
:end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
:linenos:
```
In this implementation:
1. **Override the base interface**: The `getDefaultGenerator` method overrides the virtual method from `at::PrivateUse1HooksInterface`.
2. **Delegate to device-specific implementation**: It calls `getDefaultOpenRegGenerator(device_index)`, which manages a per-device generator instance.
3. **Return device-specific generator**: The returned `at::Generator` wraps an `OpenRegGeneratorImpl` that implements device-specific random number generation.
This pattern applies to all hooks: override the interface method, validate inputs, delegate to your device-specific API, and return results in PyTorch's expected format.
## Integration Example
The following sections demonstrate how PyTorch integrates with accelerator hooks when accessing the default random number generator. The example traces the complete flow from user-facing Python code down to the device-specific implementation.
### Layer 1: User Code
User code initiates the operation by calling `manual_seed` to set the random seed for reproducible results:
```python
import torch
torch.openreg.manual_seed(42)
```
### Layer 2: Extension Python API
The Python API layer handles device management and calls into the C++ extension (defined in [`torch_openreg/openreg/random.py`][random.py]):
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py
:language: python
:start-after: LITERALINCLUDE START: OPENREG MANUAL SEED
:end-before: LITERALINCLUDE END: OPENREG MANUAL SEED
:linenos:
```
The `manual_seed` function gets the current device index and calls `torch_openreg._C._get_default_generator(idx)` to obtain the device-specific generator, then sets the seed on it.
### Layer 3: Python/C++ Bridge
The C++ extension exposes `_getDefaultGenerator` to Python, which bridges to PyTorch's core runtime:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
:end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
:linenos:
:emphasize-lines: 10-11
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 3
```
This function unpacks the device index from Python, creates a `PrivateUse1` device object, and calls `at::globalContext().defaultGenerator()`. PyTorch's context then dispatches to the registered hooks.
### Layer 4: PyTorch Core Context
PyTorch's Context class dispatches to the appropriate accelerator hooks ([`aten/src/ATen/Context.h`][Context.h]):
```{eval-rst}
.. literalinclude:: ../../../aten/src/ATen/Context.h
:language: c++
:lines: 60-103
:linenos:
:emphasize-lines: 8-9, 24-25
```
This layered architecture enables PyTorch to remain device-agnostic while delegating hardware-specific operations to accelerator implementations. The hooks are registered once at module load time:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK REGISTER
:end-before: LITERALINCLUDE END: OPENREG HOOK REGISTER
:linenos:
:emphasize-lines: 4
```
### Layer 5: Accelerator Hooks
The hooks interface provides the abstraction that PyTorch uses to delegate to device-specific implementations:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
:end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
:linenos:
```
The `getDefaultGenerator` hook method overrides the base interface and delegates to `getDefaultOpenRegGenerator`, which manages the actual generator instances.
### Layer 6: Device-Specific Implementation
The device-specific implementation manages per-device generator instances:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
:end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL
:linenos:
```
This function maintains a static vector of generators (one per device), initializes them on first access, validates the device index, and returns the appropriate generator instance.
[random.py]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py#L48-L53 "random.py"
[Context.h]: https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/Context.h#L61-L102 "Context.h"

View File

@ -42,6 +42,8 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1
device
hooks
autoload
operators
amp

Some files were not shown because too many files have changed in this diff Show More