Compare commits

...

93 Commits

Author SHA1 Message Date
605612232b Update on "small changes"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 15:09:57 -08:00
5fb77f86f3 Update base for Update on "small changes"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 15:09:57 -08:00
e4b87540e7 Update on "small changes"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 11:23:14 -08:00
232872ff0d Update base for Update on "small changes"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 11:23:14 -08:00
2f7db8ce15 Update on "small changes"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 09:53:39 -08:00
4edfc71feb Update base for Update on "small changes"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 09:53:39 -08:00
567dcdba75 Fix longstanding race condition around getAllOperatorsFor (#167860)
getAllOperatorsFor returns a const reference to internal state that is protected by a lock. Presuming that the lock is necessary in the first place (about which I offer no opinion because it's unclear to what extent the GIL should help here), this is a straightforward way to cause callers to create race conditions.

This should fix those race conditions by copying the state instead. I modified calling code to stop binding a const reference to the result for clarity.

Differential Revision: [D87088731](https://our.internmc.facebook.com/intern/diff/D87088731/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D87088731/)!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167860
Approved by: https://github.com/zou3519
2025-11-17 17:37:02 +00:00
77acc66df9 [ROCm][CI] Upgrade ROCm CI to 7.1 (#166743)
Upgrade all the ROCm docker images to ROCm 7.1 release version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166743
Approved by: https://github.com/atalman, https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Prachi Gupta <prachi.gupta@amd.com>
2025-11-17 17:17:25 +00:00
95d1df7d4e Disable CUDA MXFP4 on non-B200 GPUs (#167857)
Summary:

MXFP4 unit tests pass on B200, fail on RTX 5090 - disable non-B200
cases.

Also add a fail w/a not implemented error for non-B200 to avoid
unhelpful failure messages.

Test Plan:

```
pytest -sv -k "mxfp4" test/test_scaled_matmul_cuda.py
```

Reviewers:

@nWEIdia

Subscribers:

Tasks:

Fixes https://github.com/pytorch/pytorch/issues/167850

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167857
Approved by: https://github.com/nWEIdia, https://github.com/malfet
2025-11-17 17:14:53 +00:00
094e529c64 [MPS] Fix repeat_interleave with slices (#167961)
Alas, one can not use `repeat_interleave_common` for MPS tensors, as `data_offset` is not a valid pointer to `id<MTLTensor>`
On the other hand, one does not need to use `AT_DISPATCH_INDEX_TYPES` as dispatching is happening on the shader side

Fixes https://github.com/pytorch/pytorch/issues/167924
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167961
Approved by: https://github.com/manuelcandales
2025-11-17 17:10:59 +00:00
a4c7bf7e8d Revert "Use c10::filesystem (#167821)"
This reverts commit deabb3e36de207aa497b035a8bdf6ec1b37d17fe.

Reverted https://github.com/pytorch/pytorch/pull/167821 on behalf of https://github.com/jeanschmidt due to Breaks internal tests, see D87148810. @Skylion007 may you help the author to get this PR merged? ([comment](https://github.com/pytorch/pytorch/pull/167821#issuecomment-3542877623))
2025-11-17 16:48:57 +00:00
22ccd44d73 Revert "Improve char printing (#167899)"
This reverts commit 2245d7d3b90162ae2958929a22c140537cfc4b42.

Reverted https://github.com/pytorch/pytorch/pull/167899 on behalf of https://github.com/jeanschmidt due to need to revert in order to revert https://github.com/pytorch/pytorch/pull/167899 ([comment](https://github.com/pytorch/pytorch/pull/167899#issuecomment-3542869096))
2025-11-17 16:46:44 +00:00
39ebab1dd9 Revert "Remove python workaround for ContextDecorator (#167049)"
This reverts commit e20ca3bc2e6ef9935c782fe548348f81fabc5bd7.

Reverted https://github.com/pytorch/pytorch/pull/167049 on behalf of https://github.com/jeanschmidt due to breaks internal tests see D87120562, @Skylion007 please thelp the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/167049#issuecomment-3542847796))
2025-11-17 16:41:26 +00:00
4c152a71ad Revert "add device generalization support for distributed tests (#165067)"
This reverts commit 96a4c4b3d1c533b36cfa7259524b91a0eaf4254f.

Reverted https://github.com/pytorch/pytorch/pull/165067 on behalf of https://github.com/jeanschmidt due to breaks internal tests see D87036515, @albanD please help the author get this PR merged ([comment](https://github.com/pytorch/pytorch/pull/165067#issuecomment-3542820651))
2025-11-17 16:37:07 +00:00
1730d78e4f Update on "small changes"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 08:32:42 -08:00
11ede3e98b Update base for Update on "small changes"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-17 08:32:42 -08:00
1b43d6cd4e [ROCm] enable fastSpecializedAtomicAdd for gfx950 (#167661)
Use standard HIP headers for unsafeAtomicAdd. Removes copy/paste of unsafeAtomicAdd as "preview" implementation for gfx942.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167661
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-11-17 16:18:49 +00:00
2b69673bbf [CD] Add libopenblas to dep list for AArch64+CPU whl (#167841)
#166044 removes openblas from whl dependency list for AArch64+CPU build so this PR adds it back. Only affects CPU build since AArch64+CUDA uses NVPL.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167841
Approved by: https://github.com/tinglvv, https://github.com/malfet
2025-11-17 16:11:39 +00:00
2f74916e36 Do not hardfail on use nccl estimations for non-nccl (#167827)
Previously we hard failed if pg was "gloo".
Fallback on hardcoded formulas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167827
Approved by: https://github.com/eellison
2025-11-17 16:06:26 +00:00
2b5eabc74b Rework PyObject preservation (v2) (#167564)
Make the PyObject preservation scheme thread-safe with free threaded (nogil) Python. The general idea is:

* Python Tensor and Storage objects always hold a strong reference to their underlying c10 object
* c10 objects hold a strong reference to their Python objects if there's at least one other reference to the c10 object

This is implemented in `intrusive_ptr`:

* The top most bit (`kHasPyObject`) from the weakref count is now used to indicate if the `intrusive_ptr_target` has an associated PyObject. So `kHasPyObject` is one bit, the weakref count is now 31 bits and the strong refcount remains 32 bits.
* When the reference count increases from one to two and `kHasPyObject` is set, we incref the associated Python object to ensure that it's kept alive.
* When the reference count decreases from two to one (i.e., there are no C++ reference to the `intrusive_ptr_target` other than from the Python object), we decre the associated Python object to break the cycle.

Other benefits:

* We can delete a lot of the copypasta from Python internal `subtype_dealloc`
* This fixes the weakref and GC bugs we had in the previous scheme. Python weakrefs on Tensors and Storages should just work as expected now.

Risks:

* Extra branch for reference count operations on `intrusive_ptr<TensorImpl>`, `intrusive_ptr<StorageImpl>`, and the generic `intrusive_ptr<intrusive_ptr_target>` even when we're not using Python.
* It's a big change

(Second attempt at https://github.com/pytorch/pytorch/pull/166342)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167564
Approved by: https://github.com/albanD, https://github.com/Skylion007
2025-11-17 14:52:02 +00:00
9ff95f6835 [inductor] Expose config for fx bucket all_reduces (#167634)
Exposing `_inductor.config.bucket_all_reduces_fx` similar to all_gathers, reduce_scatters with only option "all".

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167634
Approved by: https://github.com/eellison
2025-11-17 13:10:36 +00:00
6fdb974f4a Update torch-xpu-ops commit pin (#167698)
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@1e69f4](1e69f40b3c), includes:

- Add PTL in the default AOT target list for both Win and Lin
- Use PyTorch p2p API in Copy kernel
- Add event cache and event timing to XCCL
- Add Float8_e8m0fnu support for copy
- Add CMAKE_SYCL_COMPILER_LAUNCHER for sccache
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167698
Approved by: https://github.com/EikanWang
2025-11-17 12:58:42 +00:00
661d1653aa [xla hash update] update the pinned xla hash (#167968)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned xla hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167968
Approved by: https://github.com/pytorchbot
2025-11-17 12:20:32 +00:00
53809f9640 [ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (#158250)
Co-authored-by: Nikhil Gupta [nikhil.gupta2@arm.com](mailto:nikhil.gupta2@arm.com)

This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch to boost LLM prefill & decode performance

**This change improves decode throughput by ~15% & reduces memory required to inference the model by 50%**

### Benchmark Setup
```
Model: meta-llama/Llama-3.1-8B
Test Platform: Neoverse V2
```
### Detailed Results

| Metric                           | With `--compile`         | Without `--compile`      |
|----------------------------------|---------------------------|---------------------------|
| Quantization Scheme              | INT4 symmetric channelwise | INT4 symmetric channelwise |
| Input Precision                  | BF16                      | BF16                      |
| Number of Layers Quantized       | 32                        | 32                        |
| Average Compression Ratio        | 87.49%                    | 87.49%                    |
| Total Quantization Time (s)      | 9.62                      | 10.32                     |
| Compile Time (First) (s)         | 134.48                    | 1.69                      |
| Compile Time (Second) (s)        | 80.44                     | 1.60                      |
| Compile Time (Subsequent) (s)    | 0.19                      | 0.22                      |
| Prefill Tokens                   | 54                        | 54                        |
| Decoded Tokens                   | 33                        | 33                        |
| Prefill Time (s)                 | 0.19                      | 0.22                      |
| Decode Time (s)                  | 0.76                      | 1.38                      |
| E2E Generation Time (s)          | 0.95                      | 1.60                      |
| Prefill Throughput (tokens/s)    | 288.13                    | 249.91                    |
| Decode Throughput (tokens/s)     | 43.42                     | 23.83                     |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158250
Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/fadara01

Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-17 12:06:33 +00:00
93ddd38ecd Re-land#2 "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (#167928)
Summary:
getCurrentCUDABlasHandle() and getCUDABlasLtWorkspace() use static mutable maps that are not protected from concurrent read-and-write. This leads to crashes.

This diff adds mutexes to synchronize access to the static maps.

Re-land context:

This is a re-land of https://github.com/pytorch/pytorch/pull/167248.

A few issues were addressed:
- fix for a bug in fast path: premature return in getCurrentCUDABlasHandle)
- fix for test flakiness (https://github.com/pytorch/pytorch/pull/167884)

Test Plan:
1. regression tests:
buck2 test \mode/opt //caffe2/test\:test_transformers_cuda
https://www.internalfb.com/intern/testinfra/testrun/6192449759713581

2. Use a GPU OD, run multi-threaded tests with TSAN:

buck test fbcode//mode/dev-tsan fbcode//caffe2:cuda_cublas_handle_pool_test  -- --stress-runs 100
https://www.internalfb.com/intern/testinfra/testrun/14355223937501118

Differential Revision: D87111985

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167928
Approved by: https://github.com/Skylion007
2025-11-17 12:05:08 +00:00
5804408f1b [1/3][XPU][feature] The implementation of memory private pool in XPU device allocator (#166831)
The implementation plan of MemPool for XPU, which is the dependance of [XPUGraph](https://github.com/pytorch/pytorch/pull/166285), following the [RFC](https://github.com/pytorch/pytorch/issues/162143).

- [ ] ->#166831
- [ ] #166833
- [ ] #166843

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166831
Approved by: https://github.com/EikanWang, https://github.com/gujinghui

Co-authored-by: Eikan Wang <eikan.wang@intel.com>
2025-11-17 11:11:23 +00:00
99117c1238 Remove old NVTX interface (#167637)
The PR #167401 reminded me that the removal of old NVTX interface is long overdue, as the header-only NVTX3 has been around for more than 5 years and is shipped with all CUDA Toolkit versions of 12+. In addition to that, `libnvToolsExt.so` was removed in CUDA Toolkit 13 and onward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167637
Approved by: https://github.com/eqy
2025-11-17 08:07:20 +00:00
b9bccec3bc Revert "[ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)"
This reverts commit 226850cc66217e591c706397dd212b457ed61e22.

Reverted https://github.com/pytorch/pytorch/pull/167734 on behalf of https://github.com/Aidyn-A due to fails on CUDA 12.8 ([comment](https://github.com/pytorch/pytorch/pull/167734#issuecomment-3540410067))
2025-11-17 07:56:28 +00:00
ca3aaef66e Fix clamp broadcasting on MPS (Fixes #160734) (#165058)
This PR fixes a bug where `torch.clamp` on MPS fails when min/max tensors have more dimensions than the input tensor.
CPU already supports this broadcasting, but MPS raised a RuntimeError.

Example of failing case before the fix:
```python
x = torch.randn(2, 3, device="mps")
min_t = torch.randn(1, 2, 3, device="mps")
max_t = torch.randn(1, 2, 3, device="mps")
torch.clamp(x, min=min_t, max=max_t)  # RuntimeError
```
After this fix, MPS matches CPU behavior.

Fixes #160734

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165058
Approved by: https://github.com/malfet
2025-11-17 07:40:39 +00:00
f2e6f94081 deprecate check_is_size and guard_size_oblivious (#167198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167198
Approved by: https://github.com/bobrenjc93
2025-11-17 05:47:40 +00:00
aa504d4d2a [audio hash update] update the pinned audio hash (#167914)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167914
Approved by: https://github.com/pytorchbot
2025-11-17 05:21:29 +00:00
d8ce6f8df9 Enable PyTorch OSS numerics changes, inductor heuristics (#167799)
Test Plan: CI

Differential Revision: D86211542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167799
Approved by: https://github.com/njriasan, https://github.com/eellison
2025-11-17 04:31:44 +00:00
4322354770 [Inductor] optimize scalar welford_reduce (#162709)
**Summary:**
Optimize scalar welford_reduce implementation, combining Welford algorithm with cascade sum to improve numerical stability. Specifically:

1. Use Welford algorithm to compute mean and variance.
2. Use cascade summation when computing sum over input for both mean and variance.

**Example:**
Take https://github.com/pytorch/pytorch/issues/141541 as an example:
```
import torch
import torch.nn as nn
torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=32, num_channels=32)

    def forward(self, x):
        return self.gn(x)

model = Model().eval()
x = torch.randn(1, 32, 128, 128, 128)

with torch.no_grad():
    output = model(x)
    with torch._inductor.config.patch({"cpp.simdlen": 0}):
        c_model = torch.compile(model)
        c_output = c_model(x)

print(torch.max(torch.abs(output - c_output)))
print(torch.allclose(output, c_output, 1.3e-6, 1e-5))
```
**logs**

- before
```
tensor(0.0005)
False
```
- After
```
tensor(1.4305e-06)
True
```

**Generated code:**
- before
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['float*', 'float*', 'const float*', 'const float*', 'const float*', 'float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       float* in_out_ptr1,
                       const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr2)
{
    auto out_ptr1 = in_out_ptr0;
    auto out_ptr0 = in_out_ptr1;
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<float> tmp_acc0_arr[4];
                for (int i = 0; i < 4; i++)
                {
                    tmp_acc0_arr[i] = Welford<float>();
                }
                #pragma omp parallel num_threads(4)
                {
                    int tid = omp_get_thread_num();
                    Welford<float> tmp_acc0_local = Welford<float>();
                    #pragma omp for
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            {
                                auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                                tmp_acc0_local = welford_combine(tmp_acc0_local, tmp0);
                            }
                        }
                    }
                    tmp_acc0_arr[tid] = tmp_acc0_local;
                }
                for (int tid = 0; tid < 4; tid++)
                {
                    tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]);
                }
                in_out_ptr1[static_cast<int64_t>(x0)] = tmp_acc0.mean;
                in_out_ptr0[static_cast<int64_t>(x0)] = tmp_acc0.m2;
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                {
                    auto tmp0 = out_ptr1[static_cast<int64_t>(x0)];
                    auto tmp6 = in_ptr1[static_cast<int64_t>(x0)];
                    auto tmp8 = out_ptr0[static_cast<int64_t>(x0)];
                    auto tmp11 = in_ptr2[static_cast<int64_t>(x0)];
                    auto tmp1 = static_cast<float>(2097152.0);
                    auto tmp2 = tmp0 / tmp1;
                    auto tmp3 = static_cast<float>(1e-05);
                    auto tmp4 = float(tmp2 + tmp3);
                    auto tmp5 = 1 / std::sqrt(tmp4);
                    auto tmp7 = float(tmp5 * tmp6);
                    auto tmp9 = decltype(tmp8)(-tmp8);
                    auto tmp10 = float(tmp9 * tmp7);
                    auto tmp12 = float(tmp10 + tmp11);
                    in_out_ptr0[static_cast<int64_t>(x0)] = tmp7;
                    in_out_ptr1[static_cast<int64_t>(x0)] = tmp12;
                }
            }
        }
    }
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
            {
                #pragma GCC ivdep
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                            auto tmp1 = in_out_ptr0[static_cast<int64_t>(x0)];
                            auto tmp3 = in_out_ptr1[static_cast<int64_t>(x0)];
                            auto tmp2 = float(tmp0 * tmp1);
                            auto tmp4 = float(tmp2 + tmp3);
                            out_ptr2[static_cast<int64_t>(x1 + 2097152L*x0)] = tmp4;
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1 = args
        args.clear()
        assert_size_stride(arg0_1, (32, ), (1, ))
        assert_size_stride(arg1_1, (32, ), (1, ))
        assert_size_stride(arg2_1, (1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1))
        buf0 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf1 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf3 = reinterpret_tensor(buf1, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf1  # reuse
        buf4 = reinterpret_tensor(buf0, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf0  # reuse
        buf5 = empty_strided_cpu((1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1), torch.float32)
        # [Provenance debug handles] cpp_fused_native_group_norm_0:1
        cpp_fused_native_group_norm_0(buf3, buf4, arg2_1, arg0_1, arg1_1, buf5)
        del arg0_1
        del arg1_1
        del arg2_1
        return (buf5, )
```

- After
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['float*', 'float*', 'const float*', 'const float*', 'const float*', 'float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       float* in_out_ptr1,
                       const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr2)
{
    auto out_ptr1 = in_out_ptr0;
    auto out_ptr0 = in_out_ptr1;
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<float> tmp_acc0_arr[4];
                for (int i = 0; i < 4; i++)
                {
                    tmp_acc0_arr[i] = Welford<float>();
                }
                #pragma omp parallel num_threads(4)
                {
                    int tid = omp_get_thread_num();
                    WelfordHelper<float, float, 4096> scalar_welford_helper0(static_cast<int64_t>(524288L));
                    Welford<float> tmp_acc0_local = Welford<float>();
                    #pragma omp for
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            {
                                auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                                tmp_acc0_local = welford_combine(tmp_acc0_local, tmp0, &scalar_welford_helper0);
                            }
                        }
                    }
                    tmp_acc0_local = welford_combine(tmp_acc0_local, &scalar_welford_helper0);
                    tmp_acc0_arr[tid] = tmp_acc0_local;
                }
                for (int tid = 0; tid < 4; tid++)
                {
                    tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]);
                }
                in_out_ptr1[static_cast<int64_t>(x0)] = tmp_acc0.mean;
                in_out_ptr0[static_cast<int64_t>(x0)] = tmp_acc0.m2;
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                {
                    auto tmp0 = out_ptr1[static_cast<int64_t>(x0)];
                    auto tmp6 = in_ptr1[static_cast<int64_t>(x0)];
                    auto tmp8 = out_ptr0[static_cast<int64_t>(x0)];
                    auto tmp11 = in_ptr2[static_cast<int64_t>(x0)];
                    auto tmp1 = static_cast<float>(2097152.0);
                    auto tmp2 = tmp0 / tmp1;
                    auto tmp3 = static_cast<float>(1e-05);
                    auto tmp4 = float(tmp2 + tmp3);
                    auto tmp5 = 1 / std::sqrt(tmp4);
                    auto tmp7 = float(tmp5 * tmp6);
                    auto tmp9 = decltype(tmp8)(-tmp8);
                    auto tmp10 = float(tmp9 * tmp7);
                    auto tmp12 = float(tmp10 + tmp11);
                    in_out_ptr0[static_cast<int64_t>(x0)] = tmp7;
                    in_out_ptr1[static_cast<int64_t>(x0)] = tmp12;
                }
            }
        }
    }
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
            {
                #pragma GCC ivdep
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                            auto tmp1 = in_out_ptr0[static_cast<int64_t>(x0)];
                            auto tmp3 = in_out_ptr1[static_cast<int64_t>(x0)];
                            auto tmp2 = float(tmp0 * tmp1);
                            auto tmp4 = float(tmp2 + tmp3);
                            out_ptr2[static_cast<int64_t>(x1 + 2097152L*x0)] = tmp4;
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1 = args
        args.clear()
        assert_size_stride(arg0_1, (32, ), (1, ))
        assert_size_stride(arg1_1, (32, ), (1, ))
        assert_size_stride(arg2_1, (1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1))
        buf0 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf1 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf3 = reinterpret_tensor(buf1, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf1  # reuse
        buf4 = reinterpret_tensor(buf0, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf0  # reuse
        buf5 = empty_strided_cpu((1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1), torch.float32)
        # [Provenance debug handles] cpp_fused_native_group_norm_0:1
        cpp_fused_native_group_norm_0(buf3, buf4, arg2_1, arg0_1, arg1_1, buf5)
        del arg0_1
        del arg1_1
        del arg2_1
        return (buf5, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162709
Approved by: https://github.com/CaoE, https://github.com/jansel
2025-11-17 02:52:33 +00:00
363385ad3e s/Stragety/Strategy/ (#167916)
Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167916
Approved by: https://github.com/Skylion007
2025-11-16 19:47:23 +00:00
e2e10753d7 Allow same triton kernels in export (#167862)
Summary: This diff would be a follow-up diff for D85883723.

Test Plan:
See D86719598. We are now able to publish the model.

Unit test:
```
buck run fbcode//mode/opt -c remoteexecution.local=enabled fbcode//sigmoid/inference/test:test_passes -m ovr_config//triton:experimental -- -r test_triton_hop_cpu
```

Differential Revision: D87091238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167862
Approved by: https://github.com/XueningXu
2025-11-16 17:51:23 +00:00
5d99a795f5 [xpu][test] Migrated two test files to XPU (#166684)
# Description
Fixes #114850, we will port test utils and schema check to Intel GPU
We could enable Intel GPU with following methods and try the best to keep the original code styles:

# Changes
1. Get device type with from accelerator and get_devtype helper method
2. Replace the requires cuda statement to device_type.
3. Add HAS_XPU and HAS GPU check to replace some of the HAS_XPU etc.

# Notify

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166684
Approved by: https://github.com/ezyang, https://github.com/guangyey

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
2025-11-16 14:15:28 +00:00
2245d7d3b9 Improve char printing (#167899)
This PR outputs chars to stream without building temporary strings.
They were modified by (on fish)
```
sed  -i -e 's/<< "\([^\\\']\)"/<< \'\1\'/g' (grep '<< "."' -r torch c10 aten -l)
```
and revert some invalid changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167899
Approved by: https://github.com/Skylion007
2025-11-16 07:19:16 +00:00
98b94b90dd [pallas backend] implement gpu tiles/mask for power of 2 (#167584)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167584
Approved by: https://github.com/jansel
2025-11-16 07:01:51 +00:00
5cdbda140c [vision hash update] update the pinned vision hash (#167890)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vision hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167890
Approved by: https://github.com/pytorchbot
2025-11-16 04:58:47 +00:00
0ec53beaeb Refactor TensorAccessor for headeronly. (#166855)
This PR moves the implementations of Tensor accessor classes to headeronly with the following modifications:
- Add ArrayRef and IndexBoundsCheck template parameters to refactor out the usages of `IntArrayRef` and `TORCH_CHECK_INDEX` from Tensor accessor implementations.
- Eliminate usage of `c10::irange` as it is not headeronly-compatible.
- Introduce `torch::headeronly::{TensorAccessorBase,TensorAccessor, GenericPackedTensorAccessorBase, GenericPackedTensorAccessor}` that are headeronly-equivalent to `at::{TensorAccessorBase,TensorAccessor, GenericPackedTensorAccessorBase, GenericPackedTensorAccessor}`. Both these sets of template classes use original implementations from `torch::headeronly::detail` that have new template parameters `ArrayRefCls` and `IndexBoundsCheck` to facilitate `at` and `torch::headeronly` implementations of ArrayRef and checking indices.

TODO:
- ~when https://github.com/pytorch/pytorch/pull/164991 lands, eliminate the placeholder class HeaderOnlyArrayRef~ UPDATE: done.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166855
Approved by: https://github.com/janeyx99
2025-11-15 22:37:24 +00:00
79fc0a9141 [xpu][fix]Fall back deterministic index_copy to index_put on XPU (#167830)
A minor update has been made to the deterministic behavior checks in the `index_copy_out` implementation. This change ensures that deterministic  `index_copy` is dispatched to `index_put` not only for CUDA tensors but also for XPU tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167830
Approved by: https://github.com/guangyey, https://github.com/ezyang
2025-11-15 18:09:25 +00:00
d01a7b0241 Back out "MatMal - fix folding logic" (#167884)
Summary:
For sepcific hardware (A100), Autocast will generate a relatively large error on Transformer (torch.nn.TransformerEncoder) when using no_grad decorator on dim=256 (and larger presuably).

H100 seems fine, as does A100 with mig (so less than full SMs).

For now backing out, and revisting next week.

Test Plan:
failed jobs:
https://fburl.com/scuba/remote_execution_action/jzcmujgk

 {F1983543613}

Reviewed By: t-ivan-gr

Differential Revision: D87111518

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167884
Approved by: https://github.com/malfet
2025-11-15 08:29:08 +00:00
deabb3e36d Use c10::filesystem (#167821)
This PR fixes code to use c10::filesystem functionality instead of manually implemented functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167821
Approved by: https://github.com/Skylion007
2025-11-15 06:01:01 +00:00
79d2397b6b Fix grammar issues in C++ frontend documentation (#167702)
Corrected minor grammatical errors in the documentation.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167702
Approved by: https://github.com/jerryzh168
2025-11-15 05:55:08 +00:00
6ef3a62c36 Fix typo in FP16 accumulation section (#167703)
Fix typo error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167703
Approved by: https://github.com/jerryzh168
2025-11-15 05:34:07 +00:00
530e782239 [codemod][lowrisk] Remove unused exception parameter from caffe2/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm (#167604)
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.

This:
```
try {
    ...
} catch (exception& e) {
    // no use of e
}
```
should instead be written as
```
} catch (exception&) {
```

If the code compiles, this is safe to land.

Test Plan: Sandcastle

Differential Revision: D85813836

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167604
Approved by: https://github.com/malfet, https://github.com/seemethere
2025-11-15 05:17:48 +00:00
c66a6c432e [HOP][print] Add functionalization (make sure ordering) for print (#167016)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167016
Approved by: https://github.com/angelayi
2025-11-15 05:06:05 +00:00
3d7a8b7e61 MPS: Fix clamp scalar cache key to store floats in hex representation (#167777)
Fixes #167767.

Original issue was that using std::to_string(value) does not work intended here if the value is smaller than 1e-6. The caching keys ended up as `clamp_out_mps_min:0.000000_scalar::f32[1]` instead of `clamp_out_mps_min:0.0000001_scalar::f32[1]`. After the change the values are stored as the hex representation for the floating point number. So for min_value 1e-7 the key will be `impl_min:0x1.ad7f2ap-24_scalar::f32[1]` and for min_value 0.0 `clamp_out_mps_min:0x0p+0_scalar::f32[1]`

Output of the repro code before the change:

```
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
```

Output for the repro code after the change:

```
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
tensor([0.], device='mps:0')
tensor([1.0000e-07], device='mps:0')
```
which matches the expected CPU reference.

Snippet to test with:
```
import torch

device='mps'
dtype=torch.float32
a = torch.zeros(1, device=device, dtype=dtype)

# the following line triggers the incorrect behavior, when commented, the remainder of the script appears to work as expected
a_clamped = a.clamp(min=0.0)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7, max=None)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp(min=1e-7, max=torch.inf)
print(c)

b = torch.zeros(1, device=device)
print(b)
c = b.clamp_min(1e-7)
print(c)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167777
Approved by: https://github.com/malfet
2025-11-15 03:26:38 +00:00
de0d69b2c4 Remove useless super() delegation (#167791)
This PR removes useless super() delegations detected by pylint.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167791
Approved by: https://github.com/albanD
2025-11-15 02:50:51 +00:00
bc60b86066 Skip stable diffusion models in torchbench, get tests and benchmarks green (#167896)
Test Plan:
- wait for CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167896
Approved by: https://github.com/aorenste, https://github.com/shunting314
ghstack dependencies: #167609
2025-11-15 02:44:36 +00:00
d7782ddde7 [ATEN][CUDA] Reduce register pressure introduced by CUDA_KERNEL_ASSERT to improve torch.EmbeddingBag performance (#167834)
# Summary

This PR optimizes the CUDA kernels for `torch.nn.EmbeddingBag` by reducing GPU register pressure introduced by `CUDA_KERNEL_ASSERT`, which improves kernel occupancy and overall performance. The optimization separates input validation into a dedicated loop before the main processing loop, allowing the compiler to better optimize register allocation. By extensively testing on various GPUs and CUDA versions, `torch.nn.EmbeddingBag` performance improves by 29% to 111% with this PR.

# Performance Results

The following table shows the performance improvements on various input distributions and GPUs. All benchmarks use PyTorch 2.9.0 compiled with CUDA 12.8.

**Input Distribution Types (simulating recommendation system ID patterns):**
- **random id**: Randomly sampled embedding indices from the full vocabulary (uniform distribution)
- **one-hot**: One ID appears with very high frequency across all bags, simulating a popular item in recommendation systems
- **multi-hot**: Multiple IDs appear with high frequency across all bags, simulating multiple popular items in recommendation systems

**Test Configuration:**
- Embedding shape: `(5000000, 128)` (5M vocabulary size, 128-dimensional embeddings)
- Batch size: 2048 bags
- Average bag size: 150 indices per bag

| GPU  | Input Distribution | Before (µs) | After (µs) | Speedup |
| ---- | ------------------ | ----------- | ---------- | ------- |
| H100 | random id          | 162.4       | 105.9      | 1.53×   |
| H100 | one-hot            | 120.4       | 88.6       | 1.36×   |
| H100 | multi-hot          | 113.1       | 87.8       | 1.29×   |
| H20  | random id          | 278.6       | 132.2      | 2.11×   |
| H20  | one-hot            | 189.7       | 110.3      | 1.72×   |
| H20  | multi-hot          | 172.4       | 107.4      | 1.61×   |

# Motivation

The original implementation performed bounds checking using `CUDA_KERNEL_ASSERT` inline within the main processing loop, which increased register pressure and limited GPU occupancy. From NSight Compute analysis on H20, using PyTorch 2.9 compiled with CUDA 12.8, removing the `CUDA_KERNEL_ASSERT` from the main loop with this PR increases the overall occupancy from 50% to 75%(registers per thread 52->40).

By separating validation into a dedicated loop, we:

1. **Reduce register pressure in the main loop**: The validation loop uses minimal registers, allowing the compiler to optimize the main processing loop independently with better register allocation.
2. **Maintain correctness**: All input validation is still performed, but in a more register-efficient manner.

# Changes

## Modified Kernels

1. **`EmbeddingBag_updateOutputKernel_max`**: Added separate validation loop before main processing
2. **`EmbeddingBag_updateOutputKernel_sum_mean`**: Added separate validation loop before main processing

## Key Implementation Details

- **Separate validation loop**: Input indices are validated in a dedicated loop that checks all indices before processing begins
- **No early exit**: The validation loop intentionally avoids using `break` for early exit, as benchmarking showed that early exit degrades performance, possibly due to increased branch divergence and reduced instruction-level parallelism
- **Consistent error messages**: Improved error message clarity for invalid input indices
- **Design choice: validation loop vs. separate kernel**: We considered removing `CUDA_KERNEL_ASSERT` entirely and performing bounds checking in a separate GPU kernel, which would achieve even better performance (e.g., on H20 with random id distribution: 132.2 µs → 124.6 µs). However, this approach is harder to maintain as it requires coordinating two separate kernel launches and managing additional kernel launch overhead. Instead, we chose the current approach of using a separate validation loop within the same kernel, which provides a good balance between performance improvement and code maintainability.

## Code Changes

```cpp
// Separate validation loop reduces register pressure in the main loop below.
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
bool has_invalid_index = false;
for (int64_t emb = begin; emb < end; emb++) {
  index_t input_idx = input[emb];
  has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
}
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");

// Main processing loop (now with reduced register pressure)
for (int64_t emb = begin; emb < end; emb++) {
  // ... processing logic ...
}
```

# Testing & Compatibility

## Performance Testing

I conducted extensive performance testing across multiple configurations. All tests show significant performance improvements:

**Tested CUDA Versions:**
- CUDA 12.6, 12.8, 13.0

**Tested GPU Architectures:**
- A100, H20, H100

**Tested Input Configurations:**
- **Embedding shapes**: Various sizes including `[5000000, 128]` and `[128000, 4096]`
- **Embedding dtypes**: `torch.float32`, `torch.float16`
- **Input distributions**: Random indices, one-hot (high-frequency single ID), and multi-hot (high-frequency multiple IDs) patterns, simulating recommendation system workloads
- **Input sizes**: Average bag sizes of 150, 20, and 10 indices per bag

## Correctness Testing

-  Correctness tests pass for various embedding types (bfloat16, float32), shapes, and input distributions
-  Register usage reduction verified with NSight Compute
-  Linter passes

## Compatibility

-  No API/ABI changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167834
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-11-15 02:03:38 +00:00
eqy
fb04e9ad03 [CUDA][CUDA Graphs] Respect node-priority in cudaGraphInstantiate (#167346)
Needed for e.g., stream priority-based implementations of comm-compute overlap

CC @galv

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167346
Approved by: https://github.com/ngimel
2025-11-15 01:59:36 +00:00
cfe799b4aa Revert "Ops convolution_backward optional flag bug (#165008)"
This reverts commit c429b1fc5c60a6819b041f1a881ab09735689fbe.

Reverted https://github.com/pytorch/pytorch/pull/165008 on behalf of https://github.com/clee2000 due to I think this broke some tests in the slow workflow? test/test_ops.py::TestCommonCUDA::test_compare_cpu_convolution_backward_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/19375318020/job/55443680773) [HUD commit link](c429b1fc5c) ([comment](https://github.com/pytorch/pytorch/pull/165008#issuecomment-3535354672))
2025-11-15 01:50:09 +00:00
b7f52773e6 Add meta registration for scaled_mm_v2 and test (#167653)
Summary:

`torch._scaled_mm_v2` didn't have a valid meta registration, or
`FakeTensor` tests, so anything expecting inductor to work (like
torch.ao tests) would fail horribly.

Test Plan:

```
pytest -sv -k "scaled_mm_v2" test/test_ops.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167653
Approved by: https://github.com/drisspg
2025-11-15 01:21:04 +00:00
f6b54d8899 flight_recorder: move to torch.distributed (#167782)
Summary: This moves torchfrtrace to be under `torch.distributed.flight_recorder` instead of `tools.flight_recorder` as the `tools` package is not included in the torch wheels. This makes it so you can use fr trace analyze without using it from a source checkout

Test Plan:
```
buck run //caffe2/fb/flight_recorder:fr_trace
```

CI

Differential Revision: D87022129

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167782
Approved by: https://github.com/fduwjj
2025-11-15 01:16:59 +00:00
da91bf5262 Fix incorrect attention example in ONNX exporter docstring (#167646)
Fixes #167627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167646
Approved by: https://github.com/malfet, https://github.com/titaiwangms

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-15 00:31:48 +00:00
1c1638297e Revert "distributed/debug: add an HTTP server for debugging running jobs (#167395)"
This reverts commit 4ed26f7382bc3e5217121f5085af070e57f2ef40.

Reverted https://github.com/pytorch/pytorch/pull/167395 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167395#issuecomment-3535150292))
2025-11-15 00:25:51 +00:00
ee0b5b4b1c Add new CI jobs to run dynamo tests on all python versions supported (#166978)
This PR adds 2 new CI jobs to run dynamo core (`test/dynamo/*`) and
`dynamo_wrapped` tests on Python 3.11/3.12.

**Selected Machine**
Tests are executed on `linux.c7i.2xlarge` without GPU. Which means all
cuda tests (if any) are skipped.

**Runtime**
- The core tests takes 30 minutes to run
- The `dynamo_wrapped` test is divided into three shards and each one
  takes around 1.5 hours to execute

**Schedule**
Tests are executed every day at 1:29 PDT or in the presence of
`ciflow/dynamo` label

Co-authored-by: Rob Timpe <rtimpe@openteams.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166978
Approved by: https://github.com/atalman, https://github.com/malfet
ghstack dependencies: #167092
2025-11-15 00:15:36 +00:00
fcfb213c5a [inductor] layout constraint for weight-norm-bwd (#167667)
fix https://github.com/pytorch/pytorch/issues/165749

The weight_norm backward kernel requires its inputs to be contiguous. Add those constraints to the lowering/fallback rule.

A better fix is maybe add decomposition rule for the op. But since we already fallback, this fix does no harm and can fix the attached issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167667
Approved by: https://github.com/eellison
2025-11-14 23:59:24 +00:00
08042bbb9c [6/N] Use Python 3.10 typing (#167649)
This PR applies new Union typing syntax to some python files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167649
Approved by: https://github.com/albanD
2025-11-14 23:55:08 +00:00
e20ca3bc2e Remove python workaround for ContextDecorator (#167049)
This PR removes the import workaround for ContextDecorator because the import always succeeds in Py 3.10+.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049
Approved by: https://github.com/Skylion007
2025-11-14 23:54:52 +00:00
4ed26f7382 distributed/debug: add an HTTP server for debugging running jobs (#167395)
This adds a debug HTTP server for debugging stuck or slow jobs. It runs the WorkerServer on every worker and then launches a separate flask process on rank 0 to have users connect to for debugging.

This can easily be improved to trigger profilers as well as visualize the data much better.

Initial handlers:
* pytorch profiler
* FlightRecorder data
* Python stacks

```
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"

from torch.distributed.debug import enable_debug_server

enable_debug_server()
```

Test plan:

```
torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/debug_test.py
```

<img width="1499" height="1629" alt="20251107_17h10m47s_grim" src="https://github.com/user-attachments/assets/a8b9a0cb-3bbf-4558-be12-5253e418214e" />
<img width="1192" height="1337" alt="20251107_17h10m39s_grim" src="https://github.com/user-attachments/assets/ac5d7011-4acb-4401-bf2c-f9b22c1466bd" />

<img width="984" height="851" alt="20251107_18h35m38s_grim" src="https://github.com/user-attachments/assets/98b3eb31-ed01-4345-90dd-c79345cf82ce" />
<img width="2880" height="777" alt="20251107_18h35m31s_grim" src="https://github.com/user-attachments/assets/8de84b8b-9d06-4bc8-a1bf-280a2958315b" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167395
Approved by: https://github.com/fduwjj
2025-11-14 23:14:38 +00:00
4c79305b87 [targets2buck] Clean up get_pt_ops_deps (#167690)
Summary: I didn't understand what this macro was doing so I created a bit of a mess, mess be gone!

Test Plan: `buck2 ctargets fbcode//caffe2/... fbsource//xplat/caffe2/...`

Reviewed By: mzlee

Differential Revision: D86460608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167690
Approved by: https://github.com/seemethere
2025-11-14 23:05:24 +00:00
f4b8c4f907 backed size oblivious checks for expand() (#167689)
Summary:
Support semantics when using backed_size_oblivious, similar to https://github.com/pytorch/pytorch/pull/167232

We see errors in a model exported with dynamic shapes, like
```
RuntimeError: non-broadcasting semantics require s67 == 41

While executing %expand : [num_users=1] = call_method[target=expand](args = (%reshape_5, -1, -1, %getitem_9), kwargs = {})
```

Test Plan:
test_dynamic_shapes:
```
test_backed_size_oblivious_expand (test_dynamic_shapes.TestUbackedOps) ... I1112 14:07:54.724596 1386932 Logger.cpp:995] Dropping logs in unit tests.
ok
```

Differential Revision: D86902546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167689
Approved by: https://github.com/laithsakka
2025-11-14 22:31:28 +00:00
d629b7a459 Move CppTypeToScalarType to torch/headeronly (#167610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167610
Approved by: https://github.com/pearu, https://github.com/janeyx99
2025-11-14 22:21:45 +00:00
0922ba5f42 [BE] No need to pass const enum values by reference (#167868)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167868
Approved by: https://github.com/slayton58
2025-11-14 21:56:19 +00:00
c87295c044 [precompile] Support captured global tensors. (#167846)
Summary:
In vllm we saw cases where user intialize a tensor in the global scope and reference it in the forward body. This should be supported by pruning the used globals in the scope and serialize them along the artifacts similar to how we handle closure.

Use case example: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3n.py#L65

Test Plan:
test_aot_compile.py

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167846
Approved by: https://github.com/jamesjwu
2025-11-14 21:40:07 +00:00
7aa210d215 Revert "[CodeClean] Remove the Unused MACRO for AOT Inductor Runtime (#165139)"
This reverts commit fcd5f8c352b5b75bd32e57fa044ec5df095032da.

Reverted https://github.com/pytorch/pytorch/pull/165139 on behalf of https://github.com/jeanschmidt due to trying to hevert in the hopes it fixes internal errors, will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
5a368b8010 Revert "[CodeClean] Replace std::runtime_error with TORCH_CHECK (#165119)"
This reverts commit 398775a43e9808205f75c81d36f5087117d3f3f4.

Reverted https://github.com/pytorch/pytorch/pull/165119 on behalf of https://github.com/jeanschmidt due to trying to hevert in the hopes it fixes internal errors, will land it back ([comment](https://github.com/pytorch/pytorch/pull/165139#issuecomment-3534662138))
2025-11-14 21:35:37 +00:00
602102be50 Revert "Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (#167496)"
This reverts commit bc09a84150eaadaadab8a8ecd76cd9afc60d8a19.

Reverted https://github.com/pytorch/pytorch/pull/167496 on behalf of https://github.com/jeanschmidt due to trying to revert 165139, my intention is to land it again, so, will land this once both are reverted ([comment](https://github.com/pytorch/pytorch/pull/167496#issuecomment-3534641209))
2025-11-14 21:33:02 +00:00
200156e385 DTensor: avoid unnecessary DTensorSpec creation in _ToTorchTensor.backward (#167588)
Looks like the check here is cheap and has a potentially large payoff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167588
Approved by: https://github.com/ezyang
2025-11-14 21:08:12 +00:00
a2daf3fc86 [Inductor] Add support bound methods in pattern matcher (#167795)
Fixes: #167776

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167795
Approved by: https://github.com/mlazos
2025-11-14 20:55:51 +00:00
52b45c16de Add reshape, view, flatten to torch/csrc/stable (#167600)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167600
Approved by: https://github.com/janeyx99
ghstack dependencies: #167592
2025-11-14 20:35:53 +00:00
2ef85bed5a Add empty to stable ops (#167592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167592
Approved by: https://github.com/janeyx99
2025-11-14 20:35:53 +00:00
d99c6bcf69 [export] Disable side effects on dynamo_graph_capture_for_export and warn user. (#167763)
Summary:
as title.

Test Plan:
test_dynamo_graph_capture_side_effects

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167763
Approved by: https://github.com/tugsbayasgalan
2025-11-14 20:35:22 +00:00
8378abda84 [torch.export] Fix for flaky test_annotate_on_assert (#167805)
Summary: test_annotate_on_assert become flaky with PR 166341 (Details in https://github.com/pytorch/pytorch/issues/167432). Torchdynamo related metadata can vary depending on the caller. Removing the those metadata before comparison.

Test Plan:
```
buck test mode/opt caffe2/test:test_export -- 'test_annotate_on_assert'
```
https://www.internalfb.com/intern/testinfra/testrun/7036874728749661

Differential Revision: D87036890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167805
Approved by: https://github.com/yushangdi
2025-11-14 19:56:51 +00:00
5b42a5d9a6 [doc] Add example for torch.is_storage (#161898)
Fixes #161858

### Summary:
Added comprehensive documentation examples for `torch.is_storage()` to help users understand how to check if an object is a PyTorch storage object.

### Impact:

- Enhances API Documentation
- Helps users distinguish between PyTorch storage objects and other types

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161898
Approved by: https://github.com/isuruf, https://github.com/malfet
2025-11-14 19:45:54 +00:00
caca3f2eec Revert "Re-land "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (#167722)"
This reverts commit 40e6f090d91026947fbec92a42564ad492f37eae.

Reverted https://github.com/pytorch/pytorch/pull/167722 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167722#issuecomment-3534282212))
2025-11-14 19:38:22 +00:00
9e2bf129e1 [MPS] addmm complex fix (#167826)
Fixes #167727

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167826
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-14 19:29:09 +00:00
c429b1fc5c Ops convolution_backward optional flag bug (#165008)
Fixes #89629

When using torch.ops.aten.convolution_backward, the optional argument bias_sizes was being used in the python function registration without checking whether it was defined.

## For the fix
there are two modes to consider with different results.

First @dynamo.optimize("inductor") is the most demanding.
We cannot be wrong about the size passed into the function. But we should not ignore what the user wants/thinks they are doing. For this case, we want to throw an error when the user is wrong. If the user passes in None, we calculate the expected size directly.

Second @dynamo.optimize("eager") is very lenient.
We really can provide any value we want here. If the user is wrong about bias shape in eager mode, the op will just reshape the bias to the proper size so no error is thrown here.

## For testing
An OpInfo was added for torch.ops.aten.convolution_backward.default.
For the CUDA test_noncontiguous_samples test, a slightly updated error tolerance was necessary for the compounded add multiply (for 2x2 kernel).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165008
Approved by: https://github.com/bdhirsh
2025-11-14 19:24:45 +00:00
1176b2b0b7 [BE]: Update NVTX submodule to 3.3.0 (#167751)
Update NVTX to 3.3.0. Mostly fixes some errors in the bindings, improve C++20 support, and improve C++ bindings to NVTX. Header only library upgrade so should be mostly safe.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167751
Approved by: https://github.com/albanD, https://github.com/eqy
2025-11-14 19:24:37 +00:00
dd37a1a434 Fix NaN gradients in atan2_backward when both inputs are zero (#166787)
Fixes #165427

## Description of Bug 🐛

As reported in #165427, When both the input of  `atan2` function is zero the gradient becomes `NaN`. During the forward pass, `atan2` successfully avoids division-by-zero issue, but during backpropagation gradients become `NaN`.

This is because the backward pass calculates `(self * self + other * other).reciprocal()`, which becomes `inf` at `(0, 0)`. The subsequent multiplication by zero `(0 * inf)` results in `NaN`.

## Changes
- Added an `at::where` condition to handle zero denominators in `atan2_backward`.
- If denom is zero return 0 for the reciprocal; otherwise, use the original value.

## Testing
- Added` test_atan2_zero_gradient` in `test/test_autograd.py` to verify `atan2` returns `0.0` gradients for `(0,0)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166787
Approved by: https://github.com/soulitzer
2025-11-14 19:23:33 +00:00
a74adcf80e [codemod][lowrisk] Remove unused exception parameter from caffe2/caffe2/serialize/inline_container.cc (#167612)
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.

This:
```
try {
    ...
} catch (exception& e) {
    // no use of e
}
```
should instead be written as
```
} catch (exception&) {
```

If the code compiles, this is safe to land.

Test Plan: Sandcastle

Differential Revision: D85813824

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167612
Approved by: https://github.com/seemethere, https://github.com/malfet
2025-11-14 19:11:01 +00:00
5eac46a011 add assume_32bit_indexing inductor config (#167784)
when we know all tensor and intermediate tensors fit in 32 bit but use unbacked DS
we want a way to assume that we can use 32 bit indexing(we will runtime assert on it).

It is not practical to torch check every possible intermediate tensor size ahead of time.

This is needed to enhance vLLM perf with unbacked,  since in vLLM all tensors and
intermediates assumed to fit in 32 bits.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167784
Approved by: https://github.com/jansel
2025-11-14 19:04:22 +00:00
bd04dc7750 small changes
[ghstack-poisoned]
2025-11-14 10:30:28 -08:00
e0fff31ae3 [dynamo] Make global state guards and torch function stack guards droppable. (#167674)
Summary:
Prior to this PR we will always build global and torch funciton guards in all cases.

In this PR we did 2 changes to dynamo guards:
1. Created a new guard called "GLOBAL_STATE" which corresponds to the global state guard and can be filtered out using guard_filter_fn
2. Repurpose the existing "TORCH_FUNCTION_STATE" guard for checking torch function mode stack.

Also added a new helper `torch.compiler.skip_all_guards_unsafe` which can be useful for use cases like vllm

Test Plan:
CI

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167674
Approved by: https://github.com/anijain2305
2025-11-14 18:11:44 +00:00
7ede33b8e3 Tiling bug fix (#167771)
Fix for https://github.com/pytorch/pytorch/issues/166653.

Two fixes:
- We were inducing a split for broadcasted loads. e.g. (x // 16). While a split of 16 here will make the load coalesced in one of the tile vars, since the load is already in cache it's not worth splitting. And it would make the other tile var load from memory that isnt in cache.
- Add a slight term for uncoalesced memory. This prevents doing tiling for loads which are a small % of the overall kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167771
Approved by: https://github.com/v0i0
2025-11-14 17:32:42 +00:00
065176cd97 [export] Add pytree input check for dynamo_graph_capture_for_export (#167731)
Summary:
as title.

Test Plan:
pytest test/export/test_export.py -k test_invalid_pytree_dynamo_graph_capture

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167731
Approved by: https://github.com/tugsbayasgalan
2025-11-14 17:29:55 +00:00
eqy
02ee7dd7d3 [CUDA][Test] Add serialTest() to some largeTensorTest tests (#167471)
Try to prevent two big tests from overlapping in their memory usage

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167471
Approved by: https://github.com/soulitzer
2025-11-14 17:13:14 +00:00
99fdca8f4d [ROCm] Enable StaticCudaLauncher for ROCm (#166492)
This PR enables ROCm/HIP support for PyTorch's StaticCudaLauncher, which provides static compilation and launching of Triton kernels. The implementation has been tested on AMD MI300 and MI200 hardware.

**Changes**

**Python (torch/_inductor/runtime/)**
- static_cuda_launcher.py: Added ROCm detection, .hsaco binary support, and ROCm-specific scratch parameter handling
- triton_heuristics.py: Updated device type checks to support both cuda and hip

**C++ (torch/csrc/)**
- Module.cpp: Enabled StaticCudaLauncher for ROCm builds
- inductor/static_cuda_launcher.cpp: Added HIP API equivalents for all CUDA driver calls
- inductor/static_cuda_launcher.h: Updated header guard

**Tests (test/inductor/)**
- test_static_cuda_launcher.py: Removed @skipIfRocm decorators and updated binary file handling

**Enabled Unit Tests**
All tests in test/inductor/test_static_cuda_launcher.py now pass on ROCm:
1. test_basic
2. test_unsigned_integers
3. test_signed_integers
4. test_basic_1arg
5. test_constexpr
6. test_implied_constant
7. test_kernel_no_args
8. test_high_shared_mem
9. test_too_high_shared_mem
10. test_kernel_empty_tensor
11. test_kernel_many_args
12. test_basic_compile
13. test_incompatible_code
14. test_static_launch_user_defined_triton_kernels
15. test_empty_tensor
16. test_any
17. test_disable_static_cuda_launcher

In addition to this, the following tests from test/inductor/test_codecache.py also pass:
1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False
2. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False
3. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True
4. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False
5. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False
6. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True

The following tests are skipped since triton bundling is necessary for StaticCudaLauncher:
1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True
2. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166492
Approved by: https://github.com/jeffdaily
2025-11-14 17:11:45 +00:00
9d1a74cb0c Fix mvlgamma_ FPE crash on x86 with integer input (#164230)
Fixes #161871.

Behaviour on arm:

```
PyTorch version: 2.10.0a0+gitdef3b05
Architecture: arm64
Platform: Darwin
Processor: arm

Testing mvlgamma_ with integer tensor on arm64...
 Got expected error: mvlgamma: result type Long can't be cast to the desired output type Float
```

and on x86:

```
PyTorch version: 2.10.0a0+git1310d6a
Architecture: x86_64
Platform: Linux
Processor: x86_64

Testing mvlgamma_ with integer tensor on x86_64...
 Got expected error: mvlgamma: result type Long can't be cast to the desired output type Float
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164230
Approved by: https://github.com/albanD
2025-11-14 17:09:10 +00:00
3b90bf36f9 Add multiple hiding nodes
[ghstack-poisoned]
2025-11-14 09:04:39 -08:00
40e6f090d9 Re-land "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (#167722)
Summary:
getCurrentCUDABlasHandle() and getCUDABlasLtWorkspace() use static mutable maps that are not protected from concurrent read-and-write. This leads to crashes.
This diff adds mutexes to synchronize access to the static maps.

Note: this is a re-land of D86316117 / https://github.com/pytorch/pytorch/pull/167248 (see comments for details)

Test Plan:
Use a GPU OD, run multi-threaded tests (cuda_cublas_handle_pool_test) with TSAN:
```
buck test fbcode//mode/dev-tsan fbcode//caffe2:cuda_cublas_handle_pool_test  -- --stress-runs 100
```
https://www.internalfb.com/intern/testinfra/testrun/14355223937501118

TSAN output (before synchronization was added): P2026731804

Differential Revision: D86964261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167722
Approved by: https://github.com/malfet
2025-11-14 16:16:35 +00:00
230 changed files with 5942 additions and 3128 deletions

View File

@ -75,9 +75,11 @@ if [[ "$ARCH" == "aarch64" ]]; then
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
"/opt/OpenBLAS/lib/libopenblas.so.0"
)
DEPS_SONAME+=(
"libgfortran.so.5"
"libopenblas.so.0"
)
fi

View File

@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
)
def _compile_and_extract_symbols(
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
"""
Helper to compile a C++ file and extract all symbols.
Args:
cpp_content: C++ source code to compile
compile_flags: Compilation flags
exclude_list: List of symbol names to exclude. Defaults to ["main"].
Returns:
List of all symbols found in the object file (excluding those in exclude_list).
"""
import subprocess
import tempfile
if exclude_list is None:
exclude_list = ["main"]
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
cpp_file = tmppath / "test.cpp"
obj_file = tmppath / "test.o"
cpp_file.write_text(cpp_content)
result = subprocess.run(
compile_flags + [str(cpp_file), "-o", str(obj_file)],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode != 0:
raise RuntimeError(f"Compilation failed: {result.stderr}")
symbols = get_symbols(str(obj_file))
# Return all symbol names, excluding those in the exclude list
return [name for _addr, _stype, name in symbols if name not in exclude_list]
def check_stable_only_symbols(install_root: Path) -> None:
"""
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
This approach tests:
1. WITHOUT macros -> many torch symbols exposed
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
4. WITH both macros -> zero torch symbols (all hidden)
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>
// ATen tensor library
#include <ATen/ATen.h>
// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>
int main() { return 0; }
"""
base_compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c", # Compile only, don't link
]
# Compile WITHOUT any macros
symbols_without = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=base_compile_flags,
)
# We expect constexpr symbols, inline functions used by other headers etc.
# to produce symbols
num_symbols_without = len(symbols_without)
print(f"Found {num_symbols_without} symbols without any macros defined")
assert num_symbols_without != 0, (
"Expected a non-zero number of symbols without any macros"
)
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
symbols_with_stable_only = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_stable_only,
)
num_symbols_with_stable_only = len(symbols_with_stable_only)
assert num_symbols_with_stable_only == 0, (
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
)
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
compile_flags_with_target_version = base_compile_flags + [
"-DTORCH_TARGET_VERSION=1"
]
symbols_with_target_version = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_target_version,
)
num_symbols_with_target_version = len(symbols_with_target_version)
assert num_symbols_with_target_version == 0, (
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
)
# Compile WITH both macros (expect 0 symbols)
compile_flags_with_both = base_compile_flags + [
"-DTORCH_STABLE_ONLY",
"-DTORCH_TARGET_VERSION=1",
]
symbols_with_both = _compile_and_extract_symbols(
cpp_content=test_cpp_content,
compile_flags=compile_flags_with_both,
)
num_symbols_with_both = len(symbols_with_both)
assert num_symbols_with_both == 0, (
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
)
def check_stable_api_symbols(install_root: Path) -> None:
"""
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
stable_dir = include_dir / "torch" / "csrc" / "stable"
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
stable_headers = list(stable_dir.rglob("*.h"))
if not stable_headers:
raise RuntimeError("Could not find any stable headers")
includes = []
for header in stable_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable = _compile_and_extract_symbols(
cpp_content=test_stable_content,
compile_flags=compile_flags,
)
num_symbols_stable = len(symbols_stable)
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
assert num_symbols_stable > 0, (
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable} symbols"
)
def check_headeronly_symbols(install_root: Path) -> None:
"""
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Find all headers in torch/headeronly
headeronly_dir = include_dir / "torch" / "headeronly"
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
headeronly_headers = list(headeronly_dir.rglob("*.h"))
if not headeronly_headers:
raise RuntimeError("Could not find any headeronly headers")
# Filter out platform-specific headers that may not compile everywhere
platform_specific_keywords = [
"cpu/vec",
]
filtered_headers = []
for header in headeronly_headers:
rel_path = header.relative_to(include_dir).as_posix()
if not any(
keyword in rel_path.lower() for keyword in platform_specific_keywords
):
filtered_headers.append(header)
includes = []
for header in filtered_headers:
rel_path = header.relative_to(include_dir)
includes.append(f"#include <{rel_path.as_posix()}>")
includes_str = "\n".join(includes)
test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_headeronly = _compile_and_extract_symbols(
cpp_content=test_headeronly_content,
compile_flags=compile_flags,
)
num_symbols_headeronly = len(symbols_headeronly)
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
assert num_symbols_headeronly > 0, (
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_headeronly} symbols"
)
def check_aoti_shim_symbols(install_root: Path) -> None:
"""
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
int32_t (*fp2)() = &aoti_torch_dtype_float32;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_shim = _compile_and_extract_symbols(
cpp_content=test_shim_content,
compile_flags=compile_flags,
)
num_symbols_shim = len(symbols_shim)
assert num_symbols_shim > 0, (
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_shim} symbols"
)
def check_stable_c_shim_symbols(install_root: Path) -> None:
"""
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
"""
include_dir = install_root / "include"
assert include_dir.exists(), f"Expected {include_dir} to be present"
# Check if the stable C shim exists
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
if not stable_shim.exists():
raise RuntimeError("Could not find stable c shim")
# There are no constexpr symbols etc., so we need to actually use functions
# so that some symbols are found.
test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
// Reference stable C API functions to create undefined symbols
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
(void)fp1; (void)fp2;
return 0;
}
"""
compile_flags = [
"g++",
"-std=c++17",
f"-I{include_dir}",
f"-I{include_dir}/torch/csrc/api/include",
"-c",
"-DTORCH_STABLE_ONLY",
]
symbols_stable_shim = _compile_and_extract_symbols(
cpp_content=test_stable_shim_content,
compile_flags=compile_flags,
)
num_symbols_stable_shim = len(symbols_stable_shim)
assert num_symbols_stable_shim > 0, (
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
f"but found {num_symbols_stable_shim} symbols"
)
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@ -460,13 +129,6 @@ def main() -> None:
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
# Check symbols when TORCH_STABLE_ONLY is defined
check_stable_only_symbols(install_root)
check_stable_api_symbols(install_root)
check_headeronly_symbols(install_root)
check_aoti_shim_symbols(install_root)
check_stable_c_shim_symbols(install_root)
if __name__ == "__main__":
main()

View File

@ -389,6 +389,13 @@ test_lazy_tensor_meta_reference_disabled() {
export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
}
test_dynamo_core() {
time python test/run_test.py \
--include-dynamo-core-tests \
--verbose \
--upload-artifacts-while-running
assert_git_not_dirty
}
test_dynamo_wrapped_shard() {
if [[ -z "$NUM_TEST_SHARDS" ]]; then
@ -1814,6 +1821,8 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
test_inductor_shard "${SHARD_NUMBER}"
elif [[ "${TEST_CONFIG}" == *einops* ]]; then
test_einops
elif [[ "${TEST_CONFIG}" == *dynamo_core* ]]; then
test_dynamo_core
elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then
install_torchvision
test_dynamo_wrapped_shard "${SHARD_NUMBER}"

View File

@ -1 +1 @@
07b6cbde121417a70e4dc871adb6d27030e0ce3f
ee1a1350eb37804b94334768f328144f058f14e9

View File

@ -1 +1 @@
acccf86477759b2d3500f1ae1be065f7b1e409ec
2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc

View File

@ -1 +1 @@
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
94631807d22c09723dd006f7be5beb649d5f88d0

View File

@ -7,6 +7,7 @@ ciflow_push_tags:
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
- ciflow/dynamo
- ciflow/h100
- ciflow/h100-cutlass-backend
- ciflow/h100-distributed

View File

@ -326,7 +326,7 @@ jobs:
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}

70
.github/workflows/dynamo-unittest.yml vendored Normal file
View File

@ -0,0 +1,70 @@
# Workflow: Dynamo Unit Test
# runs unit tests for dynamo.
name: dynamo-unittest
on:
push:
tags:
- ciflow/dynamo/*
workflow_call:
schedule:
- cron: 29 8 * * * # about 1:29am PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
dynamo-build:
name: dynamo-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
strategy:
matrix:
python-version: ['3.11', '3.12']
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
test-matrix: |
{ include: [
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
]}
secrets: inherit
dynamo-test:
name: dynamo-test
uses: ./.github/workflows/_linux-test.yml
needs: [get-label-type, dynamo-build]
strategy:
matrix:
python-version: ['3.11', '3.12']
with:
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
test-matrix: |
{ include: [
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
]}
secrets: inherit

View File

@ -1,5 +1,6 @@
#pragma once
#include <torch/headeronly/core/TensorAccessor.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Deprecated.h>
@ -11,252 +12,37 @@
namespace at {
// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
// is used to enable the __restrict__ keyword/modifier for the data
// passed to cuda.
template <typename T>
struct DefaultPtrTraits {
typedef T* PtrType;
};
using torch::headeronly::DefaultPtrTraits;
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename T>
struct RestrictPtrTraits {
typedef T* __restrict__ PtrType;
};
using torch::headeronly::RestrictPtrTraits;
#endif
// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
// For CUDA tensors it is used in device code (only). This means that we restrict ourselves
// to functions and types available there (e.g. IntArrayRef isn't).
// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class TensorAccessorBase {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
using TensorAccessorBase = torch::headeronly::detail::TensorAccessorBase<c10::IntArrayRef, T, N, PtrTraits, index_t>;
C10_HOST_DEVICE TensorAccessorBase(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: data_(data_), sizes_(sizes_), strides_(strides_) {}
C10_HOST IntArrayRef sizes() const {
return IntArrayRef(sizes_,N);
}
C10_HOST IntArrayRef strides() const {
return IntArrayRef(strides_,N);
}
C10_HOST_DEVICE index_t stride(index_t i) const {
return strides_[i];
}
C10_HOST_DEVICE index_t size(index_t i) const {
return sizes_[i];
}
C10_HOST_DEVICE PtrType data() {
return data_;
}
C10_HOST_DEVICE const PtrType data() const {
return data_;
}
protected:
PtrType data_;
const index_t* sizes_;
const index_t* strides_;
};
// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
// `Tensor.accessor<T, N>()`.
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
// indexing on the device uses `TensorAccessor`s.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
using TensorAccessor = torch::headeronly::detail::TensorAccessor<c10::IntArrayRef, T, N, PtrTraits, index_t>;
C10_HOST_DEVICE TensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
namespace detail {
C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
}
C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
}
};
template<typename T, template <typename U> class PtrTraits, typename index_t>
class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST_DEVICE TensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
C10_HOST_DEVICE T & operator[](index_t i) {
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
return this->data_[this->strides_[0]*i];
}
C10_HOST_DEVICE const T & operator[](index_t i) const {
return this->data_[this->strides_[0]*i];
}
};
// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host
// and as
// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
// in order to transfer them on the device when calling kernels.
// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s.
// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
// on the device, so those functions are host only.
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class GenericPackedTensorAccessorBase {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessorBase(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: data_(data_) {
std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
std::copy(strides_, strides_ + N, std::begin(this->strides_));
}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessorBase(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: data_(data_) {
for (const auto i : c10::irange(N)) {
this->sizes_[i] = sizes_[i];
this->strides_[i] = strides_[i];
}
}
C10_HOST_DEVICE index_t stride(index_t i) const {
return strides_[i];
}
C10_HOST_DEVICE index_t size(index_t i) const {
return sizes_[i];
}
C10_HOST_DEVICE PtrType data() {
return data_;
}
C10_HOST_DEVICE const PtrType data() const {
return data_;
}
protected:
PtrType data_;
// NOLINTNEXTLINE(*c-arrays*)
index_t sizes_[N];
// NOLINTNEXTLINE(*c-arrays*)
index_t strides_[N];
C10_HOST void bounds_check_(index_t i) const {
TORCH_CHECK_INDEX(
template <size_t N, typename index_t>
struct IndexBoundsCheck {
IndexBoundsCheck(index_t i) {
TORCH_CHECK_INDEX(
0 <= i && i < index_t{N},
"Index ",
i,
" is not within bounds of a tensor of dimension ",
N);
}
}
};
} // namespace detail
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
index_t* new_sizes = this->sizes_ + 1;
index_t* new_strides = this->strides_ + 1;
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
}
C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
const index_t* new_sizes = this->sizes_ + 1;
const index_t* new_strides = this->strides_ + 1;
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
}
/// Returns a PackedTensorAccessor of the same dimension after transposing the
/// two dimensions given. Does not actually move elements; transposition is
/// made by permuting the size/stride arrays. If the dimensions are not valid,
/// asserts.
C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
index_t dim1,
index_t dim2) const {
this->bounds_check_(dim1);
this->bounds_check_(dim2);
GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
this->data_, this->sizes_, this->strides_);
std::swap(result.strides_[dim1], result.strides_[dim2]);
std::swap(result.sizes_[dim1], result.sizes_[dim2]);
return result;
}
};
template<typename T, template <typename U> class PtrTraits, typename index_t>
class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
public:
typedef typename PtrTraits<T>::PtrType PtrType;
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const index_t* sizes_,
const index_t* strides_)
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
C10_HOST GenericPackedTensorAccessor(
PtrType data_,
const source_index_t* sizes_,
const source_index_t* strides_)
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
C10_DEVICE T & operator[](index_t i) {
return this->data_[this->strides_[0] * i];
}
C10_DEVICE const T& operator[](index_t i) const {
return this->data_[this->strides_[0]*i];
}
// Same as in the general N-dimensional case, but note that in the
// 1-dimensional case the returned PackedTensorAccessor will always be an
// identical copy of the original
C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
index_t dim1,
index_t dim2) const {
this->bounds_check_(dim1);
this->bounds_check_(dim2);
return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
this->data_, this->sizes_, this->strides_);
}
};
using GenericPackedTensorAccessorBase = torch::headeronly::detail::GenericPackedTensorAccessorBase<detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
using GenericPackedTensorAccessor = torch::headeronly::detail::GenericPackedTensorAccessor<TensorAccessor<T, N-1, PtrTraits, index_t>, detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
// Can't put this directly into the macro function args because of commas
#define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>

View File

@ -245,6 +245,9 @@ class TORCH_API TensorBase {
size_t weak_use_count() const noexcept {
return impl_.weak_use_count();
}
bool is_uniquely_owned() const noexcept {
return impl_.is_uniquely_owned();
}
std::string toString() const;

View File

@ -3,6 +3,7 @@
#include <cstdint>
#include <map>
#include <shared_mutex>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
struct WorkspaceMapWithMutex {
std::map<std::tuple<void*, void*>, at::DataPtr> map;
std::shared_mutex mutex;
};
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();

View File

@ -175,17 +175,24 @@ void CUDAGraph::instantiate() {
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
// who prefer not to report error message through these arguments moving forward
// (they prefer return value, or errors on api calls internal to the capture)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
// ROCM appears to fail with HIP error: invalid argument
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && !defined(USE_ROCM)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, cudaGraphInstantiateFlagUseNodePriority));
#else
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
#endif
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
} else {
#if !defined(USE_ROCM)
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch | cudaGraphInstantiateFlagUseNodePriority));
#else
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch));
#endif
}
has_graph_exec_ = true;
}

View File

@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
// - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
cublasDestroy(handle);
cublasDestroy(handle);
#endif
}
@ -107,19 +107,27 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}
std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}
void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace().clear();
cublaslt_handle_stream_to_workspace().clear();
{
auto& workspace = cublas_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
{
auto& workspace = cublaslt_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
}
size_t parseChosenWorkspaceSize() {
@ -233,6 +241,38 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
}
void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto& workspace = cublas_handle_stream_to_workspace();
size_t workspace_size = getChosenWorkspaceSize();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
handle, workspace_it->second.get(), workspace_size));
return;
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
TORCH_CUDABLAS_CHECK(
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
}
}
void* getCUDABlasLtWorkspace() {
#ifndef USE_ROCM
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
@ -241,8 +281,10 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
return workspace_it->second.mutable_get();
}
#endif
@ -250,11 +292,29 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
auto& workspace = cublaslt_handle_stream_to_workspace();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
return workspace_it->second.mutable_get();
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewCUDABlasLtWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it =
workspace.map.try_emplace(key, std::move(new_workspace)).first;
return workspace_it->second.mutable_get();
}
return workspace_it->second.mutable_get();
}
cublasHandle_t getCurrentCUDABlasHandle() {
@ -298,13 +358,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// will allocate memory dynamically (even if they're cheap) outside
// PyTorch's CUDA caching allocator. It's possible that CCA used up
// all the memory and cublas's cudaMallocAsync will return OOM
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
setWorkspaceForHandle(handle, stream);
#if !defined(USE_ROCM)
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.

View File

@ -1936,7 +1936,7 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
// We order the tensors. t1 will be the larger tensor
// We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul)
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
// and tensor1_larger iff tensor2.dim() > tensor1.dim(9
const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
: MaybeOwned<Tensor>::owned(tensor2.mT());
const int64_t dim_t1 = t1->dim();
@ -1948,11 +1948,20 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
return false;
}
// If we require a gradient, we should fold to minimize backward memory usage - even if this
// leads to a copy in forward because is needed in backward,
// only time we avoid this strict pre-allocated memory usage (has_out = True)
bool requires_grad = tensor1.requires_grad() || tensor2.requires_grad();
if (requires_grad && !has_out) {
// In this case we *do* incur in an extra copy to avoid creating an unnecessary large tensor in the backward
// Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer
// t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
// The issue appears in the backward.
// The output gradient g of this operation would have shape [b, m, k]
// The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
// Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor
// of shape [b, n, k] unnecessarily, which may cause a large memory footprint, and in the
// worst case, an OOM
bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
if (t2_requires_grad && !has_out) {
// We should be checking !at::GradMode::is_enabled(), but apparently
// this regresses performance in some cases:
// https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
return true;
}
@ -3532,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
const int64_t out_features) {
auto M = inp.size(0);
TORCH_CHECK(
inp.dtype() == kFloat,
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
__func__,
" : expect input to be 32-bit float tensor.");
" : expect input to be float32 or bfloat16 tensor.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),

View File

@ -1087,7 +1087,8 @@ TORCH_IMPL_FUNC(index_copy_out)
result.copy_(self);
// See Note [Enabling Deterministic Operations]
if (result.is_cuda() && globalContext().deterministicAlgorithms()) {
if ((result.is_cuda() || result.is_xpu()) &&
globalContext().deterministicAlgorithms()) {
torch::List<std::optional<Tensor>> indices;
indices.resize(dim + 1);
indices.set(dim, index);

View File

@ -904,19 +904,11 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
}
// since mvlgamma_ has different signature from its
// out and functional variant, we explicitly
// define it (instead of using structured kernel).
Tensor& mvlgamma_(Tensor& self, int64_t p) {
mvlgamma_check(self, p);
Tensor args = native::arange(
-p *HALF + HALF,
HALF,
HALF,
optTypeMetaToScalarType(self.options().dtype_opt()),
self.options().layout_opt(),
self.options().device_opt(),
self.options().pinned_memory_opt());
args = args.add(self.unsqueeze(-1));
const auto p2_sub_p = static_cast<double>(p * (p - 1));
return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
return at::mvlgamma_out(self, self, p);
}
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {

View File

@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <cmath>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>
@ -793,6 +794,139 @@ bool can_use_kleidiai(
}
#endif
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
size_t m,
size_t n,
size_t k,
const uint16_t* lhs_bf16,
const uint8_t* rhs_qs4cx,
const float* rhs_scales,
uint16_t* dst_bf16,
float scalar_min,
float scalar_max,
const float* bias) {
// Roundup lambda for internal stride calculations
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
// Cast bfloat16 to float32 inline
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
float f;
std::memcpy(&f, &tmp, sizeof(f));
return f;
};
// Cast float32 to bfloat16 inline
auto cast_f32_to_bf16 = [](float f) {
uint32_t bits;
std::memcpy(&bits, &f, sizeof(bits));
return static_cast<uint16_t>(bits >> 16);
};
// Quantization pack lambda (channelwise QA8DX)
auto quant_pack_8bit_channelwise =
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
for (size_t i = 0; i < M; ++i) {
const uint16_t* row_ptr = src_bf16 + i * K;
// find min/max
float mn = FLT_MAX, mx = -FLT_MAX;
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
mn = std::min(mn, v);
mx = std::max(mx, v);
}
float rmin = std::min(0.0f, mn);
float rmax = std::max(0.0f, mx);
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
float recip = scale ? 1.0f / scale : 0.0f;
int32_t zp;
float des_min = rmin * scale;
float des_max = rmax * scale;
float err_min = qmin + des_min;
float err_max = qmax + des_max;
float zp_f =
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
zp_f = std::clamp(zp_f, qmin, qmax);
zp = std::lrintf(zp_f);
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
// store header
*reinterpret_cast<float*>(out_ptr) = recip;
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
out_ptr += sizeof(float) + sizeof(int32_t);
// quantize
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
q = std::clamp(
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
*out_ptr++ = static_cast<int8_t>(q);
}
}
};
// MatMul lambda (MXN x MXK -> MNXK BF16)
auto matmul_kernel = [&](size_t M,
size_t N,
size_t K,
const int8_t* lhs,
const uint8_t* rhs,
const float* scales,
uint16_t* dst,
float lo,
float hi) {
const size_t lhs_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
const size_t rhs_stride = roundup(K, 2) / 2;
for (size_t i = 0; i < M; ++i) {
const int8_t* lhs_row = lhs + i * lhs_stride;
for (size_t j = 0; j < N; ++j) {
int32_t acc = 0;
const int8_t* lptr = lhs_row;
const uint8_t* rptr = rhs + j * rhs_stride;
float lhs_scale = *reinterpret_cast<const float*>(lptr);
int32_t lhs_off =
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
lptr += sizeof(float) + sizeof(int32_t);
for (size_t t = 0; t < K; ++t) {
int32_t lv = static_cast<int32_t>(lptr[t]);
uint8_t bv = rptr[t / 2];
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
: (static_cast<int32_t>(bv >> 4) - 8);
acc += lv * rv + lhs_off * rv;
}
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
if (bias) {
res += bias[j];
}
res = std::clamp(res, lo, hi);
*dst++ = cast_f32_to_bf16(res);
}
}
};
// allocate and run
std::unique_ptr<int8_t[]> packed(
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
matmul_kernel(
m,
n,
k,
packed.get(),
rhs_qs4cx,
rhs_scales,
dst_bf16,
scalar_min,
scalar_max);
}
/**
* The Int4 quantized weights must be represented as a uint8 tensor
* For matrix multiplication with a weight shape of (N x K)
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
#if AT_KLEIDIAI_ENABLED()
if (can_use_kleidiai(scales_zeros, K, block_size)) {
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
packed_weights.resize_({weight_packed_size});
kleidiai::kai_pack_int4_rhs(
packed_weights, weights, scales_zeros, bias, N, K, block_size);
} else
#endif
{
TORCH_CHECK(
bias.has_value() == 0,
__func__,
" : Bias is unsupported in reference implementation");
packed_weights = packed_weights.to(kFloat);
auto weight_reshaped = weights.view({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
if (bias.has_value()) {
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
}
auto res = at::cat(tensors_to_cat, 0);
packed_weights.resize_(res.sizes()).copy_(res);
}
}
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
const float* rhs_scales_f32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
// required format for matmul
auto input_quant_pack_8bit_channelwise =
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
}
// Maximum/minimum int8 values
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
zero_point0 = std::min(zero_point0, qmax);
// Round to nearest integer
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = v0_s32 + nudged_zero_point0;
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[n_idx];
}
// Clamp (min-max) operation
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float* rhs_scales_fp32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
// Lambda for LHS quantization
auto lhs_quant_pack = [&](size_t m,
size_t k,
const float* lhs_f32,
int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
min0 = std::min(src0_0, min0);
}
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
zero_point0 = std::max(zero_point0, qmin);
zero_point0 = std::min(zero_point0, qmax);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float src0_0 = src_ptr[k_idx];
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = std::max(
std::min(
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
static_cast<int32_t>(INT8_MIN));
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
static_cast<int32_t>(kI8Min));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[col_idx];
}
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
/**
* Dynamic Input Quant 4 bit weights matmul execution flow
(INT4 Weights + FP scales + FP32 Bias)
FP32 Input Packed Buffer
| |
Quantize Cast
to INT8 to INT8
| |
v v
INT8 Input INT8 Weights
\ /
\ /
\ /
INT8 Matrix Multiplication
|
v
FP32 Dequantized and Accumulate in FP32
|
v
FP32 Final Output
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
* Dynamic INT4 weight-only MatMul with per-row input quantization.
*
* Execution Flow:
*
* (INT4 Weights + FP Scales [+ optional Bias])
*
* Input (FP32 or BF16) Packed Weight Buffer
* | |
* Row-wise Quantization (INT8) |
* | |
* INT8 Input Activation INT4 Quantized Weights + Scales
* \ /
* \ /
* Quantized Matrix Multiply
* |
* Output Tensor (BF16 or FP32)
*
* Notes:
* - Groupwise kernels expect BF16 scales
* - Channelwise kernels expect FP32 scales
* - Bias is currently unsupported in fallback path
*/
void dyn_quant_matmul_4bit_kernel(
const Tensor& output,
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
if (weight_packed_size == packed_weights.numel()) {
// KleidiAI interface internally handles the Channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
} else
#endif
{
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
const auto weights_size = N * K / 2;
// The weights needs to be in uint8_t data type after quantization
auto extracted_weights =
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
auto float32_scales =
(packed_weights.narrow(
0, weights_size, packed_weights.size(0) - weights_size))
.to(kFloat);
uint8_t* rhs_4bit =
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M,
N,
K,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M,
N,
K,
block_size,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else {
TORCH_CHECK(
block_size == K || (!(block_size % 32) && !(K % block_size)),
__func__,
": Group size should be multiple 32 or in_features [",
K,
"]. Provided ",
block_size);
{
void* input = inp.data_ptr();
void* dst = output.data_ptr();
// Extract weights, sclaes and biases form from packed tensor
const int weights_elements = N * K / 2;
const int scale_elements = N * (K / block_size);
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
float* weight_scales = float32_scales.data_ptr<float>();
void* bias_data = nullptr;
if (bias_elements) {
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
bias_data = float32_bias.data_ptr();
}
// 2 elements of 4 bit weights are packed into 1 uint8 packet
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
// Dispatch to reference kernels
if (inp.scalar_type() == at::kBFloat16) {
// BF16 input, BF16 output
constexpr float BF16_MAX = 3.38953139e+38f;
constexpr float BF16_MIN = -BF16_MAX;
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
M, N, K,
(uint16_t*)input, weights_4bit, weight_scales,
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
}
} else if (inp.scalar_type() == at::kFloat) {
// FP32 input, FP32 output
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M, N, K,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M, N, K, block_size,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
}
} else {
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
}
}
}
}
} // anonymous namespace
}
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)

View File

@ -78,9 +78,18 @@ __global__ void EmbeddingBag_updateOutputKernel_max(
scalar_t weightFeatMax = 0;
int64_t bag_size_ = 0;
int64_t maxWord = -1;
// Separate validation loop reduces register pressure in the main loop below.
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
bool has_invalid_index = false;
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
}
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
for (int64_t emb = begin; emb < end; emb++) {
bool pad = (input[emb] == padding_idx);
CUDA_KERNEL_ASSERT(input[emb] < numRows);
const int64_t weightRow = input[emb] * weight_stride0;
scalar_t weightValue = weightFeat[weightRow];
if (bag_size_ == 0 || weightValue > weightFeatMax) {
@ -129,10 +138,19 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
CUDA_KERNEL_ASSERT(end >= begin);
accscalar_t weightFeatSum = 0;
int64_t bag_size_ = 0;
// Separate validation loop reduces register pressure in the main loop below.
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
bool has_invalid_index = false;
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
}
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
for (int64_t emb = begin; emb < end; emb++) {
index_t input_idx = input[emb];
bool pad = (input_idx == padding_idx);
CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
const int64_t weightRow = input_idx * weight_stride0;
scalar_t weightValue = weightFeat[weightRow];
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;

View File

@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const SwizzleType& swizzle_a,
const SwizzleType swizzle_a,
const Tensor& scale_b,
const SwizzleType& swizzle_b,
const SwizzleType swizzle_b,
const std::optional<at::Tensor>& offs,
Tensor& out) {
const bool a_is_2d = mat_a.dim() == 2;

View File

@ -5,69 +5,11 @@
#include <cuda_bf16.h>
#endif
// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM)
#include <device_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
union {
__hip_bfloat162_raw bf162_raw;
vec_short2 vs2;
} u{static_cast<__hip_bfloat162_raw>(value)};
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
return static_cast<__hip_bfloat162>(u.bf162_raw);
#else
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
union u_hold {
__hip_bfloat162_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
// The api expects an ext_vector_type of half
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
union {
__half2_raw h2r;
vec_fp162 fp16;
} u {static_cast<__half2_raw>(value)};
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
return static_cast<__half2>(u.h2r);
#else
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
union u_hold {
__half2_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
#define ATOMICADD preview_unsafeAtomicAdd
#define ATOMICADD unsafeAtomicAdd
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
#else
#define ATOMICADD atomicAdd

View File

@ -740,7 +740,12 @@ _scaled_rowwise_rowwise(
TORCH_CHECK_VALUE(scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat, "scale_a must have ", mat_a.size(0), " Float elements, got ", scale_a.numel())
TORCH_CHECK_VALUE(scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat, "scale_b must have ", mat_b.size(1), " Float elements, got ", scale_b.numel())
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
// if we have a scale of shape [256, 1] (say), then stride can be [1, 0] - handle this case
TORCH_CHECK_VALUE(
scale_a.stride(1) == 1 ||
scale_a.size(1) == 1,
"expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)
);
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
auto scaling_choice_a = ScalingType::RowWise;
@ -1096,6 +1101,19 @@ _scaled_mxfp8_mxfp8(
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
void
_check_mxfp4_support() {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
// Only on B200 GPUs
TORCH_CHECK_NOT_IMPLEMENTED(
// B200 = 10.0, B300 = 10.3
dprops->major == 10,
"MXFP4 scaling only supported in CUDA for B200/B300"
);
#endif
}
Tensor&
_scaled_mxfp4_mxfp4(
@ -1108,6 +1126,7 @@ _scaled_mxfp4_mxfp4(
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
_check_mxfp4_support();
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",

View File

@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
if (weight.scalar_type() == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl) {
const int64_t bl,
at::ScalarType tensor_dtype) {
size_t packed_size = n * k;
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
if (tensor_dtype == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
});
}
void kai_quant_pack_lhs_int4_mm(
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k) {
// Kernel IDs for GEMM and GEMV
constexpr kai_kernel_id gemm_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
constexpr kai_kernel_id gemv_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
// Get total threads and select kernel
const int64_t total_threads = at::get_num_threads();
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
if (cpuinfo_has_arm_i8mm() && m > 1) {
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
}
// Thread blocking parameters
const int64_t n_step = kernel_packet.ukernel.get_n_step();
const size_t mr = kernel_packet.ukernel.get_mr();
const size_t kr = kernel_packet.ukernel.get_kr();
const size_t sr = kernel_packet.ukernel.get_sr();
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
const uint8_t* lhs_native_mtx_bf16 =
reinterpret_cast<const uint8_t*>(input.data_ptr());
const uint8_t* rhs_packed_mtx_qs4cx =
reinterpret_cast<const uint8_t*>(weight.data_ptr());
uint8_t* lhs_packed_base = lhs_packed.get();
constexpr int32_t element_size = sizeof(uint16_t);
const size_t lhs_stride = k * element_size;
const size_t dst_stride = n * element_size;
// LHS quantization packing
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
const size_t src_stride = vec_per_thread * lhs_stride;
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.kai_run_lhs_quant_pack(
vec_num,
k,
mr,
kr,
sr,
0,
(const uint16_t*)lhs_src_ptr,
lhs_stride,
lhs_packed_ptr);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
lhs_quant_pack(thread_id);
}
});
// Matrix multiplication
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
auto mm = [=, &kernel_packet](int64_t thread_id) {
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
kernel_packet.ukernel.get_rhs_packed_offset(
thread_id * vec_per_thread, k);
auto dst_ptr = dst_act_mtx_bf16 +
kernel_packet.ukernel.get_dst_offset(
0, thread_id * vec_per_thread, dst_stride);
const int64_t vec_num = (thread_id == num_threads - 1)
? (n - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.ukernel.run_matmul(
m,
vec_num,
k,
lhs_packed_base,
rhs_packed_ptr,
(uint16_t*)dst_ptr,
dst_stride,
element_size, // dst_stride_col
-FLT_MAX,
FLT_MAX);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
mm(thread_id);
}
});
}
void kai_quant_pack_lhs_int4_mm(
const at::Tensor& output,
const at::Tensor& input,
const at::Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
const auto input_dtype = input.dtype();
if (input_dtype == at::kBFloat16) {
if (cpuinfo_has_arm_bf16()) {
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
}
} else if (input_dtype == at::kFloat) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
}
} else if ((bl % 32 == 0) && (k % bl == 0)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}

View File

@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl);
const int64_t bl,
at::ScalarType tensor_dtype = at::kFloat);
/**
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )

View File

@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
const auto scales_data = scales.data_ptr<float>();
const auto scales_data = scales.to(kFloat).data_ptr<float>();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(

View File

@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return channelwise_8bit_4bit_kernels.at(id);
}
// Kernel Mapping - BF16 Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
bf16_channelwise_8bit_4bit_kernels = {
{kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return bf16_channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif

View File

@ -10,21 +10,32 @@
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
namespace at::native::kleidiai {
enum class kai_kernel_id {
// FP32 inputs, 4-bit weights, FP32 output
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4 bit GEMV
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4 bit GEMM
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4 bit GEMV
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3 // Channelwise 4 bit GEMM
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
// BF16 inputs, 4-bit weights, BF16 output
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
4, // Channelwise 4-bit GEMV with BF16 input/output
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
5 // Channelwise 4-bit GEMM with BF16 input/output
};
// Channelwise Kernel mapping
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
// bf16 Channelwise Kernel mapping
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const void* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
const uint8_t* rhs,
const float* bias,
const float* scale,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
};
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(

View File

@ -82,6 +82,7 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string to_hex_key(float);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);

View File

@ -301,6 +301,10 @@ std::string getArrayRefString(const IntArrayRef s) {
return fmt::to_string(fmt::join(s, ","));
}
std::string to_hex_key(float f) {
return fmt::format("{:a}", f);
}
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
fmt::basic_memory_buffer<char, 100> buffer;
auto buf_iterator = std::back_inserter(buffer);

View File

@ -96,7 +96,9 @@ kernel void addmm(
auto bias =
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
static_cast<T>(
c10::metal::mul(alpha_beta[0], sum) +
c10::metal::mul(alpha_beta[1], bias));
}
}

View File

@ -121,7 +121,7 @@ Tensor& do_metal_addmm(const Tensor& self,
const Scalar& alpha,
const Scalar& beta,
const Tensor& bias) {
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
if (beta.isFloatingPoint() && alpha.isFloatingPoint() && beta.toDouble() == 0 && alpha.toDouble() == 1) {
return do_metal_mm(self, other, output);
}
auto stream = getCurrentMPSStream();
@ -147,13 +147,15 @@ Tensor& do_metal_addmm(const Tensor& self,
std::array<int64_t, 2> i64;
std::array<int32_t, 2> i32;
std::array<float, 2> f32;
} alpha_beta;
std::array<c10::complex<float>, 2> c64;
} alpha_beta{};
if (output.scalar_type() == kLong) {
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
} else if (c10::isIntegralType(output.scalar_type(), true)) {
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
} else if (c10::isComplexType(output.scalar_type())) {
alpha_beta.c64 = {alpha.toComplexFloat(), beta.toComplexFloat()};
} else {
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
}
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs

View File

@ -91,25 +91,30 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/Repeat_metallib.h>
#endif
template <typename index_t>
void computeRepeatIndices(const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer);
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
std::string scalar_type;
if constexpr (std::is_same_v<index_t, int32_t>) {
if (repeat.scalar_type() == kInt) {
scalar_type = "int32_t";
} else if constexpr (std::is_same_v<index_t, int64_t>) {
} else if (repeat.scalar_type() == kLong) {
scalar_type = "int64_t";
} else {
TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type");
TORCH_CHECK(false, "repeats has to be Long or Int tensor");
}
if (repeat.size(0) == 0) {
return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
Tensor repeat_ = repeat.contiguous();
Tensor cumsum = repeat.cumsum(0);
int64_t total = 0;
if (output_size.has_value()) {
total = output_size.value();
} else {
total = cumsum[-1].item<int64_t>();
TORCH_CHECK((repeat >= 0).all().item<uint8_t>(), "repeats can not be negative");
}
auto result = at::empty({total}, repeat.options());
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@ -121,20 +126,13 @@ void computeRepeatIndices(const index_t* repeat_ptr,
getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false);
[computeEncoder setComputePipelineState:pipelineState];
mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size);
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size);
mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0));
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0));
getMPSProfiler().endProfileKernel(pipelineState);
}
});
}
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
Tensor output;
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});
return output;
return result;
}
} // namespace at::native

View File

@ -5,6 +5,7 @@
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/mps/OperationUtils.h>
#include <algorithm>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -89,13 +90,21 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor&
auto clamp_shape = clamp_opt->sizes();
auto input_shape = input_t.sizes();
TORCH_CHECK(num_clamp_dims <= num_input_dims,
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
if (num_clamp_dims > num_input_dims) {
auto leading_dims = num_clamp_dims - num_input_dims;
for (int64_t i = 0; i < leading_dims; ++i) {
TORCH_CHECK(clamp_shape[i] == 1,
op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor");
}
}
for (int i = 0; i < num_clamp_dims; i++)
auto clamp_idx = num_clamp_dims - 1;
auto input_idx = num_input_dims - 1;
auto common_dims = std::min(num_clamp_dims, num_input_dims);
for (int64_t i = 0; i < common_dims; ++i)
// One of the indices is allowed to be 1; will be handled by broadcast
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 ||
input_shape[input_idx - i] == 1,
op_name + ": clamp tensor trailing shape must match input tensor")
}
}
@ -136,9 +145,6 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
auto result_type = output_t.scalar_type();
IntArrayRef new_min_shape;
IntArrayRef new_max_shape;
auto num_min_dims = min_opt->dim();
auto num_max_dims = max_opt->dim();
auto num_input_dims = input_t.dim();
@ -146,24 +152,32 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
std::vector<int64_t> new_min_arr(num_input_dims);
std::vector<int64_t> new_max_arr(num_input_dims);
if (has_min && num_min_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
new_min_shape = IntArrayRef(new_min_arr);
}
if (has_max && num_max_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
new_max_shape = IntArrayRef(new_max_arr);
}
Tensor min_opt_tensor;
Tensor max_opt_tensor;
auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref,
int64_t num_clamp_dims,
std::vector<int64_t>& new_shape_storage) -> Tensor {
IntArrayRef clamp_shape = clamp_tensor_ref->sizes();
bool requires_view = false;
if (num_clamp_dims > num_input_dims) {
clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims);
requires_view = true;
} else if (num_clamp_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape);
clamp_shape = IntArrayRef(new_shape_storage);
requires_view = true;
}
return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref;
};
if (has_min) {
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr);
}
if (has_max) {
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr);
}
@autoreleasepool {
@ -244,8 +258,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
if (has_min)
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar

View File

@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp

View File

@ -0,0 +1,77 @@
#include <gtest/gtest.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <atomic>
#include <thread>
#include <vector>
// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
// to verify that the data race fix is working correctly
TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
if (!at::cuda::is_available()) {
return;
}
constexpr int num_accessor_threads = 15;
constexpr int num_clear_threads = 5;
constexpr int iterations_per_thread = 50;
std::atomic<bool> stop{false};
std::atomic<int> error_count{0};
std::vector<std::thread> threads;
threads.reserve(num_accessor_threads + num_clear_threads);
// Launch accessor threads
for (int i = 0; i < num_accessor_threads; ++i) {
threads.emplace_back([&stop, &error_count]() {
try {
at::cuda::CUDAGuard device_guard(0);
while (!stop.load(std::memory_order_relaxed)) {
const auto handle = at::cuda::getCurrentCUDABlasHandle();
const auto workspace = at::cuda::getCUDABlasLtWorkspace();
if (handle == nullptr || workspace == nullptr) {
error_count++;
}
}
} catch (const std::exception& e) {
error_count++;
}
});
}
// Launch threads that clear workspaces
for (int i = 0; i < num_clear_threads; ++i) {
threads.emplace_back([&error_count]() {
try {
for (int j = 0; j < iterations_per_thread; ++j) {
at::cuda::clearCublasWorkspaces();
std::this_thread::yield();
}
} catch (const std::exception& e) {
error_count++;
}
});
}
// Let them run for a bit
std::this_thread::sleep_for(std::chrono::milliseconds(100));
stop.store(true, std::memory_order_relaxed);
for (auto& thread : threads) {
thread.join();
}
EXPECT_EQ(error_count.load(), 0);
}
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
c10::cuda::CUDACachingAllocator::init(1);
return RUN_ALL_TESTS();
}

View File

@ -10,6 +10,13 @@
...
}
{
ignore_empty_generic_uninitialised_conditional_jump
Memcheck:Cond
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
...
}
{
Cond_cuda
Memcheck:Cond

View File

@ -189,6 +189,10 @@ skip:
- hf_Whisper
- hf_distil_whisper
- timm_vision_transformer_large
# https://github.com/pytorch/pytorch/issues/167895
- stable_diffusion
- stable_diffusion_text_encoder
- stable_diffusion_unet
device:
cpu:

View File

@ -2,6 +2,7 @@
# These load paths point to different files in internal and OSS environment
load("@bazel_skylib//lib:paths.bzl", "paths")
load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
@ -590,6 +591,9 @@ def pt_operator_query_codegen(
pt_allow_forced_schema_registration = True,
compatible_with = [],
apple_sdks = None):
if get_fbsource_cell() == "fbcode":
return
oplist_dir_name = name + "_pt_oplist"
# @lint-ignore BUCKLINT
@ -865,6 +869,9 @@ def define_buck_targets(
pt_xplat_cxx_library = fb_xplat_cxx_library,
c2_fbandroid_xplat_compiler_flags = [],
labels = []):
if get_fbsource_cell() == "fbcode":
return
# @lint-ignore BUCKLINT
fb_native.filegroup(
name = "metal_build_srcs",

View File

@ -44,7 +44,7 @@ struct C10_API SafePyObject {
(*other.pyinterpreter_)->incref(other.data_);
}
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
(*pyinterpreter_)->decref(data_);
}
data_ = other.data_;
pyinterpreter_ = other.pyinterpreter_;
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
~SafePyObject() {
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
(*pyinterpreter_)->decref(data_);
}
}

View File

@ -34,20 +34,6 @@ namespace c10 {
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
// regarding macros.
template <typename T>
struct CppTypeToScalarType;
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;

View File

@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification.
// Only if the devicetype is in the allowlist,

View File

@ -105,6 +105,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_);
@ -370,4 +376,18 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
bool resizable,
std::optional<at::Device> device_opt);
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
} // namespace c10

View File

@ -277,7 +277,6 @@ void TensorImpl::release_resources() {
if (storage_) {
storage_ = {};
}
pyobj_slot_.maybe_destroy_pyobj();
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -989,6 +988,30 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
}
}
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
namespace impl {
namespace {

View File

@ -2178,6 +2178,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
private:
// See NOTE [std::optional operator usage in CUDA]
// We probably don't want to expose this publicly until
@ -3079,6 +3085,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
friend class C10_TensorImpl_Size_Check_Dummy_Class;
};
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
// Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for

View File

@ -11,8 +11,11 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void incref(PyObject* pyobj) const override {} // do nothing
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
} // do nothing
void decref(PyObject* pyobj) const override {} // do nothing
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
return false;
}
#define PANIC(m) \
TORCH_INTERNAL_ASSERT( \
@ -20,6 +23,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
"attempted to call " #m \
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
size_t refcnt(PyObject* pyobj) const override {
PANIC(refcnt);
}
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
PANIC(detach);
}

View File

@ -18,6 +18,9 @@ namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10
namespace torch::jit {
@ -126,9 +129,12 @@ struct C10_API PyInterpreterVTable {
// Run Py_INCREF on a PyObject.
virtual void incref(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
virtual void decref(PyObject* pyobj) const = 0;
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
// Run Py_REFCNT on a PyObject.
virtual size_t refcnt(PyObject* pyobj) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this

View File

@ -1,56 +0,0 @@
#include <c10/core/impl/PyObjectSlot.h>
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
}
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
// then the PyObject holds an owning reference to the tensor).
// So it is OK to clear pyobj_ here as it is impossible for it to
// be used again (modulo weak reference races)
pyobj_ = nullptr; // for safety
}
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}
void PyObjectSlot::set_owns_pyobj(bool b) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
pyobj_ = reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}
} // namespace c10::impl

View File

@ -8,117 +8,58 @@
#include <atomic>
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10::impl {
struct C10_API PyObjectSlot {
public:
PyObjectSlot();
~PyObjectSlot();
void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
//
// NB: This lives in a header so that we can inline away the switch on status
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
// Query the PyObject interpreter. This may return null if there is no
// interpreter. This is racy!
PyInterpreter* pyobj_interpreter();
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
// interpreter.
PyInterpreter* pyobj_interpreter() const {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyInterpreter& load_pyobj_interpreter() const;
PyInterpreter& load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
TORCH_INTERNAL_ASSERT(
interpreter, "cannot access PyObject for Tensor - no interpreter set");
return *interpreter;
}
bool owns_pyobj();
PyObject* load_pyobj() const {
return pyobj_.load(std::memory_order_acquire);
}
void set_owns_pyobj(bool b);
void store_pyobj(PyObject* obj) {
pyobj_.store(obj, std::memory_order_release);
}
bool has_unique_reference() const {
PyObject* pyobj = load_pyobj();
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
}
void clear() {
pyobj_.store(nullptr, std::memory_order_relaxed);
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
}
private:
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
// This is now always the global interpreter if the PyObject is set.
// Maybe we can remove this field some day...
std::atomic<PyInterpreter*> pyobj_interpreter_;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
// protected by an atomic as it is only allowed to be accessed when you hold
// the GIL, or during destruction of the tensor.
//
// When a PyObject dies, you are obligated to clear this field
// (otherwise, you will try to use-after-free the pyobj); this currently
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
//
// NB: Ordinarily, this should not be a strong reference, as if the
// PyObject owns the Tensor, this would create a reference cycle.
// However, sometimes this ownership flips. To track who owns
// who, this has a single pointer tag indicating whether or not the
// C++ object owns the PyObject (the common case, zero, means PyObject
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
// or check_pyobj for checked access. See references to PyObject
// resurrection in torch/csrc/autograd/python_variable.cpp
PyObject* pyobj_;
// The PyObject representing this Tensor or nullptr. Ownership is managed
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
// reference is already dead.
std::atomic<PyObject*> pyobj_;
friend class torch::utils::PyObjectPreservation;
};
} // namespace c10::impl

View File

@ -50,7 +50,13 @@ namespace c10 {
/// However, you should prefer to use ArrayRef when possible, because its use
/// of TORCH_CHECK will lead to better user-facing error messages.
template <typename T>
class ArrayRef final : public HeaderOnlyArrayRef<T> {
// ArrayRef cannot be derived from. Normally, we would use `final`
// specifier to force this constraint at compile time. However, Intel
// compiler does not recognize ArrayRef as a class template (which is
// required in the definition of at::TensorAccessor, for instance)
// when `final` specifier is used. So, we cannot define ArrayRef as
// final because of the Intel compiler issue.
class ArrayRef : public HeaderOnlyArrayRef<T> {
public:
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
/// SmallVector. As inherited constructors won't work with class template

View File

@ -379,7 +379,11 @@ C10_API std::string GetExceptionString(const std::exception& e);
// ----------------------------------------------------------------------------
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_RETHROW(e, ...) throw
#define TORCH_RETHROW(e, ...) \
do { \
(void)e; /* Suppress unused variable warning */ \
throw; \
} while (false)
#else
#define TORCH_RETHROW(e, ...) \
do { \

View File

@ -12,6 +12,10 @@ template <typename, typename...>
class class_;
}
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10 {
class intrusive_ptr_target;
namespace raw {
@ -33,6 +37,8 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
// Indicates whether the object has a PyObject wrapper.
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);
template <class TTarget>
struct intrusive_target_default_null_type final {
@ -55,7 +61,11 @@ inline uint32_t refcount(uint64_t combined_refcount) {
}
inline uint32_t weakcount(uint64_t combined_refcount) {
return static_cast<uint32_t>(combined_refcount >> 32);
return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32);
}
inline bool has_pyobject(uint64_t combined_refcount) {
return (combined_refcount & kHasPyObject) != 0;
}
// The only requirement for refcount increment is that it happens-before
@ -66,12 +76,6 @@ inline uint64_t atomic_combined_refcount_increment(
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
}
inline uint32_t atomic_refcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::refcount(atomic_combined_refcount_increment(
combined_refcount, kReferenceCountOne));
}
inline uint32_t atomic_weakcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_increment(
@ -99,6 +103,11 @@ inline uint32_t atomic_weakcount_decrement(
combined_refcount, kWeakReferenceCountOne));
}
template <class T, class = void>
struct TargetTraits {
static constexpr bool can_have_pyobject = false;
};
} // namespace detail
/**
@ -155,6 +164,23 @@ class C10_API intrusive_ptr_target {
// we can atomically operate on both at the same time for performance
// and defined behaviors.
//
// Note [PyObject preservation for Tensor and Storages]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// intrusive_ptr has special support for preserving PyObject wrappers
// for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
// the combined_refcount_ is used to indicate whether the object has a
// PyObject wrapper.
//
// - The PyObject, if it exists, holds a strong reference to the
// intrusive_ptr_target.
//
// - When the refcount goes from 1 to 2, we incref the PyObject.
//
// - When the refcount goes from 2 to 1, we decref the PyObject.
//
// In other words, the intrusive_ptr keeps the PyObject alive as long as there
// are other C++ references to the intrusive_ptr_target.
mutable std::atomic<uint64_t> combined_refcount_;
static_assert(sizeof(std::atomic<uint64_t>) == 8);
static_assert(alignof(std::atomic<uint64_t>) == 8);
@ -172,6 +198,8 @@ class C10_API intrusive_ptr_target {
template <typename T>
friend struct ExclusivelyOwnedTensorTraits;
friend class torch::utils::PyObjectPreservation;
protected:
// protected destructor. We never want to destruct intrusive_ptr_target*
// directly.
@ -255,6 +283,16 @@ class C10_API intrusive_ptr_target {
*/
virtual void release_resources() {}
/**
* These two methods are called when the refcount transitions between one
* and two and the object has a PyObject wrapper.
*/
virtual void incref_pyobject() const {}
virtual void decref_pyobject() const {}
virtual bool try_incref_pyobject() const {
return false;
}
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
return detail::refcount(combined_refcount_.load(order));
}
@ -265,6 +303,19 @@ class C10_API intrusive_ptr_target {
}
};
namespace detail {
#ifndef C10_MOBILE
template <>
struct TargetTraits<c10::intrusive_ptr_target> {
// A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
// or StorageImpl, so we have to allow for PyObject support.
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
template <class TTarget, class NullType>
class weak_intrusive_ptr;
@ -314,18 +365,34 @@ class intrusive_ptr final {
void retain_() {
if (target_ != NullType::singleton()) {
uint32_t new_refcount =
detail::atomic_refcount_increment(target_->combined_refcount_);
uint64_t combined = detail::atomic_combined_refcount_increment(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_refcount != 1,
"intrusive_ptr: Cannot increase refcount after it reached zero.");
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 1 to 2, we need to incref the
// PyObject. In other words, we need to ensure that the PyObject stays
// alive now that we have a C++ reference to this object in addition to
// the PyObject itself.
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
target_->incref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!detail::has_pyobject(combined),
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
void reset_() noexcept {
if (target_ != NullType::singleton()) {
if (target_->combined_refcount_.load(std::memory_order_acquire) ==
detail::kUniqueRef) {
if (is_uniquely_owned()) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
@ -337,9 +404,10 @@ class intrusive_ptr final {
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
if (detail::refcount(combined_refcount) == 0) {
bool should_delete =
(combined_refcount == detail::kWeakReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
bool should_delete = detail::weakcount(combined_refcount) == 1;
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
@ -356,6 +424,18 @@ class intrusive_ptr final {
if (should_delete) {
delete target_;
}
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
target_->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
@ -522,6 +602,16 @@ class intrusive_ptr final {
return use_count() == 1;
}
/**
* Stronger than unique() in that it must not have any weakrefs as well.
*/
bool is_uniquely_owned() const noexcept {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
uint64_t combined =
target_->combined_refcount_.load(std::memory_order_acquire);
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
}
/**
* Returns an owning (!) pointer to the underlying object and makes the
* intrusive_ptr instance invalid. That means the refcount is not decreased.
@ -932,6 +1022,7 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) {
return intrusive_ptr<TTarget, NullType>();
} else {
bool increfed = false;
auto combined_refcount =
target_->combined_refcount_.load(std::memory_order_relaxed);
do {
@ -940,12 +1031,31 @@ class weak_intrusive_ptr final {
// Return nullptr.
return intrusive_ptr<TTarget, NullType>();
}
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (detail::has_pyobject(combined_refcount) &&
detail::refcount(combined_refcount) == 1 && !increfed) {
// Object has a python wrapper with no other C++ references.
// We need to to incref the Python object before we acquire a
// strong reference to the C++ object to avoid a situation
// where the Python object is deallocated concurrently.
if (!target_->try_incref_pyobject()) {
return intrusive_ptr<TTarget, NullType>();
}
increfed = true;
}
}
} while (!target_->combined_refcount_.compare_exchange_weak(
combined_refcount,
combined_refcount + detail::kReferenceCountOne,
std::memory_order_acquire,
std::memory_order_relaxed));
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (increfed && detail::refcount(combined_refcount) != 1) {
target_->decref_pyobject();
}
}
return intrusive_ptr<TTarget, NullType>(
target_, raw::DontIncreaseRefcount{});
}
@ -1060,7 +1170,18 @@ namespace intrusive_ptr {
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
if (self) {
detail::atomic_refcount_increment(self->combined_refcount_);
uint64_t combined = detail::atomic_combined_refcount_increment(
self->combined_refcount_, detail::kReferenceCountOne);
#ifndef C10_MOBILE
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
self->incref_pyobject();
}
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
#endif
}
}

View File

@ -15,6 +15,8 @@ using namespace c10::CachingDeviceAllocator;
// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
class XPUAllocator;
namespace {
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
@ -23,14 +25,19 @@ typedef bool (*Comparison)(const Block*, const Block*);
bool BlockComparatorSize(const Block* a, const Block* b);
bool BlockComparatorAddress(const Block* a, const Block* b);
struct PrivatePool;
struct BlockPool {
BlockPool(bool small)
BlockPool(bool small, PrivatePool* private_pool = nullptr)
: blocks(BlockComparatorSize),
unmapped(BlockComparatorAddress),
is_small(small) {}
is_small(small),
owner_PrivatePool(private_pool) {}
std::set<Block*, Comparison> blocks;
std::set<Block*, Comparison> unmapped;
const bool is_small;
PrivatePool* owner_PrivatePool;
};
struct ExpandableSegment;
@ -349,6 +356,43 @@ struct AllocParams {
StatTypes stat_types = {};
};
// Internal implementation that manages actual memory blocks.
// high level MemPool interface wraps PrivatePool via MempoolId.
struct PrivatePool {
PrivatePool(MempoolId_t id, XPUAllocator* allocator = nullptr)
: id(std::move(id)),
allocator_(allocator),
large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
PrivatePool(PrivatePool&&) = delete;
PrivatePool& operator=(const PrivatePool&) = delete;
PrivatePool& operator=(PrivatePool&&) = delete;
~PrivatePool() = default;
// default Mempool when no Mempool is specified
MempoolId_t id{0, 0};
// Number of live graphs using this pool
int use_count{1};
// Number of unfreed allocations made for this pool. When use_count and
// allocation_count drop to zero, we can delete this PrivatePool from
// graph_pools.
int allocation_count{0};
XPUAllocator* allocator_;
BlockPool large_blocks;
BlockPool small_blocks;
public:
XPUAllocator* allocator() {
return allocator_;
}
};
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
} // anonymous namespace
class DeviceCachingAllocator {
@ -365,6 +409,13 @@ class DeviceCachingAllocator {
bool set_fraction = false;
std::vector<ExpandableSegment*> expandable_segments;
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
std::vector<std::pair<MempoolId_t, std::function<bool(sycl::queue*)>>>
captures_underway;
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
graph_pools;
// Pools no longer referenced by any graph.
ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
graph_pools_freeable;
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated || src->event_count > 0 ||
@ -463,7 +514,22 @@ class DeviceCachingAllocator {
}
}
BlockPool& get_pool(size_t size) {
BlockPool& get_pool(size_t size, sycl::queue* queue) {
if (C10_UNLIKELY(!captures_underway.empty())) {
for (auto& entry : captures_underway) {
// lookup for mempool id matching current capture graph
if (entry.second(queue)) {
auto it1 = graph_pools.find(entry.first);
// lookup mempool
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
if (size <= kSmallSize) {
return it1->second->small_blocks;
} else {
return it1->second->large_blocks;
}
}
}
}
if (size < kSmallSize) {
return small_blocks;
} else {
@ -669,6 +735,10 @@ class DeviceCachingAllocator {
if (!ptr) {
return false;
}
if (p.pool->owner_PrivatePool) {
p.pool->owner_PrivatePool->allocation_count++;
}
p.block = new Block(device, p.queue(), size, p.pool, ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(size);
@ -677,11 +747,14 @@ class DeviceCachingAllocator {
return true;
}
void synchronize_and_free_events() {
void synchronize_and_free_events(PrivatePool* pool = nullptr) {
for (auto& xe : xpu_events) {
for (auto& e : xe.second) {
auto event = e.first;
auto* block = e.second;
if (pool && block->pool->owner_PrivatePool != pool) {
continue;
}
event.wait();
block->event_count--;
if (block->event_count == 0) {
@ -785,6 +858,13 @@ class DeviceCachingAllocator {
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].decrease(unmapped.size);
});
if (block->pool->owner_PrivatePool) {
// The Freed block belonged to a XPU graph's PrivatePool.
TORCH_INTERNAL_ASSERT(
block->pool->owner_PrivatePool->allocation_count > 0);
block->pool->owner_PrivatePool->allocation_count--;
}
}
void release_blocks(BlockPool& pool) {
@ -812,13 +892,41 @@ class DeviceCachingAllocator {
}
}
bool release_cached_blocks() {
synchronize_and_free_events();
// See Note [Safe to Free Blocks on BlockPool]
c10::xpu::syncStreamsOnDevice(device_index);
bool release_cached_blocks(MempoolId_t mempool_id) {
if (mempool_id.first == 0 && mempool_id.second == 0 &&
captures_underway.empty()) {
synchronize_and_free_events();
// See Note [Safe to Free Blocks on BlockPool]
c10::xpu::syncStreamsOnDevice(device_index);
release_blocks(large_blocks);
release_blocks(small_blocks);
release_blocks(large_blocks);
release_blocks(small_blocks);
}
for (auto it = graph_pools_freeable.begin();
it != graph_pools_freeable.end();) {
if (mempool_id.first != 0 || mempool_id.second != 0) {
if (it->first == mempool_id) {
// If there is an active mempool, we sync only the events
// associated with the pool
synchronize_and_free_events(it->second);
} else {
// otherwise we move on
++it;
continue;
}
}
TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
release_blocks(it->second->small_blocks);
release_blocks(it->second->large_blocks);
if (it->second->allocation_count == 0) {
auto erase_count = graph_pools.erase(it->first);
TORCH_INTERNAL_ASSERT(erase_count == 1);
it = graph_pools_freeable.erase(it);
} else {
++it;
}
}
return true;
}
@ -903,6 +1011,30 @@ class DeviceCachingAllocator {
}
}
void create_or_incref_pool(
MempoolId_t mempool_id,
XPUAllocator* allocator = nullptr) {
auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool.
// Make a new pool for XPU graph capture or memory pool usage.
graph_pools.emplace(
mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
} else {
// mempool_id references an existing pool, which the current XPU graph
// capture will share.
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
TORCH_INTERNAL_ASSERT(allocator == nullptr);
it->second->use_count++;
}
}
PrivatePool* get_private_pool(MempoolId_t mempool_id) {
auto it = graph_pools.find(mempool_id);
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
return it->second.get();
}
public:
DeviceCachingAllocator(DeviceIndex device_index)
: large_blocks(/* small */ false),
@ -911,9 +1043,11 @@ class DeviceCachingAllocator {
Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
std::scoped_lock<std::recursive_mutex> lock(mutex);
process_events();
if (C10_LIKELY(captures_underway.empty())) {
process_events();
}
size_t size = round_size(orig_size);
auto& pool = get_pool(size);
auto& pool = get_pool(size, &queue);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, &queue, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool);
@ -923,7 +1057,7 @@ class DeviceCachingAllocator {
// Can't reuse an existing block, try to get a new one.
if (!block_found) {
block_found = alloc_block(params, false) ||
(release_cached_blocks() && alloc_block(params, true));
(release_cached_blocks({0, 0}) && alloc_block(params, true));
}
if (!block_found) {
const auto& raw_device = c10::xpu::get_raw_device(device);
@ -1016,9 +1150,9 @@ class DeviceCachingAllocator {
block->stream_uses.insert(stream);
}
void emptyCache() {
void emptyCache(MempoolId_t mempool_id) {
std::scoped_lock<std::recursive_mutex> lock(mutex);
release_cached_blocks();
release_cached_blocks(mempool_id);
}
DeviceStats getStats() {
@ -1172,9 +1306,9 @@ class XPUAllocator : public DeviceAllocator {
}
}
void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override {
void emptyCache(MempoolId_t mempool_id) override {
for (auto& da : device_allocators) {
da->emptyCache();
da->emptyCache(mempool_id);
}
}
@ -1290,8 +1424,8 @@ void init(DeviceIndex device_count) {
return allocator.init(device_count);
}
void emptyCache() {
return allocator.emptyCache();
void emptyCache(MempoolId_t mempool_id) {
return allocator.emptyCache(mempool_id);
}
void resetPeakStats(DeviceIndex device) {

View File

@ -10,7 +10,7 @@ C10_XPU_API Allocator* get();
C10_XPU_API void init(DeviceIndex device_count);
C10_XPU_API void emptyCache();
C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0});
C10_XPU_API void resetPeakStats(DeviceIndex device);

View File

@ -1643,8 +1643,6 @@ if(USE_CUDA)
target_link_libraries(torch_cuda PUBLIC c10_cuda)
if(TARGET torch::nvtx3)
target_link_libraries(torch_cuda PRIVATE torch::nvtx3)
else()
target_link_libraries(torch_cuda PUBLIC torch::nvtoolsext)
endif()
target_include_directories(
@ -1741,9 +1739,6 @@ if(BUILD_SHARED_LIBS)
if(USE_CUDA)
target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
target_link_libraries(torch_global_deps torch::cudart)
if(TARGET torch::nvtoolsext)
target_link_libraries(torch_global_deps torch::nvtoolsext)
endif()
endif()
install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()

View File

@ -734,7 +734,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
file_name,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary
);
} catch (const std::ios_base::failure& e) {
} catch (const std::ios_base::failure&) {
#ifdef _WIN32
// Windows have verbose error code, we prefer to use it than std errno.
uint32_t error_code = GetLastError();
@ -773,8 +773,20 @@ void PyTorchStreamWriter::writeRecord(
bool compress) {
AT_ASSERT(!finalized_);
AT_ASSERT(!archive_name_plus_slash_.empty());
TORCH_INTERNAL_ASSERT(
files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
if (files_written_.count(name) > 0) {
// Allow multiple writes for triton binaries
bool is_triton_extension =
c10::ends_with(name, ".so") ||
c10::ends_with(name, ".cubin") ||
c10::ends_with(name, ".hsaco");
if (is_triton_extension) {
LOG(WARNING) << "File '" << name << "' is being serialized multiple times";
return;
}
TORCH_INTERNAL_ASSERT(false, "Tried to serialize file twice: ", name);
}
if (name == kSerializationIdRecordName && serialization_id_.empty()) {
// In case of copying records from another file, skip writing a different
// serialization_id than the one computed in this writer.

View File

@ -118,11 +118,6 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
endif()
endif()
if("${_arch}" STREQUAL "121a")
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
endif()
endif()
endforeach()
list(JOIN _file_compile_flags " " _file_compile_flags)
@ -131,7 +126,7 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;120a;121a")
"89;90a;100a;103a;120a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")

View File

@ -968,11 +968,8 @@ find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir)
if(nvtx3_FOUND)
add_library(torch::nvtx3 INTERFACE IMPORTED)
target_include_directories(torch::nvtx3 INTERFACE "${nvtx3_dir}")
target_compile_definitions(torch::nvtx3 INTERFACE TORCH_CUDA_USE_NVTX3)
else()
message(WARNING "Cannot find NVTX3, find old NVTX instead")
add_library(torch::nvtoolsext INTERFACE IMPORTED)
set_property(TARGET torch::nvtoolsext PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvToolsExt)
message(FATAL_ERROR "Cannot find NVTX3!")
endif()

View File

@ -132,9 +132,6 @@ if(@USE_CUDA@)
else()
set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
endif()
if(TARGET torch::nvtoolsext)
list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
endif()
if(@BUILD_SHARED_LIBS@)
find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")

View File

@ -10,7 +10,7 @@ API. This API can roughly be divided into five parts:
- **TorchScript**: An interface to the TorchScript JIT compiler and interpreter.
- **C++ Extensions**: A means of extending the Python API with custom C++ and CUDA routines.
Combining, these building blocks form a research and
Combined, these building blocks form a research and
production ready C++ library for tensor computation and dynamic neural
networks with strong emphasis on GPU acceleration as well as fast CPU
performance. It is currently in use at Facebook in research and
@ -76,7 +76,7 @@ C++ Frontend
------------
The PyTorch C++ frontend provides a high level, pure C++ modeling interface for
neural network and general ML(Machine Learning) research and production use cases,
neural networks and general ML (Machine Learning) research and production use cases,
largely following the Python API in design and provided functionality. The C++
frontend includes the following:

View File

@ -254,7 +254,7 @@ To toggle the reduced precision reduction flags in C++, one can do
.. _fp16accumulation:
Full FP16 Accmumulation in FP16 GEMMs
Full FP16 Accumulation in FP16 GEMMs
-------------------------------------
Certain GPUs have increased performance when doing _all_ FP16 GEMM accumulation

View File

@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
skip_guard_on_all_nn_modules_unsafe
keep_tensor_guards_unsafe
skip_guard_on_globals_unsafe
skip_all_guards_unsafe
nested_compile_region
```

View File

@ -32,7 +32,7 @@ project-excludes = [
"torch/utils/tensorboard/summary.py",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"tools/flight_recorder/components/types.py",
"torch/distributed/flight_recorder/components/types.py",
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",

View File

@ -1358,45 +1358,6 @@ class concat_license_files:
# Need to create the proper LICENSE.txt for the wheel
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
Excludes:
- torch/include/torch/headeronly/*
- torch/include/torch/csrc/stable/*
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
- torch/include/torch/csrc/inductor/aoti_torch/generated/
"""
header_extensions = (".h", ".hpp", ".cuh")
header_files = [
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
]
# Paths to exclude from wrapping
exclude_dir_patterns = [
"torch/include/torch/headeronly/",
"torch/include/torch/csrc/stable/",
"torch/include/torch/csrc/inductor/aoti_torch/c/",
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
]
for header_file in header_files:
rel_path = header_file.relative_to(bdist_dir).as_posix()
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
report(f"Skipping header: {rel_path}")
continue
original_content = header_file.read_text(encoding="utf-8")
wrapped_content = (
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
f"{original_content}"
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)
header_file.write_text(wrapped_content, encoding="utf-8")
report(f"Wrapped header: {rel_path}")
def run(self) -> None:
with concat_license_files(include_files=True):
super().run()
@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
# need an __init__.py file otherwise we wouldn't have a package
(bdist_dir / "torch" / "__init__.py").touch()
# Wrap all header files with TORCH_STABLE_ONLY macro
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
bdist_dir = Path(self.bdist_dir)
report(
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
)
self._wrap_headers_with_macro(bdist_dir)
class clean(Command):
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
@ -1632,7 +1585,7 @@ def configure_extension_build() -> tuple[
if cmake_cache_vars["USE_DISTRIBUTED"]:
# Only enable fr_trace command if distributed is enabled
entry_points["console_scripts"].append(
"torchfrtrace = tools.flight_recorder.fr_trace:main",
"torchfrtrace = torch.distributed.flight_recorder.fr_trace:main",
)
return ext_modules, cmdclass, packages, entry_points, extra_install_requires

View File

@ -8,6 +8,7 @@ set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
# Build the cpp gtest binary containing the cpp-only tests.
set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_accessor.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp

View File

@ -0,0 +1,50 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/TensorAccessor.h>
#include <string>
TEST(TestAccessor, HeaderOnlyTensorAccessor) {
std::vector<int32_t> v = {11, 12, 13, 21, 22, 23};
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = {3, 1};
auto acc = torch::headeronly::HeaderOnlyTensorAccessor<int32_t, 2>(
v.data(), sizes.data(), strides.data());
EXPECT_EQ(acc[0][0], 11);
EXPECT_EQ(acc[0][1], 12);
EXPECT_EQ(acc[0][2], 13);
EXPECT_EQ(acc[1][0], 21);
EXPECT_EQ(acc[1][1], 22);
EXPECT_EQ(acc[1][2], 23);
}
TEST(TestAccessor, HeaderOnlyGenericPackedTensorAccessor) {
std::vector<int32_t> v = {11, 12, 13, 21, 22, 23};
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = {3, 1};
auto acc =
torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<int32_t, 2>(
v.data(), sizes.data(), strides.data());
EXPECT_EQ(acc[0][0], 11);
EXPECT_EQ(acc[0][1], 12);
EXPECT_EQ(acc[0][2], 13);
EXPECT_EQ(acc[1][0], 21);
EXPECT_EQ(acc[1][1], 22);
EXPECT_EQ(acc[1][2], 23);
auto tacc = acc.transpose(0, 1);
EXPECT_EQ(tacc[0][0], 11);
EXPECT_EQ(tacc[0][1], 21);
EXPECT_EQ(tacc[1][0], 12);
EXPECT_EQ(tacc[1][1], 22);
EXPECT_EQ(tacc[2][0], 13);
EXPECT_EQ(tacc[2][1], 23);
try {
acc.transpose(0, 2);
} catch (const std::exception& e) {
EXPECT_TRUE(
std::string(e.what()).find("HeaderOnlyIndexBoundsCheck") !=
std::string::npos);
}
}

View File

@ -13,6 +13,17 @@ TEST(TestScalarType, ScalarTypeToCPPTypeT) {
#undef DEFINE_CHECK
}
TEST(TestScalarType, CppTypeToScalarType) {
using torch::headeronly::CppTypeToScalarType;
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
EXPECT_EQ(CppTypeToScalarType<TYPE>::value, ScalarType::SCALARTYPE);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
{ \
EXPECT_EQ( \

View File

@ -0,0 +1,30 @@
#include "kernel.h"
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <cuda_runtime.h>
using torch::stable::Tensor;
Tensor mv_tensor_accessor_cuda(Tensor m, Tensor v) {
STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device");
Tensor res = new_empty(m, {m.size(0)});
THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cuda",
AT_WRAP(([&]() {
auto resa = Accessor_cuda<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
auto ma = Accessor_cuda<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
auto va = Accessor_cuda<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
mv_tensor_accessor_kernel<Accessor_cuda, scalar_t><<<1, 1, 0, 0>>>(resa, ma, va);
})),
AT_FLOATING_TYPES);
return res;
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cuda));
}

View File

@ -1,3 +1,5 @@
#include "kernel.h"
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/device.h>
@ -514,6 +516,32 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_device_is_cpu", &boxed_test_device_is_cpu);
}
Tensor mv_tensor_accessor_cpu(Tensor m, Tensor v) {
STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device");
Tensor res = new_empty(m, {m.size(0)});
THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cpu",
AT_WRAP(([&]() {
auto resa = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
auto ma = Accessor_cpu<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
auto va = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
mv_tensor_accessor_kernel<Accessor_cpu, scalar_t>(resa, ma, va);
})),
AT_FLOATING_TYPES);
return res;
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("mv_tensor_accessor(Tensor m, Tensor v) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cpu));
}
// Test functions for torch::stable::accelerator APIs
#ifdef LAE_USE_CUDA
@ -634,3 +662,38 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_parallel_for", &boxed_test_parallel_for);
m.impl("test_get_num_threads", &boxed_test_get_num_threads);
}
Tensor my_empty(
torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
std::optional<torch::headeronly::ScalarType> dtype,
std::optional<torch::stable::Device> device,
std::optional<bool> pin_memory) {
return empty(size, dtype, device, pin_memory);
}
Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
return flatten(t, start_dim, end_dim);
}
Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
return reshape(t, shape);
}
Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
return view(t, size);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def(
"my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
m.def("my_view(Tensor t, int[] size) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_empty", TORCH_BOX(&my_empty));
m.impl("my_flatten", TORCH_BOX(&my_flatten));
m.impl("my_reshape", TORCH_BOX(&my_reshape));
m.impl("my_view", TORCH_BOX(&my_view));
}

View File

@ -0,0 +1,26 @@
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>
template <typename T, size_t N>
using Accessor_cpu = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
#if defined(__CUDACC__) || defined(__HIPCC__)
#define MAYBE_GLOBAL __global__
template <typename T, size_t N>
using Accessor_cuda = torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<T, N, torch::headeronly::RestrictPtrTraits>;
#else
#define MAYBE_GLOBAL
#endif
template <template <typename, size_t> class Accessor, typename scalar_t>
MAYBE_GLOBAL void mv_tensor_accessor_kernel(Accessor<scalar_t, 1> resa, Accessor<scalar_t, 2> ma, Accessor<scalar_t, 1> va) {
for (int64_t i = 0; i < resa.size(0); i++) {
scalar_t val = 0;
for (int64_t j = 0; j < ma.size(1); j++) {
val += ma[i][j] * va[j];
}
resa[i] = val;
}
}

View File

@ -487,3 +487,72 @@ def test_get_num_threads() -> int:
Returns: int - the number of threads for the parallel backend
"""
return torch.ops.libtorch_agnostic.test_get_num_threads.default()
def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
"""
Creates an empty tensor with the specified size, dtype, device, and pin_memory.
Args:
size: list[int] - size of the tensor to create
dtype: ScalarType or None - data type of the tensor
device: Device or None - device on which to create the tensor
pin_memory: bool or None - whether to use pinned memory
Returns: Tensor - an uninitialized tensor with the specified properties
"""
return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device, pin_memory)
def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
"""
Flattens the input tensor from start_dim to end_dim into a single dimension.
Args:
t: Tensor - tensor to flatten
start_dim: int - first dimension to flatten (default: 0)
end_dim: int - last dimension to flatten (default: -1)
Returns: Tensor - flattened tensor
"""
return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim)
def my_reshape(t, shape) -> Tensor:
"""
Returns a tensor with the same data but different shape.
Args:
t: Tensor - tensor to reshape
shape: list[int] - new shape for the tensor
Returns: Tensor - reshaped tensor
"""
return torch.ops.libtorch_agnostic.my_reshape.default(t, shape)
def my_view(t, size) -> Tensor:
"""
Returns a new tensor with the same data as the input tensor but of a different shape.
Args:
t: Tensor - tensor to view
size: list[int] - new size for the tensor
Returns: Tensor - tensor with new view
"""
return torch.ops.libtorch_agnostic.my_view.default(t, size)
def mv_tensor_accessor(m, v) -> Tensor:
"""
Returns matrix-vector product.
Args:
m: any 2-D Tensor with shape (N, M)
v: any 1-D Tensor with shape (M,)
Returns:
a 1-D Tensor with shape (N,)
"""
return torch.ops.libtorch_agnostic.mv_tensor_accessor.default(m, v)

View File

@ -33,16 +33,17 @@ class clean(distutils.command.clean.clean):
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
"cxx": ["-fdiagnostics-color=always"],
}
sources = list(CSRC_DIR.glob("**/*.cpp"))
extension = CppExtension
# allow including <cuda_runtime.h>
if torch.cuda.is_available():
extra_compile_args["cxx"].append("-DLAE_USE_CUDA")
extra_compile_args["nvcc"] = ["-O2"]
extension = CUDAExtension
sources = list(CSRC_DIR.glob("**/*.cpp"))
sources.extend(CSRC_DIR.glob("**/*.cu"))
return [
extension(

View File

@ -525,6 +525,113 @@ if not IS_WINDOWS:
expected_num_threads = torch.get_num_threads()
self.assertEqual(num_threads, expected_num_threads)
def test_my_empty(self, device):
import libtorch_agnostic
deterministic = torch.are_deterministic_algorithms_enabled()
try:
# set use_deterministic_algorithms to fill uninitialized memory
torch.use_deterministic_algorithms(True)
size = [2, 3]
result = libtorch_agnostic.ops.my_empty(size, None, None, None)
expected = torch.empty(size)
self.assertEqual(result, expected, exact_device=True)
result_float = libtorch_agnostic.ops.my_empty(
size, torch.float32, None, None
)
expected_float = torch.empty(size, dtype=torch.float32)
self.assertEqual(result_float, expected_float, exact_device=True)
result_with_device = libtorch_agnostic.ops.my_empty(
size, torch.float64, device, None
)
expected_with_device = torch.empty(
size, dtype=torch.float64, device=device
)
self.assertEqual(
result_with_device, expected_with_device, exact_device=True
)
if device == "cuda":
result_pinned = libtorch_agnostic.ops.my_empty(
size, torch.float32, "cpu", True
)
expected_pinned = torch.empty(
size, dtype=torch.float32, device="cpu", pin_memory=True
)
self.assertEqual(result_pinned, expected_pinned)
self.assertTrue(result_pinned.is_pinned())
finally:
torch.use_deterministic_algorithms(deterministic)
def test_my_flatten(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_flatten(t)
expected = torch.flatten(t)
self.assertEqual(result, expected)
result_start = libtorch_agnostic.ops.my_flatten(t, 1)
expected_start = torch.flatten(t, 1)
self.assertEqual(result_start, expected_start)
result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1)
expected_range = torch.flatten(t, 2, -1)
self.assertEqual(result_range, expected_range)
def test_my_reshape(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_reshape(t, [6, 4])
expected = torch.reshape(t, [6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4])
expected_infer = torch.reshape(t, [-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_reshape(t, [-1])
expected_flat = torch.reshape(t, [-1])
self.assertEqual(result_flat, expected_flat)
def test_my_view(self, device):
import libtorch_agnostic
t = torch.randn(2, 3, 4, device=device)
result = libtorch_agnostic.ops.my_view(t, [6, 4])
expected = t.view([6, 4])
self.assertEqual(result, expected)
result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4])
expected_infer = t.view([-1, 4])
self.assertEqual(result_infer, expected_infer)
result_flat = libtorch_agnostic.ops.my_view(t, [-1])
expected_flat = t.view([-1])
self.assertEqual(result_flat, expected_flat)
def test_mv_tensor_accessor(self, device):
import libtorch_agnostic
m = torch.rand(3, 5, device=device)
v = torch.rand(5, device=device)
result = libtorch_agnostic.ops.mv_tensor_accessor(m, v)
expected = torch.mv(m, v)
self.assertEqual(result, expected)
# non-contiguous inputs
m = torch.rand(3 * 2, 5 * 3, device=device)[::2, ::3]
v = torch.rand(5 * 4, device=device)[::4]
result = libtorch_agnostic.ops.mv_tensor_accessor(m, v)
expected = torch.mv(m, v)
self.assertEqual(result, expected)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -0,0 +1,67 @@
import distutils.command.clean
import shutil
from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc"
class clean(distutils.command.clean.clean):
def run(self):
# Run default behavior first
distutils.command.clean.clean.run(self)
# Remove extension
for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"):
path.unlink()
# Remove build and dist and egg-info directories
dirs = [
ROOT_DIR / "build",
ROOT_DIR / "dist",
ROOT_DIR / "torch_stable_test.egg-info",
]
for path in dirs:
if path.exists():
shutil.rmtree(str(path), ignore_errors=True)
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
}
sources = list(CSRC_DIR.glob("**/*.cpp"))
return [
CppExtension(
"torch_stable_test._C",
sources=sorted(str(s) for s in sources),
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=[],
)
]
setup(
name="torch_stable_test",
version="0.0",
author="PyTorch Core Team",
description="Test extension to verify TORCH_STABLE_ONLY flag",
packages=find_packages(exclude=("test",)),
package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]},
install_requires=[
"torch",
],
ext_modules=get_extension(),
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean,
},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)

View File

@ -0,0 +1 @@
#include <ATen/core/TensorBase.h> // This should trigger the TORCH_STABLE_ONLY error

View File

@ -0,0 +1,22 @@
# Owner(s): ["module: cpp"]
from pathlib import Path
from torch.testing._internal.common_utils import (
install_cpp_extension,
IS_WINDOWS,
run_tests,
TestCase,
)
if not IS_WINDOWS:
class TestTorchStable(TestCase):
def test_setup_fails(self):
with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"):
install_cpp_extension(extension_root=Path(__file__).parent.parent)
if __name__ == "__main__":
run_tests()

View File

@ -65,7 +65,6 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
curr_backend = dist.get_default_backend_for_device(device_type)
class SimpleModel(nn.Module):
@ -423,10 +422,10 @@ class TestFullyShard2DStateDict(DTensorTestBase):
@property
def backend(self):
# need to specify gloo backend for testing cpu offload
return f"cpu:gloo,{device_type}:{curr_backend}"
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
def test_fully_shard_tp_2d_set_full_state_dict(self):
dummy_model = SimpleModel().to(device_type)
mesh_2d = init_device_mesh(
@ -515,8 +514,8 @@ class Test2dFSDP1ParallelIntegration(DTensorTestBase):
).to_local()
self.assertEqual(param_m2, param_m1)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
def test_2d_ddp_integration_functionality(self) -> None:
model, twod_model, dp_pg = self.init_model(self.device_type)
optim = torch.optim.Adam(model.parameters(), lr=3e-5)
@ -567,8 +566,8 @@ class TestNew2dParallelTraining(DTensorTestBase):
p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
def test_2d_fsdp_state_enable_extension(self):
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
@ -643,18 +642,18 @@ class TestNew2dParallelTraining(DTensorTestBase):
# Ensure all params are still the same after optimizer update.
self._compare_params(model, model_2d)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
def test_2d_e2e_training_default(self):
self._test_2d_e2e_training()
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
def test_2d_e2e_training_use_orig_params(self):
self._test_2d_e2e_training(use_orig_params=True)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
def test_2d_e2e_training_not_use_orig_params(self):
# TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
# self._test_2d_e2e_training(recompute_activation=True)
@ -667,10 +666,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
@property
def backend(self):
# need to specify gloo backend for testing cpu offload
return f"cpu:gloo,{device_type}:{curr_backend}"
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
def test_fsdp_2d_extension(self):
"""
Test whether _fsdp_extension from FSDPstate has been set correctly.
@ -701,8 +700,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("is_even_sharded_model", [True, False])
def test_2d_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -757,8 +756,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("is_even_sharded_model", [True, False])
def test_2d_load_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -812,8 +811,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
self.assertEqual(v1.device_mesh, v2.device_mesh)
self.assertEqual(v1.placements, v2.placements)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("is_even_sharded_model", [True, False])
def test_2d_optim_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -900,9 +899,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
else:
self.assertEqual(new_state, state)
@skip_if_lt_x_gpu(4)
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_fsdp1_tp_2d_set_full_state_dict(self):
"""
This is a workaround for loading full state dict into a FSDP1+TP 2D model.

View File

@ -29,8 +29,8 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
at_least_x_gpu,
MultiProcessTestCase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
@ -40,6 +40,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_XPU,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
@ -106,9 +107,11 @@ class ComposabilityTest(MultiProcessTestCase):
def device(self):
return self.rank
@requires_accelerator_dist_backend()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
)
def test_pp_and_dcp(self):
"""
Test that pipeline parallelism and distributed checkpointing can be used together and
@ -198,9 +201,11 @@ class ComposabilityTest(MultiProcessTestCase):
_dcp_test(self)
@requires_accelerator_dist_backend()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@parametrize(
"ScheduleClass",
[
@ -350,9 +355,11 @@ class ComposabilityTest(MultiProcessTestCase):
torch.distributed.destroy_process_group()
@requires_accelerator_dist_backend()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@parametrize(
"ScheduleClass",
[
@ -543,9 +550,11 @@ class ComposabilityTest(MultiProcessTestCase):
torch.distributed.destroy_process_group()
@requires_accelerator_dist_backend()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@parametrize(
"ScheduleClass",
[

View File

@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
import torch
@ -17,8 +18,8 @@ from torch.distributed.algorithms.ddp_comm_hooks import (
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
DistributedTestBase,
requires_accelerator_dist_backend,
MultiProcessTestCase,
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
@ -29,12 +30,9 @@ if TEST_WITH_DEV_DBG_ASAN:
sys.exit(0)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
def gpus_for_rank(world_size):
visible_devices = list(range(torch.accelerator.device_count()))
gpus_per_process = torch.accelerator.device_count() // world_size
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
@ -62,7 +60,27 @@ class TestDdpCommHook(nn.Module):
return self.t0(x ** (1 + rank))
class DistributedDataParallelCommHookTest(DistributedTestBase):
class DistributedDataParallelCommHookTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
self._spawn_processes()
def tearDown(self):
try:
os.remove(self.file_name)
except OSError:
pass
def _get_process_group_nccl(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
return dist.distributed_c10d._get_default_group()
@property
def world_size(self):
return 2
@ -101,14 +119,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
param = next(model.parameters())
return param.grad
@requires_accelerator_dist_backend()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_allreduce_hook(self):
"""
This unit test verifies the ``allreduce`` hook registered case gives same result
with no hook registered case.
"""
process_group = self.create_pg(device_type)
process_group = self._get_process_group_nccl()
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -117,14 +135,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
@requires_accelerator_dist_backend()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_fp16compress_hook(self):
"""
This unit test verifies the ``fp16 compress`` hook registered case
gives close result with no hook registered case.
"""
process_group = self.create_pg(device_type)
process_group = self._get_process_group_nccl()
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -133,14 +151,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_accelerator_dist_backend()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_quantize_per_tensor_hook(self):
"""
This unit test verifies the ``quantize per tensor`` hook registered case
gives close result with no hook registered case.
"""
process_group = self.create_pg(device_type)
process_group = self._get_process_group_nccl()
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -149,14 +167,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_accelerator_dist_backend()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_quantize_per_channel_hook(self):
"""
This unit test verifies the ``quantize per channel`` hook registered case
gives close result with no hook registered case.
"""
process_group = self.create_pg(device_type)
process_group = self._get_process_group_nccl()
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -167,14 +185,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_accelerator_dist_backend()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_noop_hook(self):
"""
This unit test verifies the ``noop`` hook registered case and a subsequent allreduce
gives same result with no hook registered case.
"""
process_group = self.create_pg(device_type)
process_group = self._get_process_group_nccl()
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -186,10 +204,10 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
@requires_accelerator_dist_backend()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_is_last_hook(self):
process_group = self.create_pg(device_type)
process_group = self._get_process_group_nccl()
def hook(flags, bucket):
flags.append(bucket.is_last())

View File

@ -32,7 +32,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
class TestStateDictUtils(DTensorTestBase):
@property
def world_size(self):
return min(4, torch.accelerator.device_count())
return min(4, torch.cuda.device_count())
@with_comms
@skip_if_lt_x_gpu(2)
@ -49,7 +49,7 @@ class TestStateDictUtils(DTensorTestBase):
dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
self.assertEqual(gathered_state_dict["dtensor"].device.type, self.device_type)
self.assertTrue(gathered_state_dict["dtensor"].is_cuda)
@with_comms
@skip_if_lt_x_gpu(4)
@ -69,16 +69,14 @@ class TestStateDictUtils(DTensorTestBase):
)
if dist.get_rank() in (0, 2):
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
self.assertNotEqual(
gathered_state_dict["dtensor"].device.type, self.device_type
)
self.assertFalse(gathered_state_dict["dtensor"].is_cuda)
else:
self.assertEqual(gathered_state_dict, {})
@with_comms
@skip_if_lt_x_gpu(4)
def test_cpu_and_ranks_only(self):
device = torch.device(self.device_type)
device = torch.device("cuda")
state_dict = {
"tensor1": torch.arange(10, device=device),
"tensor2": torch.ones(10, device=device),
@ -87,7 +85,7 @@ class TestStateDictUtils(DTensorTestBase):
cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2))
if dist.get_rank() in (0, 2):
for v in cpu_state_dict.values():
self.assertNotEqual(v.device.type, self.device_type)
self.assertFalse(v.is_cuda)
self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
else:
@ -111,27 +109,27 @@ class TestStateDictUtils(DTensorTestBase):
for _ in range(10):
tensor, dtensor = create_dtensor()
ltensor.append(tensor)
ltensor.append(torch.ones(10, device=torch.device(self.device_type)))
ltensor.append(torch.ones(10, device=torch.device("cuda")))
ldtensor.append(dtensor)
ldtensor.append(torch.ones(10, device=torch.device(self.device_type)))
ldtensor.append(torch.ones(10, device=torch.device("cuda")))
tensor, dtensor = create_dtensor()
dist_state_dict = {
"local": dtensor,
"list": ldtensor,
"arange": torch.arange(10, device=torch.device(self.device_type)),
"arange": torch.arange(10, device=torch.device("cuda")),
}
state_dict = {
"local": tensor,
"list": ltensor,
"arange": torch.arange(10, device=torch.device(self.device_type)),
"arange": torch.arange(10, device=torch.device("cuda")),
}
self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))
@with_comms
@skip_if_lt_x_gpu(2)
def test_create_cpu_state_dict(self):
device = torch.device(self.device_type)
device = torch.device("cuda")
rank = dist.get_rank()
# Scale tensors based on world size
# to fit in the tensor shards accurately.
@ -151,7 +149,7 @@ class TestStateDictUtils(DTensorTestBase):
metadata=ShardMetadata(
shard_offsets=[5 * rank, 0],
shard_sizes=[5, 10],
placement=f"rank:{rank}/{self.device_type}:{rank}",
placement=f"rank:{rank}/cuda:{rank}",
),
)
],
@ -161,7 +159,7 @@ class TestStateDictUtils(DTensorTestBase):
torch.arange(50 * scale_factor, device=device).reshape(
5 * scale_factor, 10
),
init_device_mesh(self.device_type, mesh_shape=(self.world_size,)),
init_device_mesh("cuda", mesh_shape=(self.world_size,)),
[Shard(0)],
),
"non_tensor_bytes_io": copy.deepcopy(buffer),
@ -247,7 +245,7 @@ class TestStateDictUtils(DTensorTestBase):
even_tensor = torch.randn(self.world_size, 2)
uneven_tensor = torch.randn(1, 2)
mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
even_dtensor = distribute_tensor(
torch.randn(self.world_size, 2), mesh, [Shard(0)]
)
@ -275,10 +273,10 @@ class TestStateDictUtils(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
def test_cpu_offload_for_dtensor(self):
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
device_mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
sd = {
"k": DTensor.from_local(
torch.ones(8, 8, device=self.device_type), device_mesh, [Shard(0)]
torch.ones(8, 8, device="cuda"), device_mesh, [Shard(0)]
)
}
cpu_sd = _create_cpu_state_dict(sd)
@ -292,12 +290,12 @@ class TestStateDictUtils(DTensorTestBase):
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
_copy_state_dict(sd, cpu_sd, non_blocking=True)
torch.accelerator.synchronize()
torch.cuda.synchronize()
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
sd["k"] += 1
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
_copy_state_dict(sd, cpu_sd, non_blocking=True)
torch.accelerator.synchronize()
torch.cuda.synchronize()
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))

View File

@ -2,23 +2,16 @@
import copy
import math
import pathlib
import sys
from typing import Any
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
sys.path.insert(0, str(REPO_ROOT))
from tools.flight_recorder.components.builder import build_db
from tools.flight_recorder.components.config_manager import JobConfig
from tools.flight_recorder.components.types import COLLECTIVES, MatchInfo, MatchState
from tools.flight_recorder.components.utils import match_one_event
# Make sure to remove REPO_ROOT after import is done
sys.path.remove(str(REPO_ROOT))
from torch.distributed.flight_recorder.components.builder import build_db
from torch.distributed.flight_recorder.components.config_manager import JobConfig
from torch.distributed.flight_recorder.components.types import (
COLLECTIVES,
MatchInfo,
MatchState,
)
from torch.distributed.flight_recorder.components.utils import match_one_event
from torch.testing._internal.common_utils import run_tests, TestCase

View File

@ -7,7 +7,7 @@
import copy
import sys
from contextlib import contextmanager, nullcontext
from contextlib import nullcontext
from typing import Any, cast
import numpy as np
@ -40,6 +40,7 @@ from torch.testing._internal.common_distributed import (
skip_if_rocm_multiprocess,
skip_if_win32,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -56,17 +57,7 @@ except ImportError:
HAS_TORCHVISION = False
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
@contextmanager
def deterministic_algorithms(enabled=True):
prev_state = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(enabled)
try:
yield
finally:
torch.use_deterministic_algorithms(prev_state)
device_type = str(get_devtype())
class TestZeroRedundancyOptimizer(DistributedTestBase):
@ -1250,7 +1241,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
enabled=True, deterministic=True, benchmark=False
)
if "cuda" in device
else deterministic_algorithms(True)
else torch.use_deterministic_algorithms(True)
)
with det_ctx:
device_ids = [rank] if requires_ddp_rank(device) else None

View File

@ -29,7 +29,7 @@ from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
def estimate_aten_runtime(fx_node, compute_multiplier=1.0):
def estimate_aten_runtime(fx_node, override_size=None, compute_multiplier=1.0):
# for tests, assume a matmul can hide a single collective
if "c10" in str(fx_node.target):
return 1.0
@ -1061,6 +1061,63 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
correct = func(a, b, c)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_multiple_hiding_nodes_bucketing(self):
"""Test that collectives hidden by multiple compute ops can bucket together."""
# Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden
def estimate_with_half_compute(fx_node, override_size=None):
return estimate_aten_runtime(fx_node, override_size, compute_multiplier=0.5)
def func(a, b, *, ranks):
# Two all_gathers that will be hidden by multiple compute operations
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
# Multiple compute operations that can hide the collectives
# With 0.5 multiplier: mm1 and mm2 together hide ag1, mm2 and mm3 together hide ag2
mm1 = torch.matmul(a, a.T)
mm2 = torch.matmul(b, b.T)
mm3 = torch.matmul(a + b, (a + b).T)
return ag1.sum() + ag2.sum() + mm1.sum() + mm2.sum() + mm3.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
# Patch with custom estimation that uses 0.5 multiplier
with torch._inductor.config.patch(
{
"aten_distributed_optimizations.custom_runtime_estimation": estimate_with_half_compute
}
):
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b)
# Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together)
FileCheck().check_count(
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
).run(aten_graph_str)
# Verify bucketed collective is scheduled before all matmuls
FileCheck().check("functional.all_gather_into_tensor").check(
"aten.mm"
).check("aten.mm").check("aten.mm").check("wait_tensor").run(aten_graph_str)
# Verify correctness
correct = func(a, b, ranks=ranks)
self.assertTrue(same(out, correct))
def get_toy_model(device_type: str):
"""

View File

@ -24,7 +24,7 @@ from torch.distributed._functional_collectives import (
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_distributed import (
DistributedTestBase,
MultiProcessTestCase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
@ -59,8 +59,12 @@ if not dist.is_available():
sys.exit(0)
@requires_accelerator_dist_backend()
class TestWithNCCL(DistributedTestBase):
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestWithNCCL(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@property
def world_size(self) -> int:
return 2
@ -74,7 +78,16 @@ class TestWithNCCL(DistributedTestBase):
return torch.device(self.rank)
def _init_process_group(self) -> None:
self.create_pg(self.device.type)
torch.accelerator.set_device_index(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
backend = dist.get_default_backend_for_device(self.device.type)
dist.init_process_group(
backend=backend,
world_size=self.world_size,
rank=self.rank,
store=store,
)
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
@skip_if_lt_x_gpu(2)

View File

@ -11,10 +11,13 @@ if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
from torch.testing._internal.common_utils import (
run_tests,
skipIfHpu,
TEST_CUDA,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
)
@ -26,8 +29,16 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()
if TEST_HPU:
DEVICE = "hpu"
elif TEST_CUDA:
DEVICE = "cuda"
else:
DEVICE = "cpu"
device_module = torch.get_device_module(DEVICE)
device_count = device_module.device_count()
BACKEND = dist.get_default_backend_for_device(DEVICE)
def with_comms(func=None):
@ -38,10 +49,11 @@ def with_comms(func=None):
@wraps(func)
def wrapper(self, *args, **kwargs):
if device_type != "cpu" and device_count < self.world_size:
if DEVICE != "cpu" and device_count < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
self.pg = self.create_pg(device=device_type)
kwargs["device"] = DEVICE
self.pg = self.create_pg(device=DEVICE)
try:
return func(self, *args, **kwargs)
finally:
@ -52,7 +64,7 @@ def with_comms(func=None):
class TestObjectCollectives(DistributedTestBase):
@with_comms()
def test_all_gather_object(self):
def test_all_gather_object(self, device):
output = [None] * dist.get_world_size()
dist.all_gather_object(object_list=output, obj=self.rank)
@ -60,7 +72,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(i, v, f"rank: {self.rank}")
@with_comms()
def test_gather_object(self):
def test_gather_object(self, device):
output = [None] * dist.get_world_size() if self.rank == 0 else None
dist.gather_object(obj=self.rank, object_gather_list=output)
@ -70,7 +82,7 @@ class TestObjectCollectives(DistributedTestBase):
@skipIfHpu
@with_comms()
def test_send_recv_object_list(self):
def test_send_recv_object_list(self, device):
val = 99 if self.rank == 0 else None
object_list = [val] * dist.get_world_size()
if self.rank == 0:
@ -84,7 +96,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(None, object_list[0])
@with_comms()
def test_broadcast_object_list(self):
def test_broadcast_object_list(self, device):
val = 99 if self.rank == 0 else None
object_list = [val] * dist.get_world_size()
# TODO test with broadcast_object_list's device argument
@ -93,7 +105,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(99, object_list[0])
@with_comms()
def test_scatter_object_list(self):
def test_scatter_object_list(self, device):
input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
output_list = [None]
dist.scatter_object_list(
@ -111,30 +123,34 @@ class TestObjectCollectives(DistributedTestBase):
my_pg = dist.new_group(ranks, use_local_synchronization=True)
return rank, ranks, my_pg
@skipIfHpu
@with_comms()
def test_subpg_scatter_object(self):
def test_subpg_scatter_object(self, device):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None]
dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
self.assertEqual(rank, out_list[0])
@skipIfHpu
@with_comms()
def test_subpg_all_gather_object(self):
def test_subpg_all_gather_object(self, device):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None] * len(ranks)
dist.all_gather_object(out_list, rank, group=my_pg)
self.assertEqual(ranks, out_list)
@skipIfHpu
@with_comms()
def test_subpg_gather_object(self):
def test_subpg_gather_object(self, device):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None] * len(ranks) if rank == ranks[0] else None
dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
if rank == ranks[0]:
self.assertEqual(ranks, out_list)
@skipIfHpu
@with_comms()
def test_subpg_broadcast_object(self):
def test_subpg_broadcast_object(self, device):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None]
if rank == ranks[0]:
@ -143,5 +159,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(ranks[0], out_list[0])
devices = ("cpu", "cuda", "hpu")
instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

View File

@ -29,7 +29,7 @@ from torch.distributed.tensor._collective_utils import (
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU, TestCase
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
@ -58,7 +58,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
os.environ["LOCAL_RANK"] = f"{local_rank}"
@unittest.skipIf(TEST_XPU or TEST_HPU, "XPU/HPU does not support gloo backend.")
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
class DeviceMeshTestGlooBackend(DTensorTestBase):
@property
def backend(self):

View File

@ -40,6 +40,7 @@ from torch.testing._internal.common_distributed import (
DynamoDistributedSingleProcTestCase,
MultiProcessTestCase,
requires_accelerator_dist_backend,
requires_gloo,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
@ -1773,16 +1774,10 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
inputs = [x, w, ar_0, ar_1]
f(*inputs, **self.get_world_trs())
def _pass(g):
from torch._inductor.fx_passes.bucketing import bucket_all_reduce
bucket_all_reduce(g.owning_module, lambda _: 2000)
torch._inductor.config.post_grad_custom_post_pass = _pass
with torch._inductor.config.patch(
{
"reorder_for_compute_comm_overlap": False,
"bucket_all_reduces_fx": bucket_mode,
}
):
compiled = torch.compile(f)
@ -2234,6 +2229,50 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
)
assert est_ms_nccl > 0
@skip_if_lt_x_gpu(2)
@requires_gloo()
def test_regression_use_nccl_estimate_with_gloo(self):
# Test checks that using nccl estimator option does not hard fail
# with backends that does not support runtime estimations, e.g. gloo
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="gloo", store=store, rank=self.rank, world_size=self.world_size
)
group = c10d.distributed_c10d._get_default_group()
group_name = "default"
torch._C._distributed_c10d._register_process_group(
group_name, torch.distributed.group.WORLD
)
group_size = group.size()
def func(inp, group_size, group_name):
ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
inp, group_size, group_name
)
ag_0_wait = torch.ops.c10d_functional.wait_tensor(ag_0_out)
ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_0_wait, group_size, group_name
)
ag_1_wait = torch.ops.c10d_functional.wait_tensor(ag_1_out)
return ag_1_wait
gm = make_fx(func)(torch.ones(4, 4), group_size, group_name)
g = gm.graph
for n in g.nodes:
if is_all_gather_into_tensor(n):
from torch._inductor.comm_analysis import (
estimate_nccl_collective_runtime_from_fx_node,
)
est_ms = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=False
)
assert est_ms > 0
est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=True
)
assert est_ms_nccl > 0
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -49,7 +49,8 @@ def build_collective_info(graph, hiding_annotations):
"""
Build CollectiveInfo dict from manual hiding annotations.
hiding_annotations: dict mapping collective_start -> hiding_compute_node
hiding_annotations: dict mapping collective_start -> hiding_compute_node(s)
Can be a single node or a list/OrderedSet of nodes
"""
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
@ -65,12 +66,20 @@ def build_collective_info(graph, hiding_annotations):
# Build CollectiveInfo for each collective
for start_node, wait_node in start_to_wait.items():
hiding_node = hiding_annotations.get(start_node)
hiding_annotation = hiding_annotations.get(start_node)
# Convert to OrderedSet
hiding_nodes = OrderedSet()
if hiding_annotation is not None:
if isinstance(hiding_annotation, list | OrderedSet):
hiding_nodes = OrderedSet(hiding_annotation)
else:
hiding_nodes = OrderedSet([hiding_annotation])
# Estimate size and time
size_bytes = 16 * 4 # 4x4 tensor of floats
estimated_time_ms = 1.0 # Dummy time
exposed_time_ms = 0.0 if hiding_node else 1.0 # Hidden if has hiding_node
exposed_time_ms = 0.0 if hiding_nodes else 1.0 # Hidden if has hiding_nodes
collective_info[start_node] = CollectiveInfo(
start_node=start_node,
@ -78,7 +87,7 @@ def build_collective_info(graph, hiding_annotations):
size_bytes=size_bytes,
estimated_time_ms=estimated_time_ms,
exposed_time_ms=exposed_time_ms,
hiding_node=hiding_node,
hiding_nodes=hiding_nodes,
)
return collective_info
@ -567,6 +576,97 @@ class TestOverlapPreservingBucketing(InductorTestCase):
graph_str
)
def test_can_bucket_with_multiple_hiding_nodes(self):
"""
Test that collectives with multiple hiding nodes CAN bucket.
Graph structure:
ag1_start -> ag2_start -> mm1 -> mm2 -> mm3 -> ag1_wait -> ag2_wait
Where:
- ag1 is hidden by mm1 and mm2
- ag2 is hidden by mm2 and mm3
- Both collectives share mm2 as a hiding node
"""
def func(a, b):
group_name = "0"
group_size = 1
# Start both collectives
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
a, group_size, group_name
)
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
b, group_size, group_name
)
# Three compute operations that hide the collectives
mm1 = torch.mm(a, a)
mm2 = torch.mm(b, b)
mm3 = torch.mm(a + b, a + b)
# Wait for both
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum() + mm3.sum()
# Use fake mode to trace without executing
with FakeTensorMode():
a = torch.ones(4, 4, device=self.device)
b = torch.ones(4, 4, device=self.device) * 2
# Trace with make_fx
traced = make_fx(func)(a, b)
# Find nodes using find_nodes
ag1, ag2 = traced.graph.find_nodes(
op="call_function",
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
)
mm1, mm2, mm3 = traced.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
# Manually annotate hiding relationships with multiple hiding nodes
hiding_annotations = {
ag1: [mm1, mm2], # ag1 is hidden by mm1 and mm2
ag2: [mm2, mm3], # ag2 is hidden by mm2 and mm3
}
# Build collective info and ancestors
collective_info = build_collective_info(traced.graph, hiding_annotations)
node_ancestors = compute_ancestors(traced.graph)
scheduled = OrderedSet(traced.graph.nodes)
# Verify hiding_nodes are correctly set
self.assertEqual(len(collective_info[ag1].hiding_nodes), 2)
self.assertIn(mm1, collective_info[ag1].hiding_nodes)
self.assertIn(mm2, collective_info[ag1].hiding_nodes)
self.assertEqual(len(collective_info[ag2].hiding_nodes), 2)
self.assertIn(mm2, collective_info[ag2].hiding_nodes)
self.assertIn(mm3, collective_info[ag2].hiding_nodes)
# Run bucketing
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
traced.graph,
collective_info,
node_ancestors,
scheduled,
)
bucketer.bucket_collectives()
FileCheck().check_count(
"all_gather_into_tensor_out", 1, exactly=False
).check_count("torch.ops.aten.mm.default", 3, exactly=True).run(
str(traced.graph)
)
if __name__ == "__main__":
run_tests()

View File

@ -29,6 +29,8 @@ from torch.testing._internal.common_utils import (
MY_LAMBDA = lambda x: x + 1 # noqa: E731
EPS = torch.tensor(1e-7)
class CustomCompiledFunction(torch._dynamo.aot_compile.SerializableCallable):
def __init__(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]):
@ -587,6 +589,18 @@ from user code:
actual = compiled_fn(fn, *inputs)
self.assertEqual(expected, actual)
def test_aot_compile_with_global_tensor(self):
def fn(x, y):
return x + y + EPS
def make_inputs():
return (torch.randn(3, 4), torch.randn(3, 4))
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile((make_inputs(), {}))
test_inputs = make_inputs()
self.assertEqual(compiled_fn(*test_inputs), fn(*test_inputs))
def test_aot_compile_with_default_args(self):
def fn(x, y=1):
return x + x

View File

@ -330,6 +330,13 @@ y = FakeTensor(..., size=(2,))
'obj_weakref': None
'guarded_class': None
}
global '' GLOBAL_STATE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' TORCH_FUNCTION_STATE
{
'guard_types': None,

View File

@ -90,12 +90,12 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"):
ge_1: "Sym(u0 >= 0)" = primals_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = primals_2 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge_5: "Sym(u2 >= 0)" = primals_3 >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
ge: "Sym(u0 >= 0)" = primals_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u1 >= 0)" = primals_2 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
ge_2: "Sym(u2 >= 0)" = primals_3 >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None
floordiv: "Sym((u0//2))" = primals_1 // 2

View File

@ -1214,7 +1214,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2)
with torch.enable_grad():
ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad():
@ -1226,7 +1226,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2)
with torch.enable_grad():
ref, _ = self._test_serialization("GRAD_MODE", fn, x)
ref, _ = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
# Ensure guards state loading is not affected by the current global grad mode.
guards_state = pickle.loads(self._cached_guards_state)
@ -1246,7 +1246,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
try:
x = torch.randn(3, 2)
torch.use_deterministic_algorithms(True)
ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
torch.use_deterministic_algorithms(False)
self._test_check_fn(ref, loaded, {"x": x}, False)
torch.use_deterministic_algorithms(True)
@ -1270,6 +1270,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
self._test_check_fn(ref, loaded, {"x": x}, False)
with GlobalTorchFunctionMode():
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
with GlobalTorchFunctionMode():
with torch._C.DisableTorchFunction():
self._test_check_fn(ref, loaded, {"x": x}, False)
@ -1306,7 +1309,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2)
with torch.enable_grad():
ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad():
@ -1690,6 +1693,38 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
)
def test_global_state_guard_filter(self):
def foo(x):
return x + 1
x = torch.randn(3, 2)
with torch.no_grad():
compiled_fn = torch.compile(
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(x)
# Check global guards are gone.
with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(compiled_fn(x), foo(x))
def test_torch_function_state_filter(self):
def foo(x):
return x + 1
x = torch.randn(3, 2)
with GlobalTorchFunctionMode():
compiled_fn = torch.compile(
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(x)
# Check global guards are gone.
with torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(compiled_fn(x), foo(x))
class SimpleModule(torch.nn.Module):
def __init__(self, c):

View File

@ -727,7 +727,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
expected_op_count = ifdynstaticdefault(9, 7)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
@ -747,7 +747,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -784,7 +783,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -883,7 +881,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
expected_op_count = ifdynstaticdefault(9, 7)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
@ -905,7 +903,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -956,7 +953,7 @@ class GraphModule(torch.nn.Module):
y = torch.randn(3)
arg_count = ifdynstaticdefault(5, 6)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
expected_op_count = ifdynstaticdefault(17, 13)
expected_op_count = ifdynstaticdefault(15, 11)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x, y)),
@ -977,7 +974,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -987,7 +983,6 @@ class GraphModule(torch.nn.Module):
d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None
sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
_check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None
ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None

View File

@ -244,6 +244,61 @@ class MiscTests(torch._inductor.test_case.TestCase):
self.assertTrue(same(val4, correct1))
self.assertEqual(counter.frame_count, 3)
@unittest.skipIf(not TEST_CUDA, "cuda needed")
def test_assume_32_bit_indexing(self):
@torch.compile(backend="inductor")
def func(a, b):
# Multiple concat operations
x = torch.concat([a, b], dim=0)
y = torch.concat([a, b], dim=1)
# Reshape to create indexing patterns
x_flat = x.reshape(-1)
y_flat = y.reshape(-1)
# Take the smaller one and expand
min_size = min(x_flat.shape[0], y_flat.shape[0])
x_trunc = x_flat[:min_size]
y_trunc = y_flat[:min_size]
# Combine and compute
result = (x_trunc + y_trunc) * 10
# Cumulative operations create complex indexing
cumsum = result.cumsum(dim=0)
return cumsum.sum()
a = torch.rand(100, 30, device="cuda")
b = torch.rand(100, 30, device="cuda")
torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(a, 1)
torch._dynamo.decorators.mark_unbacked(b, 0)
torch._dynamo.decorators.mark_unbacked(b, 1)
source_code = run_and_get_code(func, a, b)[1]
self.assertTrue(
"xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)\\n"
in str(source_code)
)
self.assertFalse(
"xindex = xoffset + tl.arange(0, XBLOCK)[:]\\n" in str(source_code)
)
torch._dynamo.reset()
with torch._inductor.config.patch(assume_32bit_indexing=True):
source_code = run_and_get_code(func, a, b)[1]
self.assertFalse(
"xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)\\n"
in str(source_code)
)
self.assertTrue(
"xindex = xoffset + tl.arange(0, XBLOCK)[:]\\n" in str(source_code)
)
def test_dynamo_inside_custom_op(self):
cnt = torch._dynamo.testing.InductorAndRecordGraphs()
cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()

View File

@ -562,7 +562,7 @@ class TestDynamoTimed(TestCase):
'graph_node_count': 3,
'graph_node_shapes': None,
'graph_op_count': 1,
'guard_count': 9,
'guard_count': 10,
'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,
@ -608,7 +608,7 @@ class TestDynamoTimed(TestCase):
'tensorify_float_attempt': None,
'tensorify_float_failure': None,
'tensorify_float_success': None,
'triton_compile_time_us': None,
'triton_compile_time_us': 0,
'triton_kernel_compile_times_us': None,
'triton_version': None}"""
if _IS_WINDOWS
@ -649,7 +649,7 @@ class TestDynamoTimed(TestCase):
'graph_node_count': 3,
'graph_node_shapes': None,
'graph_op_count': 1,
'guard_count': 9,
'guard_count': 10,
'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,
@ -920,7 +920,7 @@ class TestDynamoTimed(TestCase):
first, second = {
(3, 9): (10, 6),
(3, 10): (10, 6),
(3, 11): (10, 6),
(3, 11): (11, 7),
(3, 12): (11, 7),
(3, 13): (11, 7),
(3, 14): (11, 7),

View File

@ -3,6 +3,7 @@
import copy
import types
import unittest
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple
@ -18,6 +19,9 @@ from torch.testing import FileCheck
from torch.testing._internal.common_utils import TEST_CUDA
GLOBAL_LIST = []
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
def test_joint_basic(self) -> None:
@ -585,9 +589,9 @@ def forward(self, args_0):
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
l_args_0_ = L_args_0_
add = l_args_0_ + 1
add = l_args_0_ + 1; add = None
mul = l_args_0_ * 2; l_args_0_ = None
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""",
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul), self._out_spec)""",
)
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
@ -611,6 +615,34 @@ def forward(self, args_0):
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
def test_dynamo_graph_capture_side_effects(self):
GLOBAL_LIST.clear()
def foo(x):
z = x + 1
GLOBAL_LIST.append(z)
return z
def make_inputs():
return (torch.randn(2, 3),)
trace_inputs = make_inputs()
with warnings.catch_warnings(record=True) as w:
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
cnt = 0
for entry in w:
if "While compiling, we found certain side effects happened" in str(
entry.message
):
cnt += 1
self.assertEqual(cnt, 1)
self.assertEqual(len(GLOBAL_LIST), 0)
test_inputs = make_inputs()
gm_results = gm(*test_inputs)
self.assertEqual(len(GLOBAL_LIST), 0)
self.assertEqual(gm_results, foo(*test_inputs))
self.assertEqual(len(GLOBAL_LIST), 1)
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
class DummyOp(torch.autograd.Function):

View File

@ -740,18 +740,26 @@ class TestExport(TestCase):
dynamic_shapes={"x": {0: Dim("b")}, "y": None},
)
# clean up _torchdynamo related meta data as it could vary depending on the caller
# https://github.com/pytorch/pytorch/issues/167432
for node in ep.graph.nodes:
if "custom" in node.meta:
node.meta["custom"] = {
k: v
for k, v in node.meta["custom"].items()
if "_torchdynamo_disable" not in k
}
custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())
self.assertExpectedInline(
str(custom_metadata),
"""\
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
)
@requires_gpu
@ -3073,15 +3081,12 @@ def forward(self, x, y):
foo = torch.ops.export.foo.default(x, y); x = None
sym_size_int = torch.ops.aten.sym_size.int(foo, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(foo, 1)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int); sym_constrain_range_for_size_default = None
ge = sym_size_int >= 0; sym_size_int = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default_1 = None
ge_1 = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_1 = None
bar = torch.ops.export.bar.default(y); y = None
sym_size_int_2 = torch.ops.aten.sym_size.int(bar, 0)
sym_constrain_range_for_size_default_2 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_2); sym_constrain_range_for_size_default_2 = None
ge_2 = sym_size_int_2 >= 0; sym_size_int_2 = None
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_default_2 = None
return (foo, bar)""",
@ -15303,12 +15308,12 @@ graph():
def forward(self, block):
return block.a + block.b
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
with self.assertRaisesRegex(
torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
):
_dynamo_graph_capture_for_export(Foo())(
dynamo_graph_capture_for_export(Foo())(
Block(torch.randn(4, 4), torch.randn(4, 4))
)
@ -17735,7 +17740,6 @@ class TestExportCustomClass(TorchTestCase):
def forward(self, x, mask):
masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None
sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
ge = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
le = sym_size_int_1 <= 1188864

View File

@ -1,71 +0,0 @@
# Owner(s): ["oncall: export"]
from torch._dynamo import config as dynamo_config
from torch._dynamo.testing import make_test_cls_with_patches
from torch._export import config as export_config
try:
from . import test_export, testing
except ImportError:
import test_export # @manual=fbcode//caffe2/test:test_export-library
import testing # @manual=fbcode//caffe2/test:test_export-library
from torch.export import export
test_classes = {}
def mocked_strict_export(*args, **kwargs):
# If user already specified strict, don't make it strict
if "strict" in kwargs:
return export(*args, **kwargs)
return export(*args, **kwargs, strict=True)
def make_dynamic_cls(cls):
# Some test check for ending in suffix; need to make
# the `_strict` for end of string as a result
suffix = test_export.INLINE_AND_INSTALL_STRICT_SUFFIX
cls_prefix = "InlineAndInstall"
cls_a = testing.make_test_cls_with_mocked_export(
cls,
"StrictExport",
suffix,
mocked_strict_export,
xfail_prop="_expected_failure_strict",
)
test_class = make_test_cls_with_patches(
cls_a,
cls_prefix,
"",
(export_config, "use_new_tracer_experimental", True),
(dynamo_config, "install_free_tensors", True),
(dynamo_config, "inline_inbuilt_nn_modules", True),
xfail_prop="_expected_failure_inline_and_install",
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class
tests = [
test_export.TestDynamismExpression,
test_export.TestExport,
]
for test in tests:
make_dynamic_cls(test)
del test
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import torch
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_export_module
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
@ -90,6 +91,99 @@ def forward(self, arg0_1):
self.assertEqual(printed_output, f"moo 1 2\nmoo {new_inp}\nmoo 1 2\nyeehop 4")
def test_print_with_side_effect(self):
class M(torch.nn.Module):
def forward(self, x):
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
res = x + x
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
return (res,)
inputs = (torch.randn(3),)
# With functionalization, it should appear wrapped with with_effects()
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.higher_order.print, 'moo {x} {y}', x = 1, y = 2); \
arg0_1 = None
getitem = with_effects[0]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, 'moo {x} {y}', x = 1, y = 2); \
getitem = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add)""",
)
self.assertEqual(len(gs.input_tokens), 1)
self.assertEqual(len(gs.output_tokens), 1)
def test_print_with_input_mutations(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
torch._higher_order_ops.print("moo {x} {y}", x=x, y=2)
res = x + x
x.add_(res)
res = x + x
torch._higher_order_ops.print("moo {x} {y}", x=x, y=res)
return (res,)
inputs = (torch.randn(3),)
# With functionalization, it should appear wrapped with with_effects()
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
self.assertEqual(len(gs.input_tokens), 1)
self.assertEqual(len(gs.output_tokens), 1)
self.assertEqual(len(gs.user_inputs_to_mutate), 1)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.higher_order.print, 'moo {x} {y}', \
x = arg1_1, y = 2); arg0_1 = None
getitem = with_effects[0]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
add_2 = torch.ops.aten.add.Tensor(add_1, add_1)
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, 'moo {x} {y}', \
x = add_1, y = add_2); getitem = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add_1, add_2)""",
)
def test_print_gen_schema(self):
from torch._higher_order_ops.print import print as print_op
# Test basic schema generation with simple kwargs int
format_str = "Hello {x} {y}"
schema = print_op.gen_schema(format_str, x=1, y=2)
self.assertExpectedInline(
str(schema),
"""print(str format_str, *, int x, int y) -> ()""",
)
# Test schema generation with different types of inputs
# Tensor input
tensor = torch.randn(2, 2)
schema_tensor = print_op.gen_schema("Tensor: {x}", x=tensor)
self.assertExpectedInline(
str(schema_tensor),
"""print(str format_str, *, Tensor x) -> ()""",
)
# TODO: Add schema support with kwargs with value of list type
# No kwargs
schema_no_kwargs = print_op.gen_schema("Simple message")
self.assertExpectedInline(
str(schema_no_kwargs),
"""print(str format_str) -> ()""",
)
if __name__ == "__main__":
run_tests()

View File

@ -1492,8 +1492,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
@ -1513,8 +1513,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
@ -1538,8 +1538,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
@ -1557,8 +1557,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
def forward(self, arg0_1: "f32[2][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None

Some files were not shown because too many files have changed in this diff Show More