Compare commits


171 Commits

Author SHA1 Message Date
7895ff12a7 Test push
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
2025-07-31 23:20:29 -07:00
25ef3d315d [aoti][mps] Dynamic reductions (#159355)
Dynamic kernel:
```cpp
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    constant long& r0_numel,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    int x0 = xindex;
    threadgroup float tmp_acc_0[32];
    float tmp_acc_1 = 0;
    for(auto r0_1_cnt = 0; r0_1_cnt < static_cast<int>(metal::floor(static_cast<float>(0.99902343750000000 + 0.00097656250000000000*r0_numel))); ++r0_1_cnt) {
        int r0_1 = 1024 * r0_1_cnt + r0_index;
        if (r0_1 >= r0_numel) break;
        auto tmp0 = in_ptr0[x0 + 5*r0_1];
        tmp_acc_1 += tmp0;
    }
    auto tmp1 = c10::metal::threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, metal::min(static_cast<decltype(1024+r0_numel)>(1024), static_cast<decltype(1024+r0_numel)>(r0_numel)));
    if (r0_index == 0) out_ptr0[x0] = static_cast<float>(tmp1);
}

void AOTInductorModel::run_impl(...) {
    ...
    auto arg0_1_size = arg0_1.sizes();
    int64_t s77 = arg0_1_size[0];
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        aoti_torch_mps_set_arg_int(mps_lib_0_func_handle, 2, s77);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))}, {static_cast<uint64_t>(1), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```

Static kernel:
```cpp
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = xindex;
    auto tmp0 = in_ptr0[x0];
    auto tmp1 = in_ptr0[5 + x0];
    auto tmp3 = in_ptr0[10 + x0];
    auto tmp5 = in_ptr0[15 + x0];
    auto tmp2 = tmp0 + tmp1;
    auto tmp4 = tmp2 + tmp3;
    auto tmp6 = tmp4 + tmp5;
    out_ptr0[x0] = static_cast<float>(tmp6);
}

void AOTInductorModel::run_impl(...) {
    ...
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL)});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```
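
For illustration, a minimal sketch (assuming the standard torch.export + AOTI packaging flow; not code from this PR) of compiling a reduction over a dynamic dimension on MPS, which is the kind of input that yields the dynamic kernel above:
```python
# Hedged sketch: sum over a dynamic leading dimension, compiled for MPS AOTI.
# Module and dim names are illustrative.
import torch

class SumOverDim0(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=0)

x = torch.rand(4, 5, device="mps")
dim0 = torch.export.Dim("dim0")
ep = torch.export.export(SumOverDim0(), (x,), dynamic_shapes={"x": {0: dim0}})
# r0_numel in the generated kernel corresponds to the dynamic dimension.
pkg_path = torch._inductor.aoti_compile_and_package(ep)
```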

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159355
Approved by: https://github.com/malfet
2025-07-31 23:15:02 +00:00
7e00f2ec9d [AOTI] add zero size consts asm handler (#159225)
Add `get_zero_consts_asm_code` to build zero-size consts into an object file.
This function handles the zero-size consts situation, because the C++ standard does not allow zero-size arrays:
https://stackoverflow.com/questions/9722632/what-happens-if-i-define-a-0-size-array-in-c-c
1. On Windows, MSVC reports error C2466:
https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2466?view=msvc-170
So we can use the assembly compiler to handle this situation.
2. On Windows, why not use Win32 asm for all paths? Because ml64 only supports alignment up to `16`, which does not match PyTorch's `64`. Reference: https://learn.microsoft.com/en-us/cpp/assembler/masm/ml-and-ml64-command-line-reference?view=msvc-170
```
Packs structures on the specified byte boundary. The alignment can be 1, 2, 4, 8, or 16.
```
3. This function handles the zero-size case on both Windows and Linux:
    A. On Linux, we added `-pedantic` to reject zero-size arrays in the C++ compiler. 8e07c9870d/torch/_inductor/cpp_builder.py (L580)
    B. On Windows, MSVC does not support zero-size arrays by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159225
Approved by: https://github.com/desertfire
2025-07-31 22:46:33 +00:00
490cb3f1a4 Revert "[inductor] Add logging for distributed collective ops for multi‑rank diagnostics (#159190)"
This reverts commit bb62e1f769ef51e2ec149d7256c135d09425aaa0.

Reverted https://github.com/pytorch/pytorch/pull/159190 on behalf of https://github.com/clee2000 due to broke [GH job link](https://github.com/pytorch/pytorch/actions/runs/16658705097/job/47150840171) [HUD commit link](bb62e1f769) on mac ([comment](https://github.com/pytorch/pytorch/pull/159190#issuecomment-3141513921))
2025-07-31 22:22:13 +00:00
b95cf5c91d Move complex to headeronly (#159411)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159411
Approved by: https://github.com/albanD
ghstack dependencies: #159415
2025-07-31 22:05:43 +00:00
5e2ef2a465 Move Float8 variations to headeronly (#159415)
This PR is a big copy pasta from `c10/util/Float8*` -> `torch/headeronly/util/` which is why we are breaking PR sanity :C (sorry @albanD!).

Why is it not a clean copy paste?
- For BC reasons, we have to keep the old c10 file around so that OSS devs relying on those files can still get the same APIs
- We reexpose the headeronly APIs through torch::headeronly, so there is an extra chunk of code in the new torch::headeronly files to do that.

Outside of the copy paste, I:
- changed the tests to call torch::headeronly instead of c10
- updated header_only_apis.txt
- added `// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)` to pass lint (which was previously skipped for -inl.h files)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159415
Approved by: https://github.com/albanD
2025-07-31 22:05:43 +00:00
9f753f8c0d [DTensor] Improve sort strategy (#159189)
- Sort strategy now supports sharding on non sorted dim.
~~- Fix histc xfail.~~
  - ~~Previously `python test/distributed/tensor/test_dtensor_ops.py TestDTensorOpsCPU.test_dtensor_op_db_histc_cpu_float32` will fail with `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=18`. However, if we run `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=18 python test/distributed/tensor/test_dtensor_ops.py TestDTensorOpsCPU.test_dtensor_op_db_histc_cpu_float32`, the test will pass. This kind of error is due to DTensor reuses the strategy schema hashing. It turns out that not only the strategy,  the result correctness also depends on `static_argnum` or the op will reuse the previous args from hashed schema and output wrong results. I updated the document also.~~ (fixed in https://github.com/pytorch/pytorch/pull/159289)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159189
Approved by: https://github.com/XilunWu
2025-07-31 21:52:42 +00:00
db437690d1 Add myself as a reviewer for when someone touches headeronly or stable (#159583)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159583
Approved by: https://github.com/mikaylagawarecki
2025-07-31 21:30:05 +00:00
669009bcd1 [inductor] respect layout tags for ops with registered lowerings (#159134)
scaled_grouped_mm's kernel only supports column-major on the second operand. I -think- this is just for efficiency reasons. But inductor treats that buffer as flexible and may tweak the strides to be row-major instead, as seen in the issue.

~Tagging the op as "needs_fixed_stride_order"/"needs_exact_strides" does not work. Inductor only considers those tags for ops that don't have registered lowering (not sure if this is intended). scaled_grouped_mm does have a lowering, so we never check its tags.~ From discussion below, the op tags are expected to work.

FIXES https://github.com/pytorch/pytorch/issues/159097
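
As a hedged illustration of the tag mechanism discussed above (the op, schema, and namespace below are made up, not the PR's code), registering a custom op with `torch.Tag.needs_fixed_stride_order` looks roughly like this:
```python
import torch

# Hypothetical op whose kernel requires a column-major second operand.
lib = torch.library.Library("mylib", "DEF")
lib.define(
    "col_major_mm(Tensor a, Tensor b) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order,),
)

def col_major_mm(a, b):
    # Kernel-side expectation: b.stride(0) == 1 (column-major); the tag asks
    # Inductor to keep the strides it is given instead of re-laying out b.
    return a @ b

lib.impl("col_major_mm", col_major_mm, "CompositeExplicitAutograd")
```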

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159134
Approved by: https://github.com/eellison
2025-07-31 21:29:40 +00:00
e4e2701429 Add the RunLLM widget to the website (#152055)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152055
Approved by: https://github.com/albanD
2025-07-31 20:53:53 +00:00
64cc649275 [itertools] Fix accumulate (#158774)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158774
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
2025-07-31 20:32:02 +00:00
b1fb552974 Revert "Fix ep deepcopy when there is python builitin name (#159478)"
This reverts commit de7376537f2a11783169fee2b3bc276d266898bf.

Reverted https://github.com/pytorch/pytorch/pull/159478 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/159478#issuecomment-3141228423))
2025-07-31 20:20:53 +00:00
bb62e1f769 [inductor] Add logging for distributed collective ops for multi‑rank diagnostics (#159190)
This change introduces structured logging of the collective communication schedule, enabling downstream tools (e.g. TLParse) to ingest and analyze per‑rank collective‐order information for multi‑rank jobs.

- Iterates over scheduler.nodes, filters for _CollectiveKernel nodes
- Extracts each op’s python_kernel_name
- Emits a structured JSON payload under the inductor_collective_schedule artifact name
- Dumps the full schedule list to collective_schedule.json via the PyTorch trace‑structured artifact
- Added comprehensive unit tests for collective schedule tracing: Created test_collective_schedule_empty() and test_collective_schedule_real() tests to verify structured trace logging works correctly for both empty collective schedules and real collective operations (like all_reduce and wait_tensor from _c10d_functional ops).
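
A hedged sketch of the structured-trace emission pattern described above (the payload shape is illustrative; the PR's exact fields may differ):
```python
import json
from torch._logging import trace_structured

collective_schedule = [
    "torch.ops._c10d_functional.all_reduce.default",
    "torch.ops._c10d_functional.wait_tensor.default",
]
trace_structured(
    "artifact",
    metadata_fn=lambda: {"name": "inductor_collective_schedule", "encoding": "json"},
    payload_fn=lambda: json.dumps(collective_schedule),
)
```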

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159190
Approved by: https://github.com/yushangdi, https://github.com/xmfan
2025-07-31 19:58:07 +00:00
327e2ca580 [ez] get rid of unused var (#159571)
Summary: att

Test Plan:
ci

Rollback Plan:

Differential Revision: D79320299

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159571
Approved by: https://github.com/houseroad, https://github.com/georgiaphillips
2025-07-31 19:11:57 +00:00
1ebcba4e1b Fix typo in link to torch memory_viz tool (#159214)
Fixes a small typo in the torch_cuda_memory docs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159214
Approved by: https://github.com/yewentao256, https://github.com/HDCharles, https://github.com/Skylion007
2025-07-31 18:50:54 +00:00
5f7eae697d Deprecate DataLoader pin_memory_device param (#158323)
Build on top of https://github.com/pytorch/pytorch/pull/146821

- Moves enabling pin_memory back inside `_BaseDataLoaderIter`
  - This is required for `StatefulDataloader` which leveraged  `_BaseDataLoaderIter` directly and not the `Dataloader` class init
- Add a simple test for CPU only env where setting `pin_memory=True` is a no-op.
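
A small usage sketch (mine, not from the PR) of the post-deprecation behavior: `pin_memory=True` alone is sufficient, and on a CPU-only build it is simply a no-op:
```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8, dtype=torch.float32))
# No pin_memory_device needed; on a CPU-only build pinning is a no-op.
dl = DataLoader(ds, batch_size=2, pin_memory=True)
for (batch,) in dl:
    pass
```
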
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158323
Approved by: https://github.com/ramanishsingh

Co-authored-by: zeshengzong <zesheng.zong@outlook.com>
2025-07-31 18:42:07 +00:00
c1722db0f7 [NativeRT] Make VariadicOpConverter and FuseListUnpackConverter for cpu nodes only (#159519)
Summary:
VariadicOpConverter and FuseListUnpackConverter would introduce ops that only have CPU kernels.

Currently, the graph passes are run if static_dispatch is enabled.

As we plan to enable static_dispatch by default, this diff adds an additional check so the graph passes only apply to nodes whose inputs/outputs are all on CPU.

Test Plan:
CI

Rollback Plan:

Differential Revision: D79295640

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159519
Approved by: https://github.com/dolpm, https://github.com/henryoier
2025-07-31 18:17:21 +00:00
8a233d6000 Revert "[ContextParallel][FlexAttention] Prototype of supporting FlexAttention in Context Parallel (#158692)"
This reverts commit 07fad04181321d18963b71e9566d44f86a25c9f7.

Reverted https://github.com/pytorch/pytorch/pull/158692 on behalf of https://github.com/yangw-dev due to failed some internal testapf.metrics.tests.generate_graph_def_test.GenerateGraphDefTest: test_aps_generate_inference_graph_def_with_justknobs1) AssertionError: Expected 'check' to be called once. Called 3 times., please fix the internal test and reland it ([comment](https://github.com/pytorch/pytorch/pull/158692#issuecomment-3140873894))
2025-07-31 18:00:30 +00:00
bf3ebd7ad4 Fix grouped MM load along K when TMA loads are not used (#159485)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159485
Approved by: https://github.com/ngimel
2025-07-31 17:58:02 +00:00
c07bb277a0 Revert "fix strategy hashing arg mismatch (#159506)"
This reverts commit 3a556762002ec0027b2120a7e6675182c0e50dbd.

Reverted https://github.com/pytorch/pytorch/pull/159506 on behalf of https://github.com/yangw-dev due to failed the internal tests test_get_bwd_hook (torch.equal(output * 2, input_tensor.grad)) ([comment](https://github.com/pytorch/pytorch/pull/159506#issuecomment-3140858905))
2025-07-31 17:54:29 +00:00
f89c28cc6b [inductor] add lowering for repeat_interleave.Tensor with output size specified (#147160) (#158462)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158462
Approved by: https://github.com/eellison
2025-07-31 17:00:32 +00:00
8fedcfa59a [export] _ccode for PythonMod (#158851)
Summary: Adds ccode impl to PythonMod

Test Plan:
test_export

Rollback Plan:

Differential Revision: D76463347

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158851
Approved by: https://github.com/kalpit-meta-1
2025-07-31 16:46:51 +00:00
6662a76f59 [cutlass backend] Fix EVT tests post buf name change (#159541)
Differential Revision: [D79317791](https://our.internmc.facebook.com/intern/diff/D79317791/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159541
Approved by: https://github.com/mlazos
2025-07-31 16:39:49 +00:00
eqy
05aade1b6d [CUDA] Add serialTest decorator to largeTensorTest in test_cuda.py (#159271)
Hopefully helps with disabled tests due to OOM such as https://github.com/pytorch/pytorch/issues/159069

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159271
Approved by: https://github.com/Skylion007, https://github.com/ngimel
2025-07-31 16:27:16 +00:00
f946b25865 [MPS] Speedup argmax/argmin (#159524)
By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed a bug in `simd_argmax` where the result of `simd_ballot` was prematurely cast to `ushort`, and adjusted the unit test
- Fixed nan handling in compiled argmax, but can't reliably test it as the MPS (eager) implementation of argmax is buggy

Now, according to `bench_mps_ops.py`, compiled `max(x, dim=0)` is reliably faster than the eager implementation:
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float16)  |      285.8      |       272.2       |       422.3       |        354.5        |       721.6       |        683.5        |       2224.0      |        1979.1
      max (torch.float32)  |      300.2      |       267.0       |       389.6       |        342.5        |       769.4       |        682.6        |       2995.7      |        2609.8
      max (torch.int32)    |      299.6      |       275.4       |       390.0       |        361.7        |       758.7       |        686.1        |       3103.4      |        2646.5
      max (torch.int64)    |      297.5      |       275.5       |       417.0       |        382.1        |       856.1       |        722.6        |       5467.7      |        3156.8

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
2025-07-31 16:18:32 +00:00
d2e02585b8 [AOTI] Explicitly delete wait_tensor returned tensor (#159502)
Summary: In the Python wrapper codegen, the returned tensor from wait_tensor is not assigned or used anywhere, because wait_tensor always returns its input; see more discussion in https://github.com/pytorch/pytorch/issues/126773. Similarly, we should just immediately delete the returned tensor handle from aoti_torch_cpu__c10d_functional_wait_tensor in the cpp wrapper codegen, otherwise it may extend the tensor's lifetime and even cause OOM in some cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159502
Approved by: https://github.com/yushangdi, https://github.com/jingsh
ghstack dependencies: #159476, #159487
2025-07-31 15:33:36 +00:00
3dd7ebf418 [BE] Fix buf name mismatch in test_c10d_functional_native.py (#159487)
Summary: test_c10d_functional_native.py uses hard-coded buf names to check the generated code string. This is fragile given that Inductor can update its buffer naming implementation freely. Thus this PR uses name regex matching to find buffer names at the run time. This will solve issues like https://github.com/pytorch/pytorch/issues/147754. Currently we do name matching based on empty_strided_ calls. We can expand it later if needed.
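
A hedged sketch of the name-matching approach described above (regex and variable names are illustrative):
```python
import re

# Generated wrapper code materializes buffers via empty_strided_* calls, so the
# buffer names can be recovered at run time instead of hard-coding "buf0", "buf7", ...
generated = "buf3 = empty_strided_cuda((8,), (1,), torch.float32)"
buf_names = re.findall(r"(buf\d+) = empty_strided_", generated)
assert buf_names == ["buf3"]
```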

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159487
Approved by: https://github.com/yushangdi
ghstack dependencies: #159476
2025-07-31 15:33:36 +00:00
8273ee0646 [BE] Fix global config leak in test_c10d_functional_native.py (#159476)
Summary: test_c10d_functional_native.py tests torch._inductor.config.cpp_wrapper as True and False. Currently torch._inductor.config.cpp_wrapper is set globally which can cause a problem when running the whole test file. This PR changes it to use patch context.
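
A minimal sketch of the patch-context pattern, assuming Inductor's standard `config.patch` helper:
```python
from torch._inductor import config

assert config.cpp_wrapper is False       # default
with config.patch(cpp_wrapper=True):     # scoped override, used per test
    assert config.cpp_wrapper is True
    # ... run the cpp_wrapper variant of the test here ...
assert config.cpp_wrapper is False       # restored on exit, so no global leak
```
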
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159476
Approved by: https://github.com/yushangdi
2025-07-31 15:33:36 +00:00
c57382a493 Move BFloat16.h to headeronly (#159412)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159412
Approved by: https://github.com/desertfire
2025-07-31 15:29:17 +00:00
e7cc42df58 [inductor] consolidate common GEMM triton param retrieval (#159383)
\# Why

- Make loop iteration simpler
- Have a common spot where to make modifications that affect
  all the GEMM Triton templates, avoiding missed spots

\# What

- pull out common logic of taking the BaseConfig objects
  and turning them into kwargs to feed into maybe_append_choice
  for Triton GEMM templates

Differential Revision: [D79186962](https://our.internmc.facebook.com/intern/diff/D79186962)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159383
Approved by: https://github.com/jansel
2025-07-31 13:05:04 +00:00
cyy
72c69e731f set MSVC debug information only on debug builds (#159533)
Fixes: https://github.com/pytorch/pytorch/issues/159515
To reduce the binary size increment in release builds by removing debug information.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159533
Approved by: https://github.com/atalman
2025-07-31 12:57:33 +00:00
78b9dea754 [inductor] Fix set_linter's handling of f-strings for Python 3.12 and up (fix #159056) (#159252)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159252
Approved by: https://github.com/Skylion007
2025-07-31 12:56:09 +00:00
838924436e update the baseline for nightly max_autotune tests (#154973)
Hi @desertfire, according to the latest test [results](https://github.com/pytorch/pytorch/actions/runs/15385952839) from the inductor nightly for max_autotune tests, we plan to update the baseline data:

In the latest nightly test, two models require baseline updates:

- vision_maskrcnn: This model shows improved graph breaks, so I’ve updated the baseline accordingly.
- detectron2_fcos_r_50_fpn: This model has a different number of graph breaks. However, since its accuracy result still shows fail_accuracy, I skipped the graph break check for this model.

```
vision_maskrcnn                     IMPROVED:           graph_breaks=29, expected=30
Improvement: 1 models have fixed dynamo graph breaks:
    vision_maskrcnn
```

```
detectron2_fcos_r_50_fpn            XFAIL
detectron2_fcos_r_50_fpn            FAIL:               graph_breaks=24, expected=22
Error: 1 models have new dynamo graph breaks:
    detectron2_fcos_r_50_fpn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154973
Approved by: https://github.com/desertfire
2025-07-31 11:38:55 +00:00
2ffb510942 [Break XPU][Indutor UT] Fix failures introduced by community. (#159463)
Fixes #159000, Fixes #159335, Fixes #159334, Fixes #159332, Fixes #159331, Fixes #159330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159463
Approved by: https://github.com/jansel
2025-07-31 08:37:41 +00:00
20b5f694f8 [Dynamo] Make frozen dataclasses hashable (#159529)
Fixes https://github.com/pytorch/pytorch/issues/159424
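
A repro-style sketch of what this enables (my own example, not the test from the PR):
```python
import dataclasses
import torch

@dataclasses.dataclass(frozen=True)
class Config:
    dim: int

@torch.compile(fullgraph=True)
def f(x, cfg):
    # Using the frozen dataclass as a dict key requires Dynamo to model hash(cfg).
    lookup = {cfg: 2.0}
    return x * lookup[cfg] + cfg.dim

print(f(torch.ones(3), Config(dim=1)))
```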

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159529
Approved by: https://github.com/oulgen
ghstack dependencies: #159513
2025-07-31 07:03:01 +00:00
447e300d55 [Dynamo] Frozen dataclass attr access test (#159513)
Verifies https://github.com/pytorch/pytorch/issues/159424, but perhaps the issue is not fixed yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159513
Approved by: https://github.com/oulgen
2025-07-31 07:03:01 +00:00
5b2ad9279c [draft export] logging (#159004)
Summary: adds logging for draft export

Test Plan:
loggercli stage actualize-stage TorchDraftExportUsageLoggerConfig

Rollback Plan:

Differential Revision: D78308105

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159004
Approved by: https://github.com/angelayi
2025-07-31 05:52:13 +00:00
78d7f0cdec disable execution frame cleanup (#159531)
Summary: Want to disable execution frame cleanup until fix in D78621408 is merged

Test Plan:
CI

Rollback Plan:

Differential Revision: D79306602

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159531
Approved by: https://github.com/SherlockNoMad
2025-07-31 05:02:36 +00:00
d5c719ec3c [inductor] fix open temp file failed on Windows. (#159342)
Fix opening a temp file failing on Windows. Error message:
<img width="1181" height="239" alt="image" src="https://github.com/user-attachments/assets/e4a6f438-cb06-44c6-959b-0a6a49d2f44f" />

Here are two options to fix this issue: https://stackoverflow.com/questions/66744497/python-tempfile-namedtemporaryfile-cant-use-generated-tempfile
1. `tempfile.NamedTemporaryFile` must set `delete=False` on Windows
2. Use `WritableTempFile` to handle this case on Windows.
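
A hedged sketch of option 1 above (plain standard-library usage, not the PR's code):
```python
import os
import tempfile

# On Windows the file cannot be reopened by another writer while the
# NamedTemporaryFile handle is open, so create it with delete=False,
# close it, and clean it up manually.
f = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False)
try:
    f.write("print('hello')\n")
    f.close()
    # ... hand f.name to the compiler / subprocess here ...
finally:
    os.unlink(f.name)
```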

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159342
Approved by: https://github.com/jansel
2025-07-31 04:58:02 +00:00
c44efc3755 [Refactor] Fix Compile Warning: possibly dangling reference to a temporary (#159517)
```bash
DEBUG pytorch/torch/csrc/dynamo/compiled_autograd.h:1388:25: warning: possibly dangling reference to a temporary [-Wdangling-reference]
DEBUG  1388 |     for (const at::IValue& elt : lst) {
DEBUG       |                         ^~~
DEBUG pytorch/torch/csrc/dynamo/compiled_autograd.h:1388:1: note: the temporary was destroyed at the end of the full expression ‘__for_begin .c10::impl::ListIterator<c10::IValue, __gnu_cxx::__normal_iterator<c10::IValue*, std::vector<c10::IValue> > >::operator*().c10::impl::ListElementReference<c10::IValue, __gnu_cxx::__normal_iterator<c10::IValue*, std::vector<c10::IValue> > >::operator std::conditional_t<true, const c10::IValue&, c10::IValue>()’
DEBUG  1388 |     for (const at::IValue& elt : lst) {
DEBUG       | ^
```

This PR fixes this warning

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159517
Approved by: https://github.com/xmfan
2025-07-31 04:49:43 +00:00
6b9473469f [Graph Partition] add log for graph partition reasons and #partitions (#159425)
Previously, we logged `skipping cudagraphs due to [xxx reasons]` when there were cudagraph-unsafe ops. With graph partition, we split off these ops and cudagraph the remaining parts, but then the log message is also skipped.

In this PR, we add logs for graph partition reasons and the number of partitions to better understand the workload.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159425
Approved by: https://github.com/eellison
2025-07-31 04:21:06 +00:00
7a4167a164 support fabric handles with symmetric memory (#159319)
enable fabric handles for symmetric memory

Enables handle exchange via CU_MEM_HANDLE_TYPE_FABRIC on the systems that support it. This is needed to enable symmetric memory on NVLS72 systems.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159319
Approved by: https://github.com/malfet, https://github.com/kwen2501
2025-07-31 04:16:20 +00:00
8e67a6ae89 [vllm hash update] update the pinned vllm hash (#159320)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vllm hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159320
Approved by: https://github.com/pytorchbot
2025-07-31 04:08:14 +00:00
c68ad1bd6a [dynamo][guards] Always record user.stack for informative tlparse guards (#159526)
Before
<img width="1146" height="280" alt="image" src="https://github.com/user-attachments/assets/4ddb11b2-dec8-4010-a28d-63b3cd4a7929" />

After
<img width="1248" height="248" alt="image" src="https://github.com/user-attachments/assets/8aafc5be-92cd-4468-bb8f-ad966de8c717" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159526
Approved by: https://github.com/Lucaskabela
2025-07-31 03:18:33 +00:00
3e5e094615 Revert "Fix large_tensor_test skipping cpu (#158617)"
This reverts commit debc0591b888f211bfe846bdc7cfa0626a5f6f6a.

Reverted https://github.com/pytorch/pytorch/pull/158617 on behalf of https://github.com/ZainRizvi due to Sorry but this seems to be breaking trunk. See [GH job link](https://github.com/pytorch/pytorch/actions/runs/16631113381/job/47062415099) [HUD commit link](debc0591b8) ([comment](https://github.com/pytorch/pytorch/pull/158617#issuecomment-3138387762))
2025-07-31 02:57:22 +00:00
clr
c65efc8ea1 torch.compile: Record a pt2_compile_event for combo kernels (#159306)
This is off by default, but some jobs have it on. Having this show up in
perfetto and be globally queryable would be useful to see how expensive this
is.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159306
Approved by: https://github.com/masnesral
2025-07-31 02:51:38 +00:00
a9049413e2 [dynamo] Turn on recursive dict tag optimization (#159186)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159186
Approved by: https://github.com/jansel
2025-07-31 02:36:37 +00:00
d7a5ec9355 Fix the Doc of padding in avg_poolnd (#159142)
Fixes #159141

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159142
Approved by: https://github.com/mikaylagawarecki
2025-07-31 02:02:48 +00:00
2c46922ce4 Fix rand_like decomposition to preserve strides (#159294)
Summary: Like https://github.com/pytorch/pytorch/pull/158898, the rand_like variants are not preserving strides. Followed the pattern established in https://github.com/pytorch/pytorch/pull/158898.

Test Plan: New unit test (fails before this PR; but fixed after)
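
A hedged illustration of the property the new unit test is after (the exact test code may differ):
```python
import torch

x = torch.empty_strided((4, 4), (1, 4))  # non-contiguous (column-major) input

@torch.compile
def f(t):
    return torch.rand_like(t)

# With the fix, the decomposition should keep the input's strides.
print(x.stride(), f(x).stride())
```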

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159294
Approved by: https://github.com/eellison
2025-07-31 01:36:50 +00:00
668d414ae7 [CPU] Fix bias dtype issue for FP8 qlinear (#159125)
Fixes
`RuntimeError: self and mat2 must have the same dtype, but got BFloat16 and Float`

With bf16 autocast, the bias is converted to BFloat16, but fp8_qlinear_onednn_ref does not support a bf16 bias.
In this PR, the bias is converted to bf16 in fp8_qlinear_onednn_ref.

Added this case to the UT; reproduce with:
`python test/test_quantization.py -k test_qlinear_fp8`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159125
Approved by: https://github.com/Xia-Weiwen, https://github.com/cyyever, https://github.com/CaoE
2025-07-31 01:26:45 +00:00
4541509237 [Triton] [Inductor] Fix an incorrect descriptor (#159407)
Summary: Fixes a clear template typo where `a_desc_ptr` was passed instead of `b_desc_ptr` to define `b_desc`.

Test Plan:
Found by inspection.

Rollback Plan:

Reviewed By: NoamPaz

Differential Revision: D79178538

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159407
Approved by: https://github.com/NikhilAPatel
2025-07-31 00:34:19 +00:00
6c7f88c2c9 Check addmm dtypes (#159509)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159509
Approved by: https://github.com/eqy
2025-07-31 00:15:46 +00:00
c400c8e2e0 [ROCm] Add FP8 rowwise support to _scaled_grouped_mm + Submodule update (#159075)
Summary:

In this PR we integrate the [FBGEMM AMD FP8 rowwise scaling grouped GEMM kernel](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped) to add support for the `_scaled_grouped_mm` API on AMD. `_scaled_grouped_mm` is [currently supported on Nvidia](9faef3d17c/aten/src/ATen/native/cuda/Blas.cpp (L1614)), this PR aims to bring parity to AMD. Related: [[RFC]: PyTorch Low-Precision GEMMs Public API](https://github.com/pytorch/pytorch/issues/157950#top) #157950.

The kernel is developed using the Composable Kernel framework. Only MI300X is currently supported. In the near future we plan to add support for MI350X as well. For data types we support FP8 e4m3.

The kernel support will be gated with the `USE_FBGEMM_GENAI` flag. We hope to enable this by default for relevant AMD builds.

Note we also update submodule `third_party/fbgemm` to 0adf62831 for the required updates from fbgemm.

Test Plan:

**Hipify & build**
```
python tools/amd_build/build_amd.py
USE_FBGEMM_GENAI=1 python setup.py develop
```

**Unit tests**
```
python test/test_matmul_cuda.py -- TestFP8MatmulCUDA
Ran 488 tests in 32.969s
OK (skipped=454)
```

**Performance Sample**
| G  | M | N | K | Runtime Ms | GB/S | TFLOPS |
| --  | -- | -- | -- | -- | -- | -- |
| 128 | 1 | 2048 | 5120 | 0.37| 3590 | 7.17 |
| 128 | 64 | 2048 | 5120 | 0.51| 2792 | 338.34 |
| 128 | 128 | 2048 | 5120 | 0.66| 2272 | 522.72 |
| 128 | 1 | 5120 | 1024 | 0.21| 3224 | 6.43 |
| 128 | 64 | 5120 | 1024 | 0.29| 2590 | 291.40 |
| 128 | 128 | 5120 | 1024 | 0.40| 2165 | 434.76 |
| 128 | 1 | 4096 | 4096 | 0.69| 3126 | 6.25 |
| 128 | 64 | 4096 | 4096 | 0.85| 2655 | 324.66 |
| 128 | 128 | 4096 | 4096 | 1.10| 2142 | 501.40 |
| 128 | 1 | 8192 | 8192 | 2.45| 3508 | 7.01 |
| 128 | 64 | 8192 | 8192 | 3.27| 2692 | 336.74 |
| 128 | 128 | 8192 | 8192 | 4.04| 2224 | 543.76 |
| 16 | 1 | 2048 | 5120 | 0.04| 3928 | 7.85 |
| 16 | 64 | 2048 | 5120 | 0.05| 3295 | 399.29 |
| 16 | 128 | 2048 | 5120 | 0.07| 2558 | 588.69 |
| 16 | 1 | 5120 | 1024 | 0.03| 3119 | 6.23 |
| 16 | 64 | 5120 | 1024 | 0.03| 2849 | 320.62 |
| 16 | 128 | 5120 | 1024 | 0.05| 2013 | 404.11 |
| 16 | 1 | 4096 | 4096 | 0.06| 4512 | 9.02 |
| 16 | 64 | 4096 | 4096 | 0.09| 3124 | 381.95 |
| 16 | 128 | 4096 | 4096 | 0.13| 2340 | 547.67 |
| 16 | 1 | 8192 | 8192 | 0.32| 3374 | 6.75 |
| 16 | 64 | 8192 | 8192 | 0.42| 2593 | 324.28 |
| 16 | 128 | 8192 | 8192 | 0.53| 2120 | 518.36 |

- Using ROCm 6.4.1
- Collected through `triton.testing.do_bench_cudagraph`

**Binary size with gfx942 arch**
Before: 116103856 Jul 23 14:12 build/lib/libtorch_hip.so
After:  118860960 Jul 23 14:29 build/lib/libtorch_hip.so
The difference is 2757104 bytes (~2.6 MiB).

Reviewers: @drisspg @ngimel @jwfromm @jeffdaily

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159075
Approved by: https://github.com/drisspg
2025-07-30 23:53:58 +00:00
25c3a7e317 [CUDA][CUDA Graphs] Move cuda graphs test to subprocess to avoid polluting mempool tests (#159305)
Otherwise mempool test will fail as the previous graph capture failed but doesn't have its state in the caching allocator fully cleaned up. See also #159301

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159305
Approved by: https://github.com/eellison, https://github.com/BoyuanFeng, https://github.com/naromero77amd
2025-07-30 23:31:38 +00:00
de7376537f Fix ep deepcopy when there is python builitin name (#159478)
Summary: title

Test Plan:
CI

Rollback Plan:

Differential Revision: D79261007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159478
Approved by: https://github.com/pianpwk
2025-07-30 23:14:31 +00:00
fd2c64e286 Fix duplicated sources in inductor provenance tracking (#159484)
Summary:

The `replace_hook` is called once for each user of the replaced node. This fix avoids adding duplicated node sources.

This also means that if there are two nested pass like:

```
with GraphTransformObserver(gm, "outer"):
      with GraphTransformObserver(gm, "inner"):
              .....
```

We'll only see the outer pass's pass name recorded for the replaced node in the "from_node" node meta. I think this is fine. In practice, the outer pass usually contains a more meaningful name, e.g. `decompose_auto_functionalized`, and the inner pass name is just a default pass name like `pattern_matcher`.

Test Plan:
```
buck2 run @mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer_replace
```

Rollback Plan:

Differential Revision: D79203058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159484
Approved by: https://github.com/angelayi
2025-07-30 23:03:11 +00:00
2b1ae29960 [Dynamo][Better Engineering] Add typing annotations to guard and source (#158397) (#159491)
Summary:
X-link: https://github.com/pytorch/executorch/pull/12986

As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo

This PR adds strict typing support to a critical set of files for dynamo, `source.py` and the base `_guards.py`

Running
```
mypy torch/_dynamo/source.py torch/_guards.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  1227 | 2208 | 55.57% | 207 | 362 | 57.18% |
| This PR | 2217 | 2217 | 100.00% | 362 | 362 | 100.00% |
| Delta    | +990 | +9 | +44.43% | +155 | 0 | +42.82% |

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Rollback Plan:

Reviewed By: JacobSzwejbka, yangw-dev

Differential Revision: D79199389

Pulled By: Lucaskabela

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159491
Approved by: https://github.com/anijain2305, https://github.com/yangw-dev
2025-07-30 22:57:50 +00:00
1293405c8d [MPS] Add simd_[arg][max|min] (#158990)
And add eager tests for those.
Re-implement `threadgroup_[max|min]` using those functions, as they are significantly faster than before (though much slower than eager, due to the arg part), which can be verified by running the following script
```python
import itertools
import timeit
import torch
from torch.utils.benchmark import Compare, Measurement, Timer

def bench_unary_op(func, x, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()

def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)[0]

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False, fullgraph=True)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        rc_c, rc_e = f(x), f_c(x)
        rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc

def main() -> None:
    #dtypes = [torch.float16, torch.float32, torch.bfloat16, torch.int32, torch.int64]
    dtypes = [torch.float32, torch.int32, torch.int64]

    # Profile reduction ops
    rc = []
    for op, dtype in itertools.product([torch.max], dtypes):
        rc.extend(bench_reduction(op, dtype=dtype))
    Compare(rc).print()

if __name__ == "__main__":
    torch._dynamo.config.cache_size_limit = 2**16
    main()
```

Produces the following table before
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      297.3      |       531.6       |       394.1       |        2550.5       |       773.0       |        4904.7       |       3647.2      |        9682.0
      max (torch.int32)    |      297.8      |       359.2       |       387.7       |        1179.4       |       768.2       |        2175.0       |       3677.1      |        4495.9
      max (torch.int64)    |      278.7      |       541.4       |       410.2       |        2873.3       |       858.9       |        5620.4       |       6107.2      |       11176.1

Times are in microseconds (us).
```
And after
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      307.9      |       265.3       |       401.0       |        340.8        |       766.5       |        661.9        |       3463.5      |        2829.5
      max (torch.int32)    |      293.5      |       263.1       |       405.0       |        338.8        |       761.4       |        672.5        |       3050.0      |        2688.6
      max (torch.int64)    |      308.2      |       255.7       |       417.4       |        341.4        |       877.0       |        695.0        |       5812.2      |        5762.2

```

`argmax`/`argmin` are much trickier due to the nan-handling logic that needs to be added there.

Also fixes `torch.max/min` compilation for half-precision types, and adds regression tests for it.

This PR also introduces a bunch of helper functions, such as `simd_broadcast` that works for int64 and the `c10::metal::pair` template, which are used by `simd_argmax` to return both value and index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158990
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-07-30 21:57:25 +00:00
3a65ff84b6 [dynamo, easy] add comment on skipping sys.monitoring frames (#159493)
Add a comment so we know why we're doing this code (followup to https://github.com/pytorch/pytorch/pull/159369)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159493
Approved by: https://github.com/azahed98, https://github.com/Lucaskabela, https://github.com/zou3519, https://github.com/jingsh
ghstack dependencies: #159369
2025-07-30 21:54:38 +00:00
acf13a9b75 Fix a bug of distributed 'gather' with uncontiguous tensors on the Gloo backend (#158903)
Fixes #158902
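
A repro-style sketch of the fixed scenario (single-rank Gloo group for brevity; the address/port are placeholders):
```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

t = torch.arange(4.0).reshape(2, 2).t()   # non-contiguous view of the data
out = [torch.empty(2, 2)]
dist.gather(t, out, dst=0)                # previously could return wrong values for such views
print(out[0])
dist.destroy_process_group()
```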

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158903
Approved by: https://github.com/H-Huang
2025-07-30 21:44:29 +00:00
3a55676200 fix strategy hashing arg mismatch (#159506)
Reland https://github.com/pytorch/pytorch/pull/159289.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159506
Approved by: https://github.com/XilunWu
2025-07-30 21:37:13 +00:00
af39144a93 Don't use torch.backends.cuda.matmul.allow_tf32 in inductor cache key (#159480)
Summary: According to https://github.com/pytorch/pytorch/pull/158209, the API is deprecated and we should be using torch.backends.cuda.matmul.fp32_precision instead.

Fixes https://github.com/pytorch/pytorch/issues/159440
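
For reference, a hedged sketch of the replacement knob (the accepted string values are my assumption from the deprecation notes):
```python
import torch

# Old, deprecated toggle:
# torch.backends.cuda.matmul.allow_tf32 = True
# New-style precision setting:
torch.backends.cuda.matmul.fp32_precision = "tf32"   # or "ieee" for strict fp32
```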

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159480
Approved by: https://github.com/xmfan, https://github.com/oulgen
2025-07-30 21:29:38 +00:00
25343b343e [ATen][CUDA][cuFFT] Guard against deprecated error codes (#159466)
This PR adds a guard based on CUDA version, per latest cuFFT [documentation](https://docs.nvidia.com/cuda/cufft/index.html#return-value-cufftresult):
>The following error codes are deprecated and will be removed in a future release: `CUFFT_INCOMPLETE_PARAMETER_LIST`, `CUFFT_PARSE_ERROR`, `CUFFT_LICENSE_ERROR`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159466
Approved by: https://github.com/albanD, https://github.com/eqy, https://github.com/Skylion007
2025-07-30 21:10:32 +00:00
07fad04181 [ContextParallel][FlexAttention] Prototype of supporting FlexAttention in Context Parallel (#158692)
**Summary**
This PR adds an all-gather based FlexAttention and uses TorchFunctionMode to dispatch
`FlexAttentionHOP.__call__` to it.

This PR makes the following changes:

- add a user-facing API `create_cp_block_mask` for creating CP-specific `BlockMask`
which masks over the attention result of Q shard and KV global.
- add `_ContextParallelGlobalVars` to store all necessary global vars that CP FlexAttention
requires. `torch_function_mode` is critical to maintain singleton mode to avoid dynamo
recompilations.
- add a dispatch path for `FlexAttentionForwardHOP.__call__` (TorchFunctionMode dispatch
won't work correctly without this line)

What's not in this PR:
- QKV load balancing
- Test on other masking besides `causal_mask`.
- Support on small attention (i.e. qkv size is smaller than 128) because the block mask
rewrite function requires `Q_BLOCK_SIZE == KV_BLOCK_SIZE == 128`.

**Test**
`pytest test/distributed/tensor/test_attention.py -s -k test_ring_flex_attention`

**Followup**
1. create an issue to reproduce the error in `create_fw_bw_graph()` when trying to call `create_block_mask`
to re-write `block_mask` in `FlexAttentionHOP` dispatch in `TorchFunctionMode`.
2. Merge `_ContextParallelGlobalVars` and `_cp_options`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158692
Approved by: https://github.com/drisspg
2025-07-30 21:01:53 +00:00
7ac70ac4cd Revert "Fix rand_like decomposition to preserve strides (#159294)"
This reverts commit a3a51282dbabe0220c2c3947a89f7d2ecc514d33.

Reverted https://github.com/pytorch/pytorch/pull/159294 on behalf of https://github.com/yangw-dev due to failed internal build Failed to load config ([comment](https://github.com/pytorch/pytorch/pull/159294#issuecomment-3137796767))
2025-07-30 20:59:19 +00:00
e221a1c853 [Code Motion]Restructure flex attention kernel into flex subdirectory (#159437)
Mostly code motion, updating relative paths, moving some imports that had to be lazy before to top level scope now that we are free from the curse.

This will make it easier to add newer templates and provide some organization

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159437
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng, https://github.com/eellison, https://github.com/Skylion007
2025-07-30 20:12:35 +00:00
4defea1e2c [c10d] Fix setGroupName and setGroupDesc in group_split and merge_remote_group (#159429)
Summary:
We found that we don't actually set group_name correctly inside group_split, because we set group_name on `deviceTypeToBackend_`, which is set after `setBackend`. The same applies to group_desc. I added more unit tests for it.

We need to call setGroupName correctly; otherwise this will break the DeviceMesh use case when split_group is used in DeviceMesh.

Also, ncclx needs to be aware that its Option is a subclass of BackendOption.

Test Plan:
CI

Rollback Plan:

Differential Revision: D79201132

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159429
Approved by: https://github.com/xunnanxu
2025-07-30 19:55:55 +00:00
53d68b95de [ROCm CI] Migrate to MI325 Capacity. (#159059)
This PR moves PyTorch CI capacity from mi300 to a new, larger mi325 cluster. Both of these GPUs are the same architecture gfx942 and our testing plans don't change within an architecture, so we pool them under the same label `linux.rocm.gpu.gfx942.<#gpus>` with this PR as well to reduce overhead and confusion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159059
Approved by: https://github.com/jithunnair-amd, https://github.com/atalman

Co-authored-by: deedongala <deekshitha.dongala@amd.com>
2025-07-30 19:47:59 +00:00
f74842d57f [PP] Fix zero bubble schedules for eval() (#159475)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159475
Approved by: https://github.com/tianyu-l, https://github.com/Skylion007
2025-07-30 19:46:10 +00:00
644fee2610 Fix TestAutogradFallback flaky tests under Dynamo: migrate to lib._destroy() (#159443)
Under dynamo, the libraries couldn't be cleared properly unless we manually called `gc.collect()`, but that's slow. It also worked if we just used the `_destroy()` method to tear them down.
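
A hedged sketch of the teardown pattern (the namespace and op below are made up):
```python
import torch

lib = torch.library.Library("my_test_ns", "FRAGMENT")  # hypothetical test-only namespace
lib.define("foo(Tensor x) -> Tensor")
# ... register impls and run the test ...
lib._destroy()  # explicit teardown instead of relying on gc.collect()
```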

FIXES
#159398
#159349
#159254
#159237
#159153
#159114
#159040
#158910
#158841
#158763
#158735

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159443
Approved by: https://github.com/zou3519, https://github.com/Skylion007
2025-07-30 19:30:55 +00:00
7821fbc560 [BE] Clarify comment to not revert when command has been edited (#159495)
This is mostly a nit. I was a bit confused when I saw
<img width="1032" height="183" alt="image" src="https://github.com/user-attachments/assets/7a18f167-78c1-4c33-ba6f-3588914c642e" />
in https://github.com/pytorch/pytorch/pull/159172

So I decided I should clean up this message a bit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159495
Approved by: https://github.com/yangw-dev, https://github.com/clee2000, https://github.com/ZainRizvi, https://github.com/malfet
2025-07-30 19:23:33 +00:00
73ee323380 [ONNX] RMS Norm (#159377)
- Implement rms norm using onnx RMSNormalization-23
- Use the correct eps for float32
  eaadd1282c/aten/src/ATen/native/cuda/layer_norm_kernel.cu (L1844-L1866)
  <img width="743" height="107" alt="image" src="https://github.com/user-attachments/assets/a6fd45aa-01d9-4667-924d-3012232cfcde" />

- Created facility to run tests with the reference runtime by extending ONNXProgram and assert_onnx_program.

Fix https://github.com/pytorch/pytorch/issues/159257
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159377
Approved by: https://github.com/titaiwangms
2025-07-30 18:55:47 +00:00
176c6446f8 Update CODEOWNERS for ONNX (#159390)
Update CODEOWNERS for ONNX to reflect current maintainers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159390
Approved by: https://github.com/titaiwangms, https://github.com/malfet
2025-07-30 18:54:25 +00:00
debc0591b8 Fix large_tensor_test skipping cpu (#158617)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158617
Approved by: https://github.com/BoyuanFeng
2025-07-30 18:48:07 +00:00
0df78f0c11 Remove /d2implyavx512upperregs- flag (#159431)
And reopen https://github.com/pytorch/pytorch/issues/145702

This flag is not documented anywhere, slows down sccache-accelerated builds, and per https://developercommunity.visualstudio.com/t/Invalid-code-gen-when-using-AVX2-and-SSE/10527298#T-N10562579 it does not work around a compiler bug, but rather disables some optimizations of AVX512 instructions that are invoked in the AVX2 codepath

Fixes https://github.com/pytorch/pytorch/issues/159082

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159431
Approved by: https://github.com/clee2000
2025-07-30 18:47:03 +00:00
d0e8a0ec4c Add CPython test for heapq (#159370)
Not used directly but used internally by `collections.Counter`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159370
Approved by: https://github.com/zou3519, https://github.com/Skylion007
2025-07-30 18:43:06 +00:00
22492848b6 [BE]: Update CUTLASS submodule to 4.1.0 (#158854)
Update the CUTLASS submodule to the latest version with new supported architectures and new features we can use.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158854
Approved by: https://github.com/henrylhtsang
2025-07-30 17:44:38 +00:00
5c14315b05 fixed typo error (#159451)
Fixes #159375

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159451
Approved by: https://github.com/albanD
2025-07-30 17:41:30 +00:00
1b99c1859c [BE] Make PyObjectSlot use a global PyInterpreter and remove (#158427)
This PR is a bit more involved but effectively works to drastically simplify PyObjectSlot and PyInterpreter.
1) For PyObjectSlot we now use a global pyinterpreter since there is only one. From here we change all of the call sites to rely on this assumption.
2) We also remove the "tags" of the PyInterpreter by deprecating `PyInterpreterStatus`.

For the reviewer, sadly it seems like `functorch/csrc/dim/dim.cpp` needed to get linted, so there is an unreadable amount of changes there. Fortunately, the only actual change in the file is as follows which just removes `getPyInterpreter()` from  the `check_pyobj` call.

```
 mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
-    // fast case: tensor is live in python
-    std::optional<PyObject*> mb_obj =
-        t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
-    if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
-        return *mb_obj;
-    }
-    return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
-}
-}
+  // fast case: tensor is live in python
+  std::optional<PyObject*> mb_obj =
+      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+          /*ignore_hermetic_tls=*/false);
+  if (mb_obj.has_value() &&
+      !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
+    return *mb_obj;
+  }
+  return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
+}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158427
Approved by: https://github.com/albanD
2025-07-30 17:29:43 +00:00
435edbcb5d [Graph Partition] add graph partition doc (#159450)
This pr adds doc for graph partition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159450
Approved by: https://github.com/eellison
2025-07-30 17:01:10 +00:00
6c6e11c206 Revert "Fix max_width computation in _tensor_str._Formatter (#126859)"
This reverts commit 1465757959dd7e63715b7621650896eca977aefa.

Reverted https://github.com/pytorch/pytorch/pull/126859 on behalf of https://github.com/yangw-dev due to broke trunk with test  distributed/test_c10d_functional_native.py::CompileTest::test_inductor_all_reduce_single - RuntimeError: Expected to find buf7 = empty but did not find it ([comment](https://github.com/pytorch/pytorch/pull/126859#issuecomment-3137137030))
2025-07-30 16:56:32 +00:00
a775c8e73e [Profiler] Fix lost C call events problem in Python 3.12.0-3.12.4 (#155446)
Hi team,

Please help review this patch.

This PR https://github.com/pytorch/pytorch/pull/150370 tried to fix the "Empty C Call Queue" problem on Python 3.12. It added C calls for each starting Python event with a callable.

I found that the root cause is not that we cannot get C function frames via `PyFrame_GetBack` when PythonTracer is filling start frames, but rather a C call event loss bug in Python 3.12.0-3.12.4. That problem was fixed by 257c413cd1 in 3.12.5.

So I think https://github.com/pytorch/pytorch/pull/150370 cannot fix the problem; this patch reverts its change.

There are ways to fix the problem correctly, such as adding a new monitoring callback to compensate for call events of methods with C functions, or overriding the callback registered by `PyEval_SetProfile`. These solutions may make the code hard to maintain.

~~Since upgrading the micro version of Python is not difficult for users, we can just ignore C functions and suggest user upgrade.~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155446
Approved by: https://github.com/sraikund16
2025-07-30 16:35:51 +00:00
24d07b3a67 [inductor] Fix mm decomposition evaluating symints (#158998)
Fixes #154111

Resolves an issue during compilation with dynamic shapes where `torch._inductor.decomposition.mm` evaluates the SymInt expression for the input tensor due to a for loop, and thus the output tensor is not dynamically shaped. This issue is limited to (Mx1)x(1xN) small matrix multiplications, and creates an explicit error with tensor subclasses such as DTensor.

The proposed fix replaces the loop with a simple product instead. Benchmark currently running https://hud.pytorch.org/benchmark/compilers

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158998
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
2025-07-30 16:34:15 +00:00
90fd06be71 Various bugfixes for running NanoGPT training (#159166)
Fix various small bugs with running nanogpt on torchbenchmark in OSS under python 3.10. After these changes, the following now succeeds:

```
tlp python benchmarks/dynamo/torchbench.py --only nanogpt --performance  --training --backend inductor  --caching-precompile --warm-start-latency
```

Cold start: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp12LuZ5/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Warm start (we are investigating the recompile):
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpT5YTB2/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159166
Approved by: https://github.com/zhxchen17
2025-07-30 16:30:22 +00:00
002f18807e [DCP] Improve error handling for process based async checkpointing (#159374)
Summary:
### PR Context
- Kill the background process only when PG init fails or there is an explicit `TERMINATE` signal from the main process.
- When a checkpoint fails to save, log and return the error but continue the serving loop.

Test Plan:
CI

Rollback Plan:

Differential Revision: D79177410

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159374
Approved by: https://github.com/sibuachu
2025-07-30 16:25:28 +00:00
259e79e3ff Move Half to headeronly (#159172)
Essence of this copypasta:
- combine Half-inl.h and Half.h in c10/util -> torch/headeronly/util/Half.h
- Add NOLINTNEXTLINE's to the portions of Half-inl.h that were previously in the ignore list of clangtidy
- Re-expose all APIs in namespaces and through includes of the original files. Ideally, we would have the APIs in torch::headeronly and reexpose them in c10, but that runs into BC issues (see D78997465) so for now we are keeping the APIs in c10 but reexposing them in torch::headeronly.
- Change test cases in test_aoti_abi_check to test torch::headeronly::Half vs c10::Half (they're the same thing but we eventually want all the tests for headeronly APIs to only import from headeronly).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159172
Approved by: https://github.com/albanD, https://github.com/desertfire
2025-07-30 16:11:58 +00:00
ee343ce60c [RPC][TensorPipe] Fix import torch if compiled without TensorPipe (#159461)
This is a follow up on the PR #154382, as the issue still persists:
```
  File "/opt/pytorch/pytorch/torch/distributed/rpc/__init__.py", line 81, in <module>
    from . import api, backend_registry, functions
  File "/opt/pytorch/pytorch/torch/distributed/rpc/api.py", line 35, in <module>
    from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
  File "/opt/pytorch/pytorch/torch/distributed/rpc/constants.py", line 3, in <module>
    from torch._C._distributed_rpc import (
ImportError: cannot import name '_DEFAULT_NUM_WORKER_THREADS' from 'torch._C._distributed_rpc' (unknown location)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159461
Approved by: https://github.com/lw
2025-07-30 16:04:02 +00:00
ea5369113a unflatten closure (#159418)
Summary: Sometimes the call history recorded in a `nn_module_stack` does not have the stack property, where each FQN is a prefix of the next FQN. This can cause errors during `unflatten`. Instead of erroring we now drop entries from such a `nn_module_stack` to restore the stack property. This effectively leads to less unflattening: the last FQN in the call history before the stack property was broken keeps the entire flat subgraph of its call.
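
A minimal sketch of the idea, truncating the call history at the first entry that breaks the prefix property (illustrative only; the names and the exact treatment of the remaining entries differ in the real unflatten logic):

```python
def restore_stack_property(fqns):
    kept = []
    for fqn in fqns:
        # stack property: every entry must extend the previous FQN
        if kept and not fqn.startswith(kept[-1] + "."):
            break  # property broken: stop here, the rest stays flat
        kept.append(fqn)
    return kept

print(restore_stack_property(["mod", "mod.sub", "other.leaf", "mod.sub.linear"]))
# ['mod', 'mod.sub']
```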

Test Plan:
added test, updated another

Rollback Plan:

Differential Revision: D79204669

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159418
Approved by: https://github.com/angelayi
2025-07-30 15:42:18 +00:00
b268f22ab2 Move Float4 to headeronly (#159414)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159414
Approved by: https://github.com/desertfire
2025-07-30 15:34:01 +00:00
52a52d1b78 [dynamo][guards] Skip no tensor aliasing guard on inbuilt nn module buffers (#159453)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159453
Approved by: https://github.com/jansel
2025-07-30 15:31:07 +00:00
eaadd1282c Revert "Move Half to headeronly (#159172)"
This reverts commit 6d0f4566e2b6e05369d8bb6c0d0e83a0eee982aa.

Reverted https://github.com/pytorch/pytorch/pull/159172 on behalf of https://github.com/clee2000 due to broke lint [GH job link](https://github.com/pytorch/pytorch/actions/runs/16613893793/job/47002486679) [HUD commit link](6d0f4566e2).  Note to self: why isn't Dr. CI updating ([comment](https://github.com/pytorch/pytorch/pull/159172#issuecomment-3136769493))
2025-07-30 15:10:26 +00:00
1465757959 Fix max_width computation in _tensor_str._Formatter (#126859)
The previous version of `torch._tensor_str._Formatter` was not using `PRINT_OPTS.sci_mode` for the `max_width` computation but was using it for the formatting of values, leading to a weird discrepancy.

Now, the code first checks whether it should be in sci_mode, then computes `max_width`.

Here is an example to test the behavior:
```python
A = torch.tensor([10, 1e-1, 1e-2])
B = torch.tensor([10, 1e-1, 1e-1])

print("================= Default =================")
print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")

print("================= sci_mode=False =================")
with torch._tensor_str.printoptions(sci_mode=False):
    print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
    print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")

print("================= sci_mode=True =================")
with torch._tensor_str.printoptions(sci_mode=True):
    print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
    print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")
```

In the current version this prints:
```
================= Default =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=False =================
tensor([   10.0000,     0.1000,     0.0100]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=True =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 7
```

One can see that in `sci_mode=False`, the values of A are prefixed with unneeded zeros and do not have the same `max_width` as B (A keeps the `max_width` from `sci_mode = None`).

Also, in `sci_mode = True`, for B, the `max_width` is 7 but each value takes 10 chars. (This happens to be fine, as the code that uses `max_width` does not rely much on it, but it is still misleading.)

After this commit, this will print
```
================= Default =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=False =================
tensor([10.0000,  0.1000,  0.0100]) Formatter max_width: 7
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=True =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 10
```

This also allows A to be aligned with B for `sci_mode=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126859
Approved by: https://github.com/malfet
2025-07-30 14:01:00 +00:00
17b9c618dd [a2av] not returning out tensor from ops (#159435)
torch.compile of `all_to_all_vdev_2d` hits the following error:
```
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: symm_mem::all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> Tensor(a!). We only support functionalizing operators whose outputs do not have alias annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias annotation specifies that the output Tensor shares storage with an input that has the same annotation. Please check if (1) the output needs to be an output (if not, don't return it), (2) if the output doesn't share storage with any inputs, then delete the alias annotation. (3) if the output indeed shares storage with an input, then add a .clone() before returning it to prevent storage sharing and then delete the alias annotation. Otherwise, please file an issue on GitHub.
```

This PR selects option (1).
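
A minimal sketch of option (1) with the custom-op API, mutating `out` in place and returning nothing so that no returned tensor needs an alias annotation (a toy op, not `all_to_all_vdev_2d` itself):

```python
import torch

@torch.library.custom_op("demo::add_one_into", mutates_args=("out",))
def add_one_into(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 1)

x = torch.randn(4)
out = torch.empty(4)
add_one_into(x, out)
print(torch.allclose(out, x + 1))  # True
```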

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159435
Approved by: https://github.com/ngimel, https://github.com/xmfan
2025-07-30 08:30:25 +00:00
d3ce45012e Generalize torch._C._set_allocator_settings to be generic (#156175)
# Motivation
This PR moves the implementation of `torch.cuda.memory._set_allocator_settings` to `torch._C._accelerator_setAllocatorSettings`.
Since the original API was intended as a temporary/internal utility, I am not exposing the new function as a public API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156175
Approved by: https://github.com/albanD
ghstack dependencies: #149601, #157908, #150312, #156165
2025-07-30 06:37:15 +00:00
1fc010a9d8 Deprecate overlapping functions in CUDAAllocatorConfig, use AcceleratorAllocatorConfig instead (#156165)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156165
Approved by: https://github.com/albanD
ghstack dependencies: #149601, #157908, #150312
2025-07-30 06:37:15 +00:00
dfacf11f66 Refactor CUDAAllocatorConfig to reuse AcceleratorAllocatorConfig (#150312)
# Motivation
Refactor `CUDAAllocatorConfig` to reuse `AcceleratorAllocatorConfig` and `ConfigTokenizer`. We will deprecate the options that overlap with `AcceleratorAllocatorConfig` in the following PR and keep them only for BC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150312
Approved by: https://github.com/albanD
ghstack dependencies: #149601, #157908
2025-07-30 06:37:06 +00:00
c8cf811995 Enable AcceleratorAllocatorConfig key check (#157908)
# Motivation
Add a mechanism to raise an error when an unrecognized key appears in the allocator config.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157908
Approved by: https://github.com/albanD
ghstack dependencies: #149601
2025-07-30 06:36:56 +00:00
914b1a3873 Introduce AcceleratorAllocatorConfig as the common class (#149601)
# Motivation
This PR aims to generalize `AllocatorConfig` to be device-agnostic. It introduces the class `AcceleratorAllocatorConfig` to clarify its scope as a configuration manager for accelerator backends (e.g., CUDA, XPU). The name `AllocatorConfig` is now reserved for a potential future base class that can unify configuration handling for both CPU and accelerator allocators, should similar requirements arise for the CPU path.

# Design Rule
## Overall
This class configures memory allocation for both device and host memory. A single `AcceleratorAllocatorConfig` instance is shared across all accelerator backends, such as CUDA and XPU, under the assumption that relevant environment variables apply uniformly to all accelerators. Device-specific configuration extensions are supported via hooks (see `registerDeviceConfigParserHook`).
A new class `ConfigTokenizer` is introduced to help parse the environment variable's key-value pairs.

## Naming Convention:
- Public API names in `AcceleratorAllocatorConfig` should be device-generic.
- Members prefixed with `pinned_` are specific to the host/pinned allocator.
- Environment variable names should be generic across backends.
- Comma-separated key-value pairs in the format `key:value`. Use square brackets `[]` for list values. Example: `key1:123, key2:[val1,val2]`

## Environment Variables:
- The default environment variable for configuration is `PYTORCH_ALLOC_CONF`.
- For backward compatibility, `PYTORCH_CUDA_ALLOC_CONF` and `PYTORCH_HIP_ALLOC_CONF` are also supported with lower priority (see the example below).
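
A minimal sketch of setting the configuration in the format described above (the keys are placeholders showing the syntax only, not real allocator options):

```python
import os

# must be set before the allocator reads it, typically before importing torch
os.environ["PYTORCH_ALLOC_CONF"] = "key1:123,key2:[val1,val2]"
```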

Differential Revision: [D79011786](https://our.internmc.facebook.com/intern/diff/D79011786)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149601
Approved by: https://github.com/albanD
2025-07-30 06:36:46 +00:00
7eb5fdb358 [dynamo][guards] Recursive dict tag optimization (#159183)
Design doc here - https://docs.google.com/document/d/1W29DrWID5miGWlZXspsQVN5U0zydE3kjZpziOXrhuaY/edit?tab=t.0#bookmark=id.sba04iw9sp68

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159183
Approved by: https://github.com/jansel
2025-07-30 06:01:32 +00:00
f1fb57d854 Add user annotation for FX graph cache key (#159318)
Summary: The AI system co-design team requested adding a user annotation for the FX graph cache key in the PyTorch Kineto trace and Execution Trace. With this annotation, they can tell which FX graph the kernels belong to.

Test Plan:
buck2 run mode/opt caffe2/test:test_profiler_cuda -- profiler.test_execution_trace.TestExecutionTraceCUDA

Rollback Plan:

Differential Revision: D79019069

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159318
Approved by: https://github.com/sraikund16, https://github.com/jansel
2025-07-30 05:52:50 +00:00
6d0f4566e2 Move Half to headeronly (#159172)
Essence of this copypasta:
- combine Half-inl.h and Half.h in c10/util -> torch/headeronly/util/Half.h
- Add NOLINTNEXTLINE's to the portions of Half-inl.h that were previously in the ignore list of clangtidy
- Re-expose all APIs in namespaces and through includes of the original files. Ideally, we would have the APIs in torch::headeronly and reexpose them in c10, but that runs into BC issues (see D78997465) so for now we are keeping the APIs in c10 but reexposing them in torch::headeronly.
- Change test cases in test_aoti_abi_check to test torch::headeronly::Half vs c10::Half (they're the same thing but we eventually want all the tests for headeronly APIs to only import from headeronly).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159172
Approved by: https://github.com/albanD, https://github.com/desertfire
2025-07-30 05:02:13 +00:00
e785c087c5 [audio hash update] update the pinned audio hash (#159321)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159321
Approved by: https://github.com/pytorchbot
2025-07-30 04:35:01 +00:00
d214901133 Add a title to distributed._dist2.md (#159385)
Sphinx likes titles and complains when they are missing, so this adds a title to address the following warning in the build:
```
WARNING: toctree contains reference to document 'distributed._dist2' that doesn't have a title: no link will be generated
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159385
Approved by: https://github.com/d4l3k
2025-07-30 04:09:41 +00:00
96ac64d00c Migrate easy q(u)int/bits stuff to torch/headeronly (#159302)
Straightup copy pasta. Keeps APIs in c10 and reexposes them to torch::headeronly.

It is arguable that we should just get rid of some of these unused dtypes but that is outside the scope of this PR, which is meant to build up to ScalarType moving to headeronly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159302
Approved by: https://github.com/malfet, https://github.com/albanD
2025-07-30 03:41:27 +00:00
46d34d6766 (should_fold) gso to guard_or_false when checking whether to fold a 3d bmm into a 2d mm (#159184)
Switch from guard_size_oblivious to guard_or_false so that, if a DDE is encountered, we simply avoid folding the 3d bmm into a 2d mm (see the sketch below the tracebacks).

806d9e3fe7/torch/_decomp/decompositions.py (L4506-L4512)

## DDE
```
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
    elif should_fold(tensor1, tensor2, is_out):
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4472, in should_fold
    if guard_size_oblivious(t1.numel() == 0):
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(12*((u0//2)), 0) (unhinted: Eq(12*((u0//2)), 0)).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:4472 in should_fold)
```

```
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
    elif should_fold(tensor1, tensor2, is_out):
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4483, in should_fold
    return all(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*((u0//2)), 3) (unhinted: Eq(3*((u0//2)), 3)).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:4483 in should_fold)
```
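
A minimal sketch of the pattern, assuming `guard_or_false` from `torch.fx.experimental.symbolic_shapes` also accepts plain booleans (illustrative, not the decomposition code):

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false

def is_definitely_empty(t: torch.Tensor) -> bool:
    # guard_size_oblivious(t.numel() == 0) raises GuardOnDataDependentSymNode
    # for unbacked sizes; guard_or_false instead falls back to False, so the
    # caller just skips the fold rather than erroring out.
    return guard_or_false(t.numel() == 0)

print(is_definitely_empty(torch.randn(3, 4)))  # False
```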

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159184
Approved by: https://github.com/ezyang
ghstack dependencies: #158894
2025-07-30 03:12:14 +00:00
clr
880249adbc dynamo: handle AttributeErrors from nn_module when infer_paramaters throws. (#158501)
This only handles AttributeError, but in general, any exception coming from here is a user exception. Let me know if we would prefer to catch all exceptions and then re-raise them as observed exceptions.

```
 File "/packages/aps.ads.gmp/launcher_with_publish#link-tree/torch/_dynamo/symbolic_convert.py", line 2200, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/packages/aps.ads.gmp/launcher_with_publish#link-tree/torch/_dynamo/symbolic_convert.py", line 1210, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/packages/aps.ads.gmp/launcher_with_publish#link-tree/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/packages/aps.ads.gmp/launcher_with_publish#link-tree/torch/_dynamo/variables/nn_module.py", line 472, in call_function
    initialize_lazy_module(tx, mod, args, kwargs)
  File "/packages/aps.ads.gmp/launcher_with_publish#link-tree/torch/_dynamo/variables/nn_module.py", line 104, in initialize_lazy_module
    mod._infer_parameters(mod, fake_args, fake_kwargs)
  File "/packages/aps.ads.gmp/launcher_with_publish#link-tree/torch/nn/modules/lazy.py", line 261, in _infer_parameters
    module.initialize_parameters(*args, **kwargs)
  ...,
  File "/packages/aps.ads.gmp/launcher_with_publish#link-tree/torch/nn/modules/module.py", line 1962, in __getattr__
    raise AttributeError(
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: '...' object has no attribute '...'
```

Note that we crash with a slightly different exception trace in the other test I added. Let me know if we want this not to be thrown directly to the end user.
```
======================================================================
ERROR: test_lazy_module_bad_params (__main__.NNModuleTests.test_lazy_module_bad_params)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/clr/pytorch/torch/testing/_internal/common_utils.py", line 3223, in wrapper
    method(*args, **kwargs)
    ~~~~~~^^^^^^^^^^^^^^^^^
  File "/data/users/clr/pytorch/test/dynamo/test_modules.py", line 1683, in test_lazy_module_bad_params
    exp_res = opt_m(x, y)
  File "/data/users/clr/pytorch/torch/_dynamo/eval_frame.py", line 411, in __call__
    return super().__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/data/users/clr/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/data/users/clr/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/clr/pytorch/torch/_dynamo/eval_frame.py", line 473, in _call_lazy_check
    self._orig_mod._infer_parameters(self._orig_mod, args, kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/clr/pytorch/torch/nn/modules/lazy.py", line 261, in _infer_parameters
    module.initialize_parameters(*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/data/users/clr/pytorch/test/dynamo/test_modules.py", line 711, in initialize_parameters
    self.foo += 1
    ^^^^^^^^
  File "/data/users/clr/pytorch/torch/nn/modules/module.py", line 1962, in __getattr__
    raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'"
    )
AttributeError: 'LazyModuleBadInferParams' object has no attribute 'foo'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158501
Approved by: https://github.com/williamwen42, https://github.com/jansel
2025-07-30 02:41:41 +00:00
846ada4973 [AOTI] disable crashed AOTI UTs on Windows. (#159427)
Disable AOTI unit tests that crash on Windows.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159427
Approved by: https://github.com/angelayi
2025-07-30 02:23:27 +00:00
badd0618e4 Remove unused parameter on CUDA AllocParams (#159159)
# Motivation
While refactoring the caching allocator, I noticed that the `AllocParams` constructor on CUDA had an unused parameter. This change removes that unused argument to avoid potential confusion.

# Additional Context
I noticed that `AllocParams` is defined in a cpp file, so it should be safe to make this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159159
Approved by: https://github.com/cyyever, https://github.com/albanD
2025-07-30 02:05:25 +00:00
a753a72b14 [BE] Modify PyObjectSlot to assume only a single interpreter is in use (#158407)
This PR makes some lower-risk changes to PyObjectSlot, as there is a lot of machinery we no longer need now that there is only one interpreter. Specifically, `check_interpreter` and `has_pyobj_nonhermetic` are removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158407
Approved by: https://github.com/albanD
ghstack dependencies: #158290, #158291
2025-07-30 01:36:03 +00:00
b57d1ef110 [BE] Remove __reduce_deploy__ (#158291)
This PR removes the integration point torch.fx had with torch::deploy (and another minor change).

Note: This PR surfaces some mypy errors, but I believe those were already in the code base beforehand and should be fixed in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158291
Approved by: https://github.com/albanD
ghstack dependencies: #158290
2025-07-30 01:36:03 +00:00
dd7c996d5c [BE] Remove torch deploy | remove torch deploy specific files (#158290)
This PR removes specific files found in pytorch which are only used for torch::deploy. This is mostly testing code and a debugger.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158290
Approved by: https://github.com/albanD
2025-07-30 01:36:03 +00:00
70d2e9ba45 [MPS] Avoid outputting zeros from exponential_ for MPS (#159386)
Fixes #159103
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159386
Approved by: https://github.com/malfet
2025-07-30 00:20:31 +00:00
eqy
62f98dbb44 [CUDA][Convolution] Add tf32_on_and_off decorator to test_deconv_freezing_cuda (#159280)
Blackwell seems to select TF32 kernels for this case

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159280
Approved by: https://github.com/zou3519, https://github.com/jingsh, https://github.com/Skylion007
2025-07-29 23:44:10 +00:00
e288c258f7 Revert "Remove tensorexpr tests (#158928)"
This reverts commit d742a2896c571a535003d5928fe80397325575a5.

Reverted https://github.com/pytorch/pytorch/pull/158928 on behalf of https://github.com/yangw-dev due to this breaks bunch of internal dependency since some tests are still using the deleted test files from this pr, the internal reviewer please help fix this using codev ([comment](https://github.com/pytorch/pytorch/pull/158928#issuecomment-3134378616))
2025-07-29 23:32:07 +00:00
df58db8831 [dynamo, docs] add recompilation, observability, reporting issues docs (#159062)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159062
Approved by: https://github.com/svekars, https://github.com/zou3519, https://github.com/anijain2305
2025-07-29 23:23:51 +00:00
15bb81ea4f [2/N][CI] Remove MacOS-13 workarounds from tests (#159304)
Part of https://github.com/pytorch/pytorch/issues/159275

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159304
Approved by: https://github.com/dcci, https://github.com/cyyever
ghstack dependencies: #159277, #159278
2025-07-29 23:12:13 +00:00
8d37073bac [ROCm] Update jit_utils.cpp trait modification based on HIP version. (#159292)
The MI355 CI regression and hiprtc kernel compilation are failing due to duplicate definitions of traits, leading to errors like `error: redefinition of 'integral_constant'`. This seems to be the culprit: https://github.com/pytorch/pytorch/pull/158868. Switching the check to the HIP version instead of the ROCm version should help resolve this, as the ROCm and HIP versions aren't synced: the ROCm 7.0 Alpha build used in CI is still on HIP 6.5.

Confirmed that this patch works here: https://github.com/pytorch/pytorch/actions/runs/16579227179?pr=159292

Also, this PR increases the frequency of the MI355 CI to twice a day so that, for now, we can catch and identify regressions more easily if they happen.

Jeff is on vacation, so Jithun asked me to reach out to y'all. Please help stamp and approve, so we can resolve the recent MI355 CI regression/timeout (https://github.com/pytorch/pytorch/actions/workflows/rocm-mi355.yml) :) @huydhn @malfet @atalman @seemethere

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159292
Approved by: https://github.com/malfet
2025-07-29 22:45:27 +00:00
dc286aef61 Fused RMSNorm Housekeeping (#159317)
Small PR to address comments from the original fused RMSNorm PR whose suggested changes were not landed.

Changes:
- Warning message when input.dtype doesn't match weight.dtype
- Ensure default epsilon value is correct

Comments:
https://github.com/pytorch/pytorch/pull/153666#discussion_r2114735005
https://github.com/pytorch/pytorch/pull/153666#discussion_r2223518064

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159317
Approved by: https://github.com/ngimel, https://github.com/Skylion007, https://github.com/eqy
2025-07-29 22:39:18 +00:00
b4619f0272 Pin Helion to 0.0.10 in PyTorch CI (#159420)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159420
Approved by: https://github.com/aorenste, https://github.com/malfet
2025-07-29 22:06:50 +00:00
477c2273e1 [dynamo] better way to skip tracing sys.monitoring callables (#159369)
Better approach to https://github.com/pytorch/pytorch/pull/158171, according to https://github.com/python/cpython/issues/137178#issuecomment-3131617493.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159369
Approved by: https://github.com/Skylion007
2025-07-29 21:54:58 +00:00
2176d481c1 [DTensor] dispatch to sharding prop over decomps (#159324)
Fixes #159110

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159324
Approved by: https://github.com/ezyang
2025-07-29 21:28:36 +00:00
b97274e8ac [iter] Raise TypeError if the iter arg is not iterable (#158410)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158410
Approved by: https://github.com/XuehaiPan, https://github.com/zou3519
ghstack dependencies: #156371, #156416, #156460
2025-07-29 21:24:21 +00:00
f9be65cea4 [iter] Wrap iter(..) call in an ObjectIteratorVariable (#156460)
This object keeps track of when the iterator is exhausted (raises StopIteration).
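
For reference, the CPython behavior being tracked: once an iterator raises StopIteration, every further `next()` call keeps raising it.

```python
it = iter([1, 2])
print(next(it), next(it))  # 1 2
for _ in range(2):
    try:
        next(it)
    except StopIteration:
        print("exhausted")  # printed both times; the iterator stays exhausted
```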

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156460
Approved by: https://github.com/zou3519
ghstack dependencies: #156371, #156416
2025-07-29 21:24:20 +00:00
4e3e3dc0a7 [iter] support iter(callable, sentinel) (#156416)
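
For reference, the two-argument form of the builtin being supported here: `iter(callable, sentinel)` calls `callable` with no arguments until it returns `sentinel`.

```python
import random

random.seed(0)
roll = lambda: random.randint(1, 6)

# keep calling roll() until it returns the sentinel value 6
rolls = list(iter(roll, 6))
print(rolls)  # every value rolled before the first 6
```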
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156416
Approved by: https://github.com/XuehaiPan, https://github.com/zou3519
ghstack dependencies: #156371
2025-07-29 21:24:20 +00:00
fcf59df2b6 [iter] Add support for sequence protocol in iter(..) (#156371)
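
For reference, the sequence protocol being supported here: an object with only `__getitem__` is iterable, with `iter()` calling it at 0, 1, 2, ... until IndexError is raised.

```python
class Squares:
    def __getitem__(self, i):
        if i >= 5:
            raise IndexError  # ends the iteration
        return i * i

print(list(iter(Squares())))  # [0, 1, 4, 9, 16]
```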
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156371
Approved by: https://github.com/zou3519
2025-07-29 21:24:20 +00:00
1bcb2f41e0 [BE] Eliminate workspace info in templates with new API (#159055)
Summary: Moves the workspace info calculations to the old TMA API.

Test Plan:
NFC

Rollback Plan:

Differential Revision: D78904434

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159055
Approved by: https://github.com/NikhilAPatel
2025-07-29 21:22:36 +00:00
8460131087 [nativert] Add OSS version of ModelRunner (#159268)
Summary: Implement a ModelRunner from scratch with the minimum features for OSS only

Test Plan:
test_export -r NativeRT

Rollback Plan:

Differential Revision: D78979812

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159268
Approved by: https://github.com/dolpm
2025-07-29 21:08:14 +00:00
c0c24b61ff Revert "Partitioner: Fix to align partition node order with original graph (#157892)"
This reverts commit 2d1e92307d3e67622f4fe8058d62e44fe4fa2f4e.

Reverted https://github.com/pytorch/pytorch/pull/157892 on behalf of https://github.com/yangw-dev due to fails internal tests : [executorch/backends/xnnpack/partition/xnnpack_partitioner.py:101:24] Incompatible parameter type [6]: In call `Partition.__init__`, for argument `nodes`, expected `Optional[Iterable[Tuple[Node, Optional[int]]]]` but got `dict_keys[Node, str]`. ([comment](https://github.com/pytorch/pytorch/pull/157892#issuecomment-3134004881))
2025-07-29 20:41:45 +00:00
4fac43b21f [BE] Move _freeze.py to torch/fb/utils (#159307)
Summary: We are trying to deprecate torch deploy externally. However, a bunch of legacy code still uses it. This PR allows the legacy tests to still run if necessary.

Test Plan:
It's a targets change so CI should suffice

Rollback Plan:

Differential Revision: D78910653

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159307
Approved by: https://github.com/albanD
2025-07-29 20:07:17 +00:00
b794e77b7b Disable cudagraph GCs by default (#158649)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158649
Approved by: https://github.com/eellison
ghstack dependencies: #158193
2025-07-29 19:56:11 +00:00
d987a6f7f0 Revert "[Dynamo][Better Engineering] Add typing annotations to guard and source (#158397)"
This reverts commit abcb24f4de11f8fedf2c2c9ff53b6092ef42306d.

Reverted https://github.com/pytorch/pytorch/pull/158397 on behalf of https://github.com/yangw-dev due to Suggested to fix failing internal signals on D78911890 ([comment](https://github.com/pytorch/pytorch/pull/158397#issuecomment-3133823766))
2025-07-29 19:49:40 +00:00
5d93127c87 Revert "[HOP, map] Rework of map autograd to the new interface (#153343)"
This reverts commit 24b1f10ca13d682430725c511812e43a35fcd6a6.

Reverted https://github.com/pytorch/pytorch/pull/153343 on behalf of https://github.com/yangw-dev due to a older pr this pr dependes on needed to revert, rebase it after it's in ([comment](https://github.com/pytorch/pytorch/pull/153343#issuecomment-3133816812))
2025-07-29 19:46:42 +00:00
a3a51282db Fix rand_like decomposition to preserve strides (#159294)
Summary: Like https://github.com/pytorch/pytorch/pull/158898, the rand_like variants are not preserving strides. Followed the pattern established in https://github.com/pytorch/pytorch/pull/158898.
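
For context, "preserving strides" means the output keeps the input's (possibly non-contiguous) layout; a quick eager illustration of the expected behavior:

```python
import torch

x = torch.randn(4, 6).t()      # non-contiguous view: shape (6, 4), strides (1, 6)
y = torch.rand_like(x)
print(x.stride(), y.stride())  # both (1, 6): the layout is preserved
```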

Test Plan: New unit test (fails before this PR; but fixed after)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159294
Approved by: https://github.com/eellison
2025-07-29 19:26:20 +00:00
e557b3d5e5 Revert "[inductor] Fix mm decomposition evaluating symints (#158998)"
This reverts commit 52e180c3799a7638ee668b1291a711865ab8cfec.

Reverted https://github.com/pytorch/pytorch/pull/158998 on behalf of https://github.com/yangw-dev due to it broke trunk with pr_time_benchmark test  ([comment](https://github.com/pytorch/pytorch/pull/158998#issuecomment-3133696775))
2025-07-29 19:04:11 +00:00
f3a9e99036 Fix inductor cuda sort nan behavior (#159308)
Fix for https://github.com/pytorch/pytorch/issues/152423

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159308
Approved by: https://github.com/isuruf
2025-07-29 19:02:45 +00:00
f7d6e9f500 [dynamo][guards] More small guard optimizations (#159345)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159345
Approved by: https://github.com/williamwen42
ghstack dependencies: #159288
2025-07-29 18:36:49 +00:00
e43e09e6c1 [dynamo][guards] Use lambda guards for object aliasing to improve object aliasing guards (#159288)
# Note - On Lambda guarding of object aliasing
        # We previously installed object‑aliasing guards as relational guards,
        # but that undermined the recursive‑dict guard optimization: placing the
        # aliasing guard at a leaf prevented the parent dict node from
        # qualifying as a recursive‑dict guard root. Because aliasing guards are
        # rare, we now emit them as epilogue guards via a small Python lambda.
        # This repeats the access in Python—adding a bit of work—but the
        # overhead is outweighed by the gains from enabling recursive‑dict guard
        # optimization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159288
Approved by: https://github.com/StrongerXi
2025-07-29 18:36:49 +00:00
2004f8aa10 FXConverter handling of generic output in inductor fallback kernel (#159002) (#159297)
Summary:

A fallback kernel's output may not be a list/tuple but rather a `MultiOutput` with empty indices. Allow the `FXConverter` to handle such a case.

Test Plan:
Modified the fxir test for fallbacks, then ran `buck2 test mode/dev-nosan caffe2/test/inductor:fxir_backend -- test_fallback`.

Before this diff the modified test would fail with
```
File "/re_cwd/buck-out/v2/gen/fbcode/e2105f7329ead90a/caffe2/test/inductor/__fxir_backend__/fxir_backend#link-tree/torch/_inductor/codegen/wrapper_fxir.py", line 341, in generate
    line.codegen_fx(self)(line)
  File "/re_cwd/buck-out/v2/gen/fbcode/e2105f7329ead90a/caffe2/test/inductor/__fxir_backend__/fxir_backend#link-tree/torch/_inductor/codegen/wrapper_fxir.py", line 489, in _generate_multi_output
    inds = line.indices[0][1:]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
IndexError: list index out of range
```
 (Full error paste in P1878839403)

With this diff the error is no longer present.

Rollback Plan:

Differential Revision: [D79126619](https://our.internmc.facebook.com/intern/diff/D79126619)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159297
Approved by: https://github.com/blaine-rister
2025-07-29 18:29:01 +00:00
31b3b38e3a Ensure export joint with descriptors + compile works (#159337)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159337
Approved by: https://github.com/wconstab
ghstack dependencies: #159336
2025-07-29 17:43:52 +00:00
2f0db0444e Track previous MetricsContext edits for ease of debugging. (#159336)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159336
Approved by: https://github.com/wconstab
2025-07-29 17:43:52 +00:00
6162e650b0 [BE] remove torch deploy - conditionals (#158288)
This PR is part of the work to deprecate torch::deploy in OSS. Effectively it does 3 things to get started.
1. Remove test_deploy_interaction as we no longer need to worry about this
2. Remove all torch._running_with_deploy checks and use the False path always (surfaced 1)
3. Remove `USE_DEPLOY` and switch to the default path always

Note: MyPy does fail on a bunch of things here, as a bunch of older files are touched. It may be better to fix these things in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288
Approved by: https://github.com/albanD
2025-07-29 17:40:49 +00:00
5d89634ca8 Graph break with error message (#158800)
Fixes #157452

Test with
```
python test/dynamo/test_repros.py ReproTests.test_nn_parameter_ctor_graph_breaks
```

### Release Notes

Change to nn.Parameter Constructor Behavior in Dynamo

This introduces a semantic change in the nn.Parameter constructor: previously, if the constructor lacked a clean source, the system would attempt to infer arguments to construct a clone and lift this synthetic proxy into the computation graph. That approach had many potential edge cases and was difficult to reason about. The new behavior defaults to graph breaking when the nn.Parameter constructor does not have a clean source, and users are now advised to manually move the constructor out of the graph in such cases. This change improves clarity and reduces complexity in graph construction and debugging. Users can escape-hatch back to the old semantics with `torch._dynamo.config.graph_break_on_nn_param_ctor=False` if this cannot be done.
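
A minimal sketch of the suggested workaround, constructing the Parameter eagerly outside the compiled region and passing it in (illustrative shapes and function names):

```python
import torch
import torch.nn as nn

def make_param(shape):
    # runs eagerly, so the constructor never appears inside the traced graph
    return nn.Parameter(torch.randn(shape))

@torch.compile
def forward(p, x):
    return x @ p

p = make_param((4, 4))
print(forward(p, torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```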

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158800
Approved by: https://github.com/anijain2305
2025-07-29 17:34:49 +00:00
52e180c379 [inductor] Fix mm decomposition evaluating symints (#158998)
Fixes #154111

Resolves an issue during compilation with dynamic shapes where `torch._inductor.decomposition.mm` evaluates the SymInt expression for the input tensor due to a for loop, and thus the output tensor is not dynamically shaped. This issue is limited to (Mx1)x(1xN) small matrix multiplications, and creates an explicit error with tensor subclasses such as DTensor.

The proposed fix replaces the loop with a simple product instead. Benchmark currently running https://hud.pytorch.org/benchmark/compilers

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158998
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
2025-07-29 17:29:38 +00:00
c55e72bea1 [Re-land][Inductor] Support native Inductor as backend for MTIA (#159211)
The previous [diff/PR] (https://github.com/pytorch/pytorch/pull/158526) was reverted due to this docstring lint error:
<img width="1736" height="722" alt="image" src="https://github.com/user-attachments/assets/216b1720-4002-48da-b5f3-32b5d48aaa54" />
I didn't add the docstring because I thought I wasn't supposed to add a docstring to an EXISTING function.

So this diff/PR is an exact copy of the previous one, except for adding the docstring.

-------------
This diff/PR includes the changes to support native Inductor integration for MTIA. The goal is to support `torch.compile(backend="inductor")` for MTIA. Inductor should generate code (Triton kernel + Python wrapper code) similar to CUDA, and the Triton kernels can be launched eagerly.

The changes include:
- Add MTIA device interfaces used by Dynamo and Inductor, including APIs on device, stream, event, etc.
- Add required torch.mtia APIs, like is_bf16_supported, memory_allocated, set_stream_by_id, etc.
- MTIA specific codegen logic, for example, loading MTIA dynamic_library.
- Other necessary changes to integrate with Inductor codegn, following other devices like CUDA, XPU.
- Integrate with the [empty_strided_mtia](https://www.internalfb.com/code/fbsource/[0d017d3a4a1bdff7253f9c66a9f38e77bd62166b]/fbcode/caffe2/aten/src/ATen/native/mtia/EmptyTensor.cpp?lines=49%2C63%2C71%2C74%2C78) API that we’ve added for the new MTIA ATen backend.
- A change in Inductor runtime to avoid re-initialize MTIADriver.
- BUCK changes to include ATen-mtia in Inductor, and to use -USE_MTIA preprocessor flag.
- Update `test_mnist_e2e.py` to cover native Inductor as backend, using the `--use_native_inductor` flag.
- Add a personal script(`scripts/anwang/run_native_inductor_script.py`) for testing purpose.

Note:
- This approach (option 3) aims to provide a PyTorch-native approach to Inductor integration for MTIA, minimizing the onboarding overhead. The downside is that it doesn't leverage MTIA-specific graph optimizations and is limited by eager launch overhead.
- MTIA will support another approach (option 2) to provide the best performance, based on WrapperFxCodegen. We should be able to reuse the fundamental changes of this diff for option 2, like the device interfaces and stream/event APIs, especially as WrapperFxCodegen inherits from PythonWrapperCodegen.

Internal:
References:
- [post for context](https://fb.workplace.com/groups/mtiasw/permalink/1718377262384606/)
- [Inductor integration discussion(option 1/2/3)](https://docs.google.com/document/d/1p6363OXtVIRv1hPoaKlRSK3j-iir3QIbDd5bjyqCNig/edit?tab=t.0#heading=h.7s4ns6wcnhmb)
- [Project design doc(option 3)](https://docs.google.com/document/d/1jXUmhgoV9WvkMf-bcY3Od_kK9K_RDOdgHdt1LoQ5Tc4/edit?tab=t.0#heading=h.y43gwdqlv46w)
- [early prototying diff](https://www.internalfb.com/diff/D75110196)
- [MPS integration PR](https://github.com/pytorch/pytorch/pull/153959)
- [empty_strided_xpu PR](https://github.com/pytorch/pytorch/pull/126678)

Differential Revision: [D79040806](https://our.internmc.facebook.com/intern/diff/D79040806/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159211
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/jansel
2025-07-29 17:03:24 +00:00
750348b579 [NativeRT] Clean up use of TargetDevice in KernelFactory (#159298)
Summary:
Remove use of targetDevice in KernelFactory.

AOTI will infer the device when creating AOTIDelegateExecutor.

Test Plan:
CI

Rollback Plan:

Reviewed By: dolpm

Differential Revision: D79007317

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159298
Approved by: https://github.com/dolpm
2025-07-29 16:24:33 +00:00
52b9af163c Add avg_pool3d for MPS (#158877)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158877
Approved by: https://github.com/malfet
2025-07-29 15:22:22 +00:00
f4bfac11c7 [Precompile] [easy] API For Editable PrecompileCacheArtifacts (#158586)
This adds an option for backend precompile artifacts to be *editable*, i.e. not serialized right away, so that a Callable edit_fn can still be applied to them.

In the next PR, this will allow us to support editing the precompile artifact with updated autotune results at a later time. The goal flow here is:
- User runs AOTAutograd -> Inductor -> Triton
- User saves to AOTAutogradCache the normal results
- User runs autotuning
- User calls serialize(), it takes the new autotuning results at runtime and saves only the necessary triton kernels.

This PR just implements the API for editing the cache artifacts. The next PR actually adds the autotuning saving support.
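
A minimal sketch of what such an editable artifact could look like (hypothetical names, not the actual Precompile API):

```python
import pickle
from typing import Callable

class EditableArtifact:
    def __init__(self, payload: dict):
        self.payload = payload          # kept unserialized until serialize()

    def edit(self, edit_fn: Callable[[dict], dict]) -> None:
        self.payload = edit_fn(self.payload)

    def serialize(self) -> bytes:
        return pickle.dumps(self.payload)

art = EditableArtifact({"kernels": ["k0", "k1"]})
art.edit(lambda p: {**p, "autotune": {"k0": "best_config"}})  # later, after autotuning
print(len(art.serialize()) > 0)  # True
```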

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158586
Approved by: https://github.com/zhxchen17
2025-07-29 14:53:21 +00:00
8d00833fdb [PP] Fix eval step under no_grad() (#159293)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159293
Approved by: https://github.com/tianyu-l, https://github.com/wconstab
2025-07-29 14:42:33 +00:00
de529ef002 [ONNX] onnx.md to simplify deprecated entities (#159312)
Simplify documentation of deprecated entities and remove the auto-generated page for JitScalarType
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159312
Approved by: https://github.com/titaiwangms
2025-07-29 14:24:17 +00:00
61aa2ae20f Revert "[CPU] fix _weight_int8pack_mm with large output shape (#158341)"
This reverts commit e469414b59ceeaae2860e36708de8852b9892776.

Reverted https://github.com/pytorch/pytorch/pull/158341 on behalf of https://github.com/albanD due to Breaks slowtest ([comment](https://github.com/pytorch/pytorch/pull/158341#issuecomment-3132641530))
2025-07-29 13:56:20 +00:00
9d32aa9789 Help fix numpy detection in cross compiled layouts (#137084)
We had trouble at conda-forge getting numpy to get detected on aarch64 due to our splayed layout and cross compilation needs.

see:
* https://github.com/conda-forge/pytorch-cpu-feedstock/pull/256
* https://github.com/conda-forge/pytorch-cpu-feedstock/issues/266
* https://github.com/conda-forge/pytorch-cpu-feedstock/pull/267

This is my attempt at making an "upstreamable patch" that tries to follow your structure.

It could introduce a new environment variable `Python_NumPy_INCLUDE_DIR` if you want, but CMake doesn't use it as an environment variable, so I feel like that would be weird.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137084
Approved by: https://github.com/atalman
2025-07-29 12:08:56 +00:00
5cf77a0ea2 Fix redistribution costs for slice_scatter (#159223)
We were previously assuming that `input_strategy == src_strategy`, which is not true in all cases.

This PR should fix that.

On the side, I also realized that for `slice_scatter` some DTensorSpecs don't have TensorMeta, e.g., https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_tensor_ops.py#L524

It would be good to fix it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159223
Approved by: https://github.com/ezyang, https://github.com/wconstab
2025-07-29 12:00:39 +00:00
efcf87654e [CI] update flake8 and mypy lint dependencies (#158720)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158720
Approved by: https://github.com/Skylion007
2025-07-29 08:05:56 +00:00
2523e58781 unbacked handling for view_copy (#159244)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159244
Approved by: https://github.com/bobrenjc93
2025-07-29 07:10:46 +00:00
222fa451a2 Move some of vec into headeronly in preparation for Half.h (#158976)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158976
Approved by: https://github.com/albanD, https://github.com/desertfire
2025-07-29 05:43:53 +00:00
6de24135e5 Fix flaky test_inductor_multiple_specializations (#159264)
Summary: This test was using do_bench, so it was flaky since performance is non-deterministic.

Test Plan:
buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:compile_subprocess -- --exact 'caffe2/test/inductor:compile_subprocess - test_inductor_multiple_specializations_cuda (caffe2.test.inductor.test_compile_subprocess.GPUTests)' --run-disabled

Rollback Plan:

Differential Revision: D79098692

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159264
Approved by: https://github.com/jingsh
2025-07-29 05:16:55 +00:00
27ae72036d [cutlass] Prep for cutlass upgrade by ignoring Wunused-but-set-variable (#159276)
Differential Revision: [D79106238](https://our.internmc.facebook.com/intern/diff/D79106238/)

This is in prep for cutlass upgrade.

More context: https://github.com/NVIDIA/cutlass/issues/2487

Tested in https://github.com/pytorch/pytorch/pull/159115
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159276
Approved by: https://github.com/adamomainz, https://github.com/njriasan, https://github.com/Skylion007
2025-07-29 04:40:24 +00:00
e924df23a6 [NativeRT] Strengthen matcher check for StaticDispatch kernel (#159187)
Summary:
Strengthen the matcher for StaticDispatch kernels: all input and output tensors must be on CPU, and all Device-typed attributes must be CPU.

Previously, we only checked that output tensors are on CPU, which misses the case where we do a device-to-host aten._to_copy.

This prepares for turning on static dispatch kernels by default.

Test Plan:
I should add some tests before landing.

Rollback Plan:

Differential Revision: D78747600

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159187
Approved by: https://github.com/dolpm
2025-07-29 04:03:49 +00:00
67e68e0785 [c10d] Cleanup split_group logic using the newly built splitGroup (#158488)
With https://github.com/pytorch/pytorch/pull/157716 merged, we want to further clean up the Python-side code for the `split_group` API. We do need to keep some old global bookkeeping for BC; the rest of the logic is now all in cpp. Regarding the change brought in by https://github.com/pytorch/pytorch/pull/152175, we did the cleanup in https://github.com/pytorch/pytorch/pull/158790 (including internal changes) so that we can safely remove it.

Differential Revision: [D78777152](https://our.internmc.facebook.com/intern/diff/D78777152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158488
Approved by: https://github.com/d4l3k
2025-07-29 03:27:11 +00:00
775788f93b [BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144556
Approved by: https://github.com/ezyang
2025-07-29 03:26:09 +00:00
19ce1beb05 [AOTInductor] Add test for enabling CUDACachingAllocator for AOTInductor's Weight (#159279)
Summary:
Add test for enabling CUDACachingAllocator for AOTInductor's Weight.
Implementation TBD

Test Plan:
N/A, commit is adding a test.

Rollback Plan:

Differential Revision: D79107507

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159279
Approved by: https://github.com/desertfire, https://github.com/jingsh
2025-07-29 02:52:10 +00:00
a91ddea61f Add CPython tests for collections module (#158950)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158950
Approved by: https://github.com/zou3519
2025-07-29 02:24:27 +00:00
ffccb90ff4 [dynamo, docs] add fullgraph=False docs (#159050)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159050
Approved by: https://github.com/svekars, https://github.com/anijain2305
ghstack dependencies: #157985, #158055, #158531
2025-07-29 01:53:47 +00:00
f916f34739 [dynamo, docs] non-strict programming model docs (#158531)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158531
Approved by: https://github.com/AlannaBurke, https://github.com/mlazos, https://github.com/anijain2305
ghstack dependencies: #157985, #158055

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
2025-07-29 01:53:47 +00:00
c32994ce4b [docs, dynamo] add fullgraph=True, common graph breaks docs (#158055)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158055
Approved by: https://github.com/AlannaBurke, https://github.com/anijain2305
ghstack dependencies: #157985

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
2025-07-29 01:53:41 +00:00
433e43cbec [dynamo, docs] programming model dynamo core concepts (#157985)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157985
Approved by: https://github.com/svekars, https://github.com/anijain2305
2025-07-29 01:53:34 +00:00
e469414b59 [CPU] fix _weight_int8pack_mm with large output shape (#158341)
**Summary**
`_weight_int8pack_mm` on CPU may cause a segmentation fault if the output shape is large (i.e., M * N is large). That's because the kernel computes the output buffer address as
```c++
auto* C_ptr = C_data + mb_start * N + nb_start;
```
where both `mb_start` and `N` are `int`, so when they are large their product may overflow.
The solution is simple: declare these variables as `int64_t` so that the product won't overflow.

**Test plan**
```
pytest -sv test/test_linalg.py -k test__int8_mm_large_shape
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158341
Approved by: https://github.com/mingfeima, https://github.com/drisspg
2025-07-29 01:14:50 +00:00
657e5e9aa6 All custom operators go through Inductor's graph.call_function (#159174)
Fixes #158892

All custom operators should go through the graph.call_function path. The other fallback path is for aten/prim operations that lack support for certain features (like torch.float8_e8m0fn).

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159174
Approved by: https://github.com/eellison
2025-07-29 00:31:57 +00:00
f02b783aae [1/N] Remove MacOS-13 MPS testing (#159278)
Starts addressing https://github.com/pytorch/pytorch/issues/159275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159278
Approved by: https://github.com/dcci
ghstack dependencies: #159277
2025-07-28 23:52:47 +00:00
8ad96a563c [inductor] normalize path of the code. (#159255)
Error stack:
<img width="1361" height="345" alt="image" src="https://github.com/user-attachments/assets/50fb2baa-34fd-4a48-a3e7-76e3185391d4" />

After fix:
<img width="1103" height="398" alt="image" src="https://github.com/user-attachments/assets/ece5a9ba-a085-46fe-b061-0c2ebda3a2df" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159255
Approved by: https://github.com/desertfire
2025-07-28 23:42:11 +00:00
59e261bbd8 Revert "[CI] update flake8 and mypy lint dependencies (#158720)"
This reverts commit f5130bf339f12ccf5c6296130c47685bdc4858e4.

Reverted https://github.com/pytorch/pytorch/pull/158720 on behalf of https://github.com/yangw-dev due to this pr failed internally when build torchgen due to rror: fail: Unknown PyPI project: pyyaml, it seems like this is caused by change PyYAML into  pyyaml, please fix it ([comment](https://github.com/pytorch/pytorch/pull/158720#issuecomment-3129995414))
2025-07-28 22:02:10 +00:00
611 changed files with 60271 additions and 12788 deletions

View File

@ -103,5 +103,5 @@ fi
# It depends on torch and triton. We don't want to install
# triton and torch from production on Docker CI images
if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then
pip_install helion --no-deps
pip_install helion==0.0.10 --no-deps
fi

View File

@ -1,7 +1,7 @@
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@722b7e6f9ca512fcc526ad07d62b3d28c50bb6cd#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
@ -50,8 +50,8 @@ IPython==8.12.0
#Pinned versions: 8.12.0
myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 0.13.2
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
@ -59,4 +59,3 @@ sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-nb

View File

@ -50,6 +50,9 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then
export ATEN_THREADING=NATIVE
fi
# Enable LLVM dependency for TensorExpr testing
export USE_LLVM=/opt/llvm
export LLVM_DIR=/opt/llvm/lib/cmake/llvm
if ! which conda; then
# In ROCm CIs, we are doing cross compilation on build machines with
@ -189,6 +192,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then
export USE_ASAN=1
export REL_WITH_DEB_INFO=1
export UBSAN_FLAGS="-fno-sanitize-recover=all"
unset USE_LLVM
fi
if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then

View File

@ -462,7 +462,7 @@ test_inductor_aoti() {
# rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled
/usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}"
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
}
test_inductor_cpp_wrapper_shard() {
@ -1039,10 +1039,20 @@ test_libtorch_api() {
mkdir -p $TEST_REPORTS_DIR
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml
"$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml
else
# Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest"
# On s390x, pytorch is built without llvm.
# Even if it would be built with llvm, llvm currently doesn't support used features on s390x and
# test fails with errors like:
# JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer
# unknown file: Failure
# C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) }
if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
fi
fi
# quantization is not fully supported on s390x yet

View File

@ -53,13 +53,12 @@ self-hosted-runner:
- linux.rocm.gpu.mi250
- linux.rocm.gpu.2
- linux.rocm.gpu.4
# MI300 runners
- linux.rocm.gpu.mi300.2
- linux.rocm.gpu.mi300.4
# gfx942 runners
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4
- rocm-docker
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
- macos-m1-stable
- macos-m1-13
- macos-m1-14
# GitHub-hosted MacOS runners
- macos-latest-xlarge

View File

@ -1 +1 @@
f6dfe1231dcdd221a68416e49ab85c2575cbb824
bf305f538005f2e900f8850ed57146024a8bc559

View File

@ -1 +1 @@
8f605ee30912541126c0fe46d0c8c413101b600a
ca9e2be3ed6320b51f52f536595cd24e254f8bb2

View File

@ -2,7 +2,7 @@ boto3==1.35.42
cmake==3.27.*
expecttest==0.3.0
fbscribelogger==0.1.7
filelock==3.13.1
filelock==3.18.0
hypothesis==6.56.4
librosa>=0.6.2
mpmath==1.3.0

View File

@ -1891,7 +1891,9 @@ def validate_revert(
else pr.get_comment_by_id(comment_id)
)
if comment.editor_login is not None:
raise PostCommentError("Don't want to revert based on edited command")
raise PostCommentError(
"Halting the revert as the revert comment has been edited."
)
author_association = comment.author_association
author_login = comment.author_login
allowed_reverters = ["COLLABORATOR", "MEMBER", "OWNER"]

View File

@ -269,8 +269,8 @@ jobs:
# copy test results back to the mounted workspace, needed sudo, resulting permissions were correct
docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "cd ../pytorch && sudo cp -R test/test-reports ../workspace/test"
- name: Change permissions (only needed for MI300 and MI355 kubernetes runners for now)
if: ${{ always() && steps.test.conclusion && (contains(matrix.runner, 'mi300') || contains(matrix.runner, 'mi355')) }}
- name: Change permissions (only needed for kubernetes runners for now)
if: ${{ always() && steps.test.conclusion && (contains(matrix.runner, 'gfx942') || contains(matrix.runner, 'mi355')) }}
run: |
docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "sudo chown -R 1001:1001 test"

View File

@ -88,23 +88,23 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
]}
secrets: inherit

View File

@ -47,8 +47,8 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
]}
secrets: inherit

View File

@ -28,7 +28,6 @@ jobs:
# than our AWS macos-m1-14 runners
test-matrix: |
{ include: [
{ config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m1-13" },
{ config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m1-14" },
{ config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m2-15" },
]}

View File

@ -59,9 +59,9 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] },
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] },
]}
secrets: inherit

View File

@ -48,12 +48,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
]}
secrets: inherit

View File

@ -3,7 +3,7 @@ name: rocm-mi355
on:
workflow_dispatch:
schedule:
- cron: 30 9 * * * # about 2:30am PDT
- cron: 30 11,1 * * * # about 4:30am PDT and 6:30pm PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}

View File

@ -94,7 +94,6 @@ jobs:
{ config: "default", shard: 1, num_shards: 3, runner: "macos-m1-stable" },
{ config: "default", shard: 2, num_shards: 3, runner: "macos-m1-stable" },
{ config: "default", shard: 3, num_shards: 3, runner: "macos-m1-stable" },
{ config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" },
{ config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" },
{ config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-15" },
]}

View File

@ -164,7 +164,7 @@ init_command = [
'types-setuptools==79.0.0.20250422',
'types-jinja2==2.11.9',
'types-colorama==0.4.6',
'filelock==3.13.1',
'filelock==3.18.0',
'junitparser==2.1.1',
'rich==14.1.0',
'pyyaml==6.0.2',

View File

@ -679,6 +679,7 @@ cc_library(
[
"torch/*.h",
"torch/csrc/**/*.h",
"torch/nativert/**/*.h",
"torch/csrc/distributed/c10d/**/*.hpp",
"torch/lib/libshm/*.h",
],

View File

@ -564,7 +564,7 @@ if(MSVC)
set(CMAKE_NINJA_CMCLDEPS_RC OFF)
if(MSVC_Z7_OVERRIDE)
# CMake set debug flags to use /Z7
set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT Embedded)
set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT "$<$<CONFIG:Debug,RelWithDebInfo>:Embedded>")
endif()
foreach(
flag_var
@ -872,6 +872,14 @@ cmake_dependent_option(
"USE_CUDA OR USE_ROCM;NOT MSVC"
OFF)
cmake_dependent_option(
USE_FBGEMM_GENAI
"Whether to build FBGEMM GenAI quantized GEMM kernels.\
Will be disabled if not supported by the platform"
OFF
"USE_CUDA OR USE_ROCM"
OFF)
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
cmake_dependent_option(
@ -905,6 +913,10 @@ if(USE_FBGEMM)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
endif()
if(USE_FBGEMM_GENAI)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM_GENAI")
endif()
if(USE_PYTORCH_QNNPACK)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK")
endif()

View File

@ -14,7 +14,6 @@
/torch/csrc/autograd/ @albanD @soulitzer
/torch/autograd/ @albanD @soulitzer
/tools/autograd/ @albanD @soulitzer
/torch/header_only_apis.txt @janeyx99
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
/torch/optim/ @albanD @janeyx99
/test/test_public_bindings.py @albanD
@ -51,12 +50,12 @@ nn/qat/ @jerryzh168
/torch/csrc/distributed/c10d/Ops.* @kwen2501
# ONNX Export
/torch/_dynamo/backends/onnxrt.py @wschin
/torch/csrc/jit/passes/onnx.h @titaiwangms @shubhambhokare1
/torch/csrc/jit/passes/onnx.cpp @titaiwangms @shubhambhokare1
/torch/csrc/jit/passes/onnx/ @titaiwangms @shubhambhokare1
/torch/onnx/ @titaiwangms @shubhambhokare1 @justinchuby @wschin
/test/onnx/ @titaiwangms @shubhambhokare1 @justinchuby @wschin
/torch/_dynamo/backends/onnxrt.py @titaiwangms @xadupre @justinchuby
/torch/csrc/jit/passes/onnx.h @titaiwangms @xadupre
/torch/csrc/jit/passes/onnx.cpp @titaiwangms @xadupre
/torch/csrc/jit/passes/onnx/ @titaiwangms @xadupre
/torch/onnx/ @titaiwangms @xadupre @justinchuby
/test/onnx/ @titaiwangms @xadupre @justinchuby
# CI
/.ci @pytorch/pytorch-dev-infra
@ -196,3 +195,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed
/torch/utils/_cxx_pytree.py @XuehaiPan
/torch/utils/pytree/ @XuehaiPan
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
# Relating to libtorch ABI
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99

View File

@ -247,6 +247,50 @@ if(USE_MEM_EFF_ATTENTION)
list(APPEND ATen_ATTENTION_KERNEL_SRCS ${mem_eff_attention_cuda_kernels_cu})
endif()
IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
message(WARNING "Unsupported ROCM arch for FBGEMM GenAI, will set USE_FBGEMM_GENAI to OFF")
set(USE_FBGEMM_GENAI off)
endif()
# FBGEMM GenAI
IF(USE_FBGEMM_GENAI)
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
set(FBGEMM_GENAI_DIR ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
if(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PUBLIC
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_GENAI_DIR}/include/
${FBGEMM_GENAI_DIR}/common/include/
)
endif()
endif()
# XNNPACK
file(GLOB native_xnnpack "native/xnnpack/*.cpp")

View File

@ -1,55 +1 @@
#pragma once
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC or clang-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* Clang-compatible compiler, targeting arm neon */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* CLANG-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#elif defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#if _MSC_VER <= 1900
#define _mm256_extract_epi64(X, Y) \
(_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
#define _mm256_extract_epi32(X, Y) \
(_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
#define _mm256_extract_epi16(X, Y) \
(_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
#define _mm256_extract_epi8(X, Y) \
(_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
#endif
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* GCC-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#if defined(MISSING_ARM_VLD1)
#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
#elif defined(MISSING_ARM_VST1)
#include <ATen/cpu/vec/vec256/missing_vst1_neon.h>
#endif
#elif defined(__GNUC__) && defined(__IWMMXT__)
/* GCC-compatible compiler, targeting ARM with WMMX */
#include <mmintrin.h>
#elif defined(__s390x__)
// targets Z/architecture
// we will include vecintrin later
#elif (defined(__GNUC__) || defined(__xlC__)) && \
(defined(__VEC__) || defined(__ALTIVEC__))
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
#include <altivec.h>
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
with the C++ types. => Can still use __bool/__vector */
#undef bool
#undef vector
#undef pixel
#elif defined(__GNUC__) && defined(__SPE__)
/* GCC-compatible compiler, targeting PowerPC with SPE */
#include <spe.h>
#endif
#include <torch/headeronly/cpu/vec/intrinsics.h>

View File

@ -1,396 +1 @@
/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */
__extension__ extern __inline uint8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u8_x2(const uint8_t* __a) {
uint8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s8_x2(const int8_t* __a) {
int8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u16_x2(const uint16_t* __a) {
uint16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s16_x2(const int16_t* __a) {
int16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u32_x2(const uint32_t* __a) {
uint32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s32_x2(const int32_t* __a) {
int32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u64_x2(const uint64_t* __a) {
uint64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s64_x2(const int64_t* __a) {
int64x1x2_t ret;
__builtin_aarch64_simd_oi __o;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f16_x2(const float16_t* __a) {
float16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f32_x2(const float32_t* __a) {
float32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f64_x2(const float64_t* __a) {
float64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p8_x2(const poly8_t* __a) {
poly8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p16_x2(const poly16_t* __a) {
poly16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p64_x2(const poly64_t* __a) {
poly64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u8_x2(const uint8_t* __a) {
uint8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s8_x2(const int8_t* __a) {
int8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u16_x2(const uint16_t* __a) {
uint16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s16_x2(const int16_t* __a) {
int16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u32_x2(const uint32_t* __a) {
uint32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s32_x2(const int32_t* __a) {
int32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u64_x2(const uint64_t* __a) {
uint64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s64_x2(const int64_t* __a) {
int64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f16_x2(const float16_t* __a) {
float16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f32_x2(const float32_t* __a) {
float32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f64_x2(const float64_t* __a) {
float64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p8_x2(const poly8_t* __a) {
poly8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p16_x2(const poly16_t* __a) {
poly16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p64_x2(const poly64_t* __a) {
poly64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
/* vst1x2 */
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s64_x2(int64_t* __a, int64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f64_x2(float64_t* __a, float64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s8_x2(int8_t* __a, int8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s16_x2(int16_t* __a, int16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s32_x2(int32_t* __a, int32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f16_x2(float16_t* __a, float16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f32_x2(float32_t* __a, float32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s8_x2(int8_t* __a, int8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s16_x2(int16_t* __a, int16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s32_x2(int32_t* __a, int32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s64_x2(int64_t* __a, int64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f16_x2(float16_t* __a, float16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f64_x2(float64_t* __a, float64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
#include <torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h>

View File

@ -1,7 +1 @@
/* Workaround for missing vst1q_f32_x2 in gcc-8. */
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
#include <torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h>

View File

@ -3,50 +3,12 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <c10/util/Exception.h>
#include <torch/headeronly/cpu/vec/vec_half.h>
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
static inline uint16_t float2half_scalar(float val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
__m256 v = _mm256_set1_ps(val);
__m128i o =
_mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return static_cast<std::uint16_t>(_mm_cvtsi128_si32(o));
#else
return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
__m512 v = _mm512_set1_ps(val);
__m256i o =
_mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return static_cast<std::uint16_t>(
_mm_cvtsi128_si32(_mm256_castsi256_si128(o)));
#endif
}
static inline float half2float_scalar(uint16_t val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
__m128i v = _mm_cvtsi32_si128(val);
__m256 o = _mm256_cvtph_ps(v);
return _mm256_cvtss_f32(o);
#else
return _cvtsh_ss(val);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
__m256i v =
_mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
__m512 o = _mm512_cvtph_ps(v);
return _mm512_cvtss_f32(o);
#endif
}
#endif
// Transpose a [2, 32] matrix to [32, 2]
// Note: the output leading dimension should be 2,
// that is, the output must be contiguous

View File

@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl
}
bool pinned_use_background_threads() override {
return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
return c10::CachingAllocator::AcceleratorAllocatorConfig::
pinned_use_background_threads();
}

View File

@ -21,6 +21,10 @@
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@ -1216,7 +1220,7 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
// - `use_fast_accum`: if true, enables fast float8 accumulation
// - `use_fast_accum`: if true, enables fast float8 accumulation. Backends may ignore this option if not applicable.
// - `out`: a reference to the output tensor
Tensor&
@ -1525,6 +1529,7 @@ namespace {
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
#ifndef USE_ROCM
// For TMA transfers, strides of output tensor have to be either
// 1, or aligned to 16 bytes.
const auto last_dim = out_size.size() - 1;
@ -1536,9 +1541,10 @@ namespace {
} else {
out_stride = {out_size[1] * size_padded, size_padded, 1};
}
auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
return out;
return at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
#else
return at::empty(out_size, mat_a.options().dtype(out_dtype_));
#endif
}
bool check_valid_strides_and_return_transposed(const Tensor& mat) {
@ -1619,12 +1625,9 @@ const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
#ifndef USE_ROCM
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true);
TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0");
bool allowed_device = _scaled_mm_allowed_device();
TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0, or ROCm MI300+");
TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
@ -1664,6 +1667,10 @@ bool use_fast_accum) {
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
#ifndef USE_ROCM
TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_b to be Float8_e4m3 matrix got ", mat_b.scalar_type());
at::cuda::detail::f8f8bf16_grouped_mm(
mat_a,
mat_b,
@ -1674,12 +1681,23 @@ bool use_fast_accum) {
use_fast_accum,
out);
return out;
#else
TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
#ifdef USE_FBGEMM_GENAI
TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_b to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
mat_a,
// FBGEMM expects B matrix shape to be (.., N, K)
mat_b.transpose(-2, -1),
scale_a,
scale_b,
offs,
out);
return out;
#else
TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
#endif
#endif
}

View File

@ -38,17 +38,19 @@ static inline std::string _cudaGetErrorEnum(cufftResult error)
return "CUFFT_INVALID_SIZE";
case CUFFT_UNALIGNED_DATA:
return "CUFFT_UNALIGNED_DATA";
case CUFFT_INCOMPLETE_PARAMETER_LIST:
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_INVALID_DEVICE:
return "CUFFT_INVALID_DEVICE";
case CUFFT_PARSE_ERROR:
return "CUFFT_PARSE_ERROR";
case CUFFT_NO_WORKSPACE:
return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED:
return "CUFFT_NOT_IMPLEMENTED";
#if !defined(USE_ROCM)
#if CUDA_VERSION <= 12090
case CUFFT_INCOMPLETE_PARAMETER_LIST:
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_PARSE_ERROR:
return "CUFFT_PARSE_ERROR";
#endif
#if !defined(USE_ROCM) && CUDA_VERSION <= 12090
case CUFFT_LICENSE_ERROR:
return "CUFFT_LICENSE_ERROR";
#endif

View File

@ -9,6 +9,7 @@
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-field-initializers")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
// Determine if the architecture supports rowwise scaled mm
// Currently failing on windows with:
@ -44,6 +45,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-field-initializers")
#include <ATen/native/cuda/cutlass_common.cuh>
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()

View File

@ -10,6 +10,7 @@
// Two warnings in Cutlass included header files
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
// Determine if the architecture supports rowwise scaled mm
// Currently failing on windows with:
@ -44,6 +45,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()

View File

@ -45,7 +45,7 @@ namespace at::cuda::jit {
// Copied from aten/src/ATen/cuda/llvm_basic.cpp, then modified as above.
// If not compiling for ROCm, return the original get_traits_string().
std::string get_traits_string_but_hiprtc_safe() {
#if defined(USE_ROCM) && ROCM_VERSION < 70000
#if defined(USE_ROCM) && HIP_VERSION_MAJOR < 7
return R"ESCAPE(
namespace std {

View File

@ -342,8 +342,8 @@ Tensor rms_norm_symint(
if (weight_opt.has_value() && weight_opt.value().defined() && weight_opt.value().dtype() != input.dtype()) {
TORCH_WARN_ONCE(
"Mismatch dtype between input and module: input dtype = ", input.dtype(),
", module dtype = ", weight_opt.value().dtype(), ", Can not dispatch to fused implementation"
"Mismatch dtype between input and weight: input dtype = ", input.dtype(),
", weight dtype = ", weight_opt.value().dtype(), ", Cannot dispatch to fused implementation."
);
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}

View File

@ -22,6 +22,22 @@ struct PoolingParams {
bool return_indices;
};
template <unsigned N = 5, typename idx_type_t = int32_t>
struct AvgPoolingParams {
int32_t dims;
int32_t pooling_dims;
::c10::metal::array<idx_type_t, N> input_sizes;
::c10::metal::array<idx_type_t, N> input_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N - 2> kernel_size;
::c10::metal::array<idx_type_t, N - 2> stride;
::c10::metal::array<idx_type_t, N - 2> padding;
bool count_include_pad;
bool has_divisor_override;
int32_t divisor_override;
};
template <unsigned N = 5, typename idx_type_t = int32_t>
struct PoolingBackwardParams {
int32_t dims;

View File

@ -292,12 +292,154 @@ kernel void max_pool_backward(
pooling_dims);
}
#define REGISTER_MAX_POOL_OP(DTYPE) \
template <typename T>
struct AvgPoolIterBounds {
T start;
T end;
T count;
};
template <int32_t dim>
AvgPoolIterBounds<int32_t> get_avg_pool_input_iter_bounds(
constant int32_t* input_sizes,
thread int32_t (&pooling_dim_indices)[3],
constant int32_t* kernel_size,
constant int32_t* stride,
constant int32_t* padding,
bool count_include_pad) {
auto start = stride[dim] * pooling_dim_indices[dim] - padding[dim];
auto end = start + kernel_size[dim];
auto end_corrected = min(start + kernel_size[dim], input_sizes[dim]);
auto start_corrected = (start < 0) ? 0 : start;
auto count = count_include_pad
? (min(end, input_sizes[dim] + padding[dim]) - start)
: (end_corrected - start_corrected);
return {start_corrected, end_corrected, count};
}
// Iterates through all the input elements that this kernel needs to
// average over. Specialized for 3 pooling dimensions.
template <typename T>
void avg_pool_3d_input_iter(
constant T* input,
device T* output,
constant int32_t* input_sizes,
constant int32_t* input_strides,
thread int32_t (&pooling_dim_indices)[3],
constant int32_t* kernel_size,
constant int32_t* stride,
constant int32_t* padding,
bool count_include_pad,
bool has_divisor_override,
int32_t divisor_override) {
auto bounds0 = get_avg_pool_input_iter_bounds<0>(
input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto bounds1 = get_avg_pool_input_iter_bounds<1>(
input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto bounds2 = get_avg_pool_input_iter_bounds<2>(
input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
T value_sum = 0;
auto divisor = has_divisor_override
? divisor_override
: (bounds0.count) * (bounds1.count) * (bounds2.count);
auto size12 = input_sizes[1] * input_sizes[2];
for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
auto offset0 = input_strides[0] * i0;
for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
auto offset1 = input_strides[1] * i1;
for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
auto offset2 = input_strides[2] * i2;
auto input_value = input[offset0 + offset1 + offset2];
value_sum += input_value;
}
}
}
*output = value_sum / static_cast<T>(divisor);
}
// Kernel computes one element of the output per kernel call.
template <typename T>
kernel void avg_pool(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant AvgPoolingParams<5>& params [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
auto pooling_dims = params.pooling_dims;
auto dims = params.dims;
auto input_sizes = params.input_sizes.data();
auto input_strides = params.input_strides.data();
auto output_sizes = params.output_sizes.data();
auto output_strides = params.output_strides.data();
auto kernel_size = params.kernel_size.data();
auto stride = params.stride.data();
auto padding = params.padding.data();
auto leading_dims = dims - pooling_dims;
// This buffer keeps track of the pooling dimension indices of this thread's
// element of the output. We need to fill it with the proper values below.
int32_t pooling_dim_indices[3];
PoolOffsets offsets = find_pool_offsets(
output_sizes,
output_strides,
/*indices_strides=*/nullptr,
input_strides,
pooling_dim_indices,
dims,
leading_dims,
/*return_indices=*/false,
tid);
output += offsets.output;
input += offsets.input_leading;
input_sizes += leading_dims;
input_strides += leading_dims;
avg_pool_3d_input_iter<T>(
input,
output,
input_sizes,
input_strides,
pooling_dim_indices,
kernel_size,
stride,
padding,
params.count_include_pad,
params.has_divisor_override,
params.divisor_override);
}
#define REGISTER_POOL_OP(DTYPE) \
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant PoolingParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant AvgPoolingParams<5> & params [[buffer(2)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
@ -309,19 +451,19 @@ kernel void max_pool_backward(
constant PoolingBackwardParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);
REGISTER_MAX_POOL_OP(float);
REGISTER_MAX_POOL_OP(half);
REGISTER_MAX_POOL_OP(int);
REGISTER_MAX_POOL_OP(long);
REGISTER_MAX_POOL_OP(short);
REGISTER_MAX_POOL_OP(char);
REGISTER_MAX_POOL_OP(uchar);
REGISTER_MAX_POOL_OP(bool);
REGISTER_POOL_OP(float);
REGISTER_POOL_OP(half);
REGISTER_POOL_OP(int);
REGISTER_POOL_OP(long);
REGISTER_POOL_OP(short);
REGISTER_POOL_OP(char);
REGISTER_POOL_OP(uchar);
REGISTER_POOL_OP(bool);
REGISTER_MAX_POOL_BACKWARD_OP(float);
REGISTER_MAX_POOL_BACKWARD_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_MAX_POOL_OP(bfloat);
REGISTER_POOL_OP(bfloat);
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
#endif
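The subtle part of the new kernel is the per-dimension window arithmetic in `get_avg_pool_input_iter_bounds`: with `count_include_pad` the divisor counts padded positions (clipped at `input_size + padding`), otherwise only in-bounds elements count, and `divisor_override` replaces the product of the three counts entirely. A minimal standalone C++ sketch of that arithmetic for one dimension (the sizes below are invented for illustration, not taken from any test):

```cpp
// Mirrors the start/end/count computation of get_avg_pool_input_iter_bounds
// for a single pooling dimension. Values are illustrative only.
#include <algorithm>
#include <cstdio>
#include <initializer_list>

struct Bounds {
  int start;  // first input index actually read
  int end;    // one past the last input index actually read
  int count;  // number of positions contributing to the divisor
};

Bounds window(int out_idx, int input_size, int kernel, int stride, int pad,
              bool count_include_pad) {
  const int start = stride * out_idx - pad;
  const int end = start + kernel;
  const int start_corrected = std::max(start, 0);
  const int end_corrected = std::min(end, input_size);
  // count_include_pad keeps padded positions in the divisor, but the kernel
  // still clips the window at input_size + pad.
  const int count = count_include_pad
      ? std::min(end, input_size + pad) - start
      : end_corrected - start_corrected;
  return {start_corrected, end_corrected, count};
}

int main() {
  // Boundary windows for a 3-wide kernel, stride 2, padding 1 over 7 inputs.
  for (int out_idx : {0, 3}) {
    for (bool include_pad : {false, true}) {
      const Bounds b = window(out_idx, /*input_size=*/7, /*kernel=*/3,
                              /*stride=*/2, /*pad=*/1, include_pad);
      std::printf("out=%d include_pad=%d -> reads [%d, %d), count=%d\n",
                  out_idx, include_pad, b.start, b.end, b.count);
    }
  }
  return 0;
}
```

In the 3-D kernel the divisor is the product of the three per-dimension counts unless `divisor_override` is set.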

View File

@ -418,8 +418,9 @@ Tensor& exponential_mps_(Tensor& self, double lambda, std::optional<Generator> g
MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil];
return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil];
};
auto eps = std::numeric_limits<float>::epsilon();
return mps::random_mps_impl<double>(self,
0.0,
eps,
1.0,
std::nullopt,
std::nullopt,
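For context on the `eps` change above: `exponential_` on MPS is implemented as an inverse-CDF-style transform of a uniform draw, and if the uniform draw can be exactly 0.0, the usual `-log(1 - u) / lambda` form returns exactly 0, outside the distribution's open support (0, ∞). Raising the lower bound to `std::numeric_limits<float>::epsilon()` presumably keeps every sample strictly positive. A self-contained sketch of that transform in plain C++ (not the MPSGraph code, just the math it implements):

```cpp
// Inverse-CDF exponential sampling: x = -log(1 - u) / lambda. With u drawn
// from [0, 1) a draw of u == 0 yields x == 0; starting the range at epsilon
// keeps x > 0. Lambda and the seed are arbitrary.
#include <cmath>
#include <cstdio>
#include <limits>
#include <random>

int main() {
  const double lambda = 1.5;
  const float eps = std::numeric_limits<float>::epsilon();
  std::mt19937 gen(0);
  std::uniform_real_distribution<float> uniform(eps, 1.0f);  // was [0, 1)
  for (int i = 0; i < 5; ++i) {
    const double u = uniform(gen);
    const double x = -std::log1p(-u) / lambda;  // log1p(-u) == log(1 - u)
    std::printf("u=%.9f -> x=%.9f\n", u, x);    // always strictly positive
  }
  return 0;
}
```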

View File

@ -14,6 +14,7 @@
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#include <ATen/ops/max_pool2d_backward_native.h>
#include <ATen/ops/max_pool2d_native.h>
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
@ -265,13 +266,13 @@ using PoolSizes = std::tuple<int32_t,
std::vector<int32_t>,
std::vector<int32_t>,
std::vector<int32_t>,
std::vector<int32_t>>;
std::optional<std::vector<int32_t>>>;
static PoolSizes process_pool_sizes(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
std::optional<IntArrayRef> dilation_opt,
bool ceil_mode,
const int32_t pooling_dims,
const std::string& op_name) {
@ -305,18 +306,22 @@ static PoolSizes process_pool_sizes(const Tensor& input,
pooling_dims,
" ints");
TORCH_CHECK(dilation.size() == 1 || dilation.size() == pooling_dims,
op_name,
": dilation must be either a single int, or a tuple of ",
pooling_dims,
" ints");
if (dilation_opt.has_value()) {
auto dilation = dilation_opt.value();
TORCH_CHECK(dilation.size() == 1 || dilation.size() == pooling_dims,
op_name,
": dilation must be either a single int, or a tuple of ",
pooling_dims,
" ints");
}
int32_t leading_dims = input.dim() - pooling_dims;
const auto kernel_size_expanded = copy_and_maybe_expand(kernel_size, pooling_dims);
const auto stride_expanded = copy_and_maybe_expand(stride.empty() ? kernel_size : stride, pooling_dims);
const auto padding_expanded = copy_and_maybe_expand(padding, pooling_dims);
const auto dilation_expanded = copy_and_maybe_expand(dilation, pooling_dims);
const auto dilation_expanded = dilation_opt.has_value() ? copy_and_maybe_expand(dilation_opt.value(), pooling_dims)
: std::vector<int32_t>(pooling_dims, 1);
for (const auto dim : c10::irange(pooling_dims)) {
TORCH_CHECK(padding_expanded[dim] >= 0, op_name, ": pad must be non-negative");
@ -362,7 +367,12 @@ static PoolSizes process_pool_sizes(const Tensor& input,
output_size[leading_dims + dim] = output_pooling_size[dim];
}
return PoolSizes(dims, output_size, kernel_size_expanded, stride_expanded, padding_expanded, dilation_expanded);
return PoolSizes(dims,
output_size,
kernel_size_expanded,
stride_expanded,
padding_expanded,
dilation_opt.has_value() ? std::make_optional(dilation_expanded) : std::nullopt);
}
static void max_pool_with_indices_out_mps_template(const Tensor& output,
@ -375,8 +385,10 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
bool ceil_mode,
const int32_t pooling_dims,
const std::string& op_name) {
auto [dims, output_size, kernel_size, stride, padding, dilation] =
auto [dims, output_size, kernel_size, stride, padding, dilation_opt] =
process_pool_sizes(input, _kernel_size, _stride, _padding, _dilation, ceil_mode, pooling_dims, op_name);
TORCH_INTERNAL_ASSERT(dilation_opt.has_value());
auto dilation = dilation_opt.value();
const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt));
const bool return_indices = indices.defined();
@ -442,7 +454,7 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
bool ceil_mode,
const int32_t pooling_dims,
const std::string& op_name) {
auto [dims, output_size, kernel_size, stride, padding, dilation] =
auto [dims, output_size, kernel_size, stride, padding, dilation_opt] =
process_pool_sizes(input, _kernel_size, _stride, _padding, _dilation, ceil_mode, pooling_dims, op_name);
const auto memory_format = input.suggest_memory_format();
@ -601,6 +613,62 @@ static void avg_pool2d_template(const Tensor& input,
op_name);
}
static void avg_pool_out_mps_template(const Tensor& output,
const Tensor& input,
IntArrayRef _kernel_size,
IntArrayRef _stride,
IntArrayRef _padding,
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override,
const int32_t pooling_dims,
const std::string& op_name) {
auto [dims, output_size, kernel_size, stride, padding, _] =
process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name);
const auto memory_format = input.suggest_memory_format();
output.resize_(output_size, memory_format);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const auto numThreads = output.numel();
AvgPoolingParams<5> params;
params.dims = dims;
params.pooling_dims = pooling_dims;
params.count_include_pad = count_include_pad;
params.has_divisor_override = divisor_override.has_value();
if (divisor_override.has_value()) {
params.divisor_override = safe_downcast<int32_t, int64_t>(divisor_override.value());
}
for (const auto dim : c10::irange(dims)) {
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
}
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto PSO = lib.getPipelineStateForFunc("avg_pool_" + scalarToMetalTypeString(input));
getMPSProfiler().beginProfileKernel(PSO, op_name, {input});
[computeEncoder setComputePipelineState:PSO];
mtl_setArgs(computeEncoder, input, output, params);
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
getMPSProfiler().endProfileKernel(PSO);
}
});
}
} // namespace mps
Tensor mps_max_pool2d(const Tensor& input,
@ -876,4 +944,25 @@ TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps)
"avg_pool2d_backward");
}
TORCH_IMPL_FUNC(avg_pool3d_out_mps)
(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override,
const Tensor& output) {
mps::avg_pool_out_mps_template(output,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
/*pooling_dims=*/3,
"avg_pool3d");
}
} // namespace at::native

View File

@ -7124,18 +7124,21 @@
dispatch:
CPU: _scaled_mm_cpu
CUDA: _scaled_mm_cuda
tags: needs_exact_strides
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CPU: _scaled_mm_out_cpu
CUDA: _scaled_mm_out_cuda
tags: needs_exact_strides
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_grouped_mm_cuda
tags: needs_exact_strides
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
variants: function
@ -12334,6 +12337,7 @@
dispatch:
CPU: avg_pool3d_out_cpu
CUDA: avg_pool3d_out_cuda
MPS: avg_pool3d_out_mps
MkldnnCPU: mkldnn_avg_pool3d_out
- func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor

View File

@ -955,7 +955,10 @@ static at::Tensor fp8_qlinear_onednn_ref(
std::vector<int64_t> w_scales_new_shape(weight.dim(), 1);
w_scales_new_shape[0] = -1;
auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
auto y_f32 = at::linear(dqx, dqw, bias);
auto y_f32 = at::linear(dqx, dqw);
if (bias.has_value()) {
y_f32 += bias.value().to(at::kFloat);
}
if (binary_post_op == "none") {
if (unary_post_op == "relu") {
at::relu_(y_f32);
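The reference path above now runs the matmul without a fused bias and adds the bias explicitly in fp32, presumably so a lower-precision bias tensor is upcast before the add instead of being passed straight into `at::linear`. A minimal libtorch sketch of that pattern (shapes and dtypes below are hypothetical, not taken from the test):

```cpp
// linear + explicit fp32 bias add, mirroring the refactor in
// fp8_qlinear_onednn_ref. Shapes and dtypes are illustrative only.
#include <ATen/ATen.h>
#include <cstdio>
#include <optional>

at::Tensor linear_with_fp32_bias(const at::Tensor& x,
                                 const at::Tensor& w,
                                 const std::optional<at::Tensor>& bias) {
  auto y = at::linear(x, w);             // matmul only, no fused bias
  if (bias.has_value()) {
    y += bias.value().to(at::kFloat);    // bias added after upcast to fp32
  }
  return y;
}

int main() {
  auto x = at::randn({4, 8});              // fp32 activations
  auto w = at::randn({16, 8});             // fp32 (dequantized) weights
  auto b = at::randn({16}).to(at::kHalf);  // low-precision bias
  auto y = linear_with_fp32_bias(x, w, b);
  auto ref = at::linear(x, w) + b.to(at::kFloat);
  std::printf("max |y - ref| = %g\n", (y - ref).abs().max().item<float>());
  return 0;
}
```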

View File

@ -1,8 +1,7 @@
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <thread>
@ -10,7 +9,7 @@
// numbers of threads set and also whether the scheduler
// will throw an exception when multiple threads call
// their first parallel construct.
static void test(int given_num_threads) {
void test(int given_num_threads) {
auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
ASSERT_TRUE(given_num_threads >= 0);
ASSERT_EQ(at::get_num_threads(), given_num_threads);
@ -20,7 +19,7 @@ static void test(int given_num_threads) {
}
}
TEST(ThreadInitTest, ThreadInit) {
int main() {
at::init_num_threads();
at::set_num_threads(4);
@ -33,11 +32,13 @@ TEST(ThreadInitTest, ThreadInit) {
#if !AT_PARALLEL_NATIVE
at::set_num_threads(5);
ASSERT_EQ(at::get_num_threads(), 5);
ASSERT_TRUE(at::get_num_threads() == 5);
#endif
// test inter-op settings
at::set_num_interop_threads(5);
ASSERT_EQ(at::get_num_interop_threads(), 5);
ASSERT_ANY_THROW(at::set_num_interop_threads(6));
return 0;
}

View File

@ -13,6 +13,7 @@ flaky_models = {
"gluon_inception_v3",
"detectron2_maskrcnn_r_101_c4",
"XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148
"detectron2_fcos_r_50_fpn",
}

View File

@ -346,7 +346,7 @@ vgg16,pass,0
vision_maskrcnn,fail_accuracy,30
vision_maskrcnn,fail_accuracy,29

View File

@ -1,32 +1,32 @@
add_loop_eager,compile_time_instruction_count,3070000000,0.10
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.10
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
add_loop_inductor,compile_time_instruction_count,30280000000,0.10
add_loop_inductor,compile_time_instruction_count,30280000000,0.1
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.10
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.1
add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.10
add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.10
basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.10
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.10
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
@ -34,56 +34,56 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1719000000,0.10
update_hint_regression,compile_time_instruction_count,1719000000,0.1
sum_floordiv_regression,compile_time_instruction_count,966100000,0.10
sum_floordiv_regression,compile_time_instruction_count,966100000,0.1
symint_sum,compile_time_instruction_count,3237000000,0.10
symint_sum,compile_time_instruction_count,3237000000,0.1
symint_sum_loop,compile_time_instruction_count,4299000000,0.10
symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.10
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.10
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.10
aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.10
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.10
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.10
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.10
mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.1
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.10
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.10
basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.10
basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.1

View File

@ -944,6 +944,7 @@ def define_buck_targets(
[
("torch/csrc/api/include", "torch/**/*.h"),
("", "torch/csrc/**/*.h"),
("", "torch/nativert/**/*.h"),
("", "torch/headeronly/**/*.h"),
("", "torch/script.h"),
("", "torch/library.h"),

View File

@ -593,11 +593,13 @@ libtorch_core_jit_sources = sorted(jit_sources_full)
libtorch_nativert_sources = [
"torch/nativert/ModelRunner.cpp",
"torch/nativert/graph/Graph.cpp",
"torch/nativert/graph/GraphPasses.cpp",
"torch/nativert/graph/GraphSignature.cpp",
"torch/nativert/graph/Serialization.cpp",
"torch/nativert/graph/TensorMeta.cpp",
"torch/nativert/graph/GraphUtils.cpp",
"torch/nativert/executor/DelegateExecutor.cpp",
"torch/nativert/executor/Placement.cpp",
"torch/nativert/executor/ExecutionPlanner.cpp",
@ -864,6 +866,7 @@ libtorch_python_core_sources = [
"torch/csrc/QScheme.cpp",
"torch/csrc/Module.cpp",
"torch/csrc/PyInterpreter.cpp",
"torch/csrc/PyInterpreterHooks.cpp",
"torch/csrc/python_dimname.cpp",
"torch/csrc/Size.cpp",
"torch/csrc/Storage.cpp",
@ -986,6 +989,7 @@ libtorch_python_core_sources = [
"torch/csrc/utils/verbose.cpp",
"torch/csrc/cpu/Module.cpp",
"torch/csrc/instruction_counter/Module.cpp",
"torch/nativert/python/Bindings.cpp",
] + lazy_tensor_core_python_sources
libtorch_python_distributed_core_sources = [

View File

@ -0,0 +1,241 @@
#include <c10/core/AllocatorConfig.h>
#include <c10/core/DeviceType.h>
#include <c10/util/env.h>
namespace c10::CachingAllocator {
namespace {
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
constexpr size_t kMB = 1024 * 1024ul;
constexpr size_t kRoundUpPowerOfTwoStart = 1 * kMB; // 1MB
constexpr size_t kRoundUpPowerOfTwoEnd = 64 * 1024ul * kMB; // 64GB
} // anonymous namespace
AcceleratorAllocatorConfig& AcceleratorAllocatorConfig::instance() {
static AcceleratorAllocatorConfig instance;
#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env, deprecated) \
auto env##_name = c10::utils::get_env(#env); \
if (env##_name.has_value()) { \
if (deprecated) { \
TORCH_WARN_ONCE(#env " is deprecated, use PYTORCH_ALLOC_CONF instead"); \
} \
instance.parseArgs(env##_name.value()); \
return true; \
}
static bool env_flag [[maybe_unused]] = []() {
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF, false)
// Keep this for backwards compatibility
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF, /*deprecated=*/true)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF, /*deprecated=*/true)
return false;
}();
#undef C10_ALLOCATOR_CONFIG_PARSE_ENV
return instance;
}
AcceleratorAllocatorConfig::AcceleratorAllocatorConfig() {
roundup_power2_divisions_.assign(kRoundUpPowerOfTwoIntervals, 0);
}
size_t AcceleratorAllocatorConfig::roundup_power2_divisions(size_t size) {
size_t log_size = (63 - llvm::countLeadingZeros(size));
// Our intervals start at 1MB and end at 64GB
const size_t interval_start =
63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoStart);
const size_t interval_end =
63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoEnd);
TORCH_CHECK_VALUE(
interval_end - interval_start == kRoundUpPowerOfTwoIntervals,
"kRoundUpPowerOfTwoIntervals mismatch");
size_t index =
(log_size > interval_start) ? (log_size - interval_start) : 0ul;
index = std::min(index, kRoundUpPowerOfTwoIntervals - 1);
return instance().roundup_power2_divisions_[index];
}
size_t AcceleratorAllocatorConfig::parseMaxSplitSize(
const ConfigTokenizer& tokenizer,
size_t i) {
tokenizer.checkToken(++i, ":");
constexpr size_t min_allowed_split_size_mb = kLargeBuffer / kMB;
constexpr size_t max_allowed_split_size_mb =
std::numeric_limits<size_t>::max() / kMB;
size_t val_env = tokenizer.toSizeT(++i);
TORCH_CHECK_VALUE(
val_env >= min_allowed_split_size_mb,
"CachingAllocator option max_split_size_mb too small, must be >= ",
min_allowed_split_size_mb);
val_env = std::min(val_env, max_allowed_split_size_mb);
max_split_size_ = val_env * kMB;
return i;
}
size_t AcceleratorAllocatorConfig::parseMaxNonSplitRoundingSize(
const ConfigTokenizer& tokenizer,
size_t i) {
tokenizer.checkToken(++i, ":");
constexpr size_t min_allowed_split_size_mb = kLargeBuffer / kMB;
constexpr size_t max_allowed_split_size_mb =
std::numeric_limits<size_t>::max() / kMB;
size_t val_env = tokenizer.toSizeT(++i);
TORCH_CHECK_VALUE(
val_env >= min_allowed_split_size_mb,
"CachingAllocator option max_non_split_rounding_mb too small, must be >= ",
min_allowed_split_size_mb);
val_env = std::min(val_env, max_allowed_split_size_mb);
max_non_split_rounding_size_ = val_env * kMB;
return i;
}
size_t AcceleratorAllocatorConfig::parseGarbageCollectionThreshold(
const ConfigTokenizer& tokenizer,
size_t i) {
tokenizer.checkToken(++i, ":");
double val_env = tokenizer.toDouble(++i);
TORCH_CHECK_VALUE(
val_env > 0 && val_env < 1.0,
"garbage_collect_threshold is invalid, set it in (0.0, 1.0)");
garbage_collection_threshold_ = val_env;
return i;
}
size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
const ConfigTokenizer& tokenizer,
size_t i) {
tokenizer.checkToken(++i, ":");
bool first_value = true;
if (tokenizer[++i] == "[") {
size_t last_index = 0;
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
while (++i < tokenizer.size() && tokenizer[i] != "]") {
size_t value_index = i;
tokenizer.checkToken(++i, ":");
size_t value = tokenizer.toSizeT(++i);
TORCH_CHECK_VALUE(
value == 0 || llvm::isPowerOf2_64(value),
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ");
if (tokenizer[value_index] == ">") {
std::fill(
std::next(
roundup_power2_divisions_.begin(),
static_cast<std::vector<size_t>::difference_type>(
last_index + 1)),
roundup_power2_divisions_.end(),
value);
} else {
size_t boundary = tokenizer.toSizeT(value_index);
TORCH_CHECK_VALUE(
llvm::isPowerOf2_64(boundary),
"For roundups, the intervals have to be power of 2 ");
size_t index = 63 - llvm::countLeadingZeros(boundary);
index =
std::clamp(index, size_t{0}, roundup_power2_divisions_.size() - 1);
if (first_value) {
std::fill(
roundup_power2_divisions_.begin(),
std::next(
roundup_power2_divisions_.begin(),
static_cast<std::vector<size_t>::difference_type>(index)),
value);
first_value = false;
}
roundup_power2_divisions_[index] = value;
last_index = index;
}
if (tokenizer[i + 1] != "]") {
tokenizer.checkToken(++i, ",");
}
}
TORCH_INTERNAL_ASSERT(
i < tokenizer.size(),
"Expected closing bracket ']' in ConfigTokenizer but reached end of config");
} else { // Keep this for backwards compatibility
size_t value = tokenizer.toSizeT(i);
TORCH_CHECK_VALUE(
llvm::isPowerOf2_64(value),
"For roundups, the divisions has to be power of 2 ");
std::fill(
roundup_power2_divisions_.begin(),
roundup_power2_divisions_.end(),
value);
}
return i;
}
size_t AcceleratorAllocatorConfig::parseExpandableSegments(
const ConfigTokenizer& tokenizer,
size_t i) {
tokenizer.checkToken(++i, ":");
use_expandable_segments_ = tokenizer.toBool(++i);
return i;
}
size_t AcceleratorAllocatorConfig::parsePinnedUseBackgroundThreads(
const ConfigTokenizer& tokenizer,
size_t i) {
tokenizer.checkToken(++i, ":");
pinned_use_background_threads_ = tokenizer.toBool(++i);
return i;
}
void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
// The following option will be reset to its default value if not explicitly
// set each time.
max_split_size_ = std::numeric_limits<size_t>::max();
roundup_power2_divisions_.assign(kRoundUpPowerOfTwoIntervals, 0);
garbage_collection_threshold_ = 0;
{
std::lock_guard<std::mutex> lock(last_allocator_settings_mutex_);
last_allocator_settings_ = env;
}
ConfigTokenizer tokenizer(env);
for (size_t i = 0; i < tokenizer.size(); i++) {
const auto& key = tokenizer[i];
if (key == "max_split_size_mb") {
i = parseMaxSplitSize(tokenizer, i);
} else if (key == "max_non_split_rounding_mb") {
i = parseMaxNonSplitRoundingSize(tokenizer, i);
} else if (key == "garbage_collection_threshold") {
i = parseGarbageCollectionThreshold(tokenizer, i);
} else if (key == "roundup_power2_divisions") {
i = parseRoundUpPower2Divisions(tokenizer, i);
} else if (key == "expandable_segments") {
i = parseExpandableSegments(tokenizer, i);
} else if (key == "pinned_use_background_threads") {
i = parsePinnedUseBackgroundThreads(tokenizer, i);
} else {
// If a device-specific configuration parser hook is registered, it will
// check if the key is unrecognized.
if (device_config_parser_hook_) {
TORCH_CHECK(
keys_.find(key) != keys_.end(),
"Unrecognized key '",
key,
"' in Accelerator allocator config.");
}
i = tokenizer.skipKey(i);
}
if (i + 1 < tokenizer.size()) {
tokenizer.checkToken(++i, ",");
}
}
}
} // namespace c10::CachingAllocator
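
To make the parsing flow above concrete, here is a minimal usage sketch (not part of this diff), assuming a build that exposes `c10/core/AllocatorConfig.h`; the config string and the printed fields are illustrative only.

```cpp
#include <c10/core/AllocatorConfig.h>

#include <iostream>

int main() {
  using namespace c10::CachingAllocator;

  // Tokenized as: "max_split_size_mb", ":", "256", ",",
  // "roundup_power2_divisions", ":", "[", "256", ":", "8", ",", ">", ":", "2", "]", ...
  // Each parse*() helper consumes its own "key : value" (or "key : [ ... ]") run.
  setAllocatorSettings(
      "max_split_size_mb:256,"
      "garbage_collection_threshold:0.8,"
      "roundup_power2_divisions:[256:8,>:2],"
      "expandable_segments:True");

  // Everything lands in the shared AcceleratorAllocatorConfig singleton, and
  // the same string is forwarded to any registered device-specific hook.
  std::cout << AcceleratorAllocatorConfig::max_split_size() << '\n';  // 256 MiB, in bytes
  std::cout << AcceleratorAllocatorConfig::garbage_collection_threshold() << '\n';
  std::cout << AcceleratorAllocatorConfig::use_expandable_segments() << '\n';
  std::cout << getAllocatorSettings() << '\n';  // echoes the last parsed string
  return 0;
}
```

The same string can instead come from `PYTORCH_ALLOC_CONF` (or the deprecated `PYTORCH_CUDA_ALLOC_CONF`/`PYTORCH_HIP_ALLOC_CONF`), which `instance()` parses lazily on first use.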

372
c10/core/AllocatorConfig.h Normal file
View File

@ -0,0 +1,372 @@
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>
#include <c10/util/llvmMathExtras.h>
#include <atomic>
#include <mutex>
#include <string>
#include <unordered_set>
#include <vector>
namespace c10::CachingAllocator {
// "large" allocations may be packed in 20 MiB blocks
const size_t kLargeBuffer = 20971520;
// A utility class for tokenizing allocator configuration strings into discrete
// parts. For example, the config string:
// "key1:val1,key2:[val2,val3]"
// is tokenized into:
// "key1", ":", "val1", ",", "key2", ":", "[", "val2", ",", "val3", "]",
//
// Tokens include keys, values, and special characters (':', ',', '[', ']').
// Whitespace is ignored.
class ConfigTokenizer {
public:
explicit ConfigTokenizer(const std::string& env) {
std::string buffer;
for (char ch : env) {
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
if (!buffer.empty()) {
config_.emplace_back(std::move(buffer));
buffer.clear();
}
config_.emplace_back(1, ch);
} else if (!std::isspace(static_cast<unsigned char>(ch))) {
buffer += ch;
}
}
if (!buffer.empty()) {
config_.emplace_back(std::move(buffer));
}
}
const std::string& operator[](size_t i) const {
TORCH_INTERNAL_ASSERT(
i < config_.size(), "Index out of bounds in ConfigTokenizer");
return config_[i];
}
size_t size() const {
return config_.size();
}
bool checkToken(size_t i, const std::string& token) const {
checkIndex(i);
return config_[i] == token;
}
size_t toSizeT(size_t i) const {
checkIndex(i);
return std::stoull(config_[i]);
}
double toDouble(size_t i) const {
checkIndex(i);
return std::stod(config_[i]);
}
bool toBool(size_t i) const {
checkIndex(i);
const auto& token = config_[i];
if (token == "True") {
return true;
} else if (token == "False") {
return false;
} else {
TORCH_CHECK_VALUE(
false,
"Expected 'True' or 'False' at index ",
i,
" in ConfigTokenizer but got '",
token,
"'");
}
}
// Skips the current token group and returns the index of the value token.
// Assumes the current index `i` points to a key name in a key-value pair.
size_t skipKey(size_t i) const {
// Expect a colon after the key
checkToken(++i, ":");
++i; // Move to the value
checkIndex(i);
if (config_[i] != "[") {
// Value is a single token (not a list) -> return its index
return i;
}
// Skip tokens inside the list until matching ']'
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
while (++i < config_.size() && config_[i] != "]") {
}
TORCH_INTERNAL_ASSERT(
i < config_.size(),
"Expected closing bracket ']' in ConfigTokenizer but reached end of config");
return i; // Return the index of the closing ']'
}
private:
void checkIndex(size_t i) const {
TORCH_INTERNAL_ASSERT(
i < config_.size(), "Index out of bounds in ConfigTokenizer");
}
std::vector<std::string> config_;
};
/**
* Note [AcceleratorAllocatorConfig design]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* This class configures memory allocation for both device and host memory. A
* single `AcceleratorAllocatorConfig` instance is shared across all accelerator
* backends, such as CUDA and XPU, under the assumption that relevant
* environment variables apply uniformly to all accelerators. Device-specific
* configuration extensions are supported via hooks (see
* `registerDeviceConfigParserHook`).
*
* Recommended design:
* - Place common configurations in `AcceleratorAllocatorConfig`.
* - Extend backend-specific configurations in corresponding device-specific
* classes, such as `CUDAAllocatorConfig`, etc.
*
* Scope:
* - Configuration options must be environment-variable driven.
*
* Naming Convention:
* - Public API names in `AcceleratorAllocatorConfig` should be device-generic.
* - Members prefixed with `pinned_` are specific to the host/pinned allocator.
* - Environment variable names should be generic across backends.
* - Comma-separated key-value pairs in the format: `key:value`. Use square
* brackets `[]` for list values Example: `key1:123, key2:[val1,val2]`
*
* Environment Variables:
* - The primary environment variable for configuration is `PYTORCH_ALLOC_CONF`.
* - For backward compatibility, `PYTORCH_CUDA_ALLOC_CONF` is also supported
* with lower priority.
*/
class C10_API AcceleratorAllocatorConfig {
public:
static AcceleratorAllocatorConfig& instance();
C10_DISABLE_COPY_AND_ASSIGN(AcceleratorAllocatorConfig);
AcceleratorAllocatorConfig(AcceleratorAllocatorConfig&&) = delete;
AcceleratorAllocatorConfig& operator=(AcceleratorAllocatorConfig&&) = delete;
~AcceleratorAllocatorConfig() = default;
/* Device allocator settings */
// Returns the maximum block size (in MB) that is allowed to be split. The
// default is unlimited (all blocks can be split).
static size_t max_split_size() {
return instance().max_split_size_;
}
// Returns the maximum block size (in MB) that is allowed to be rounded up
// without requiring splitting when searching for a free block. The default is
// 20 MiB.
static size_t max_non_split_rounding_size() {
return instance().max_non_split_rounding_size_;
}
// Return the number of divisions used when rounding up allocation sizes (in
// MB) to the nearest power-of-2 boundary.
static size_t roundup_power2_divisions(size_t size);
// Returns the vector of division factors used for rounding up allocation
// sizes. These divisions apply to size intervals between 1MB and 64GB.
static const std::vector<size_t>& roundup_power2_divisions() {
return instance().roundup_power2_divisions_;
}
// Returns the threshold that triggers garbage collection when the ratio of
// used memory to maximum allowed memory exceeds this value. The default is 0,
// meaning no garbage collection is triggered. The value should be in the
// range (0.0, 1.0).
static double garbage_collection_threshold() {
return instance().garbage_collection_threshold_;
}
// Returns whether the expandable segment feature is enabled. This allows the
// allocator to start with one segment that grows as needed, rather than
// creating a new segment for each allocation. Default is false (expandable
// segments disabled).
static bool use_expandable_segments() {
return instance().use_expandable_segments_;
}
/* Host allocator settings */
// Returns whether the pinned host allocator uses background threads for
// processing events. This is useful for improving performance in scenarios
// where many small allocations are made. Default is false (background threads
// disabled).
static bool pinned_use_background_threads() {
return instance().pinned_use_background_threads_;
}
/* Settings for both device and host allocator */
// Returns the current allocator settings as a string. This string is useful
// to expand device-specific allocator configurations
static std::string last_allocator_settings() {
std::lock_guard<std::mutex> lock(instance().last_allocator_settings_mutex_);
return instance().last_allocator_settings_;
}
// Returns the set of valid keys for the allocator configuration.
// This set is used to validate the presence and correctness of keys in
// device-specific configuration parsers.
static const std::unordered_set<std::string>& getKeys() {
return keys_;
}
// Registers a device-specific configuration parser hook and its key. This
// allows backends to parse additional device-specific configuration options
// from the environment variable. The hook should be a function that takes a
// string (the environment variable value) and parses it to set
// device-specific configuration options. The hook will be called when the
// environment variable is parsed. If a hook is already registered, it will be
// replaced with the new one.
static void registerDeviceConfigParserHook(
std::function<void(const std::string&)>&& hook,
const std::unordered_set<std::string>& keys) {
device_config_parser_hook_ = std::move(hook);
for (auto& key : keys) {
TORCH_CHECK(
keys_.insert(key).second,
"Duplicated key '",
key,
"' found in device-specific configuration parser hook registration");
}
}
// Calls the registered device-specific configuration parser hook with the
// provided environment string. This allows backends to parse additional
// device-specific configuration options from the environment variable.
// If no hook is registered, this function does nothing.
static void callDeviceConfigParserHook(const std::string& env) {
if (device_config_parser_hook_) {
device_config_parser_hook_(env);
}
}
// Parses the environment variable `env` to update the allocator settings.
// If the environment variable is not set, it does nothing.
// The configuration string should be a comma-separated list of key-value
// pairs, where each key is a configuration option and the value is the
// corresponding setting. For example:
// "max_split_size_mb:100,max_non_split_rounding_mb:20,garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,256:4,1024:4,>:1],expandable_segments:true,pinned_use_background_threads:true"
void parseArgs(const std::string& env);
private:
AcceleratorAllocatorConfig();
/* Internal functions for device allocator */
// Parse `max_split_size_mb` from environment variable.
size_t parseMaxSplitSize(const ConfigTokenizer& tokenizer, size_t i);
// Parse `max_non_split_rounding_mb` from environment variable.
size_t parseMaxNonSplitRoundingSize(
const ConfigTokenizer& tokenizer,
size_t i);
// Parse `garbage_collection_threshold` from environment variable.
size_t parseGarbageCollectionThreshold(
const ConfigTokenizer& tokenizer,
size_t i);
// Parse `roundup_power2_divisions` from environment variable.
size_t parseRoundUpPower2Divisions(
const ConfigTokenizer& tokenizer,
size_t i);
// Parse `expandable_segments` from environment variable.
size_t parseExpandableSegments(const ConfigTokenizer& tokenizer, size_t i);
/* Internal functions for host allocator */
// Parse `pinned_use_background_threads` from environment variable.
size_t parsePinnedUseBackgroundThreads(
const ConfigTokenizer& tokenizer,
size_t i);
/* The following members are specifically used for the device allocator. */
// The maximum block size that is allowed to be split.
std::atomic<size_t> max_split_size_{std::numeric_limits<size_t>::max()};
// The maximum allowable extra size of a memory block without requiring
// splitting when searching for a free block.
std::atomic<size_t> max_non_split_rounding_size_{kLargeBuffer};
// Used to store how memory allocations of different sizes should be rounded
// up to the nearest power of 2 divisions.
std::vector<size_t> roundup_power2_divisions_;
// The threshold that triggers garbage collection when the ratio of used
// memory to maximum allowed memory exceeds this value.
std::atomic<double> garbage_collection_threshold_{0};
// A flag to enable expandable segments feature.
std::atomic<bool> use_expandable_segments_{false};
/* The following members are specifically used for the host allocator. */
// A flag to enable background thread for processing events.
std::atomic<bool> pinned_use_background_threads_{false};
/* The following members are used for both device and host allocator. */
// Record the last allocator config environment setting.
std::mutex last_allocator_settings_mutex_;
std::string last_allocator_settings_;
// Optional hook for parsing additional device-specific allocator settings.
// This allows backends (e.g., CUDA, XPU) to register a custom parser for
// their own environment configuration extensions.
inline static std::function<void(const std::string&)>
device_config_parser_hook_{nullptr};
// A set of valid configuration keys, including both common and
// device-specific options. This set is used to validate the presence and
// correctness of keys during parsing.
inline static std::unordered_set<std::string> keys_{
"max_split_size_mb",
"max_non_split_rounding_mb",
"garbage_collection_threshold",
"roundup_power2_divisions",
"expandable_segments",
"pinned_use_background_threads"};
};
C10_API inline void setAllocatorSettings(const std::string& env) {
AcceleratorAllocatorConfig::instance().parseArgs(env);
AcceleratorAllocatorConfig::callDeviceConfigParserHook(env);
}
C10_API inline std::string getAllocatorSettings() {
return AcceleratorAllocatorConfig::instance().last_allocator_settings();
}
struct DeviceConfigParserHookRegistry {
explicit DeviceConfigParserHookRegistry(
std::function<void(const std::string&)>&& hook,
const std::unordered_set<std::string>& keys) {
// Use static method to avoid static initialization order fiasco issues
AcceleratorAllocatorConfig::registerDeviceConfigParserHook(
std::move(hook), keys);
}
};
// Assume each config parser has `parseArgs` and `getKeys` methods
#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(parser_cls) \
namespace { \
static at::CachingAllocator::DeviceConfigParserHookRegistry \
g_device_config_parse_hook_registry_instance( \
[](const std::string& env) { \
parser_cls::instance().parseArgs(env); \
}, \
parser_cls::getKeys()); \
}
} // namespace c10::CachingAllocator
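
As a registration sketch (again, not part of the diff), a backend-specific parser could plug into the shared config roughly as follows; `MyBackendAllocatorConfig` and its `mybackend_use_fast_path` key are hypothetical and only illustrate the hook contract that the `REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK` macro wraps for real backends.

```cpp
#include <c10/core/AllocatorConfig.h>

#include <string>
#include <unordered_set>

// Hypothetical backend config reusing the generic ConfigTokenizer.
class MyBackendAllocatorConfig {
 public:
  static MyBackendAllocatorConfig& instance() {
    static MyBackendAllocatorConfig inst;
    return inst;
  }

  static const std::unordered_set<std::string>& getKeys() {
    static const std::unordered_set<std::string> keys{"mybackend_use_fast_path"};
    return keys;
  }

  void parseArgs(const std::string& env) {
    c10::CachingAllocator::ConfigTokenizer tokenizer(env);
    for (size_t i = 0; i < tokenizer.size(); i++) {
      if (tokenizer[i] == "mybackend_use_fast_path") {
        tokenizer.checkToken(++i, ":");
        use_fast_path_ = tokenizer.toBool(++i);
      } else {
        // Keys owned by the common config or by other backends: skip over
        // "key:value" or "key:[...]" and keep going.
        i = tokenizer.skipKey(i);
      }
      if (i + 1 < tokenizer.size()) {
        tokenizer.checkToken(++i, ",");
      }
    }
  }

  bool use_fast_path() const {
    return use_fast_path_;
  }

 private:
  bool use_fast_path_{false};
};

// Registering at static-initialization time merges the key into
// AcceleratorAllocatorConfig::getKeys() and makes setAllocatorSettings()
// (and PYTORCH_ALLOC_CONF) reach parseArgs() above.
static c10::CachingAllocator::DeviceConfigParserHookRegistry
    g_mybackend_config_hook_registry(
        [](const std::string& env) {
          MyBackendAllocatorConfig::instance().parseArgs(env);
        },
        MyBackendAllocatorConfig::getKeys());
```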

View File

@ -240,24 +240,4 @@ struct C10_API PyInterpreter {
void disarm() noexcept;
};
// PyInterpreterStatus describes what the state of its interpreter tag
// is, relative to the thread currently holding the GIL.
enum class PyInterpreterStatus {
// We just allocated the Tensor, it hasn't escaped to other threads,
// we know that it definitely hasn't been tagged to be associated
// with an interpreter.
DEFINITELY_UNINITIALIZED,
// We queried the interpreter field and it looked uninitialized. But
// another thread may have raced with us to tag it with some other
// interpreter id. So we will have to do a CEX to make sure we can
// actually nab it.
MAYBE_UNINITIALIZED,
// We queried the interpreter field and it was tagged to belong to us.
// This means we have sole write access (as we hold the GIL for this
// interpreter)
TAGGED_BY_US,
// Someone else tagged this. We can't use this TensorImpl from Python.
TAGGED_BY_OTHER,
};
} // namespace c10::impl

View File

@ -0,0 +1,32 @@
#include <c10/core/impl/PyInterpreterHooks.h>
namespace c10::impl {
// Define the registry
C10_DEFINE_REGISTRY(
PyInterpreterHooksRegistry,
PyInterpreterHooksInterface,
PyInterpreterHooksArgs)
const PyInterpreterHooksInterface& getPyInterpreterHooks() {
auto create_impl = [] {
#if !defined C10_MOBILE
auto hooks = PyInterpreterHooksRegistry()->Create(
"PyInterpreterHooks", PyInterpreterHooksArgs{});
if (hooks) {
return hooks;
}
#endif
// Return stub implementation that will throw errors when methods are called
return std::make_unique<PyInterpreterHooksInterface>();
};
static auto hooks = create_impl();
return *hooks;
}
// Main function to get global PyInterpreter
PyInterpreter* getGlobalPyInterpreter() {
return getPyInterpreterHooks().getPyInterpreter();
}
} // namespace c10::impl

View File

@ -0,0 +1,39 @@
#pragma once
#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Export.h>
#include <c10/util/Registry.h>
#include <memory>
namespace c10::impl {
// Minimal interface for PyInterpreter hooks
struct C10_API PyInterpreterHooksInterface {
virtual ~PyInterpreterHooksInterface() = default;
// Get the PyInterpreter instance
// Stub implementation throws error when Python is not available
virtual PyInterpreter* getPyInterpreter() const {
TORCH_CHECK(
false,
"PyTorch was compiled without Python support. "
"Cannot access Python interpreter from C++.");
}
};
struct C10_API PyInterpreterHooksArgs{};
C10_DECLARE_REGISTRY(
PyInterpreterHooksRegistry,
PyInterpreterHooksInterface,
PyInterpreterHooksArgs);
#define REGISTER_PYTHON_HOOKS(clsname) \
C10_REGISTER_CLASS(PyInterpreterHooksRegistry, clsname, clsname)
// Get the global PyInterpreter hooks instance
C10_API const PyInterpreterHooksInterface& getPyInterpreterHooks();
C10_API PyInterpreter* getGlobalPyInterpreter();
} // namespace c10::impl
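
For context, a minimal sketch of the intended wiring on the Python side (not part of this diff); the class below is hypothetical and only shows how a Python-enabled build would replace the throwing stub.

```cpp
#include <c10/core/impl/PyInterpreterHooks.h>

namespace torch::detail {

using namespace c10::impl; // bring the registry declarations into scope

// Hypothetical hooks implementation for a Python-enabled build; the key it
// registers under must match the "PyInterpreterHooks" lookup performed in
// getPyInterpreterHooks(), hence the class name.
struct PyInterpreterHooks : PyInterpreterHooksInterface {
  explicit PyInterpreterHooks(PyInterpreterHooksArgs /*args*/) {}

  PyInterpreter* getPyInterpreter() const override {
    // Assumed to be populated during torch Python module initialization.
    return interpreter_;
  }

  static PyInterpreter* interpreter_;
};

PyInterpreter* PyInterpreterHooks::interpreter_ = nullptr;

// After this registration, c10::impl::getGlobalPyInterpreter() returns
// interpreter_ instead of hitting the stub that throws.
REGISTER_PYTHON_HOOKS(PyInterpreterHooks);

} // namespace torch::detail
```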

View File

@ -34,29 +34,12 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
void PyObjectSlot::unchecked_clear_pyobj(PyInterpreter* interpreter) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(interpreter == pyobj_interpreter_.load());
pyobj_ = nullptr;
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}
TORCH_CHECK(
false,
"cannot access PyObject for Tensor on interpreter ",
(*pyobj_interpreter_.load())->name());
}
bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) {
return interpreter == pyobj_interpreter();
}
bool PyObjectSlot::has_pyobj_nonhermetic() {
return check_pyobj(pyobj_interpreter(), /*ignore_hermetic_tls=*/true)
.has_value();
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}
bool PyObjectSlot::owns_pyobj() {

View File

@ -2,6 +2,7 @@
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/core/impl/PyInterpreterHooks.h>
#include <c10/util/python_stub.h>
#include <optional>
@ -24,52 +25,9 @@ struct C10_API PyObjectSlot {
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(
PyInterpreter* self_interpreter,
PyObject* pyobj,
PyInterpreterStatus status) {
impl::PyInterpreter* expected = nullptr;
switch (status) {
case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
// caller guarantees there is no multithreaded access; if there is
// no data race OK to do a relaxed store
pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed);
break;
case impl::PyInterpreterStatus::TAGGED_BY_US:
// no tagging is necessary, the tag is already correct
break;
case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED:
// attempt to claim this TensorImpl with the specified interpreter
// tag
if (pyobj_interpreter_.compare_exchange_strong(
expected, self_interpreter, std::memory_order_acq_rel)) {
break;
}
// test if, actually, it was already tagged by us! this situation can't
// be caused by a race, but it could be caused by a situation
// where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED
// (because they didn't pre-check the tag) when actually it was
// owned by the interpreter
if (expected == self_interpreter) {
break;
}
// fallthrough, we lost the race. We are guaranteed not to lose the
// race with ourself, as calls to init_pyobj with the same interpreter
// ID must be sequentialized by the GIL
[[fallthrough]];
case impl::PyInterpreterStatus::TAGGED_BY_OTHER:
TORCH_CHECK(
false,
"cannot allocate PyObject for Tensor on interpreter ",
self_interpreter,
" that has already been used by another torch deploy interpreter ",
pyobj_interpreter_.load());
}
// we are the ONLY thread that can have gotten to this point. It is not
// possible to conflict with another zero interpreter as access is protected
// by GIL
// NB: owns_pyobj tag is initially false
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
@ -94,49 +52,25 @@ struct C10_API PyObjectSlot {
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
std::optional<PyObject*> check_pyobj(
PyInterpreter* self_interpreter,
bool ignore_hermetic_tls = false) const {
// Note [Memory ordering on Python interpreter tag]
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
// NB: This never returns DEFINITELY_UNINITIALIZED because there is
// always the possibility that another thread races to initialize
// after we query here. The only time when we can conclude a tensor
// is definitely uninitialized is when we have just allocated it and
// it cannot have escaped to other threads yet
return std::nullopt;
} else if (interpreter == self_interpreter) {
// NB: pyobj_ could still be null!
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
}
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
TORCH_CHECK(
false,
"cannot access PyObject for Tensor on interpreter ",
(*self_interpreter)->name(),
" that has already been used by another torch deploy interpreter ",
(*pyobj_interpreter_.load())->name());
return _unchecked_untagged_pyobj();
}
}
// Clear the PyObject field for an interpreter, in situations where we
// statically know the tensor is tagged with our interpreter.
void unchecked_clear_pyobj(PyInterpreter* interpreter);
PyInterpreter& load_pyobj_interpreter() const;
// Check if the PyObjectSlot's interpreter is the same as the specified
// interpreter
bool check_interpreter(PyInterpreter* interpreter);
// Check if the PyObjectSlot is holding a PyObject, owned or non-owned
bool has_pyobj_nonhermetic();
bool owns_pyobj();
void set_owns_pyobj(bool b);
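
The practical effect on call sites is small; a hedged sketch follows (the helper functions are hypothetical, only the `PyObjectSlot` calls come from this header):

```cpp
#include <c10/core/impl/PyObjectSlot.h>

#include <optional>

// Previously a caller had to supply an interpreter plus a PyInterpreterStatus
// tag, e.g.:
//   slot->init_pyobj(self_interpreter, obj, PyInterpreterStatus::TAGGED_BY_US);
//   auto wrapper = slot->check_pyobj(self_interpreter);
// With a single global interpreter, the same operations reduce to:

void attach_pyobj(c10::impl::PyObjectSlot* slot, PyObject* obj) {
  // The interpreter tag now comes from c10::impl::getGlobalPyInterpreter().
  slot->init_pyobj(obj);
}

std::optional<PyObject*> lookup_pyobj(c10::impl::PyObjectSlot* slot) {
  return slot->check_pyobj(/*ignore_hermetic_tls=*/false);
}
```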

View File

@ -1,389 +1,119 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/llvmMathExtras.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif
#include <cuda_runtime_api.h>
namespace c10::cuda::CUDACachingAllocator {
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
CUDAAllocatorConfig::CUDAAllocatorConfig()
: m_max_split_size(std::numeric_limits<size_t>::max()),
m_max_non_split_rounding_size(kLargeBuffer),
m_garbage_collection_threshold(0),
m_pinned_num_register_threads(1),
m_expandable_segments(false),
#if CUDA_VERSION >= 12030
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::UNSPECIFIED),
#else
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::POSIX_FD),
#endif
m_release_lock_on_cudamalloc(false),
m_pinned_use_cuda_host_register(false),
m_pinned_use_background_threads(false) {
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}
size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
size_t log_size = (63 - llvm::countLeadingZeros(size));
// Our intervals start at 1MB and end at 64GB
const size_t interval_start =
63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
const size_t interval_end =
63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
TORCH_CHECK(
(interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
"kRoundUpPowerOfTwoIntervals mismatch");
int index = static_cast<int>(log_size) - static_cast<int>(interval_start);
index = std::max(0, index);
index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
return instance().m_roundup_power2_divisions[index];
}
void CUDAAllocatorConfig::lexArgs(
const std::string& env,
std::vector<std::string>& config) {
std::vector<char> buf;
for (char ch : env) {
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
buf.clear();
}
config.emplace_back(1, ch);
} else if (ch != ' ') {
buf.emplace_back(ch);
}
}
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
}
}
void CUDAAllocatorConfig::consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c) {
TORCH_CHECK(
i < config.size() && config[i] == std::string(1, c),
"Error parsing CachingAllocator settings, expected ",
c,
"");
}
size_t CUDAAllocatorConfig::parseMaxSplitSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_split_size_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_split_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_non_split_rounding_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_non_split_rounding_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
double val1 = stod(config[i]);
TORCH_CHECK(
val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
TORCH_CHECK(
val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
m_garbage_collection_threshold = val1;
} else {
TORCH_CHECK(
false, "Error, expecting garbage_collection_threshold value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
bool first_value = true;
if (++i < config.size()) {
if (std::string_view(config[i]) == "[") {
size_t last_index = 0;
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
while (++i < config.size() && std::string_view(config[i]) != "]") {
const std::string& val1 = config[i];
size_t val2 = 0;
consumeToken(config, ++i, ':');
if (++i < config.size()) {
val2 = stoi(config[i]);
} else {
TORCH_CHECK(
false, "Error parsing roundup_power2_divisions value", "");
}
TORCH_CHECK(
val2 == 0 || llvm::isPowerOf2_64(val2),
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ",
"");
if (std::string_view(val1) == ">") {
std::fill(
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
last_index)),
m_roundup_power2_divisions.end(),
val2);
} else {
size_t val1_long = stoul(val1);
TORCH_CHECK(
llvm::isPowerOf2_64(val1_long),
"For roundups, the intervals have to be power of 2 ",
"");
size_t index = 63 - llvm::countLeadingZeros(val1_long);
index = std::max((size_t)0, index);
index = std::min(index, m_roundup_power2_divisions.size() - 1);
if (first_value) {
std::fill(
m_roundup_power2_divisions.begin(),
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
index)),
val2);
first_value = false;
}
if (index < m_roundup_power2_divisions.size()) {
m_roundup_power2_divisions[index] = val2;
}
last_index = index;
}
if (std::string_view(config[i + 1]) != "]") {
consumeToken(config, ++i, ',');
}
}
} else { // Keep this for backwards compatibility
size_t val1 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val1),
"For roundups, the divisions has to be power of 2 ",
"");
std::fill(
m_roundup_power2_divisions.begin(),
m_roundup_power2_divisions.end(),
val1);
}
} else {
TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync) {
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i) {
// For ease of maintenance and understanding, the CUDA and ROCm
// implementations of this function are separated. This avoids having many
// #ifdef's throughout.
#ifdef USE_ROCM
// Ease burden on ROCm users by allowing either cuda or hip tokens.
// cuda token is broken up to prevent hipify matching it.
#define PYTORCH_TOKEN1 \
"cud" \
"aMallocAsync"
#define PYTORCH_TOKEN2 "hipMallocAsync"
consumeToken(config, ++i, ':');
if (++i < config.size()) {
tokenizer.checkToken(++i, ":");
i++; // Move to the value after the colon
TORCH_CHECK_VALUE(
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
(tokenizer[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
if (m_is_allocator_loaded) {
bool aync_allocator_at_runtime = (tokenizer[i] != "native");
TORCH_CHECK(
((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
(config[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
used_cudaMallocAsync =
(config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
TORCH_INTERNAL_ASSERT(
config[i] == get()->name() ||
(config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time, ",
config[i],
aync_allocator_at_runtime == m_use_async_allocator,
"Allocator async backend parsed at runtime != allocator async backend parsed at load time, ",
aync_allocator_at_runtime,
" != ",
get()->name());
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
m_use_async_allocator);
}
m_use_async_allocator =
(tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2);
// CUDA allocator is always loaded at the start of the program
m_is_allocator_loaded = true;
#if defined(CUDA_VERSION)
if (m_use_async_allocator) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
#endif
return i;
#undef PYTORCH_TOKEN1
#undef PYTORCH_TOKEN2
#else // USE_ROCM
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
((config[i] == "native") || (config[i] == "cudaMallocAsync")),
"Unknown allocator backend, "
"options are native and cudaMallocAsync");
used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
TORCH_INTERNAL_ASSERT(
config[i] == get()->name(),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time");
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
}
return i;
#endif // USE_ROCM
}
void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
void CUDAAllocatorConfig::parseArgs(const std::string& env) {
// If empty, set the default values
m_max_split_size = std::numeric_limits<size_t>::max();
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
m_garbage_collection_threshold = 0;
bool used_cudaMallocAsync = false;
bool used_native_specific_option = false;
if (!env.has_value()) {
return;
}
{
std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
m_last_allocator_settings = env.value();
}
std::vector<std::string> config;
lexArgs(env.value(), config);
for (size_t i = 0; i < config.size(); i++) {
std::string_view config_item_view(config[i]);
if (config_item_view == "max_split_size_mb") {
i = parseMaxSplitSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "max_non_split_rounding_mb") {
i = parseMaxNonSplitRoundingSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "garbage_collection_threshold") {
i = parseGarbageCollectionThreshold(config, i);
used_native_specific_option = true;
} else if (config_item_view == "roundup_power2_divisions") {
i = parseRoundUpPower2Divisions(config, i);
used_native_specific_option = true;
} else if (config_item_view == "backend") {
i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
} else if (config_item_view == "expandable_segments") {
used_native_specific_option = true;
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for expandable_segments");
config_item_view = config[i];
m_expandable_segments = (config_item_view == "True");
c10::CachingAllocator::ConfigTokenizer tokenizer(env);
for (size_t i = 0; i < tokenizer.size(); i++) {
const auto& key = tokenizer[i];
if (key == "backend") {
i = parseAllocatorConfig(tokenizer, i);
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
config_item_view == "release_lock_on_hipmalloc" ||
config_item_view ==
key == "release_lock_on_hipmalloc" ||
key ==
"release_lock_on_c"
"udamalloc") {
used_native_specific_option = true;
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for release_lock_on_cudamalloc");
config_item_view = config[i];
m_release_lock_on_cudamalloc = (config_item_view == "True");
tokenizer.checkToken(++i, ":");
m_release_lock_on_cudamalloc = tokenizer.toBool(++i);
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
config_item_view == "pinned_use_hip_host_register" ||
config_item_view ==
key == "pinned_use_hip_host_register" ||
key ==
"pinned_use_c"
"uda_host_register") {
i = parsePinnedUseCudaHostRegister(config, i);
i = parsePinnedUseCudaHostRegister(tokenizer, i);
used_native_specific_option = true;
} else if (config_item_view == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(config, i);
used_native_specific_option = true;
} else if (config_item_view == "pinned_use_background_threads") {
i = parsePinnedUseBackgroundThreads(config, i);
} else if (key == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(tokenizer, i);
used_native_specific_option = true;
} else {
const auto& keys =
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
TORCH_CHECK(
false, "Unrecognized CachingAllocator option: ", config_item_view);
keys.find(key) != keys.end(),
"Unrecognized key '",
key,
"' in Accelerator allocator config.");
i = tokenizer.skipKey(i);
}
if (i + 1 < config.size()) {
consumeToken(config, ++i, ',');
if (i + 1 < tokenizer.size()) {
tokenizer.checkToken(++i, ",");
}
}
if (used_cudaMallocAsync && used_native_specific_option) {
if (m_use_async_allocator && used_native_specific_option) {
TORCH_WARN(
"backend:cudaMallocAsync ignores max_split_size_mb,"
"roundup_power2_divisions, and garbage_collect_threshold.");
@ -391,64 +121,33 @@ void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
}
size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
const std::vector<std::string>& config,
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_cuda_host_register");
m_pinned_use_cuda_host_register = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_cuda_host_register value", "");
}
tokenizer.checkToken(++i, ":");
m_pinned_use_cuda_host_register = tokenizer.toBool(++i);
return i;
}
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
const std::vector<std::string>& config,
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
size_t val2 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
} else {
TORCH_CHECK(
false, "Error, expecting pinned_num_register_threads value", "");
}
tokenizer.checkToken(++i, ":");
size_t val2 = tokenizer.toSizeT(++i);
TORCH_CHECK_VALUE(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK_VALUE(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
return i;
}
size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_background_threads");
m_pinned_use_background_threads = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_background_threads value", "");
}
return i;
}
// General caching allocator utilities
void setAllocatorSettings(const std::string& env) {
CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
}
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig)
} // namespace c10::cuda::CUDACachingAllocator
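
To show the end result of the CUDA-side rewrite, a small sketch (not part of the diff) assuming a CUDA-enabled build that links `c10_cuda`:

```cpp
#include <c10/core/AllocatorConfig.h>
#include <c10/cuda/CUDAAllocatorConfig.h>

#include <iostream>

int main() {
  namespace cuda_alloc = c10::cuda::CUDACachingAllocator;

  // One generic entry point: common keys are parsed by
  // AcceleratorAllocatorConfig, and the whole string is then handed to the
  // CUDA hook registered via REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK above.
  c10::CachingAllocator::setAllocatorSettings(
      "max_split_size_mb:256,"
      "expandable_segments:True,"
      "pinned_num_register_threads:8");

  std::cout << c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size()
            << '\n';
  // Delegates to AcceleratorAllocatorConfig::use_expandable_segments().
  std::cout << cuda_alloc::CUDAAllocatorConfig::expandable_segments() << '\n';
  // CUDA-specific option parsed by CUDAAllocatorConfig::parseArgs().
  std::cout << cuda_alloc::CUDAAllocatorConfig::pinned_num_register_threads()
            << '\n';
  return 0;
}
```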

View File

@ -1,16 +1,12 @@
#pragma once
#include <c10/core/AllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <mutex>
#include <string>
#include <vector>
namespace c10::cuda::CUDACachingAllocator {
enum class Expandable_Segments_Handle_Type : int {
@ -22,21 +18,28 @@ enum class Expandable_Segments_Handle_Type : int {
// Environment config parser
class C10_CUDA_API CUDAAllocatorConfig {
public:
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.")
static size_t max_split_size() {
return instance().m_max_split_size;
return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.")
static double garbage_collection_threshold() {
return instance().m_garbage_collection_threshold;
return c10::CachingAllocator::AcceleratorAllocatorConfig::
garbage_collection_threshold();
}
static bool expandable_segments() {
bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig::
use_expandable_segments();
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
if (instance().m_expandable_segments) {
if (enabled) {
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
}
return false;
#else
return instance().m_expandable_segments;
return enabled;
#endif
}
@ -62,8 +65,11 @@ class C10_CUDA_API CUDAAllocatorConfig {
return instance().m_pinned_num_register_threads;
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.")
static bool pinned_use_background_threads() {
return instance().m_pinned_use_background_threads;
return c10::CachingAllocator::AcceleratorAllocatorConfig::
pinned_use_background_threads();
}
static size_t pinned_max_register_threads() {
@ -73,92 +79,105 @@ class C10_CUDA_API CUDAAllocatorConfig {
return 128;
}
// This is used to round-up allocation size to nearest power of 2 divisions.
// More description below in function roundup_power2_next_division
// As an example, if we want 4 divisions between 2's power, this can be done
// using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
static size_t roundup_power2_divisions(size_t size);
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
static size_t roundup_power2_divisions(size_t size) {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions(size);
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
static std::vector<size_t> roundup_power2_divisions() {
return instance().m_roundup_power2_divisions;
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions();
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_non_split_rounding_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_non_split_rounding_size() instead.")
static size_t max_non_split_rounding_size() {
return instance().m_max_non_split_rounding_size;
return c10::CachingAllocator::AcceleratorAllocatorConfig::
max_non_split_rounding_size();
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.")
static std::string last_allocator_settings() {
std::lock_guard<std::mutex> lock(
instance().m_last_allocator_settings_mutex);
return instance().m_last_allocator_settings;
return c10::CachingAllocator::getAllocatorSettings();
}
static bool use_async_allocator() {
return instance().m_use_async_allocator;
}
static const std::unordered_set<std::string>& getKeys() {
return keys_;
}
static CUDAAllocatorConfig& instance() {
static CUDAAllocatorConfig* s_instance = ([]() {
auto inst = new CUDAAllocatorConfig();
auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
if (!env.has_value()) {
// For backward compatibility, check for the old environment variable
// PYTORCH_CUDA_ALLOC_CONF.
env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
}
#ifdef USE_ROCM
// convenience for ROCm users, allow alternative HIP token
if (!env.has_value()) {
env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
inst->parseArgs(env);
if (env.has_value()) {
inst->parseArgs(env.value());
}
return inst;
})();
return *s_instance;
}
void parseArgs(const std::optional<std::string>& env);
void parseArgs(const std::string& env);
private:
CUDAAllocatorConfig();
CUDAAllocatorConfig() = default;
static void lexArgs(const std::string& env, std::vector<std::string>& config);
static void consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c);
size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
size_t parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i);
size_t parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i);
size_t parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i);
size_t parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync);
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i);
size_t parsePinnedUseCudaHostRegister(
const std::vector<std::string>& config,
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i);
size_t parsePinnedNumRegisterThreads(
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i);
std::atomic<size_t> m_max_split_size;
std::atomic<size_t> m_max_non_split_rounding_size;
std::vector<size_t> m_roundup_power2_divisions;
std::atomic<double> m_garbage_collection_threshold;
std::atomic<size_t> m_pinned_num_register_threads;
std::atomic<bool> m_expandable_segments;
std::atomic<Expandable_Segments_Handle_Type>
m_expandable_segments_handle_type;
std::atomic<bool> m_release_lock_on_cudamalloc;
std::atomic<bool> m_pinned_use_cuda_host_register;
std::atomic<bool> m_pinned_use_background_threads;
std::string m_last_allocator_settings;
std::mutex m_last_allocator_settings_mutex;
std::atomic<size_t> m_pinned_num_register_threads{1};
std::atomic<Expandable_Segments_Handle_Type> m_expandable_segments_handle_type
#if CUDA_VERSION >= 12030
{Expandable_Segments_Handle_Type::UNSPECIFIED};
#else
{Expandable_Segments_Handle_Type::POSIX_FD};
#endif
std::atomic<bool> m_release_lock_on_cudamalloc{false};
std::atomic<bool> m_pinned_use_cuda_host_register{false};
std::atomic<bool> m_use_async_allocator{false};
std::atomic<bool> m_is_allocator_loaded{false};
inline static std::unordered_set<std::string> keys_{
"backend",
// keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_cud"
"amalloc",
"pinned_use_cud"
"a_host_register",
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_hipmalloc",
"pinned_use_hip_host_register",
"pinned_num_register_threads"};
};
// General caching allocator utilities
C10_CUDA_API void setAllocatorSettings(const std::string& env);
// Keep this for backwards compatibility
using c10::CachingAllocator::setAllocatorSettings;
} // namespace c10::cuda::CUDACachingAllocator
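
For downstream callers, the deprecation attributes above amount to a mechanical migration; a brief sketch:

```cpp
#include <c10/core/AllocatorConfig.h>
#include <c10/cuda/CUDAAllocatorConfig.h>

size_t current_max_split_size() {
  // Deprecated spelling: still compiles, warns, and simply forwards to the
  // shared config:
  //   return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size();

  // Preferred, accelerator-generic spelling:
  return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
}
```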

View File

@ -1,7 +1,6 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
@ -64,10 +63,6 @@ namespace cuda::CUDACachingAllocator {
using namespace c10::CachingAllocator;
using namespace c10::CachingDeviceAllocator;
// Included here as this is externally used in CUDAAllocatorConfig
const size_t kLargeBuffer =
20971520; // "large" allocations may be packed in 20 MiB blocks
namespace Native {
//
@ -843,8 +838,7 @@ struct AllocParams {
size_t size,
cudaStream_t stream,
BlockPool* pool,
size_t alloc_size,
DeviceStats& stats)
size_t alloc_size)
: search_key(device, stream, size), pool(pool), alloc_size(alloc_size) {}
c10::DeviceIndex device() const {
@ -1231,7 +1225,7 @@ class DeviceCachingAllocator {
DeviceCachingAllocator()
: large_blocks(/*small=*/false), small_blocks(/*small=*/true) {
stats.max_split_size =
static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
static_cast<int64_t>(AcceleratorAllocatorConfig::max_split_size());
context_recorder_.store(nullptr);
}
@ -1341,7 +1335,7 @@ class DeviceCachingAllocator {
size_t size = round_size(orig_size);
auto& pool = get_pool(size, stream);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, stream, &pool, alloc_size, stats);
AllocParams params(device, size, stream, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool);
// First, try to get a block from the existing pool.
@ -1356,7 +1350,8 @@ class DeviceCachingAllocator {
// Do garbage collection if the flag is set.
if (C10_UNLIKELY(
set_fraction &&
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
AcceleratorAllocatorConfig::garbage_collection_threshold() >
0.0)) {
garbage_collect_cached_blocks(context);
}
// Attempt allocate
@@ -1388,7 +1383,7 @@ class DeviceCachingAllocator {
beginAllocateToPool(mempool_id, filter);
auto& mempool = get_pool(size, stream);
AllocParams mempool_params(
device, size, stream, &mempool, alloc_size, stats);
device, size, stream, &mempool, alloc_size);
mempool_params.stat_types = get_stat_types_for_pool(mempool);
block_found = get_free_block(mempool_params);
endAllocateToPool(mempool_id);
@@ -1608,7 +1603,7 @@ class DeviceCachingAllocator {
stats.active_bytes[stat_type].increase(block->size);
stats.requested_bytes[stat_type].increase(block->requested_size);
});
if (block->size >= CUDAAllocatorConfig::max_split_size())
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
stats.oversize_allocations.increase(1);
auto allocated_bytes_gauge =
@@ -1659,7 +1654,7 @@ class DeviceCachingAllocator {
block->pool->owner_MempoolId(),
context ? context : block->context_when_allocated);
if (block->size >= CUDAAllocatorConfig::max_split_size())
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
stats.oversize_allocations.decrease(1);
if (!block->stream_uses.empty()) {
@@ -1929,8 +1924,7 @@ class DeviceCachingAllocator {
block_state.size,
block_state.stream,
&pool,
block_state.size,
stats);
block_state.size);
pool.blocks.erase(curr_block);
params.block = curr_block;
params.stat_types = get_stat_types_for_pool(pool);
@@ -2209,7 +2203,8 @@ class DeviceCachingAllocator {
if (size < kMinBlockSize) {
return kMinBlockSize;
} else {
auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
auto divisions =
AcceleratorAllocatorConfig::roundup_power2_divisions(size);
if (divisions > 1 && size > (kMinBlockSize * divisions)) {
return roundup_power2_next_division(size, divisions);
} else {
@@ -2699,7 +2694,7 @@ class DeviceCachingAllocator {
if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) {
return remaining >= kMinBlockSize;
} else {
return (size < CUDAAllocatorConfig::max_split_size()) &&
return (size < AcceleratorAllocatorConfig::max_split_size()) &&
(remaining > kSmallSize);
}
}
@@ -2719,7 +2714,7 @@ class DeviceCachingAllocator {
if (C10_UNLIKELY(
set_fraction &&
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
// Track block reuse interval only when garbage collection is enabled.
++pool.get_free_blocks_call_count;
}
@@ -2761,13 +2756,13 @@ class DeviceCachingAllocator {
}
// Do not return an oversized block for a large request
if ((p.size() < CUDAAllocatorConfig::max_split_size()) &&
((*it)->size >= CUDAAllocatorConfig::max_split_size()))
if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) &&
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()))
return false;
// Allow oversized block size to be rounded up but within a limit
if ((p.size() >= CUDAAllocatorConfig::max_split_size()) &&
if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) &&
((*it)->size >=
p.size() + CUDAAllocatorConfig::max_non_split_rounding_size()))
p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size()))
return false;
p.block = *it;
pool.blocks.erase(it);
@@ -2790,7 +2785,7 @@ class DeviceCachingAllocator {
// therefore should be of less overheads.
size_t gc_threshold = static_cast<size_t>(
CUDAAllocatorConfig::garbage_collection_threshold() *
AcceleratorAllocatorConfig::garbage_collection_threshold() *
static_cast<double>(allowed_memory_maximum));
// No need to trigger GC yet
if (total_allocated_memory <= gc_threshold) {
@@ -2938,7 +2933,7 @@ class DeviceCachingAllocator {
stats.segment[stat_type].increase(1);
stats.reserved_bytes[stat_type].increase(size);
});
if (size >= CUDAAllocatorConfig::max_split_size())
if (size >= AcceleratorAllocatorConfig::max_split_size())
stats.oversize_segments.increase(1);
auto reserved_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
@@ -2967,7 +2962,7 @@ class DeviceCachingAllocator {
bool release_available_cached_blocks(
const AllocParams& p,
const std::shared_ptr<GatheredContext>& context) {
if (CUDAAllocatorConfig::max_split_size() ==
if (AcceleratorAllocatorConfig::max_split_size() ==
std::numeric_limits<size_t>::max())
return false;
BlockPool& pool = *p.pool;
@@ -2975,8 +2970,8 @@ class DeviceCachingAllocator {
// because of std::unique_ptr, block cannot be trivially copied
// Use constructor for search key.
Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
key.size = (key.size < CUDAAllocatorConfig::max_split_size())
? CUDAAllocatorConfig::max_split_size()
key.size = (key.size < AcceleratorAllocatorConfig::max_split_size())
? AcceleratorAllocatorConfig::max_split_size()
: key.size;
auto it = pool.blocks.lower_bound(&key);
if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
@@ -2989,7 +2984,7 @@ class DeviceCachingAllocator {
--it; // Back up one item. Now on the largest block for the correct
// stream
while ((totalReleased < key.size) &&
((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) &&
((*it)->stream == p.stream())) {
auto cur = it;
bool is_first = cur == pool.blocks.begin();
@@ -3114,7 +3109,7 @@ class DeviceCachingAllocator {
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);
if (block->size >= CUDAAllocatorConfig::max_split_size())
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
stats.oversize_segments.decrease(1);
pool->blocks.erase(block);
delete block;
@@ -3741,8 +3736,8 @@ class NativeCachingAllocator : public CUDAAllocator {
auto& md = result.config_metadata;
md.garbage_collection_threshold =
CUDAAllocatorConfig::garbage_collection_threshold();
md.max_split_size = CUDAAllocatorConfig::max_split_size();
AcceleratorAllocatorConfig::garbage_collection_threshold();
md.max_split_size = AcceleratorAllocatorConfig::max_split_size();
md.pinned_num_register_threads =
CUDAAllocatorConfig::pinned_num_register_threads();
md.expandable_segments = CUDAAllocatorConfig::expandable_segments();
@@ -3750,9 +3745,10 @@ class NativeCachingAllocator : public CUDAAllocator {
CUDAAllocatorConfig::release_lock_on_cudamalloc();
md.pinned_use_host_register =
CUDAAllocatorConfig::pinned_use_cuda_host_register();
md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
md.last_allocator_settings =
AcceleratorAllocatorConfig::last_allocator_settings();
md.roundup_power2_divisions =
CUDAAllocatorConfig::roundup_power2_divisions();
AcceleratorAllocatorConfig::roundup_power2_divisions();
return result;
}
@@ -4130,49 +4126,10 @@ CUDAAllocator* allocator();
} // namespace CudaMallocAsync
struct BackendStaticInitializer {
// Parses env for backend at load time, duplicating some logic from
// CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at
// runtime). Defers verbose exceptions and error checks, including Cuda
// version checks, to CUDAAllocatorConfig's runtime doublecheck. If this
// works, maybe we should move all of CUDAAllocatorConfig here?
CUDAAllocator* parseEnvForBackend() {
auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (!val.has_value()) {
val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
if (val.has_value()) {
const std::string& config = val.value();
std::regex exp("[\\s,]+");
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);
for (auto option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
std::sregex_token_iterator end2;
std::vector<std::string> kv(it2, end2);
if (kv.size() >= 2) {
if (kv[0] == "backend") {
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (kv[1] ==
"cud"
"aMallocAsync" ||
kv[1] == "hipMallocAsync")
#else
if (kv[1] == "cudaMallocAsync")
#endif
return CudaMallocAsync::allocator();
if (kv[1] == "native")
return &Native::allocator;
}
}
}
// If the environment variable is set, we use the CudaMallocAsync allocator.
if (CUDAAllocatorConfig::use_async_allocator()) {
return CudaMallocAsync::allocator();
}
return &Native::allocator;
}
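With the hand-rolled regex parsing removed, the backend decision reduces to one flag that the shared config parser fills in from the same environment variables. A small sketch of the resulting mapping (the env keys and values are the ones handled by the removed code above; the wrapper function is illustrative):
```cpp
// PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync -> use_async_allocator() == true
// PYTORCH_CUDA_ALLOC_CONF=backend:native          -> use_async_allocator() == false
// (On ROCm builds PYTORCH_HIP_ALLOC_CONF and hipMallocAsync are accepted as well.)
CUDAAllocator* select_backend_example() {
  return CUDAAllocatorConfig::use_async_allocator()
      ? CudaMallocAsync::allocator()
      : &Native::allocator;
}
```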

View File

@@ -1,6 +1,7 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
@@ -49,10 +50,9 @@ namespace c10::cuda::CUDACachingAllocator {
// Preserved only for BC reasons
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::CachingAllocator::kLargeBuffer;
using c10::CachingDeviceAllocator::DeviceStats;
extern const size_t kLargeBuffer;
typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();
// Struct containing info of an allocation block (i.e. a fractional part of a

View File

@@ -5,15 +5,86 @@
namespace c10 {
namespace metal {
namespace detail {
template <typename T>
struct simd_type {
using t = T;
};
// Helper that allows one to run simd ops over bfl16 by upcasting them to fp32
template <typename T>
using simd_type_t = typename simd_type<T>::t;
#if __METAL_VERSION__ >= 310
template <>
struct simd_type<bfloat> {
using t = float;
};
#endif
} // namespace detail
template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
return ::metal::simd_sum(val);
return T(::metal::simd_sum(detail::simd_type_t<T>(val)));
}
template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_prod(T val) {
return ::metal::simd_product(val);
return T(::metal::simd_product(detail::simd_type_t<T>(val)));
}
// Extend simd_broadcast to 64-bit integral types using int2 trick
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) == 8, bool> =
true>
inline T simd_broadcast(T val, ushort lane_id) {
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), lane_id));
}
template <
typename T,
::metal::enable_if_t<!::metal::is_integral_v<T> || sizeof(T) != 8, bool> =
true>
inline T simd_broadcast(T val, ushort lane_id) {
return ::metal::simd_broadcast(val, lane_id);
}
// Floating simd_min/max with nan propagation
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline T simd_max(T val) {
if (::metal::simd_any(::metal::isnan(val))) {
return ::metal::numeric_limits<T>::quiet_NaN();
}
return T(::metal::simd_max(detail::simd_type_t<T>(val)));
}
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline T simd_min(T val) {
if (::metal::simd_any(::metal::isnan(val))) {
return ::metal::numeric_limits<T>::quiet_NaN();
}
return T(::metal::simd_min(detail::simd_type_t<T>(val)));
}
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) != 8, bool> =
true>
inline T simd_max(T val) {
return ::metal::simd_max(val);
}
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) != 8, bool> =
true>
inline T simd_min(T val) {
return ::metal::simd_min(val);
}
// Metal does not support SIMD reductions over 64-bit types, but it could be
@@ -28,7 +99,7 @@ inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_sum(T val) {
val += as_type<T>(
::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
}
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
return simd_broadcast(val, 0);
}
template <typename T>
@@ -37,7 +108,78 @@ inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_prod(T val) {
val *= as_type<T>(
::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
}
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
return simd_broadcast(val, 0);
}
template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_max(T val) {
for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
val = ::metal::max(
val,
as_type<T>(::metal::simd_shuffle_and_fill_down(
as_type<int2>(val), int2(0), i)));
}
return simd_broadcast(val, 0);
}
template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_min(T val) {
for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
val = ::metal::min(
val,
as_type<T>(::metal::simd_shuffle_and_fill_down(
as_type<int2>(val), int2(0), i)));
}
return simd_broadcast(val, 0);
}
// argmin/argmax helpers using simd_ballot
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
const auto rc = simd_min(val);
const auto vote = ::metal::simd_ballot(val == rc);
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
}
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
const auto rc = simd_min(val);
const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
}
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
const auto rc = simd_max(val);
const auto vote = ::metal::simd_ballot(val == rc);
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
}
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
const auto rc = simd_max(val);
const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
}
template <typename ARG_T, typename IDX_T>
inline c10::metal::pair<ARG_T, IDX_T> simd_argmin(ARG_T val, IDX_T idx_val) {
auto rc = simd_argmin(val);
return {rc.first, simd_broadcast(idx_val, rc.second)};
}
template <typename ARG_T, typename IDX_T>
inline c10::metal::pair<ARG_T, IDX_T> simd_argmax(ARG_T val, IDX_T idx_val) {
auto rc = simd_argmax(val);
return {rc.first, simd_broadcast(idx_val, rc.second)};
}
// Below algorithms are written with hardcoded assumption that simdgroup is 32
@@ -88,6 +230,44 @@ opmath_t<T> threadgroup_prod(
return data[0];
}
template <typename T>
T threadgroup_max(threadgroup T* data, T val, unsigned idx, unsigned size) {
auto rc = simd_max(val);
if (idx % simdgroup_size == 0) {
data[idx / simdgroup_size] = rc;
}
if (size > simdgroup_size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_max(data[idx]);
if (idx == 0) {
data[0] = rc1;
}
}
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return data[0];
}
template <typename T>
T threadgroup_min(threadgroup T* data, T val, unsigned idx, unsigned size) {
auto rc = simd_min(val);
if (idx % simdgroup_size == 0) {
data[idx / simdgroup_size] = rc;
}
if (size > simdgroup_size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_min(data[idx]);
if (idx == 0) {
data[0] = rc1;
}
}
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return data[0];
}
template <typename T>
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
@@ -123,52 +303,58 @@ float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
return rc;
}
template <typename T>
T threadgroup_max(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
T rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = ::c10::metal::max(rc, data[idx]);
template <typename ARG_T, typename IDX_T>
IDX_T threadgroup_argmax(
threadgroup ARG_T* arg_data,
threadgroup IDX_T* idx_data,
ARG_T val,
IDX_T idx_val,
unsigned idx,
unsigned size) {
auto rc = simd_argmax(val, idx_val);
if (size <= simdgroup_size) {
return rc.second;
}
return rc;
}
template <typename T>
T threadgroup_min(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
T rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = ::c10::metal::min(rc, data[idx]);
if (idx % simdgroup_size == 0) {
arg_data[idx / simdgroup_size] = rc.first;
idx_data[idx / simdgroup_size] = rc.second;
}
return rc;
}
template <typename T>
int threadgroup_argmax(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
int rc = 0;
for (unsigned idx = 1; idx < size; ++idx) {
if (data[idx] > data[rc]) {
rc = idx;
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_argmax(arg_data[idx], idx_data[idx]);
if (idx == 0) {
idx_data[0] = rc1.second;
}
}
return rc;
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return idx_data[0];
}
template <typename T>
int threadgroup_argmin(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
template <typename ARG_T, typename IDX_T>
IDX_T threadgroup_argmin(
threadgroup ARG_T* arg_data,
threadgroup IDX_T* idx_data,
ARG_T val,
IDX_T idx_val,
unsigned idx,
unsigned size) {
auto rc = simd_argmin(val, idx_val);
if (size <= simdgroup_size) {
return rc.second;
}
if (idx % simdgroup_size == 0) {
arg_data[idx / simdgroup_size] = rc.first;
idx_data[idx / simdgroup_size] = rc.second;
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
int rc = 0;
for (unsigned idx = 1; idx < size; ++idx) {
if (data[idx] < data[rc]) {
rc = idx;
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_argmin(arg_data[idx], idx_data[idx]);
if (idx == 0) {
idx_data[0] = rc1.second;
}
}
return rc;
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return idx_data[0];
}
} // namespace metal
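For context, a sketch of how a reduction kernel might drive the new cooperative threadgroup_argmax (written as Metal C++; the kernel name, buffer layout, include path and scratch sizing are assumptions for illustration, not part of this diff):
```cpp
#include <metal_stdlib>
#include <c10/metal/utils.h>  // assumed include path for the helpers above

kernel void argmax_example(
    device const float* input [[buffer(0)]],
    device int* out_idx [[buffer(1)]],
    constant uint& numel [[buffer(2)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]]) {
  // One scratch slot per 32-thread simdgroup; 32 slots cover up to 1024 threads.
  threadgroup float arg_scratch[32];
  threadgroup int idx_scratch[32];

  // Each thread first reduces its own strided slice of the input.
  float best_val = -::metal::numeric_limits<float>::infinity();
  int best_idx = 0;
  for (uint i = tid; i < numel; i += tg_size) {
    if (input[i] > best_val) {
      best_val = input[i];
      best_idx = int(i);
    }
  }

  // Two-level reduction: simd_argmax inside each simdgroup, then one more
  // pass over the per-simdgroup partials staged in threadgroup memory.
  int winner = ::c10::metal::threadgroup_argmax(
      arg_scratch, idx_scratch, best_val, best_idx, tid, tg_size);
  if (tid == 0) {
    out_idx[0] = winner;
  }
}
```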

View File

@@ -330,5 +330,11 @@ inline float log1p(float x) {
return rc;
}
template <typename T1, typename T2 = T1>
struct pair {
T1 first;
T2 second;
};
} // namespace metal
} // namespace c10

View File

@@ -0,0 +1,130 @@
#include <c10/core/AllocatorConfig.h>
#include <gtest/gtest.h>
using namespace c10::CachingAllocator;
constexpr size_t kMB = 1024 * 1024ul;
struct ExtendedAllocatorConfig {
static ExtendedAllocatorConfig& instance() {
static ExtendedAllocatorConfig instance;
return instance;
}
// Returns the device-specific option value in bytes.
static size_t device_specific_option() {
return instance().device_specific_option_;
}
static const std::unordered_set<std::string>& getKeys() {
return keys_;
}
void parseArgs(const std::string& env) {
// Parse device-specific options from the environment variable
ConfigTokenizer tokenizer(env);
for (size_t i = 0; i < tokenizer.size(); i++) {
const auto& key = tokenizer[i];
if (key == "device_specific_option_mb") {
tokenizer.checkToken(++i, ":");
device_specific_option_ = tokenizer.toSizeT(++i) * kMB;
} else {
i = tokenizer.skipKey(i);
}
if (i + 1 < tokenizer.size()) {
tokenizer.checkToken(++i, ",");
}
}
}
private:
// Device-specific option, e.g., memory limit for a specific device.
std::atomic<size_t> device_specific_option_{0};
inline static std::unordered_set<std::string> keys_{
"device_specific_option_mb"};
};
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(ExtendedAllocatorConfig)
TEST(AllocatorConfigTest, allocator_config_test) {
std::string env =
"max_split_size_mb:40,"
"max_non_split_rounding_mb:30,"
"garbage_collection_threshold:0.5,"
"roundup_power2_divisions:[64:8,128:2,256:4,512:2,1024:4,>:1],"
"expandable_segments:True,"
"pinned_use_background_threads:True,"
"device_specific_option_mb:64";
c10::CachingAllocator::setAllocatorSettings(env);
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
EXPECT_EQ(AcceleratorAllocatorConfig::max_split_size(), 40 * kMB);
EXPECT_EQ(
AcceleratorAllocatorConfig::max_non_split_rounding_size(), 30 * kMB);
EXPECT_EQ(AcceleratorAllocatorConfig::garbage_collection_threshold(), 0.5);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(32 * kMB), 8);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 8);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 2);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 2);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 1);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 1);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(8192 * kMB), 1);
EXPECT_EQ(AcceleratorAllocatorConfig::use_expandable_segments(), true);
EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), true);
EXPECT_EQ(ExtendedAllocatorConfig::device_specific_option(), 64 * kMB);
env =
"max_split_size_mb:20,"
"max_non_split_rounding_mb:40,"
"garbage_collection_threshold:0.8";
c10::CachingAllocator::setAllocatorSettings(env);
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
EXPECT_EQ(AcceleratorAllocatorConfig::max_split_size(), 20 * kMB);
EXPECT_EQ(
AcceleratorAllocatorConfig::max_non_split_rounding_size(), 40 * kMB);
EXPECT_EQ(AcceleratorAllocatorConfig::garbage_collection_threshold(), 0.8);
// roundup_power2_divisions knob array syntax
env = "roundup_power2_divisions:[128:8,256:16,512:1,2048:8,>:2]";
c10::CachingAllocator::setAllocatorSettings(env);
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 8);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 8);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 16);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 1);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 0);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 2);
// roundup_power2_divisions single value syntax for backward compatibility
env = "roundup_power2_divisions:4";
c10::CachingAllocator::setAllocatorSettings(env);
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 4);
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4);
EXPECT_EQ(
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 4);
env = "expandable_segments:False,";
c10::CachingAllocator::setAllocatorSettings(env);
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
EXPECT_EQ(AcceleratorAllocatorConfig::use_expandable_segments(), false);
env = "pinned_use_background_threads:False";
c10::CachingAllocator::setAllocatorSettings(env);
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), false);
env = "foo:123,bar:456";
ASSERT_THROW(c10::CachingAllocator::setAllocatorSettings(env), c10::Error);
}
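The roundup_power2_divisions expectations above are easier to read with a concrete size: each power-of-two interval is split into the configured number of equal steps and a request rounds up to the next step boundary. A minimal sketch of that idea (not the PyTorch implementation; the function and constant names are illustrative):
```cpp
#include <bit>      // std::bit_floor, std::has_single_bit (C++20)
#include <cassert>
#include <cstdint>

// Rounds `size` up to the next boundary of the power-of-two interval it falls
// in, with the interval split into `divisions` equal steps.
uint64_t roundup_with_divisions(uint64_t size, uint64_t divisions) {
  if (divisions <= 1 || std::has_single_bit(size)) {
    return size;  // exact powers of two are left untouched
  }
  const uint64_t step = std::bit_floor(size) / divisions;
  if (step == 0) {
    return std::bit_floor(size) << 1;  // tiny sizes: fall back to the next power of two
  }
  return (size + step - 1) / step * step;
}

int main() {
  constexpr uint64_t MiB = 1024ull * 1024ull;
  // With "roundup_power2_divisions:[128:8,...]" a 100 MiB request lies in the
  // 64..128 MiB interval, which is split into 8 steps of 8 MiB, so it rounds
  // up to 104 MiB.
  assert(roundup_with_divisions(100 * MiB, 8) == 104 * MiB);
  return 0;
}
```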

View File

@@ -1,340 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#elif defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
:
#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
__CUDA_ARCH__ >= 800
x(__bfloat16_as_ushort(__float2bfloat16(value)))
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
#else
// RNE by default
x(detail::round_to_nearest_even(value))
#endif
{
}
/// Implicit conversions
inline C10_HOST_DEVICE BFloat16::operator float() const {
#if defined(__CUDACC__) && !defined(USE_ROCM)
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
#else
return detail::f32_from_bits(x);
#endif
}
#if defined(__CUDACC__) && !defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
return *reinterpret_cast<const __nv_bfloat16*>(&x);
}
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16::BFloat16(
const sycl::ext::oneapi::bfloat16& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
}
#endif
// CUDA intrinsics
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
#else
return *ptr;
#endif
}
#endif
/// Arithmetic
inline C10_HOST_DEVICE BFloat16
operator+(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16
operator-(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16
operator*(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
a = a / b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
a.x = a.x | b.x;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
a.x = a.x ^ b.x;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
a.x = a.x & b.x;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
return a + static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
return a - static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
return a * static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
return a / static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
return static_cast<BFloat16>(a) + b;
}
inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
return static_cast<BFloat16>(a) - b;
}
inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
return static_cast<BFloat16>(a) * b;
}
inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
return static_cast<BFloat16>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
return a + static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
return a - static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
return a * static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
return a / static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) + b;
}
inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) - b;
}
inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) * b;
}
inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) / b;
}
// Overloading < and > operators, because std::max and std::min use them.
inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
return float(lhs) > float(rhs);
}
inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
return float(lhs) < float(rhs);
}
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::BFloat16> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_specialized = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
static constexpr auto has_denorm_loss =
numeric_limits<float>::has_denorm_loss;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 8;
static constexpr int digits10 = 2;
static constexpr int max_digits10 = 4;
static constexpr int radix = 2;
static constexpr int min_exponent = -125;
static constexpr int min_exponent10 = -37;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::BFloat16 min() {
return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 lowest() {
return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 max() {
return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 epsilon() {
return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 round_error() {
return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 infinity() {
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 quiet_NaN() {
return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 signaling_NaN() {
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 denorm_min() {
return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/BFloat16.h>
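A short usage sketch of the operators and limits defined above (the printed values follow from the 1-sign/8-exponent/7-mantissa layout; the program itself is illustrative):
```cpp
#include <c10/util/BFloat16.h>
#include <iostream>
#include <limits>

int main() {
  c10::BFloat16 a(1.0f), b(3.0f);
  // Every arithmetic operator above converts to float, computes there, and
  // rounds the result back to bfloat16 on assignment.
  c10::BFloat16 c = a / b;
  std::cout << c << '\n';  // ~0.333984, the nearest bfloat16 to 1/3
  // epsilon() == 2^-7 because bfloat16 keeps 8 significand bits (7 stored + 1 implicit).
  std::cout << float(std::numeric_limits<c10::BFloat16>::epsilon()) << '\n';  // 0.0078125
  return 0;
}
```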

View File

@@ -1,116 +1 @@
#pragma once
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <ostream>
#if defined(__CUDACC__) && !defined(USE_ROCM)
#include <cuda_bf16.h>
#endif
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#elif defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
namespace c10 {
namespace detail {
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
float res = 0;
uint32_t tmp = src;
tmp <<= 16;
#if defined(USE_ROCM) && defined(__HIPCC__)
float* tempRes;
// We should be using memcpy in order to respect the strict aliasing rule
// but it fails in the HIP environment.
tempRes = reinterpret_cast<float*>(&tmp);
res = *tempRes;
#else
std::memcpy(&res, &tmp, sizeof(tmp));
#endif
return res;
}
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
uint32_t res = 0;
#if defined(USE_ROCM) && defined(__HIPCC__)
// We should be using memcpy in order to respect the strict aliasing rule
// but it fails in the HIP environment.
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
res = *tempRes;
#else
std::memcpy(&res, &src, sizeof(res));
#endif
return res >> 16;
}
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM) && defined(__HIPCC__)
if (src != src) {
#elif defined(_MSC_VER)
if (isnan(src)) {
#else
if (std::isnan(src)) {
#endif
return UINT16_C(0x7FC0);
} else {
const uint32_t U32 = c10::bit_cast<uint32_t>(src);
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
}
} // namespace detail
struct alignas(2) BFloat16 {
uint16_t x;
// HIP wants __host__ __device__ tag, CUDA does not
#if defined(USE_ROCM) && defined(__HIPCC__)
C10_HOST_DEVICE BFloat16() = default;
#else
BFloat16() = default;
#endif
struct from_bits_t {};
static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
return from_bits_t();
}
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
: x(bits) {}
/* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
inline C10_HOST_DEVICE operator float() const;
#if defined(__CUDACC__) && !defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
#endif
};
inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/BFloat16.h>
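To make the rounding-bias trick in detail::round_to_nearest_even concrete, here is the same arithmetic traced on raw bit patterns (a standalone sketch, not part of the header; the NaN check that precedes this path is omitted):
```cpp
#include <cassert>
#include <cstdint>

// Same arithmetic as the non-NaN branch of detail::round_to_nearest_even:
// add 0x7FFF plus the lowest kept mantissa bit, then truncate to 16 bits.
uint16_t rne_bf16_from_bits(uint32_t f32_bits) {
  const uint32_t rounding_bias = ((f32_bits >> 16) & 1) + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((f32_bits + rounding_bias) >> 16);
}

int main() {
  // Exactly halfway between 0x3F81 and 0x3F82: the tie goes to the even
  // mantissa, 0x3F82.
  assert(rne_bf16_from_bits(0x3F818000u) == 0x3F82);
  // Just above halfway rounds up: 0x3F808001 -> 0x3F81.
  assert(rne_bf16_from_bits(0x3F808001u) == 0x3F81);
  // Below halfway rounds down: 0x3F807FFF -> 0x3F80 (1.0 in bfloat16).
  assert(rne_bf16_from_bits(0x3F807FFFu) == 0x3F80);
  return 0;
}
```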

View File

@@ -1,28 +1 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
/// into one byte). This is the FP4 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.3.3)
///
/// Given two high precision values val0 and val1, here is the
/// binary configuration of their packed representation, from MSB to LSB:
///
/// original value | val1 : val0
/// ========================================
/// bit index (MSB==7, LSB==0) | 7654 : 3210
/// sign/exponent/mantissa | seem : seem
///
namespace c10 {
struct alignas(1) Float4_e2m1fn_x2 {
uint8_t val_;
Float4_e2m1fn_x2() = default;
C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
};
} // namespace c10
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
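A tiny sketch of the packing scheme described above; the pack helper and the chosen nibble values are illustrative and not part of the c10 API (the e2m1 value encodings follow the OCP MX spec referenced in the comment):
```cpp
#include <c10/util/Float4_e2m1fn_x2.h>
#include <cstdint>

// val1 lives in bits 7..4, val0 in bits 3..0; each nibble is s.ee.m (e2m1).
uint8_t pack_e2m1_pair(uint8_t val1_nibble, uint8_t val0_nibble) {
  return static_cast<uint8_t>(((val1_nibble & 0x0F) << 4) | (val0_nibble & 0x0F));
}

c10::Float4_e2m1fn_x2 make_packed_example() {
  // 0b0101 (3.0 in e2m1) as val1 and 0b0010 (1.0 in e2m1) as val0,
  // packed into a single byte 0x52.
  return c10::Float4_e2m1fn_x2(pack_e2m1_pair(/*val1=*/0b0101, /*val0=*/0b0010));
}
```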

View File

@@ -1,274 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstdint>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value)
: x(detail::fp8e4m3fn_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
return detail::fp8e4m3fn_to_fp32_value(x);
}
/// Special values helper
inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
return (x & 0b01111111) == 0b01111111;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e4m3fn
operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn
operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn
operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(
const Float8_e4m3fn& a,
const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
return a + static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
return a - static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
return a * static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
return a / static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
return a + static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
return a - static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
return a * static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
return a / static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e4m3fn to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e4m3fn> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 4;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 3;
static constexpr int radix = 2;
static constexpr int min_exponent = -5;
static constexpr int min_exponent10 = -1;
static constexpr int max_exponent = 8;
static constexpr int max_exponent10 = 2;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before = false;
static constexpr c10::Float8_e4m3fn min() {
return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn lowest() {
return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn max() {
return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn epsilon() {
return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn round_error() {
return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn quiet_NaN() {
return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn denorm_min() {
return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/Float8_e4m3fn.h>
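A quick sanity-check sketch of the limits declared above (1 sign, 4 exponent and 3 mantissa bits, bias 7, no infinity, single NaN pattern 0x7F):
```cpp
#include <c10/util/Float8_e4m3fn.h>
#include <cassert>
#include <limits>

int main() {
  using F8 = c10::Float8_e4m3fn;
  assert(float(std::numeric_limits<F8>::max()) == 448.0f);               // 0x7E = 1.75 * 2^8
  assert(float(std::numeric_limits<F8>::min()) == 0.015625f);            // 0x08 = 2^-6
  assert(float(std::numeric_limits<F8>::denorm_min()) == 0.001953125f);  // 0x01 = 2^-9
  assert(std::numeric_limits<F8>::quiet_NaN().isnan());                  // 0x7F
  return 0;
}
```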

View File

@@ -1,238 +1 @@
#pragma once
/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
/// bias = 7
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#if defined(__cplusplus)
#include <cmath>
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <climits>
#include <iostream>
namespace c10 {
namespace detail {
/*
* Convert an 8-bit floating-point number in fp8 E4M3FN format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
/*
* Extend the fp8 E4M3FN number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+----+---+-----------------------------+
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 27-30 24-26 0-23
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)input << 24;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the bits 0-30
* of the 32-bit word:
*
* +---+----+---+-----------------------------+
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 27-30 24-26 0-23
*/
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
/*
* Renorm shift is the number of bits to shift mantissa left to make the
* fp8e4m3fn number normalized. If the initial number is normalized, some
* of its high 5 bits (sign == 0 and 4-bit exponent) equal one. In this case
* renorm_shift == 0. If the number is denormalized, renorm_shift > 0. Note
* that if we shift denormalized nonsign by renorm_shift, the unit bit of
* mantissa will shift into exponent, turning the biased exponent into 1, and
* making mantissa normalized (i.e. without leading 1).
*/
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
uint32_t renorm_shift = __clz(nonsign);
#elif defined(__SYCL_DEVICE_ONLY__)
// Note: zero is not a supported input into `__builtin_clz`
uint32_t renorm_shift =
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
#elif defined(_MSC_VER) && !defined(__clang__)
unsigned long nonsign_bsr;
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
// Note: zero is not a supported input into `__builtin_clz`
uint32_t renorm_shift =
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
#endif
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
/*
* Iff fp8e4m3fn number has all exponent and mantissa bits set to 1,
* the addition overflows it into bit 31, and the subsequent shift turns the
* high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number
* is Nan, 0x00000000 otherwise
*/
const int32_t inf_nan_mask =
((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
/*
* Iff nonsign is 0, subtracting 1 wraps it around to 0xFFFFFFFF, turning bit
* 31 into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
* broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
* 0xFFFFFFFF if the fp8e4m3fn number was zero (+0.0 or -0.0),
* 0x00000000 otherwise
*/
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
/*
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
* was denormal)
* 2. Shift nonsign right by 4 so the exponent (4 bits originally)
* becomes an 8-bit field and 3-bit mantissa shifts into the 3 high
* bits of the 23-bit mantissa of IEEE single-precision number.
* 3. Add 0x78 to the exponent (starting at bit 23) to compensate for the
* difference in exponent bias (0x7F for single-precision number less 0x07
* for fp8e4m3fn number).
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
* account for renormalization. As renorm_shift is less than 0x78, this
* can be combined with step 3.
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
* input was NaN or infinity.
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
* into zero if the input was zero.
* 7. Combine with the sign of the input number.
*/
uint32_t result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to an
* 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
/*
* Binary representation of 480.0f, which is the first value
* not representable in fp8e4m3fn range:
* 0 1111 111 - fp8e4m3fn
* 0 10000111 11100000000000000000000 - fp32
*/
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
/*
* A mask for converting fp32 numbers lower than fp8e4m3fn normal range
* into denorm representation
* magic number: ((127 - 7) + (23 - 3) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint8_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fp8_max) {
// NaN - all exponent and mantissa bits set to 1
result = 0x7f;
} else {
if (f_bits < (UINT32_C(121) << 23)) {
// Input number is smaller than 2^(-6), which is the smallest
// fp8e4m3fn normal number
f_bits =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 20) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 20);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
} // namespace detail
struct alignas(1) Float8_e4m3fn {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e4m3fn() = default;
constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e4m3fn(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
inline std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/Float8_e4m3fn.h>
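To ground the normal-range branch of fp8e4m3fn_from_fp32_value, here is its arithmetic traced for the input 1.0f (a standalone sketch reusing the constants from the function above):
```cpp
#include <cassert>
#include <cstdint>

int main() {
  uint32_t f_bits = 0x3F800000u;  // bit pattern of 1.0f; sign bit already zero
  // 1.0f is neither >= 480.0f (fp8_max) nor < 2^-6 (the smallest e4m3fn
  // normal), so the function takes the rounding branch:
  const uint8_t mant_odd = (f_bits >> 20) & 1;  // lowest kept mantissa bit == 0
  f_bits += (static_cast<uint32_t>(7 - 127) << 23) + 0x7FFFFu;  // rebias + rounding bias part 1
  f_bits += mant_odd;                                           // rounding bias part 2
  const uint8_t result = static_cast<uint8_t>(f_bits >> 20);
  // sign 0, exponent 0111 (== bias 7), mantissa 000  ->  1.0 in e4m3fn
  assert(result == 0x38);
  return 0;
}
```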

View File

@@ -1,279 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Float8_fnuz_cvt.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
: x(detail::fp8e4m3fnuz_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const {
return detail::fp8_fnuz_to_fp32_value<4, 3>(x);
}
/// Special values helper
inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const {
return x == 0b10000000;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(
const Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) {
return a + static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) {
return a - static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) {
return a * static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) {
return a / static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) {
return a + static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) {
return a - static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) {
return a * static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) {
return a / static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e4m3fnuz to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e4m3fnuz> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 4;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 3;
static constexpr int radix = 2;
static constexpr int min_exponent = -6;
static constexpr int min_exponent10 = -1;
static constexpr int max_exponent = 8;
static constexpr int max_exponent10 = 2;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before = false;
static constexpr c10::Float8_e4m3fnuz min() {
return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz lowest() {
return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz max() {
return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz epsilon() {
return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz round_error() {
return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz infinity() {
// NaN (no infinities)
return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz quiet_NaN() {
return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz denorm_min() {
return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
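
The numeric_limits codes above are raw e4m3fnuz bit patterns. As a sanity check, here is a minimal standalone decoder sketch (the helper name `decode_e4m3fnuz` is illustrative, not part of c10), assuming the documented layout: bias 8, 4 exponent bits, 3 mantissa bits, no infinities, and 0x80 as the only NaN.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative decoder for e4m3fnuz bytes (not the c10 implementation).
float decode_e4m3fnuz(uint8_t x) {
  if (x == 0x80) return NAN;                              // the single NaN code
  float sign = (x & 0x80) ? -1.0f : 1.0f;
  int exp = (x >> 3) & 0xF;                               // 4 exponent bits
  int man = x & 0x7;                                      // 3 mantissa bits
  if (exp == 0)                                           // subnormal: m/8 * 2^(1-8)
    return sign * std::ldexp(man / 8.0f, 1 - 8);
  return sign * std::ldexp(1.0f + man / 8.0f, exp - 8);   // normal: (1 + m/8) * 2^(e-8)
}

int main() {
  printf("min        0x08 -> %g\n", decode_e4m3fnuz(0x08));  // 2^-7
  printf("max        0x7F -> %g\n", decode_e4m3fnuz(0x7F));  // 240
  printf("lowest     0xFF -> %g\n", decode_e4m3fnuz(0xFF));  // -240
  printf("epsilon    0x28 -> %g\n", decode_e4m3fnuz(0x28));  // 0.125
  printf("denorm_min 0x01 -> %g\n", decode_e4m3fnuz(0x01));  // 2^-10
}
```

The printed values line up with the min(), max(), lowest(), epsilon(), and denorm_min() specializations above.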


@@ -1,139 +1 @@
#pragma once
/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration remains the same as Float8_e4m3fn:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
/// The key differences versus Float8_e4m3fn are:
/// bias = 8
/// no infinities or negative zero
/// NaN only when sign bit is 1, rest all 0s
///
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
/// the existing Float8_e4m3fn implementation.
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#include <iosfwd>
#include <ostream>
namespace c10 {
namespace detail {
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
/*
* Binary representation of 256.0f, which is the first value not representable
 * (i.e. the first value which would overflow into the sign bit, resulting in
* a NaN) in fp8e4m3fnuz range:
* 1 0000 000 - fp8e4m3fnuz
* 0 10000111 00000000000000000000000 - fp32
*/
constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23;
/*
* A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range
* into denorm representation
* magic number: ((127 - 8) + (23 - 3) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint32_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fnuz_max) {
// NaN -- sign bit set to 1, rest 0s.
return 0x80;
}
if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) {
// Input exponent is less than -7, the smallest e4m3fnuz exponent, so the
// number will become subnormal.
f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
if (result == 0) {
// fnuz types don't have negative zero.
return 0;
}
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 20) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 20);
}
result |= sign >> 24;
return result;
}
} // namespace detail
struct alignas(1) Float8_e4m3fnuz {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e4m3fnuz() = default;
constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e4m3fnuz(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
inline std::ostream& operator<<(
std::ostream& out,
const Float8_e4m3fnuz& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e4m3fnuz-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
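
For reference, a standalone mirror of the encoder above, with `bits_of`/`float_of` standing in for c10::detail::fp32_to_bits/fp32_from_bits (illustrative names; a minimal sketch rather than a drop-in replacement):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint32_t bits_of(float f) { uint32_t u; std::memcpy(&u, &f, sizeof(u)); return u; }
static float float_of(uint32_t u) { float f; std::memcpy(&f, &u, sizeof(f)); return f; }

uint8_t encode_e4m3fnuz(float f) {
  constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23;     // bits of 256.0f, first overflow
  constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23;  // bits of 2^13, subnormal shifter
  uint32_t f_bits = bits_of(f);
  const uint32_t sign = f_bits & UINT32_C(0x80000000);
  f_bits ^= sign;
  if (f_bits >= fnuz_max) return 0x80;                    // overflow and NaN -> 0x80
  uint32_t result;
  if (f_bits < (UINT32_C(0x78) << 23)) {                  // below 2^-7: subnormal path
    f_bits = bits_of(float_of(f_bits) + float_of(denorm_mask));
    result = static_cast<uint8_t>(f_bits - denorm_mask);
    if (result == 0) return 0;                            // fnuz has no negative zero
  } else {
    uint8_t mant_odd = (f_bits >> 20) & 1;                // LSB of the 3 kept mantissa bits
    f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF;      // rebias exponent + rounding bias
    f_bits += mant_odd;                                   // ties-to-even
    result = static_cast<uint8_t>(f_bits >> 20);
  }
  return static_cast<uint8_t>(result | (sign >> 24));
}

int main() {
  printf("1.0f   -> 0x%02X\n", encode_e4m3fnuz(1.0f));    // 0x40
  printf("240.0f -> 0x%02X\n", encode_e4m3fnuz(240.0f));  // 0x7F (max finite)
  printf("448.0f -> 0x%02X\n", encode_e4m3fnuz(448.0f));  // 0x80 (overflows to NaN)
  printf("-0.0f  -> 0x%02X\n", encode_e4m3fnuz(-0.0f));   // 0x00 (negative zero collapses)
}
```

The -0.0f case shows where the "fnuz types don't have negative zero" early return fires: the subnormal path yields 0 and returns before the sign bit is merged back in.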


@@ -1,286 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
#define EXP_WIDTH_FP8 5
#define MAN_WIDTH_FP8 2
#define EXP_BIAS_FP8 15
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value)
: x(detail::fp8e5m2_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
return detail::fp8e5m2_to_fp32_value(x);
}
/// Special values helpers
inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
return (x & 0b01111111) > 0b01111100;
}
inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
return (x & 0b01111111) == 0b01111100;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e5m2
operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2
operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2
operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(
const Float8_e5m2& a,
const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e5m2& operator+=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator-=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator*=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator/=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
return a + static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
return a - static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
return a * static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
return a / static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
return a + static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
return a - static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
return a * static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
return a / static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e5m2 to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e5m2> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_specialized = true;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 3;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
static constexpr int radix = 2;
static constexpr int min_exponent = -13;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::Float8_e5m2 min() {
return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 max() {
return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 lowest() {
return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 epsilon() {
return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 round_error() {
return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 infinity() {
return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 quiet_NaN() {
return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 denorm_min() {
return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/Float8_e5m2.h>
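
Unlike the fnuz variants, e5m2 keeps IEEE-style infinities, which is exactly what the isnan/isinf masks above encode: exponent all ones (0x7C after masking off the sign) with a zero mantissa is an infinity, and with a nonzero mantissa it is a NaN. A small standalone sketch of those predicates over a few sample codes (the decoded values in the comments follow the usual e5m2 interpretation):

```cpp
#include <cstdint>
#include <cstdio>

// Same predicates as the member functions above, written as free functions.
bool e5m2_isnan(uint8_t x) { return (x & 0x7F) > 0x7C; }
bool e5m2_isinf(uint8_t x) { return (x & 0x7F) == 0x7C; }

int main() {
  const uint8_t samples[] = {0x7B, 0x7C, 0x7D, 0x7F, 0xFC, 0xFE};
  for (uint8_t b : samples)
    printf("0x%02X  isnan=%d  isinf=%d\n", b, e5m2_isnan(b), e5m2_isinf(b));
  // 0x7B = 57344 (max finite), 0x7C = +inf, 0x7D and 0x7F = NaN,
  // 0xFC = -inf, 0xFE = NaN with the sign bit set.
}
```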


@@ -1,146 +1 @@
#pragma once
/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
/// bias = 15
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h
#include <c10/util/Half.h>
namespace c10 {
namespace detail {
/*
* Convert a 8-bit floating-point number in fp8 E5M2 format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) {
/*
* Extend the fp8 E5M2 number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+----+---+-----------------------------+
* | S |EEEEE|MM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 26-30 24-25 0-23
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
uint16_t half_representation = input;
half_representation <<= 8;
return fp16_ieee_to_fp32_value(half_representation);
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) {
/*
* Binary representation of fp32 infinity
* 0 11111111 00000000000000000000000
*/
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
/*
* Binary representation of 65536.0f, which is the first value
* not representable in fp8e5m2 range:
* 0 11111 00 - fp8e5m2
* 0 10001111 00000000000000000000000 - fp32
*/
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
/*
* A mask for converting fp32 numbers lower than fp8e5m2 normal range
* into denorm representation
* magic number: ((127 - 15) + (23 - 2) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint8_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fp8_max) {
// NaN - all exponent and mantissa bits set to 1
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
} else {
if (f_bits < (UINT32_C(113) << 23)) {
// Input number is smaller than 2^(-14), which is the smallest
// fp8e5m2 normal number
f_bits =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint32_t mant_odd = (f_bits >> 21) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 21);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
} // namespace detail
struct alignas(1) Float8_e5m2 {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e5m2() = default;
constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {}
inline C10_HOST_DEVICE Float8_e5m2(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
inline C10_HOST_DEVICE bool isinf() const;
};
inline std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e5m2-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/Float8_e5m2.h>
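
The subnormal branch of fp8e5m2_from_fp32_value above leans on float32's own rounding: adding the magic constant whose bit pattern is 134 << 23 (the value 2^7 = 128.0f) snaps any input below 2^-14 onto e5m2's subnormal grid of 2^-16 steps, after which the low mantissa bits of the sum are exactly the fp8 payload. A standalone sketch of just that trick, with inputs chosen for illustration:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  constexpr uint32_t denorm_mask = UINT32_C(134) << 23;  // bit pattern of 128.0f (2^7)
  float magic;
  std::memcpy(&magic, &denorm_mask, sizeof(magic));
  // 2^-16 (one subnormal step), 3 * 2^-16, and 2^-17 (a tie, rounds to even -> 0).
  const float inputs[] = {1.52587890625e-05f, 4.57763671875e-05f, 7.62939453125e-06f};
  for (float f : inputs) {
    float sum = f + magic;                    // float32 addition does the rounding
    uint32_t sum_bits;
    std::memcpy(&sum_bits, &sum, sizeof(sum_bits));
    printf("f = %.12g -> payload 0x%02X\n", (double)f, (uint8_t)(sum_bits - denorm_mask));
  }
  // Expected payloads: 0x01, 0x03, 0x00.
}
```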


@@ -1,285 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Float8_fnuz_cvt.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
: x(detail::fp8e5m2fnuz_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const {
return detail::fp8_fnuz_to_fp32_value<5, 2>(x);
}
/// Special values helpers
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const {
return x == 0b10000000;
}
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const {
return false;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(
const Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) {
return a + static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) {
return a - static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) {
return a * static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) {
return a / static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) {
return a + static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) {
return a - static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) {
return a * static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) {
return a / static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e5m2fnuz to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e5m2fnuz> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_specialized = true;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 3;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
static constexpr int radix = 2;
static constexpr int min_exponent = -14;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::Float8_e5m2fnuz min() {
return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz max() {
return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz lowest() {
return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz epsilon() {
return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz round_error() {
return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz infinity() {
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
}
// TODO(future): we are mapping neg_zero to both inf and NaN, this is
// surprising and we should figure out what to do about it.
static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz denorm_min() {
return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
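
Because the fnuz layout drops infinities and keeps a single NaN at 0x80, the codes that plain e5m2 spends on inf/NaN become ordinary finite values here, which is why max() above is 0x7F rather than 0x7B. A standalone decoder sketch (bias 16; the helper name is illustrative) over a few of the numeric_limits codes:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative decoder for e5m2fnuz bytes (not the c10 implementation).
float decode_e5m2fnuz(uint8_t x) {
  if (x == 0x80) return NAN;                                      // the single NaN code
  float sign = (x & 0x80) ? -1.0f : 1.0f;
  int exp = (x >> 2) & 0x1F;                                      // 5 exponent bits
  int man = x & 0x3;                                              // 2 mantissa bits
  if (exp == 0) return sign * std::ldexp(man / 4.0f, 1 - 16);     // subnormal
  return sign * std::ldexp(1.0f + man / 4.0f, exp - 16);          // normal
}

int main() {
  printf("0x7C -> %g\n", decode_e5m2fnuz(0x7C));  // 32768 (this code is +inf in plain e5m2)
  printf("0x7F -> %g\n", decode_e5m2fnuz(0x7F));  // 57344, max()
  printf("0x04 -> %g\n", decode_e5m2fnuz(0x04));  // 2^-15, min()
  printf("0x01 -> %g\n", decode_e5m2fnuz(0x01));  // 2^-17, denorm_min()
  printf("0x80 -> %g\n", decode_e5m2fnuz(0x80));  // nan (also the quiet_NaN()/infinity() code)
}
```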


@@ -1,138 +1 @@
#pragma once
/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration remains the same as e5m2:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
/// The key differences that e5m2fnuz brings are:
/// bias = 16
/// no infinities or negative zero
/// NaN only when sign bit is 1, rest all 0s
///
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
/// the existing Float8_e4m3fn implementation.
#include <c10/macros/Macros.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/floating_point_utils.h>
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#include <iosfwd>
#include <ostream>
namespace c10 {
namespace detail {
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
/*
* Binary representation of 65536.0f, which is the first value not
 * representable (i.e. the first value which would overflow into the sign
 * bit, resulting in a NaN) in fp8e5m2fnuz range:
* 1 00000 00 - fp8e5m2fnuz
* 0 10001111 00000000000000000000000 - fp32
*/
constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23;
/*
* A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range
* into denormalized representation.
* magic number: ((127 - 16) + (23 - 2) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint32_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fnuz_max) {
// NaN -- sign bit set to 1, rest 0s
return 0x80;
}
if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) {
// Input exponent is less than -15, the smallest e5m2fnuz exponent, so the
// number will become subnormal.
f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
if (result == 0) {
// fnuz types don't have negative zero.
return 0;
}
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 21) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 21);
}
result |= sign >> 24;
return result;
}
} // namespace detail
struct alignas(1) Float8_e5m2fnuz {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e5m2fnuz() = default;
constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e5m2fnuz(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
inline C10_HOST_DEVICE bool isinf() const;
};
inline std::ostream& operator<<(
std::ostream& out,
const Float8_e5m2fnuz& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e5m2fnuz-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
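
In the normal range, the encoder above folds the bias change (16 for e5m2fnuz versus 127 for float32) and round-to-nearest-even into a single integer add before shifting out the 21 dropped mantissa bits. A short standalone trace of that arithmetic for f = 1.0f, with memcpy standing in for fp32_to_bits:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float f = 1.0f;
  uint32_t f_bits;
  std::memcpy(&f_bits, &f, sizeof(f_bits));            // 0x3F800000
  uint32_t mant_odd = (f_bits >> 21) & 1;              // LSB of the 2 kept mantissa bits (0 here)
  f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF;    // rebias exponent, add rounding bias
  f_bits += mant_odd;                                  // ties-to-even adjustment
  uint8_t result = (uint8_t)(f_bits >> 21);            // drop the 21 discarded bits
  printf("1.0f -> 0x%02X\n", result);                  // 0x40: sign 0, biased exponent 16, mantissa 0
}
```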


@@ -1,112 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <cstring>
#include <limits>
// TODO(#146647): Can we remove the below warning?
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value)
: x(detail::fp8e8m0fnu_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const {
// TODO(#146647): maybe rewrite without control flow
// if exponent is zero, need to special case to return 2^-127 instead of zero
if (x == 0) {
return c10::detail::fp32_from_bits(0x00400000);
}
// if exponent is NaN, need to special case to return properly encoded NaN
if (isnan()) {
return c10::detail::fp32_from_bits(0x7f800001);
}
// leave sign at 0, set the exponent bits, leave stored mantissa at 0
uint32_t res = x << 23;
return c10::detail::fp32_from_bits(res);
}
/// Special values helper
inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const {
return x == 0b11111111;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e8m0fnu to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e8m0fnu> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = false;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = false;
static constexpr auto has_denorm_loss = false;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 1;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 1; // just a 2!
static constexpr int radix = 2;
static constexpr int min_exponent = -126;
static constexpr int min_exponent10 = -38;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before = false;
static constexpr c10::Float8_e8m0fnu min() {
// 2^-127
return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu lowest() {
// 2^-127
return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu max() {
// 254 biased, which is 127 unbiased, so 2^127
return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu epsilon() {
// according to https://en.cppreference.com/w/cpp/types/numeric_limits, this
// is "the difference between 1.0 and the next representable value of the
// given floating-point type". The next representable value is 2.0, so the
// difference is 1.0 which is 2^0. 0 unbiased is 127 biased.
return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu round_error() {
// 0.5 in float, which is 2^-1, and -1 + 127 = 126
return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits());
}
static constexpr c10::Float8_e8m0fnu quiet_NaN() {
return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/Float8_e8m0fnu.h>
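
Every e8m0fnu code is a power of two: the two special cases handled above are 0x00, which decodes to 2^-127 instead of zero, and 0xFF, the single NaN. A small decoder sketch (illustrative name, not the c10 implementation) covering the numeric_limits codes:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative e8m0fnu decoder.
float decode_e8m0fnu(uint8_t x) {
  if (x == 0xFF) return NAN;                // the single NaN code
  return std::ldexp(1.0f, (int)x - 127);    // 2^(x - 127); 2^-127 is a float32 subnormal
}

int main() {
  printf("0x00 -> %g\n", decode_e8m0fnu(0x00));  // 2^-127, both min() and lowest()
  printf("0x7F -> %g\n", decode_e8m0fnu(0x7F));  // 1.0, also epsilon() (the next value is 2.0)
  printf("0xFE -> %g\n", decode_e8m0fnu(0xFE));  // 2^127, max()
  printf("0xFF -> %g\n", decode_e8m0fnu(0xFF));  // nan
}
```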


@@ -1,120 +1 @@
#pragma once
/// Defines the Float8_e8m0fnu type (8-bit floating-point) including
/// conversions to standard C types
/// Binary configuration :
/// eeeeeeee
/// no sign bits
/// 8 exponent bits
/// no mantissa bits
///
/// This is the E8M0 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.4.1)
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
// TODO(#146647): do we need to special case OPENCL?
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#include <iosfwd>
#include <ostream>
namespace c10 {
namespace detail {
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 e8m0fnu format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) {
// TODO(#146647): maybe rewrite without control flow
uint32_t f_bits = c10::detail::fp32_to_bits(f);
// extract the exponent
uint32_t exponent = (f_bits >> 23) & 0b11111111;
// special case float32 NaN and +-inf to map to e8m0 nan
if (exponent == 0b11111111) {
return exponent;
}
// next, we use guard, round, sticky bits and the LSB to implement round to
// nearest, with ties to even
// guard bit - bit 23, or 22 zero-indexed
uint8_t g = (f_bits & 0x400000) > 0;
// round bit - bit 22, or 21 zero-indexed
uint8_t r = (f_bits & 0x200000) > 0;
// sticky bit - bits 21 to 1, or 20 to 0 zero-indexed
uint8_t s = (f_bits & 0x1FFFFF) > 0;
// in casting to e8m0, LSB is the implied mantissa bit. It equals 0 if the
// original float32 is denormal, and 1 if the original float32 is normal.
uint8_t lsb = exponent > 0;
// implement the RNE logic
bool round_up = false;
// if g == 0, round down (no-op)
if (g == 1) {
if ((r == 1) || (s == 1)) {
// round up
round_up = true;
} else {
if (lsb == 1) {
// round up
round_up = true;
}
// if lsb == 0, round down (no-op)
}
}
if (round_up) {
// adjust exponent
// note that if exponent was 255 we would have already returned earlier, so
// we know we can add one safely without running out of bounds
exponent++;
}
return exponent;
}
} // namespace detail
struct alignas(1) Float8_e8m0fnu {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e8m0fnu() = default;
constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e8m0fnu(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
inline std::ostream& operator<<(
std::ostream& out,
const Float8_e8m0fnu& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e8m0fnu-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/Float8_e8m0fnu.h>
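
The guard/round/sticky logic above reduces to one condition: bump the exponent when the guard bit is set and either something below it (round or sticky) is also set, or the kept value's implied LSB is 1, which is the ties case; since every normal input has an implied LSB of 1, ties round upward here. A standalone mirror (minimal sketch) with a few sample inputs:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Standalone mirror of the e8m0fnu encoder above.
uint8_t encode_e8m0fnu(float f) {
  uint32_t f_bits;
  std::memcpy(&f_bits, &f, sizeof(f_bits));
  uint32_t exponent = (f_bits >> 23) & 0xFF;
  if (exponent == 0xFF) return 0xFF;              // float32 NaN/inf map to the e8m0 NaN
  uint8_t g = (f_bits & 0x400000) > 0;            // guard: highest dropped mantissa bit
  uint8_t r = (f_bits & 0x200000) > 0;            // round: next bit down
  uint8_t s = (f_bits & 0x1FFFFF) > 0;            // sticky: everything below that
  uint8_t lsb = exponent > 0;                     // implied mantissa bit of the input
  if (g && (r || s || lsb)) exponent++;           // same round-up condition as above
  return (uint8_t)exponent;
}

int main() {
  printf("1.25f -> 0x%02X (1.0)\n", encode_e8m0fnu(1.25f));  // guard clear: round down
  printf("1.5f  -> 0x%02X (2.0)\n", encode_e8m0fnu(1.5f));   // 0x80
  printf("3.0f  -> 0x%02X (4.0)\n", encode_e8m0fnu(3.0f));   // 0x81
  printf("6.0f  -> 0x%02X (8.0)\n", encode_e8m0fnu(6.0f));   // 0x82
}
```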


@@ -1,350 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <cstring>
#include <limits>
#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif
#ifdef __HIPCC__
#include <hip/hip_fp16.h>
#endif
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#elif defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
#include <ATen/cpu/vec/vec_half.h>
#endif
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
#if defined(__aarch64__) && !defined(__CUDACC__)
/// Constructors
inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {}
inline Half::operator float16_t() const {
return detail::fp16_from_bits(x);
}
#else
inline C10_HOST_DEVICE Half::Half(float value)
:
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
x(__half_as_short(__float2half(value)))
#elif defined(__SYCL_DEVICE_ONLY__)
x(c10::bit_cast<uint16_t>(sycl::half(value)))
#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
x(at::vec::float2half_scalar(value))
#else
x(detail::fp16_ieee_from_fp32_value(value))
#endif
{
}
/// Implicit conversions
inline C10_HOST_DEVICE Half::operator float() const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return __half2float(*reinterpret_cast<const __half*>(&x));
#elif defined(__SYCL_DEVICE_ONLY__)
return float(c10::bit_cast<sycl::half>(x));
#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
return at::vec::half2float_scalar(x);
#elif defined(__aarch64__) && !defined(__CUDACC__)
return detail::native_fp16_to_fp32_value(x);
#else
return detail::fp16_ieee_to_fp32_value(x);
#endif
}
#endif /* !defined(__aarch64__) || defined(__CUDACC__) \
*/
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_HOST_DEVICE Half::Half(const __half& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE Half::operator __half() const {
return *reinterpret_cast<const __half*>(&x);
}
#endif
#ifdef SYCL_LANGUAGE_VERSION
inline C10_HOST_DEVICE Half::Half(const sycl::half& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE Half::operator sycl::half() const {
return *reinterpret_cast<const sycl::half*>(&x);
}
#endif
// CUDA intrinsics
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \
(defined(__clang__) && defined(__CUDA__))
inline __device__ Half __ldg(const Half* ptr) {
return __ldg(reinterpret_cast<const __half*>(ptr));
}
#endif
/// Arithmetic
inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Half operator-(const Half& a) {
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
defined(__HIP_DEVICE_COMPILE__)
return __hneg(a);
#elif defined(__SYCL_DEVICE_ONLY__)
return -c10::bit_cast<sycl::half>(a);
#else
return -static_cast<float>(a);
#endif
}
inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Half a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Half a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Half a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Half a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Half b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Half b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Half b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Half b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Half a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Half a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Half a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Half a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Half b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Half b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Half b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Half b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Half operator+(Half a, int b) {
return a + static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator-(Half a, int b) {
return a - static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator*(Half a, int b) {
return a * static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator/(Half a, int b) {
return a / static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator+(int a, Half b) {
return static_cast<Half>(a) + b;
}
inline C10_HOST_DEVICE Half operator-(int a, Half b) {
return static_cast<Half>(a) - b;
}
inline C10_HOST_DEVICE Half operator*(int a, Half b) {
return static_cast<Half>(a) * b;
}
inline C10_HOST_DEVICE Half operator/(int a, Half b) {
return static_cast<Half>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) {
return a + static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) {
return a - static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) {
return a * static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) {
return a / static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) {
return static_cast<Half>(a) + b;
}
inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) {
return static_cast<Half>(a) - b;
}
inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) {
return static_cast<Half>(a) * b;
}
inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) {
return static_cast<Half>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Half to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Half> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
static constexpr auto has_denorm_loss =
numeric_limits<float>::has_denorm_loss;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = true;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 11;
static constexpr int digits10 = 3;
static constexpr int max_digits10 = 5;
static constexpr int radix = 2;
static constexpr int min_exponent = -13;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::Half min() {
return c10::Half(0x0400, c10::Half::from_bits());
}
static constexpr c10::Half lowest() {
return c10::Half(0xFBFF, c10::Half::from_bits());
}
static constexpr c10::Half max() {
return c10::Half(0x7BFF, c10::Half::from_bits());
}
static constexpr c10::Half epsilon() {
return c10::Half(0x1400, c10::Half::from_bits());
}
static constexpr c10::Half round_error() {
return c10::Half(0x3800, c10::Half::from_bits());
}
static constexpr c10::Half infinity() {
return c10::Half(0x7C00, c10::Half::from_bits());
}
static constexpr c10::Half quiet_NaN() {
return c10::Half(0x7E00, c10::Half::from_bits());
}
static constexpr c10::Half signaling_NaN() {
return c10::Half(0x7D00, c10::Half::from_bits());
}
static constexpr c10::Half denorm_min() {
return c10::Half(0x0001, c10::Half::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/Half.h>
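
As with the fp8 types above, the Half numeric_limits entries are raw IEEE binary16 bit patterns. A minimal standalone binary16 decoder (illustrative only, not the c10 conversion path) confirms what they encode:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative IEEE binary16 decoder: 1 sign, 5 exponent (bias 15), 10 mantissa bits.
float decode_half(uint16_t h) {
  float sign = (h & 0x8000) ? -1.0f : 1.0f;
  int exp = (h >> 10) & 0x1F;
  int man = h & 0x3FF;
  if (exp == 0x1F) return man ? NAN : sign * INFINITY;            // inf / NaN
  if (exp == 0) return sign * std::ldexp(man / 1024.0f, -14);     // subnormal
  return sign * std::ldexp(1.0f + man / 1024.0f, exp - 15);       // normal
}

int main() {
  printf("min        0x0400 -> %g\n", decode_half(0x0400));  // 2^-14
  printf("max        0x7BFF -> %g\n", decode_half(0x7BFF));  // 65504
  printf("epsilon    0x1400 -> %g\n", decode_half(0x1400));  // 2^-10
  printf("denorm_min 0x0001 -> %g\n", decode_half(0x0001));  // 2^-24
  printf("infinity   0x7C00 -> %g\n", decode_half(0x7C00));  // inf
}
```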


@@ -1,424 +1,8 @@
#pragma once
#include <torch/headeronly/util/Half.h>
/// Defines the Half type (half-precision floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32, instead of using CUDA half intrinsics.
/// Most uses of this type within ATen are memory bound, including the
/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs.
/// If you are writing a compute bound kernel, you can use the CUDA half
/// intrinsics directly on the Half type from device code.
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
#if defined(__cplusplus)
#include <cmath>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
// need to keep the following for BC because the APIs in here were exposed
// before migrating Half to torch/headeronly
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
#include <ATen/cpu/vec/vec_half.h>
#endif
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <ostream>
#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif
#ifdef __HIPCC__
#include <hip/hip_fp16.h>
#endif
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#elif defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#if defined(__aarch64__) && !defined(__CUDACC__)
#include <arm_neon.h>
#endif
#if defined(__GNUC__) || defined(__clang__)
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
defined(_M_IX86)
#if defined(__F16C__) && \
!(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \
defined(__HIP_DEVICE_COMPILE__))
#define C10_X86_F16 1
#include <immintrin.h> // import conversion ops from f16cintrin.h
#endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__)
// || defined(__HIP_DEVICE_COMPILE__))
#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86
#endif // __GNUC__ || __clang__
namespace c10 {
namespace detail {
/*
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
/*
* Extend the half-precision floating-point number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+-----+------------+-------------------+
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 31 26-30 16-25 0-15
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)h << 16;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the bits 0-30
* of the 32-bit word:
*
* +---+-----+------------+-------------------+
* | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 30 27-31 17-26 0-16
*/
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
/*
* Renorm shift is the number of bits to shift mantissa left to make the
* half-precision number normalized. If the initial number is normalized, some
 * of its high 6 bits (sign == 0 and 5-bit exponent) equal one. In this case
 * renorm_shift == 0. If the number is denormalized, renorm_shift > 0. Note
* that if we shift denormalized nonsign by renorm_shift, the unit bit of
* mantissa will shift into exponent, turning the biased exponent into 1, and
* making mantissa normalized (i.e. without leading 1).
*/
#ifdef _MSC_VER
unsigned long nonsign_bsr;
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
uint32_t renorm_shift = __builtin_clz(nonsign);
#endif
renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
/*
* Iff half-precision number has exponent of 15, the addition overflows
* it into bit 31, and the subsequent shift turns the high 9 bits
* into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number
 * had exponent of 15 (i.e. was NaN or infinity), and 0x00000000 otherwise.
*/
const int32_t inf_nan_mask =
((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
/*
* Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
* into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
* broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
 * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h), and
* 0x00000000 otherwise
*/
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
/*
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
* was denormal)
* 2. Shift nonsign right by 3 so the exponent (5 bits originally)
* becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
* bits of the 23-bit mantissa of IEEE single-precision number.
 * 3. Add 0x70 to the exponent (starting at bit 23) to compensate for the
 * difference in exponent bias (0x7F for single-precision number less 0xF
* for half-precision number).
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
* account for renormalization. As renorm_shift is less than 0x70, this
* can be combined with step 3.
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
* input was NaN or infinity.
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
* into zero if the input was zero.
* 7. Combine with the sign of the input number.
*/
return sign |
((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
}
/*
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format.
*
* @note The implementation relies on IEEE-like (no assumption about rounding
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
#ifdef C10_X86_F16
return _cvtsh_ss(h);
#else
/*
* Extend the half-precision floating-point number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+-----+------------+-------------------+
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 31 26-30 16-25 0-15
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)h << 16;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the high bits
* of the 32-bit word:
*
* +-----+------------+---------------------+
* |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
* +-----+------------+---------------------+
* Bits 27-31 17-26 0-16
*/
const uint32_t two_w = w + w;
/*
* Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
* mantissa and exponent of a single-precision floating-point number:
*
* S|Exponent | Mantissa
* +-+---+-----+------------+----------------+
* |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
* +-+---+-----+------------+----------------+
* Bits | 23-31 | 0-22
*
* Next, there are some adjustments to the exponent:
* - The exponent needs to be corrected by the difference in exponent bias
* between single-precision and half-precision formats (0x7F - 0xF = 0x70)
* - Inf and NaN values in the inputs should become Inf and NaN values after
* conversion to the single-precision number. Therefore, if the biased
* exponent of the half-precision input was 0x1F (max possible value), the
* biased exponent of the single-precision output must be 0xFF (max possible
* value). We do this correction in two steps:
* - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
* below) rather than by 0x70 suggested by the difference in the exponent bias
* (see above).
* - Then we multiply the single-precision result of exponent adjustment by
* 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
* necessary exponent adjustment by 0x70 due to difference in exponent bias.
* The floating-point multiplication hardware would ensure that Inf and
* NaN would retain their value on at least partially IEEE754-compliant
* implementations.
*
* Note that the above operations do not handle denormal inputs (where biased
* exponent == 0). However, they also do not operate on denormal inputs, and
* do not produce denormal results.
*/
constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
// const float exp_scale = 0x1.0p-112f;
constexpr uint32_t scale_bits = (uint32_t)15 << 23;
float exp_scale_val = 0;
#if defined(_MSC_VER) && defined(__clang__)
__builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
#else
std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
#endif
const float exp_scale = exp_scale_val;
const float normalized_value =
fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
/*
* Convert denormalized half-precision inputs into single-precision results
* (always normalized). Zero inputs are also handled here.
*
* In a denormalized number the biased exponent is zero, and the mantissa has
* non-zero bits. First, we shift the mantissa into bits 0-9 of the 32-bit word.
*
* zeros | mantissa
* +---------------------------+------------+
* |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
* +---------------------------+------------+
* Bits 10-31 0-9
*
* Now, remember that denormalized half-precision numbers are represented as:
* FP16 = mantissa * 2**(-24).
* The trick is to construct a normalized single-precision number with the
* same mantissa as the half-precision input and with an exponent which would
* scale the corresponding mantissa bits to 2**(-24). A normalized
* single-precision floating-point number is represented as: FP32 = (1 +
* mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased
* exponent is 126, a unit change in the mantissa of the input denormalized
* half-precision number causes a change of the constructed single-precision
* number by 2**(-24), i.e. the same amount.
*
* The last step is to adjust the bias of the constructed single-precision
* number. When the input half-precision number is zero, the constructed
* single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
* 2**(-1) = 0.5. Therefore, we need to subtract 0.5 from the constructed
* single-precision number to get the numerical equivalent of the input
* half-precision number.
*/
constexpr uint32_t magic_mask = UINT32_C(126) << 23;
constexpr float magic_bias = 0.5f;
const float denormalized_value =
fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
/*
* - Choose either the result of converting the input as a normalized number or
* as a denormalized number, depending on the input exponent. The variable
* two_w contains the input exponent in bits 27-31; therefore, if it is smaller
* than 2**27, the input is either a denormal number or zero.
* - Combine the result of conversion of exponent and mantissa with the sign
* of the input number.
*/
constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
const uint32_t result = sign |
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
: fp32_to_bits(normalized_value));
return fp32_from_bits(result);
#endif // C10_X86_F16
}
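A minimal standalone sketch of the denormalized branch above, for h = 0x0001 (the smallest fp16 subnormal, 2**-24); the memcpy stands in for fp32_from_bits:
```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
  const uint16_t h = 0x0001;
  const uint32_t two_w = ((uint32_t)h << 16) + ((uint32_t)h << 16);  // w + w
  const uint32_t bits = (two_w >> 17) | (126u << 23);  // mantissa under exponent 126
  float v;
  std::memcpy(&v, &bits, sizeof v);  // 0.5 + 2**-24
  v -= 0.5f;                         // the magic_bias; leaves exactly 2**-24
  std::printf("%a\n", v);            // prints 0x1p-24
  return 0;
}
```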
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 16-bit floating-point number in IEEE half-precision format, in bit
* representation.
*
* @note The implementation relies on IEEE-like (no assumption about rounding
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
inline uint16_t fp16_ieee_from_fp32_value(float f) {
#ifdef C10_X86_F16
return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
#else
// const float scale_to_inf = 0x1.0p+112f;
// const float scale_to_zero = 0x1.0p-110f;
constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
float scale_to_inf_val = 0, scale_to_zero_val = 0;
std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
std::memcpy(
&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
const float scale_to_inf = scale_to_inf_val;
const float scale_to_zero = scale_to_zero_val;
#if defined(_MSC_VER) && _MSC_VER == 1916
float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
#else
float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
#endif
const uint32_t w = fp32_to_bits(f);
const uint32_t shl1_w = w + w;
const uint32_t sign = w & UINT32_C(0x80000000);
uint32_t bias = shl1_w & UINT32_C(0xFF000000);
if (bias < UINT32_C(0x71000000)) {
bias = UINT32_C(0x71000000);
}
base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
const uint32_t bits = fp32_to_bits(base);
const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
const uint32_t nonsign = exp_bits + mantissa_bits;
return static_cast<uint16_t>(
(sign >> 16) |
(shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
#endif // C10_X86_F16
}
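A small round-trip sketch, assuming a build where this header is reachable as <c10/util/Half.h> (e.g. a PyTorch source checkout); every value below is exactly representable in fp16, so the round trip is lossless:
```cpp
#include <c10/util/Half.h>  // declares c10::detail::fp16_ieee_*_fp32_value
#include <cstdint>
#include <cstdio>

int main() {
  const float values[] = {0.0f, 0.5f, 1.0f, 65504.0f};  // 65504 = largest fp16
  for (float f : values) {
    const uint16_t h = c10::detail::fp16_ieee_from_fp32_value(f);
    const float back = c10::detail::fp16_ieee_to_fp32_value(h);
    std::printf("%g -> 0x%04X -> %g\n", f, (unsigned)h, back);  // back == f here
  }
  return 0;
}
```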
#ifdef C10_X86_F16
#undef C10_X86_F16
#endif // C10_X86_F16
#if defined(__aarch64__) && !defined(__CUDACC__)
inline float16_t fp16_from_bits(uint16_t h) {
return c10::bit_cast<float16_t>(h);
}
inline uint16_t fp16_to_bits(float16_t f) {
return c10::bit_cast<uint16_t>(f);
}
// According to https://godbolt.org/z/frExdbsWG it would translate to a single
// fcvt s0, h0
inline float native_fp16_to_fp32_value(uint16_t h) {
return static_cast<float>(fp16_from_bits(h));
}
inline uint16_t native_fp16_from_fp32_value(float f) {
return fp16_to_bits(static_cast<float16_t>(f));
}
#endif
} // namespace detail
struct alignas(2) Half {
unsigned short x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
// HIP wants __host__ __device__ tag, CUDA does not
#if defined(USE_ROCM)
C10_HOST_DEVICE Half() = default;
#else
Half() = default;
#endif
constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {}
#if defined(__aarch64__) && !defined(__CUDACC__)
inline Half(float16_t value);
inline operator float16_t() const;
#else
inline C10_HOST_DEVICE Half(float value);
inline C10_HOST_DEVICE operator float() const;
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_HOST_DEVICE Half(const __half& value);
inline C10_HOST_DEVICE operator __half() const;
#endif
#ifdef SYCL_LANGUAGE_VERSION
inline C10_HOST_DEVICE Half(const sycl::half& value);
inline C10_HOST_DEVICE operator sycl::half() const;
#endif
};
inline std::ostream& operator<<(std::ostream& out, const Half& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Half-inl.h> // IWYU pragma: keep
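Usage sketch for the Half wrapper above, again assuming <c10/util/Half.h> is on the include path of a typical (non-aarch64) build: from_bits() constructs a Half directly from a raw fp16 bit pattern, while construction from float goes through the conversion routines.
```cpp
#include <c10/util/Half.h>
#include <cstdio>

int main() {
  // 0x3C00 is the fp16 bit pattern of 1.0; from_bits() skips any conversion.
  const c10::Half one(0x3C00, c10::Half::from_bits());
  const c10::Half pi(3.14159f);  // rounded to the nearest fp16 (3.140625)
  std::printf("%g %g\n", static_cast<float>(one), static_cast<float>(pi));
  return 0;
}
```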

View File

@ -1,140 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <limits>
#include <type_traits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wstring-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion")
#endif
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Returns false since we cannot have x < 0 if x is unsigned.
template <typename T>
inline constexpr bool is_negative(
const T& /*x*/,
std::true_type /*is_unsigned*/) {
return false;
}
/// Returns true if a signed variable x < 0
template <typename T>
inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) {
return x < T(0);
}
/// Returns true if x < 0
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename T>
inline constexpr bool is_negative(const T& x) {
return is_negative(x, std::is_unsigned<T>());
}
/// Returns the sign of an unsigned variable x as 0, 1
template <typename T>
inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) {
return T(0) < x;
}
/// Returns the sign of a signed variable x as -1, 0, 1
template <typename T>
inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) {
return (T(0) < x) - (x < T(0));
}
/// Returns the sign of x as -1, 0, 1
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename T>
inline constexpr int signum(const T& x) {
return signum(x, std::is_unsigned<T>());
}
/// Returns true if a and b are not both negative
template <typename T, typename U>
inline constexpr bool signs_differ(const T& a, const U& b) {
return is_negative(a) != is_negative(b);
}
// Suppress sign compare warning when compiling with GCC
// as later does not account for short-circuit rule before
// raising the warning, see https://godbolt.org/z/Tr3Msnz99
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-compare"
#endif
/// Returns true if x is greater than the greatest value of the type Limit
template <typename Limit, typename T>
inline constexpr bool greater_than_max(const T& x) {
constexpr bool can_overflow =
std::numeric_limits<T>::digits > std::numeric_limits<Limit>::digits;
return can_overflow && x > (std::numeric_limits<Limit>::max)();
}
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
/// Returns true if x < lowest(Limit). Standard comparison
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
const T& x,
std::false_type /*limit_is_unsigned*/,
std::false_type /*x_is_unsigned*/) {
return x < std::numeric_limits<Limit>::lowest();
}
/// Returns false since the limit is signed and therefore includes
/// negative values, but x cannot be negative because it is unsigned
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
const T& /*x*/,
std::false_type /*limit_is_unsigned*/,
std::true_type /*x_is_unsigned*/) {
return false;
}
/// Returns true if x < 0, where 0 is constructed from T.
/// Limit is not signed, so its lower value is zero
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
const T& x,
std::true_type /*limit_is_unsigned*/,
std::false_type /*x_is_unsigned*/) {
return x < T(0);
}
/// Returns false since both types are unsigned
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
const T& /*x*/,
std::true_type /*limit_is_unsigned*/,
std::true_type /*x_is_unsigned*/) {
return false;
}
/// Returns true if x is less than the lowest value of type T
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T& x) {
return less_than_lowest<Limit>(
x, std::is_unsigned<Limit>(), std::is_unsigned<T>());
}
} // namespace c10
C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/TypeSafeSignMath.h>
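Usage sketch for the helpers above, assuming they are reachable via <c10/util/TypeSafeSignMath.h>; `checked_narrow` itself is a hypothetical helper built on greater_than_max and less_than_lowest, not part of the header:
```cpp
#include <c10/util/TypeSafeSignMath.h>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper for illustration: reject values outside To's range.
template <typename To, typename From>
To checked_narrow(const From& x) {
  if (c10::greater_than_max<To>(x) || c10::less_than_lowest<To>(x)) {
    throw std::out_of_range("value does not fit in the target type");
  }
  return static_cast<To>(x);
}

int main() {
  (void)checked_narrow<int8_t>(100);    // fine: 100 fits in int8_t
  try {
    (void)checked_narrow<int8_t>(300);  // throws: greater_than_max<int8_t>(300)
  } catch (const std::out_of_range&) {
  }
  return 0;
}
```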

View File

@ -1,46 +1 @@
#pragma once
#include <cstring>
#include <type_traits>
#include <c10/macros/Macros.h>
#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
#include <bit>
#define C10_HAVE_STD_BIT_CAST 1
#else
#define C10_HAVE_STD_BIT_CAST 0
#endif // __has_include(<bit>) && (__cplusplus >= 202002L ||
// (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L))
namespace c10 {
#if C10_HAVE_STD_BIT_CAST
using std::bit_cast;
#else
// Implementations of std::bit_cast() from C++ 20.
//
// This is a less sketchy version of reinterpret_cast.
//
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
// information as well as the source of our implementations.
template <class To, class From>
C10_HOST_DEVICE std::enable_if_t<
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
std::is_trivially_copyable_v<To>,
To>
// constexpr support needs compiler magic
bit_cast(const From& src) noexcept {
static_assert(
std::is_trivially_constructible_v<To>,
"This implementation additionally requires "
"destination type to be trivially constructible");
To dst;
std::memcpy(&dst, &src, sizeof(To));
return dst;
}
#endif // C10_HAVE_STD_BIT_CAST
#undef C10_HAVE_STD_BIT_CAST
} // namespace c10
#include <torch/headeronly/util/bit_cast.h>
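Usage sketch, assuming the implementation above is reachable via <c10/util/bit_cast.h>; on C++20 toolchains c10::bit_cast is simply std::bit_cast.
```cpp
#include <c10/util/bit_cast.h>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t bits = c10::bit_cast<uint32_t>(1.0f);
  std::printf("0x%08X\n", (unsigned)bits);          // 0x3F800000, IEEE-754 bits of 1.0f
  std::printf("%g\n", c10::bit_cast<float>(bits));  // 1
  return 0;
}
```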

View File

@ -4,531 +4,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Half.h>
#if defined(__CUDACC__) || defined(__HIPCC__)
#include <thrust/complex.h>
#endif
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif
#if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
#endif
namespace c10 {
// c10::complex is an implementation of complex numbers that aims
// to work on all devices supported by PyTorch
//
// Most of the APIs duplicate std::complex
// Reference: https://en.cppreference.com/w/cpp/numeric/complex
//
// [NOTE: Complex Operator Unification]
// Operators currently use a mix of std::complex, thrust::complex, and
// c10::complex internally. The end state is that all operators will use
// c10::complex internally. Until then, there may be some hacks to support all
// variants.
//
//
// [Note on Constructors]
//
// The APIs of constructors are mostly copied from C++ standard:
// https://en.cppreference.com/w/cpp/numeric/complex/complex
//
// Since C++14, all constructors are constexpr in std::complex
//
// There are three types of constructors:
// - initializing from real and imag:
// `constexpr complex( const T& re = T(), const T& im = T() );`
// - implicitly-declared copy constructor
// - converting constructors
//
// Converting constructors:
// - std::complex defines converting constructor between float/double/long
// double,
// while we define converting constructor between float/double.
// - For these converting constructors, upcasting is implicit, downcasting is
// explicit.
// - We also define explicit casting from std::complex/thrust::complex
// - Note that the conversion from thrust is not constexpr, because
// thrust does not define them as constexpr ????
//
//
// [Operator =]
//
// The APIs of operator = are mostly copied from C++ standard:
// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
//
// Since C++20, all operator= are constexpr. Although we are not building with
// C++20, we also obey this behavior.
//
// There are three types of assign operator:
// - Assign a real value from the same scalar type
// - In std, this is templated as complex& operator=(const T& x)
// with specialization `complex& operator=(T x)` for float/double/long
// double. Since we only support float and double, one will use `complex&
// operator=(T x)`
// - Copy assignment operator and converting assignment operator
// - There is no specialization of converting assignment operators; which type
// is convertible depends solely on whether the scalar type is convertible
//
// In addition to the standard assignment, we also provide assignment operators
// with std and thrust
//
//
// [Casting operators]
//
// std::complex does not have casting operators. We define casting operators
// casting to std::complex and thrust::complex
//
//
// [Operator ""]
//
// std::complex has custom literals `i`, `if` and `il` defined in namespace
// `std::literals::complex_literals`. We define our own custom literals in the
// namespace `c10::complex_literals`. Our custom literals do not follow the
// same behavior as in std::complex, instead, we define _if, _id to construct
// float/double complex literals.
//
//
// [real() and imag()]
//
// In C++20, there are two overloads of these functions: one is to return the
// real/imag, the other is to set the real/imag; both are constexpr. We follow
// this design.
//
//
// [Operator +=,-=,*=,/=]
//
// Since C++20, these operators become constexpr. In our implementation, they
// are also constexpr.
//
// There are two types of such operators: operating with a real number, or
// operating with another complex number. For the operating with a real number,
// the generic template form has argument type `const T &`, while the overload
// for float/double/long double has `T`. We will follow the same type as
// float/double/long double in std.
//
// [Unary operator +-]
//
// Since C++20, they are constexpr. We also make them constexpr
//
// [Binary operators +-*/]
//
// Each operator has three versions (taking + as example):
// - complex + complex
// - complex + real
// - real + complex
//
// [Operator ==, !=]
//
// Each operator has three versions (taking == as example):
// - complex == complex
// - complex == real
// - real == complex
//
// Some of them are removed on C++20, but we decide to keep them
//
// [Operator <<, >>]
//
// These are implemented by casting to std::complex
//
//
//
// TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
// because:
// - lots of members and functions of c10::Half are not constexpr
// - thrust::complex only support float and double
template <typename T>
struct alignas(sizeof(T) * 2) complex {
using value_type = T;
T real_ = T(0);
T imag_ = T(0);
constexpr complex() = default;
C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
: real_(re), imag_(im) {}
template <typename U>
explicit constexpr complex(const std::complex<U>& other)
: complex(other.real(), other.imag()) {}
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename U>
explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
: real_(other.real()), imag_(other.imag()) {}
// NOTE cannot be implemented as follows due to a ROCm bug:
// explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
// complex(other.real(), other.imag()) {}
#endif
// Use SFINAE to specialize casting constructor for c10::complex<float> and
// c10::complex<double>
template <typename U = T>
C10_HOST_DEVICE explicit constexpr complex(
const std::enable_if_t<std::is_same_v<U, float>, complex<double>>& other)
: real_(other.real_), imag_(other.imag_) {}
template <typename U = T>
C10_HOST_DEVICE constexpr complex(
const std::enable_if_t<std::is_same_v<U, double>, complex<float>>& other)
: real_(other.real_), imag_(other.imag_) {}
constexpr complex<T>& operator=(T re) {
real_ = re;
imag_ = 0;
return *this;
}
constexpr complex<T>& operator+=(T re) {
real_ += re;
return *this;
}
constexpr complex<T>& operator-=(T re) {
real_ -= re;
return *this;
}
constexpr complex<T>& operator*=(T re) {
real_ *= re;
imag_ *= re;
return *this;
}
constexpr complex<T>& operator/=(T re) {
real_ /= re;
imag_ /= re;
return *this;
}
template <typename U>
constexpr complex<T>& operator=(const complex<U>& rhs) {
real_ = rhs.real();
imag_ = rhs.imag();
return *this;
}
template <typename U>
constexpr complex<T>& operator+=(const complex<U>& rhs) {
real_ += rhs.real();
imag_ += rhs.imag();
return *this;
}
template <typename U>
constexpr complex<T>& operator-=(const complex<U>& rhs) {
real_ -= rhs.real();
imag_ -= rhs.imag();
return *this;
}
template <typename U>
constexpr complex<T>& operator*=(const complex<U>& rhs) {
// (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
T a = real_;
T b = imag_;
U c = rhs.real();
U d = rhs.imag();
real_ = a * c - b * d;
imag_ = a * d + b * c;
return *this;
}
#ifdef __APPLE__
#define FORCE_INLINE_APPLE __attribute__((always_inline))
#else
#define FORCE_INLINE_APPLE
#endif
template <typename U>
constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
__ubsan_ignore_float_divide_by_zero__ {
// (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
// the calculation below follows numpy's complex division
T a = real_;
T b = imag_;
U c = rhs.real();
U d = rhs.imag();
#if defined(__GNUC__) && !defined(__clang__)
// std::abs is already constexpr by gcc
auto abs_c = std::abs(c);
auto abs_d = std::abs(d);
#else
auto abs_c = c < 0 ? -c : c;
auto abs_d = d < 0 ? -d : d;
#endif
if (abs_c >= abs_d) {
if (abs_c == U(0) && abs_d == U(0)) {
/* divide by zeros should yield a complex inf or nan */
real_ = a / abs_c;
imag_ = b / abs_d;
} else {
auto rat = d / c;
auto scl = U(1.0) / (c + d * rat);
real_ = (a + b * rat) * scl;
imag_ = (b - a * rat) * scl;
}
} else {
auto rat = c / d;
auto scl = U(1.0) / (d + c * rat);
real_ = (a * rat + b) * scl;
imag_ = (b * rat - a) * scl;
}
return *this;
}
#undef FORCE_INLINE_APPLE
template <typename U>
constexpr complex<T>& operator=(const std::complex<U>& rhs) {
real_ = rhs.real();
imag_ = rhs.imag();
return *this;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename U>
C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
real_ = rhs.real();
imag_ = rhs.imag();
return *this;
}
#endif
template <typename U>
explicit constexpr operator std::complex<U>() const {
return std::complex<U>(std::complex<T>(real(), imag()));
}
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename U>
C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
}
#endif
// consistent with NumPy behavior
explicit constexpr operator bool() const {
return real() || imag();
}
C10_HOST_DEVICE constexpr T real() const {
return real_;
}
constexpr void real(T value) {
real_ = value;
}
C10_HOST_DEVICE constexpr T imag() const {
return imag_;
}
constexpr void imag(T value) {
imag_ = value;
}
};
namespace complex_literals {
constexpr complex<float> operator""_if(long double imag) {
return complex<float>(0.0f, static_cast<float>(imag));
}
constexpr complex<double> operator""_id(long double imag) {
return complex<double>(0.0, static_cast<double>(imag));
}
constexpr complex<float> operator""_if(unsigned long long imag) {
return complex<float>(0.0f, static_cast<float>(imag));
}
constexpr complex<double> operator""_id(unsigned long long imag) {
return complex<double>(0.0, static_cast<double>(imag));
}
} // namespace complex_literals
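A brief usage sketch of the literals and the constexpr arithmetic defined in this file, assuming it is reachable as <c10/util/complex.h>:
```cpp
#include <c10/util/complex.h>
#include <cstdio>

int main() {
  using namespace c10::complex_literals;
  constexpr c10::complex<float> a = c10::complex<float>(1.0f) + 2.0_if;  // 1 + 2i
  constexpr c10::complex<float> b(3.0f, -1.0f);                          // 3 - i
  constexpr c10::complex<float> c = a * b;  // (1 + 2i)(3 - i) = 5 + 5i
  std::printf("%g + %gi\n", c.real(), c.imag());
  return 0;
}
```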
template <typename T>
constexpr complex<T> operator+(const complex<T>& val) {
return val;
}
template <typename T>
constexpr complex<T> operator-(const complex<T>& val) {
return complex<T>(-val.real(), -val.imag());
}
template <typename T>
constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
complex<T> result = lhs;
return result += rhs;
}
template <typename T>
constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
complex<T> result = lhs;
return result += rhs;
}
template <typename T>
constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
return complex<T>(lhs + rhs.real(), rhs.imag());
}
template <typename T>
constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
complex<T> result = lhs;
return result -= rhs;
}
template <typename T>
constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
complex<T> result = lhs;
return result -= rhs;
}
template <typename T>
constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
complex<T> result = -rhs;
return result += lhs;
}
template <typename T>
constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
complex<T> result = lhs;
return result *= rhs;
}
template <typename T>
constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
complex<T> result = lhs;
return result *= rhs;
}
template <typename T>
constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
complex<T> result = rhs;
return result *= lhs;
}
template <typename T>
constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
complex<T> result = lhs;
return result /= rhs;
}
template <typename T>
constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
complex<T> result = lhs;
return result /= rhs;
}
template <typename T>
constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
complex<T> result(lhs, T());
return result /= rhs;
}
// Define operators between integral scalars and c10::complex. std::complex does
// not support this when T is a floating-point number. This is useful because it
// saves a lot of "static_cast" when operating on a complex and an integer. This
// makes the code both less verbose and potentially more efficient.
#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
typename std::enable_if_t< \
std::is_floating_point_v<fT> && std::is_integral_v<iT>, \
int> = 0
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
return a + static_cast<fT>(b);
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
return static_cast<fT>(a) + b;
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
return a - static_cast<fT>(b);
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
return static_cast<fT>(a) - b;
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
return a * static_cast<fT>(b);
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
return static_cast<fT>(a) * b;
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
return a / static_cast<fT>(b);
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
return static_cast<fT>(a) / b;
}
#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
template <typename T>
constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
}
template <typename T>
constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
return (lhs.real() == rhs) && (lhs.imag() == T());
}
template <typename T>
constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
return (lhs == rhs.real()) && (T() == rhs.imag());
}
template <typename T>
constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
return !(lhs == rhs);
}
template <typename T>
constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
return !(lhs == rhs);
}
template <typename T>
constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
return !(lhs == rhs);
}
template <typename T, typename CharT, typename Traits>
std::basic_ostream<CharT, Traits>& operator<<(
std::basic_ostream<CharT, Traits>& os,
const complex<T>& x) {
return (os << static_cast<std::complex<T>>(x));
}
template <typename T, typename CharT, typename Traits>
std::basic_istream<CharT, Traits>& operator>>(
std::basic_istream<CharT, Traits>& is,
complex<T>& x) {
std::complex<T> tmp;
is >> tmp;
x = tmp;
return is;
}
} // namespace c10
#include <torch/headeronly/util/complex.h>
// std functions
//
@ -594,72 +70,6 @@ constexpr c10::complex<T> conj(const c10::complex<T>& z) {
} // namespace std
namespace c10 {
template <typename T>
C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<complex<T>>(thrust::polar(r, theta));
#else
// std::polar() requires r >= 0, so spell out the explicit implementation to
// avoid a branch.
return complex<T>(r * std::cos(theta), r * std::sin(theta));
#endif
}
template <>
struct alignas(4) complex<Half> {
Half real_;
Half imag_;
// Constructors
complex() = default;
// Half constructor is not constexpr so the following constructor can't
// be constexpr
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
: real_(real), imag_(imag) {}
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
: real_(value.real()), imag_(value.imag()) {}
// Conversion operator
inline C10_HOST_DEVICE operator c10::complex<float>() const {
return {real_, imag_};
}
constexpr C10_HOST_DEVICE Half real() const {
return real_;
}
constexpr C10_HOST_DEVICE Half imag() const {
return imag_;
}
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
return *this;
}
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
return *this;
}
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
auto a = static_cast<float>(real_);
auto b = static_cast<float>(imag_);
auto c = static_cast<float>(other.real());
auto d = static_cast<float>(other.imag());
real_ = a * c - b * d;
imag_ = a * d + b * c;
return *this;
}
};
} // namespace c10
C10_CLANG_DIAGNOSTIC_POP()
#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
// math functions are included in a separate file
#include <c10/util/complex_math.h> // IWYU pragma: keep
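A sketch of the scaled (numpy-style) division implemented by operator/= earlier in this file, again assuming <c10/util/complex.h> is available; scaling by the larger component avoids the overflow a naive c*c + d*d denominator would hit.
```cpp
#include <c10/util/complex.h>
#include <cstdio>

int main() {
  const c10::complex<double> a(1.0, 2.0), b(3.0, 4.0);
  const auto q = a / b;  // (1 + 2i) / (3 + 4i) = (11 + 2i) / 25 = 0.44 + 0.08i
  std::printf("%g + %gi\n", q.real(), q.imag());

  // The naive denominator c*c + d*d would overflow for these components,
  // but the scaled formulation still returns 1 + 0i.
  const c10::complex<double> big(1e200, 1e200);
  const auto r = big / big;
  std::printf("%g + %gi\n", r.real(), r.imag());
  return 0;
}
```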

View File

@ -1,33 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <cstdint>
namespace c10::detail {
C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
#if defined(__OPENCL_VERSION__)
return as_float(w);
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return __uint_as_float((unsigned int)w);
#elif defined(__INTEL_COMPILER)
return _castu32_f32(w);
#else
return c10::bit_cast<float>(w);
#endif
}
C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
#if defined(__OPENCL_VERSION__)
return as_uint(f);
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return (uint32_t)__float_as_uint(f);
#elif defined(__INTEL_COMPILER)
return _castf32_u32(f);
#else
return c10::bit_cast<uint32_t>(f);
#endif
}
} // namespace c10::detail
#include <torch/headeronly/util/floating_point_utils.h>
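Usage sketch, assuming the helpers above are reachable via <c10/util/floating_point_utils.h>; on an ordinary host build both functions reduce to c10::bit_cast.
```cpp
#include <c10/util/floating_point_utils.h>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t bits = c10::detail::fp32_to_bits(-2.0f);
  std::printf("0x%08X\n", (unsigned)bits);  // 0xC0000000: sign 1, exponent 128, mantissa 0
  std::printf("%g\n", c10::detail::fp32_from_bits(0x3F800000u));  // 1
  return 0;
}
```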

View File

@ -1,3 +1,4 @@
#include <c10/core/AllocatorConfig.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUCachingAllocator.h>
@ -20,8 +21,6 @@ constexpr size_t kMinBlockSize = 512;
constexpr size_t kSmallSize = 1048576;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// "large" allocations may be packed in 20 MiB blocks
constexpr size_t kLargeBuffer = 20971520;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB

View File

@ -1346,6 +1346,10 @@ if(BUILD_TEST)
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
add_subdirectory(
${TORCH_ROOT}/test/cpp/tensorexpr
${CMAKE_BINARY_DIR}/test_tensorexpr
)
if(USE_DISTRIBUTED)
add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d)
if(NOT WIN32)
@ -1767,6 +1771,10 @@ if(USE_ROCM)
target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS})
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
if(USE_FBGEMM_GENAI)
target_link_libraries(torch_hip PRIVATE fbgemm_genai)
endif()
# Since PyTorch files contain HIP headers, this is also needed to capture the includes.
# ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})

View File

@ -362,14 +362,6 @@ function(torch_compile_options libname)
# For MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/zc-preprocessor
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:preprocessor" PARENT_SCOPE)
if(${MSVC_TOOLSET_VERSION} GREATER_EQUAL 143)
# Add /d2implyavx512upperregs- to disable over-aggressive compiler optimization, which caused AVX512 registers to be involved on AVX2 machines.
# Reference: https://github.com/pytorch/pytorch/issues/145702#issuecomment-2874029459
target_compile_options(${libname} PUBLIC $<$<COMPILE_LANGUAGE:CXX>:/d2implyavx512upperregs->)
endif()
target_compile_options(${libname} PUBLIC
$<$<COMPILE_LANGUAGE:CXX>:
${MSVC_RUNTIME_LIBRARY_OPTION}

View File

@ -0,0 +1,17 @@
document.addEventListener("DOMContentLoaded", function () {
var script = document.createElement("script");
script.type = "module";
script.id = "runllm-widget-script";
script.src = "https://widget.runllm.com";
script.setAttribute("version", "stable");
script.setAttribute("crossorigin", "true");
script.setAttribute("runllm-keyboard-shortcut", "Mod+j");
script.setAttribute("runllm-name", "PyTorch");
script.setAttribute("runllm-position", "BOTTOM_RIGHT");
script.setAttribute("runllm-assistant-id", "834");
script.async = true;
document.head.appendChild(script);
});

Binary file not shown (new image added, 424 KiB).

View File

@ -0,0 +1,15 @@
import functools
import os
import torch
# to lower notebook execution time while hiding the backend="eager" override
torch.compile = functools.partial(torch.compile, backend="eager")
# to clear torch logs format
os.environ["TORCH_LOGS_FORMAT"] = ""
torch._logging._internal.DEFAULT_FORMATTER = (
torch._logging._internal._default_formatter()
)
torch._logging._internal._init_logs()

View File

@ -0,0 +1,142 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
torch._logging.set_logs(graph_breaks=True)
```
# Common Graph Breaks
Below are some common graph breaks and some workarounds.
## Incorrect Code
Your code might contain errors (meaning it doesn't execute even without `torch.compile`). In the example below, there's a typo in the `torch.sin` call due to an extra argument. **Always disable `torch.compile` to check if the code runs correctly.**
```{code-cell}
@torch.compile
def fn(x):
y = torch.sin(x, x)
return y
try:
fn(torch.ones(3, 3))
except Exception as e:
pass
```
Dynamo makes a best-effort attempt to hint when a graph break is caused by your code.
But it can still sometimes be difficult to tell from the logs whether a graph break is caused by an error in your code,
is a more complicated graph break, or is a `torch.compile` bug. To differentiate, we recommend running your code without `torch.compile` and checking whether you still get the error reported by the graph break.
## Data-dependent operations
`torch.compile` graph breaks on data-dependent operations such as data-dependent control flow (if-statements, loops with tensors) and direct tensor data accesses (`.item`, `.data_ptr`).
```{code-cell}
@torch.compile
def fn(x):
y = x.sum()
if y > 0:
return x + y.item()
return x - y.item()
print(fn(torch.ones(3, 3)))
```
The general workaround for these graph breaks is to avoid doing data-dependent operations. Some specific workarounds are:
- If your control flow doesn't actually depend on data values, consider modifying your code to perform control flow on constants.
```{code-cell}
# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
if x.sum() > 0:
return y + x
else:
return y - x
print(fn(torch.ones(3, 3)))
```
```{code-cell}
# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
if cond:
return y + x
else:
return y - x
print(fn(torch.ones(3, 3)))
```
- Use higher-order ops like {ref}`cond` in place of data-dependent control flow
```{code-cell}
# old
@torch.compile
def fn(x):
if x.sum() > 0:
return x + 1
return x - 1
print(fn(torch.ones(3, 3)))
```
```{code-cell}
# new
@torch.compile
def fn(x):
return torch.cond(
x.sum() > 0,
lambda x: x + 1,
lambda x: x - 1,
(x,),
)
print(fn(torch.ones(3, 3)))
```
- If you have a `.item()` call, try `torch._dynamo.config.capture_scalar_outputs = True`
or `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`.
- Wrap problematic parts of the function in a custom operator.
## Printing and logging
Printing/logging/issuing warnings will result in a graph break.
You can try working around this by using `torch._dynamo.config.reorderable_logging_functions`.
This config is used to reorder logging functions so that they are called at the end of the
traced function, thus avoiding a graph break.
However, the logged contents may differ if, for example, a mutation occurs.
```{code-cell}
torch._dynamo.config.reorderable_logging_functions.add(print)
@torch.compile
def fn(x):
x += 1
print("log!")
return torch.sin(x)
print(fn(torch.ones(3, 3)))
```

View File

@ -0,0 +1,75 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
torch._logging.set_logs(graph_breaks=True, graph_code=True)
```
# Disabling and Suppressing Errors
For some model architectures, there are portions of the model which are particularly difficult to compile -
either there are many graph breaks, or there are crashes.
You may want to explicitly disable these portions of the model which are problematic so that you can apply
`torch.compile` to the parts that work. You can do this by using the `@torch.compiler.disable` decorator.
When `torch.compile` attempts to call a disabled function, it breaks the graph and skips tracing the disabled function,
resuming tracing after the call. By default, all recursive calls made from a disabled function are also disabled.
Use the `recursive=False` option to allow compilation for recursive calls.
```{code-cell}
def inner1(x):
torch._dynamo.graph_break() # not traced
return x + 1 # not traced
@torch.compiler.disable
def outer1(x):
x = x + 2 # not traced
torch._dynamo.graph_break() # not traced
return inner1(x)
@torch.compile
def f(x):
x = outer1(x)
return x + 4 # traced
print(f(torch.ones(3)))
```
```{code-cell}
def inner2(x):
torch._dynamo.graph_break() # traced
return x + 1 # traced
@torch.compiler.disable(recursive=False)
def outer2(x):
x = x + 2 # not traced
torch._dynamo.graph_break() # not traced
return inner2(x)
@torch.compile
def g(x):
x = outer2(x)
return x + 4 # traced
print(g(torch.ones(3)))
```
For example, one can use `torch.compiler.disable` to disable `torch.compile` on sparse architecture in
recommendation models, as the sparse arch is difficult to compile.
Preprocessing and logging functions are other examples of functions that typically cause
a lot of graph breaks and do not get value from being compiled.
If you are experiencing compiler crashes and you want to continue regardless,
you can set `torch._dynamo.config.suppress_errors = True`.
When the compiler crashes, we will just skip tracing the function and try again later.
**This is not best practice** - it is better to eventually manually add `disable` annotations as necessary.

View File

@ -0,0 +1,12 @@
# Custom Operators
**Summary:**
- Use custom operators to have `torch.compile` treat a function as opaque. `torch.compile` will never trace into the function and Inductor (the backend) will run the function as-is.
You may wish to use a custom operator in any of the following situations:
- Your code calls some C/C++/CUDA code. Dynamo is a Python bytecode interpreter and generally does not know how to handle calls to C/C++/CUDA functions that are bound to Python.
- Dynamo and non-strict tracing have trouble tracing through a function and you want it to be ignored by `torch.compile`.
Please see [the Python custom ops tutorial](https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial) for more details on how to wrap a Python function into a `torch.compile`-understood custom operator.
For more advanced use cases, you may wish to use our C++ Custom Operator API; please see [here](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) for more information.
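As a rough sketch of what the C++ route looks like (the operator name `mylib::square` and its schema below are illustrative, not part of any shipped library; see the linked tutorials for the authoritative API):
```cpp
// Minimal C++ custom operator sketch using the TORCH_LIBRARY registration API.
#include <torch/library.h>
#include <ATen/ATen.h>

// Plain ATen implementation; torch.compile treats the registered op as opaque.
at::Tensor square_cpu(const at::Tensor& x) {
  return x * x;
}

// Declare the operator schema under a hypothetical "mylib" namespace...
TORCH_LIBRARY(mylib, m) {
  m.def("square(Tensor x) -> Tensor");
}

// ...and register the CPU kernel for it.
TORCH_LIBRARY_IMPL(mylib, CPU, m) {
  m.impl("square", &square_cpu);
}
```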

View File

@ -0,0 +1,167 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
```
# Dynamo Core Concepts
**Summary:**
- Dynamo, `torch.compile`'s frontend, performs **tracing** to capture the semantics of a Python function
(and its nested function calls) into a linear sequence of operations (the "(FX) graph"),
residual bytecode, and "guards" (a list of conditions under which the graph and bytecode are valid).
- Unsupported Python features lead to **graph breaks**, where Dynamo compiles a partial graph acquired from tracing,
then runs the unsupported code, then resumes tracing.
- Graph breaks may lead to slowness in torch.compile and prevent backend optimization opportunities.
If you're not seeing the performance you expect, then check for graph breaks.
## Dynamo Tracing
`torch.compile`'s frontend (Dynamo) is a custom Python bytecode interpreter designed to allow graph compilation
in PyTorch programs while retaining the full flexibility of Python. Given a function to be compiled, Dynamo
interprets Python bytecode to extract sequences of PyTorch operations into 1 or more FX graphs that may be further optimized by a backend.
![Summary diagram of Dynamo](_static/dynamo_summary_diagram.png)
For example, for the function `f` in the above diagram, Dynamo produces:
- a single **FX graph** that takes in the original input plus some additional inputs required by the function.
- **Python bytecode** that can be used as a drop-in replacement for `f`. In our example, the bytecode retrieves
the additional inputs and passes them to the graph, and also contains unoptimizable Python side effects (the list append)
- **guards** that specify the conditions under which the graph and bytecode are valid. Unless otherwise specified,
the graph produced by Dynamo specializes on the shapes of input Tensors.
(programming_model.dynamo_core_concepts.graph_breaks)=
## Graph Breaks
Dynamo traces your code and attempts to capture your PyTorch code into a single computation graph of PyTorch
operators (FX graph). However, this is not always possible. When encountering code that can't be traced, a "**graph break**" occurs.
In the default `torch.compile` settings, a graph break involves compiling the FX graph that has been determined so far,
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.
Graph breaks are a feature that allows Dynamo to run over arbitrary Python code and carve out functional subgraphs that can each be individually optimized.
However, it is possible for graph breaks to lead to unexpected slowness in `torch.compile`.
If you're not getting the speedups you expect, we recommend checking for graph breaks and removing them.
Graph breaks may occur on things like:
- Data-dependent if-statements
- Many Python built-in functions
- C functions
```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(graph_breaks=True)
```
Below is an example of a graph break due to calling an unsupported operation `torch.save`:
```{code-cell}
@torch.compile
def f(x):
y = x ** 2 / 2
torch.save(y, "foo.pt") # torch.save is an unsupported operation
z = y ** 3 / 6
return z
x = torch.randn(3)
print(f(x))
```
```{code-cell}
:tags: [remove-cell]
import os
os.remove("foo.pt")
```
The semantics of `torch.compile(f)(x)` are roughly this:
```python
def compiled_f_semantics(x):
y = torch.compile(g, fullgraph=True)(x)
torch.save(y, "foo.pt")
z = torch.compile(h, fullgraph=True)(y)
return z
def g(x):
return x ** 2 / 2
def h(y):
return y ** 3 / 6
```
## Guards
`torch.compile` makes some assumptions about runtime values as we trace through code. During tracing, we generate "guards",
which are runtime checks for these assumptions. Guards are run in future calls to the compiled function to determine if we
can reuse previously compiled code. Examples of runtime checks are constant values, types, and object IDs.
Below is an example of generated guards. The `TENSOR_MATCH` guard checks for the input's type, device, dtype, shape, etc.
```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(guards=True)
```
```{code-cell}
@torch.compile
def fn(x):
return x + 1
print(fn(torch.ones(3, 3)))
```
## Recompilations
If the guards fail for every instance of previously compiled code, then `torch.compile` must "recompile" the function,
requiring the original code to be traced again. In the example below, recompilation is necessary because the guard checking the tensor argument's shape failed.
```{code-cell}
:tags: [remove-cell]
torch._logging.set_logs(recompiles=True)
```
```{code-cell}
@torch.compile
def fn(x):
return x + 1
print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
```
## Dynamic Shapes
`torch.compile` initially assumes tensor shapes are static/constant and guards based on these assumptions. By using "dynamic shapes,"
we can get `torch.compile` to produce compiled code that can accept tensor inputs with different shapes - we avoid recompiling every time shapes differ.
By default, automatic dynamic shapes are enabled in `torch.compile(dynamic=None)` - if compilation fails due to shape mismatch,
recompilation is attempted with dynamic shapes. Dynamic shapes can also be fully enabled (`dynamic=True`) or disabled (`dynamic=False`).
Below, we enable dynamic shapes and note that we no longer need to recompile.
```{code-cell}
:tags: [remove-cell]
import logging
torch._logging.set_logs(dynamic=logging.DEBUG, recompiles=True)
```
```{code-cell}
@torch.compile(dynamic=True)
def fn(x):
return x + 1
print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
```
For more information on dynamic shapes, see [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).

View File

@ -0,0 +1,101 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
```
# Use `torch._dynamo.nonstrict_trace`
**Summary:**
- Use `nonstrict_trace` to trace a function with non-strict tracing inside of a `torch.compile`'d region.
You may wish to do this because the Dynamo graph breaks on something inside of the function
and you are sure that the function is non-strict traceable.
Consider the following scenario:
```{code-cell}
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = get_magic_num()
return x + n
try:
func(torch.rand(10))
except Exception as e:
print(e)
```
If we run the code above, we'll get an error from Dynamo, because it sees a graph break while the user specified `fullgraph=True`.
In these situations, if a user still wants to keep `fullgraph=True`, they typically have several options:
1. The graph break is due to a language feature Dynamo doesn't yet support.
In this case, the user either rewrites their code, or files an issue on GitHub.
2. The graph break is due to a call to a function implemented in C.
In this case, the user can try to use a custom op.
The user could also try providing a polyfill (a reference implementation in Python)
so that Dynamo can trace through it.
3. Worst case scenario -- an internal compiler error. In this case, the user likely has to file an issue on GitHub.
In addition to all these options, PyTorch does provide an alternative `torch._dynamo.nonstrict_trace`, if the function call that induced the graph break satisfies certain requirements:
- The requirements of [general non-strict tracing](programming_model.non_strict_tracing_model).
- The inputs and outputs must contain either basic types (e.g., `int`, `float`, `list`, `dict`, `torch.Tensor`),
or user-defined types that are registered to `torch.utils._pytree`.
- The function must be defined outside the `torch.compile`'d region.
- Any non-input values read by the function will be treated as a constant
(e.g., a global tensor), and will not be guarded on.
When tracing through a call to a `torch._dynamo.nonstrict_trace`'d function, `torch.compile` switches to [non-strict tracing](programming_model.non_strict_tracing_model),
and the FX graph will eventually contain all the relevant tensor operations which happened inside that function.
For the example above, we can use `torch._dynamo.nonstrict_trace` to eliminate the graph break:
```{code-cell}
@torch._dynamo.nonstrict_trace
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = get_magic_num()
return x + n
print(func(torch.rand(10)))
# No graph break and no error.
```
Note that one can use it inside a `torch.compile`'d region as well:
```{code-cell}
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = torch._dynamo.nonstrict_trace(get_magic_num)()
return x + n
print(func(torch.rand(10)))
# No graph break and no error.
```

View File

@ -0,0 +1,24 @@
# Working with `fullgraph=False`
While `fullgraph=False` is the default `torch.compile` setting, the semantics of resuming compilation upon encountering a graph break are more complicated.
You can find details on the `fullgraph=False` semantics in the subsections.
The strategy for using `torch.compile(fullgraph=False)` is as follows:
1. [Determine the ideal location to place `torch.compile`](programming_model.where_to_apply_compile). Normally, it is the highest-level function that doesn't result in excessive graph breaks.
Functions that do a lot of preprocessing or I/O operations are examples of functions that result in many graph breaks and do not significantly benefit from `torch.compile`.
a. You can isolate issues by first compiling individual functions/modules before compiling entire models.
2. [Apply `torch.compiler.disable` to functions in the compiled region that result in a lot of graph breaks
and do not benefit from compilation](programming_model.compiler_disable). In this case, one graph break is better than potentially tens or hundreds.
3. [Use `TORCH_LOGS="graph_breaks"` or tlparse to investigate remaining graph breaks.](programming_model.observability)
Work around these graph breaks using the same approaches as working around graph breaks under
the `fullgraph=True` programming model. Not all graph breaks need to be removed - some may
impact performance more than others. The general rule is to focus on graph breaks that are happening during model computation.
a. We recommend using `torch.compile(backend='eager')` when debugging graph breaks, for faster debugging iteration times
```{toctree}
programming_model.where_to_apply_compile
programming_model.compiler_disable
programming_model.nested_graph_breaks
programming_model.skipped_functions
```

View File

@ -0,0 +1,247 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
```
# Use `fullgraph=True` to Identify and Eliminate Graph Breaks
Using `torch.compile(fullgraph=False)` (the default) is a good way to get started with `torch.compile`: it supports all Python programs out-of-the-box via the ability to graph break and gives good performance on common cases.
However, if you're trying to get more performance out of your model, you should explicitly think about what regions of code should be compiled:
- We recommend using `torch.compile(fullgraph=True)` to find and eliminate graph breaks in your code.
- If you're a library developer (or testing if your code "works" with `torch.compile`), we recommend testing using `torch.compile(fullgraph=True)`.
`torch.compile(fullgraph=True)` offers stronger guarantees over `fullgraph=False`:
we will always capture a single FX graph to be compiled (or error if we cannot due to a graph break).
**In particular, you are forced to resolve every graph break that is encountered.**
There are a number of strategies for resolving a graph break.
## Strategy 1: Rewrite the unsupported code to use features supported by Dynamo
Many graph break error messages will give some suggestions on how to rewrite code to avoid the graph break.
If the graph break is still difficult to resolve, then please move on to the next strategy
or submit an issue to the [PyTorch GitHub repo](https://github.com/pytorch/pytorch/issues).
More graph break examples and how to resolve them can be found in [Common Graph Breaks](programming_model.common_graph_breaks).
Example: Dynamo does not support calling `next` on a `list_iterator` object that was an input to the function being compiled.
```{code-cell}
@torch.compile(fullgraph=True)
def f(xs):
a = next(xs)
b = next(xs)
return a + b
xs = [torch.tensor(1.), torch.tensor(2.)]
try:
out = f(iter(xs))
except Exception as e:
print(e)
```
Instead, rewrite the compiled function to accept a list.
```{code-cell}
@torch.compile(fullgraph=True)
def f_rewritten(xs):
it = iter(xs)
a = next(it)
b = next(it)
return a + b
f_rewritten(xs)
```
## Strategy 2: Pure functions can always be compiled via an escape hatch.
**Summary**: The space of all Python functions is vast and thus it is impractical for Dynamo to be able to trace
through every Python function without graph breaks. For Python functions considered to be "pure"
that Dynamo cannot trace through without graph breaks, we provide some escape hatches to attempt
to trace through these functions anyway:
1. Use `custom_op` or `triton_op` on pure triton kernels.
2. Use `nonstrict_trace` for pure functions that only use PyTorch Tensor ops.
3. Use `custom_op` for all other pure functions.
A "pure function" is a function with the following properties:
- Determinism. Given the same inputs, the pure function will always return the same output
- No external side effects. A pure function does not have any externally-visible side effects,
such as modifying external state or performing I/O operations.
Side effects that remain internal to the function are allowed (e.g. mutating intermediate tensors).
One notable exception is that mutating `torch.*` ops on function input Tensors are generally allowed.
- Explicit input/output. All the input data must be passed through the function parameters and all of the outputs are returned from the function.
See [Pure Functions](programming_model.non_strict_tracing_model.pure_functions) for examples.
Dynamo is theoretically able to handle a wide variety of impure functions, but may be lacking coverage for specific
Python language features. However, pure functions can always be compiled via an escape hatch.
If you have a graph break it may be possible to refactor the code around it into a pure function and use an escape hatch that bypasses Dynamo tracing:
1. Use `torch._dynamo.nonstrict_trace` if you want the Tensor operations in the function to show up in the Dynamo output graph (and therefore be optimizable). `nonstrict_trace` tells Dynamo to use **non-strict tracing**.
2. Use custom operators if you want the function to be opaque with respect to `torch.compile` (both the Dynamo frontend and the backend).
Note that there is nothing preventing these escape hatches from being applied to impure functions,
but **we do not provide any soundness guarantees**.
Example: If Dynamo doesn't support some Python feature or API that is non-strict traceable (e.g. it uses PyTorch operations), [use `torch._dynamo.nonstrict_trace` to capture it instead](programming_model.dynamo_nonstrict_trace).
```{code-cell}
# this is a function that Dynamo doesn't support (due to the graph_break() call).
def g(x):
    y = x.sin()
    torch._dynamo.graph_break()
    z = y.sin()
    return z

@torch.compile(fullgraph=True)
def f(x):
    w = x.sin()
    return g(w)

x = torch.randn(3)
try:
    f(x)  # Graph Break: there was a call to torch._dynamo.graph_break()
except Exception as e:
    print(e)

@torch.compile(fullgraph=True)
def f_rewritten(x):
    w = x.sin()
    return torch._dynamo.nonstrict_trace(g)(w)

f_rewritten(x)  # works
```
Example: use [custom operators](programming_model.custom_ops) to create functions that are opaque to `torch.compile`
```{code-cell}
from torch.utils.cpp_extension import load_inline

# C++ source code for the square operation
cpp_source = """
torch::Tensor square_cpu(torch::Tensor input) {
    // Check that input is a CPU tensor
    TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");

    // Create output tensor with same shape and dtype as input
    torch::Tensor output = torch::empty_like(input);

    // Get data pointers
    float* input_data = input.data_ptr<float>();
    float* output_data = output.data_ptr<float>();

    // Get total number of elements
    int64_t numel = input.numel();

    // For loop to compute square of each element
    for (int64_t i = 0; i < numel; i++) {
        output_data[i] = input_data[i] * input_data[i];
    }

    return output;
}
"""

# Load the extension inline
square_module = load_inline(
    name="square_cpu_kernel",
    cpp_sources=cpp_source,
    functions=["square_cpu"],
    verbose=True
)

def square(x):
    return square_module.square_cpu(x)

@torch.compile(fullgraph=True)
def f(x):
    return square(x)

try:
    f(torch.randn(3, 3))  # graph break
except Exception as e:
    print(e)
```
```{code-cell}
# Use torch.library.custom_op to define a new custom operator.
# Custom operators are opaque with respect to torch.compile:
# that is, torch.compile does not peek into them.
@torch.library.custom_op("mylib::square", mutates_args=())
def square(x: torch.Tensor) -> torch.Tensor:
    return square_module.square_cpu(x)

# Use register_fake to add a ``FakeTensor`` kernel for the operator
@square.register_fake
def _(x):
    return x.new_empty(x.size())

print(f(torch.randn(3, 3)))  # no graph break
```
For more information on `triton_op` for custom triton kernels, see the
[user-defined triton kernel tutorial](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html).
## Strategy 3: Don't compile the code
Not all code is amenable to being compiled. `torch.compile` is a compiler for Tensor computation;
it will not be able to optimize things like disk IO. Try to refactor the code such that the unsupported
code is not called in the compiled region.
```{code-cell}
@torch.compile(fullgraph=True)
def f(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")
    z = y ** 3 / 6
    return z

x = torch.randn(3)
try:
    f(x)  # Graph Break: torch.save not supported
except Exception as e:
    print(e)
```
```{code-cell}
def f_rewritten(x):
    y = g(x)
    torch.save(y, "foo.pt")
    z = h(y)
    return z

@torch.compile(fullgraph=True)
def g(x):
    y = x ** 2 / 2
    return y

@torch.compile(fullgraph=True)
def h(y):
    z = y ** 3 / 6
    return z

f_rewritten(x)
```
```{code-cell}
:tags: [remove-cell]
import os
os.remove("foo.pt")
```

View File

@ -0,0 +1,21 @@
# Working with Graph Breaks
As you might remember from [Dynamo Core Concepts](programming_model.dynamo_core_concepts), Dynamo performs a graph break when
it encounters code that can't be traced. In the default `torch.compile` settings, Dynamo compiles the FX graph
that has been determined up to that point, executes the unsupported code in regular Python, and then resumes tracing.
Graph breaks enable Dynamo to trace through arbitrary Python code and carve out functional
subgraphs that can each be individually optimized.
However, graph breaks may cause unexpected slowness in `torch.compile`.
If you're not seeing the expected speedups, we recommend checking for graph breaks and removing them.
The following sections outline strategies for addressing graph breaks.
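Before diving in, here is a minimal sketch of the default behavior described above, using `torch._dynamo.graph_break()` as a stand-in for unsupported code: the function below is compiled as two FX graphs, with the unsupported call run in regular Python between them.
```python
import torch

@torch.compile  # default fullgraph=False
def fn(x):
    y = x + 1                    # captured in the first FX graph
    torch._dynamo.graph_break()  # unsupported code runs in regular Python
    return y * 2                 # tracing resumes into a second FX graph

fn(torch.randn(3))
```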
```{toctree}
programming_model.fullgraph_true
programming_model.common_graph_breaks
programming_model.dynamo_nonstrict_trace
programming_model.custom_ops
programming_model.fullgraph_false
```

View File

@ -0,0 +1,16 @@
# torch.compile Programming Model
The `torch.compile` programming model:
1. Clarifies some internal behaviors of `torch.compile` so that one can better predict compiler behavior on user code and
2. Provides ways for one to take more fine-grained control over `torch.compile`.
By understanding the `torch.compile` programming model, one can systematically unblock themselves when encountering issues with `torch.compile`.
```{toctree}
programming_model.dynamo_core_concepts
programming_model.graph_breaks_index
programming_model.non_strict_tracing_model
programming_model.recompilation
programming_model.observability
programming_model.reporting_issues
```

View File

@ -0,0 +1,191 @@
# Nested Graph Breaks
Summary:
- Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below
- A nested graph break results in {math}`\mathcal O(N)` duplicate graph break behavior
Recall that when `torch.compile` is applied to a function, any nested function calls are also traced.
A **nested graph break** refers to any graph break that happens in a nested function call.
```python
def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...
```
The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.
Recall that in `fullgraph=False`, [graph breaks are handled](programming_model.dynamo_core_concepts.graph_breaks) by compiling the FX graph that has been determined so far,
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.
Resuming a function is actually a fairly complicated technical feat, so resuming tracing is only supported on top-level functions.
Given this restriction, we resume tracing after a nested graph break in the following way:
First, consider the below example where `torch.compile` traces from `f` and traces all the way until the
graph break in `inner1` is encountered.
```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    x = x + 8
    return x

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    x = x + 32
    return x

f(torch.randn(3))
```
Since we can only resume from top-level functions, we graph break on the `inner2` call in `f`.
```python
# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```
`inner2` is then automatically compiled as a top-level function.
We trace all the way until the graph break in `inner1` is encountered again.
```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    x = x + 8
    return x

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```
Then we graph break on the `inner1` call in `inner2`.
```python
def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8
```
`inner1` is then automatically compiled as a top-level function.
The graph break is from `inner1`, so we handle the graph break normally.
```python
# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))
```
`inner1` is handled normally:
```python
def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2
```
So the initial code is semantically equivalent to
```python
def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))
```
Note in particular that we traced 3 top-level functions, and that we traced the same graph break 3 times.
**This explains why you may encounter duplicate graph breaks when using `torch.compile`.**
In summary, nested graph breaks are handled by:
- Tracing from the top-level function all the way to the nested graph break
- Graph breaking on the top-level function at the call to the second-level function
- Compiling the PyTorch ops tracked so far and running the compiled graph
- Calling the second-level function, which gets automatically compiled as a top-level function
- Resuming tracing after the second-level function call
Note that the runtime of handling this graph break is {math}`\mathcal O(NK)`, where {math}`N` is the nesting depth,
and {math}`K` is the number of instructions from the top-level function to the graph break.
We end up tracing {math}`\mathcal O(N^2)` frames, and we trace the same graph break {math}`\mathcal O(N)` times.
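To make the duplicate-graph-break behavior concrete, here is a small sketch (with hypothetical function names) with three nested frames; running it with `TORCH_LOGS="graph_breaks"` is expected to report the same graph break once per enclosing frame.
```python
import torch

def innermost(x):
    torch._dynamo.graph_break()  # this same break is re-traced in each frame
    return x + 1

def middle(x):
    return innermost(x) + 2

@torch.compile
def top(x):
    return middle(x) + 4

top(torch.randn(3))
```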

View File

@ -0,0 +1,204 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
```
# Non-strict Tracing Programming Model
**Summary:**
- **Non-strict tracing** is a way to trace Python code that is less strict than Dynamo, but may result in silent incorrectness.
- Non-strict tracing runs a Python function and uses Python and PyTorch's operator overloading capabilities to record what Tensor operations occurred during execution into a trace.
- A function is **non-strict traceable** if it complies with some constraints, namely, that the function is **pure** and does not directly manipulate `Tensor.data_ptr()`.
- Non-strict tracing may **specialize** on certain variables and treat them as **constants**, baking the values of the variables into the trace.
`torch.compile` internals (`make_fx`, AOTDispatcher) use **non-strict tracing**. [`torch._dynamo.nonstrict_trace`](programming_model.dynamo_nonstrict_trace) can also be used in `torch.compile`d code to mark sections of code to be traced with non-strict tracing.
Non-strict tracing runs a Python function and uses Python and PyTorch's operator overloading capabilities to record what Tensor operations occurred during execution into a trace.
**`make_fx`** is the main entrypoint for non-strict tracing. For the following function, only the first branch is taken when executed with the given inputs, so the captured graph contains only that branch.
```{code-cell}
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    if x.shape[0] > 2:
        return x ** 2 / 6
    else:
        return x * 3

x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
```
Non-strict tracing differs from Dynamo (strict) tracing in that **it is unsafe**: given a function, it captures a graph of Tensor operations that may have different semantics than the original function.
In contrast, given a Python function, Dynamo tracing captures a graph of Tensor operations and residual bytecode that, when combined, have the same semantics as the Python function.
(programming_model.non_strict_tracing_model.pure_functions)=
## Pure Functions
Non-strict tracing is sound only on **pure functions**, and thus only pure functions should be non-strict traced.
A pure function is a function with the following properties:
- **Determinism.** Given the same inputs, the pure function will always return the same output.
- **No side effects.** A pure function does not have any side effects such as modifying external state or performing I/O operations.
- **Explicit input/output.** All the input data must be passed through the function parameters and all of the outputs are returned from the function.
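For contrast, here is a minimal sketch of a pure function (the helper name is ours): all data flows through the parameters and the return value, so the captured graph computes the same result as the original function.
```{code-cell}
# A pure function: deterministic, no external side effects, and explicit
# inputs/outputs. Mutating intermediate tensors inside would still be fine.
def pure_fn(a, b):
    c = a + b
    return c * 2

a = torch.tensor([0., 1., 2.])
b = torch.tensor([3., 4., 5.])
gm = make_fx(pure_fn, tracing_mode="fake")(a, b)
print(pure_fn(a, b))
print(gm(a, b))
```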
Here are some examples of impure functions for which the captured graph behaves differently from the original function.
### Example 1: No explicit input (e.g. accesses global tensor)
```{code-cell}
var = torch.tensor(1)

def function_with_global_access(y):
    return y + var

x = torch.tensor([0, 1, 2])

# _allow_non_fake_inputs=True is needed to capture the global variable
# for demonstration purposes.
gm = make_fx(
    function_with_global_access, tracing_mode="fake", _allow_non_fake_inputs=True
)(x)

# Non-strict Tracing captures the value of the global (1.)
print("1. call function", function_with_global_access(x))
print("1. call graph", gm(x))

# However, after changing the global, the captured graph
# produces a different result from the original function
var = torch.tensor(2)
print("2. call function", function_with_global_access(x))
print("2. call graph", gm(x))

# To capture a graph that can have a varying `var` tensor,
# it must be an explicit input:
def function_fixed(y, var):
    return y + var

var = torch.tensor(3)
gm = make_fx(function_fixed, tracing_mode="fake")(x, var)
print("3. call function", function_fixed(x, var))
print("3. call graph", gm(x, var))

var = torch.tensor(4)
print("4. call function", function_fixed(x, var))
print("4. call graph", gm(x, var))
```
See [Specialization and Constants](specialization-and-constants) for an explanation of why.
### Example 2: Side effect (printing)
```{code-cell}
def function_with_side_effect(y):
    print(y)

x = torch.tensor([0, 1, 2])
_ = function_with_side_effect(x)
```
Running `function_with_side_effect` in Python prints a Tensor as a side effect.
```{code-cell}
gm = make_fx(function_with_side_effect, tracing_mode="fake")(x)
```
During non-strict tracing, this print occurs during the graph capture.
```{code-cell}
_ = gm(x)
```
The graph does not record the call to `print`, so executing the graph doesn't print anything.
### Example 3: Side effect (input list mutation)
```{code-cell}
lst = []

def function_with_input_list_mutation(lst):
    val = lst.pop()
    return val

x = torch.tensor([0, 1, 2])
y = torch.tensor([0, 1, 2])

# Each time the function is executed, the list shrinks in size
lst = [x, y]
function_with_input_list_mutation(lst)
print("len(lst) after one call", len(lst))
function_with_input_list_mutation(lst)
print("len(lst) after two calls", len(lst))

# With Non-strict Tracing, the length of the list shrinks during
# the graph capture but not in invocations of the graph.
lst = [x, y]
gm = make_fx(function_with_input_list_mutation, tracing_mode="fake")(lst)
print("len(lst) after graph capture", len(lst))
gm(lst)
print("len(lst) after one call to graph", len(lst))
gm(lst)
print("len(lst) after two calls to graph", len(lst))
```
### No direct data_ptr manipulation
Directly manipulating `Tensor.data_ptr` is not non-strict traceable. The intuition behind this is that PyTorch is unable to tell *how* you manipulated the `data_ptr`.
```{code-cell}
import ctypes

# Create a tensor with a single element
tensor = torch.tensor([42], dtype=torch.int32)  # Using int32 for simplicity

def function_with_data_ptr(tensor):
    # Get the data pointer
    ptr = tensor.data_ptr()
    # Cast the pointer to a ctypes pointer
    ctypes_ptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_int32))
    # Increment the value at the pointer
    ctypes_ptr.contents.value += 1
    return tensor

try:
    make_fx(function_with_data_ptr, tracing_mode="fake")(tensor)
except Exception as e:
    print(e)
```
(specialization-and-constants)=
## Specialization and Constants
Non-strict tracing captures a graph that may be specialized on some values: the captured graph is only valid for those exact values. We say the graph treats those values as **constants**.
All non-Tensor variables are treated as constants during non-strict tracing:
```{code-cell}
def f(x, y):
    return x + y

x = torch.tensor([0, 1, 2])
y = 3.14
gm = make_fx(f, tracing_mode="fake")(x, y)
gm.print_readable()
```
3.14 is a constant in the graph.
Non-strict tracing will also specialize on properties of the input Tensors.
```{code-cell}
def f(x):
    if x.shape[0] > 2:
        return x ** 2 / 6
    else:
        return x * 3

x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
```
And it will also specialize on any variables not directly passed into the function:
```{code-cell}
var = torch.tensor(1)

def f(x):
    return x + var

x = torch.randn(3)
# As in Example 1 above, _allow_non_fake_inputs=True lets tracing capture the
# global tensor for demonstration purposes.
gm = make_fx(f, tracing_mode="fake", _allow_non_fake_inputs=True)(x)
gm.print_readable()
```

View File

@ -0,0 +1,141 @@
# tlparse / TORCH_TRACE
tlparse / `TORCH_TRACE` are a pair of tools that produce compilation reports that look [like this](https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html).
Traces are fairly straightforward to collect. To collect a trace, run your model like so:
```bash
TORCH_TRACE="/tmp/tracedir" python foo.py
pip install tlparse
tlparse /tmp/tracedir
```
This approach works even if you are running a distributed job, producing a trace for each rank.
Running `tlparse` will open your browser with HTML similar to the report linked above.
If you are making a bug report for a complicated problem that you don't have a standalone reproduction for,
you can still greatly assist PyTorch developers by attaching the trace log generated in `/tmp/tracedir`.
```{warning}
The trace log contains all of your model code.
Do not share the trace log if the model you are working on is sensitive. The trace log does NOT contain weights.
```
```{raw} html
<style>
.red {background-color:#ff0000;}
.green {background-color:#00ff00;}
.dark-green {background-color:#027f02;}
</style>
```
```{eval-rst}
.. role:: red
.. role:: green
.. role:: dark-green
```
The output of `tlparse` is primarily aimed at PyTorch developers,
and the log format is easy to upload and share on GitHub.
However, as a non-PyTorch developer, you can still extract useful information from it.
We recommend starting with the inline help text in the report, which explains its contents.
Here are some insights you can gain from a `tlparse`:
- What model code was compiled? Look at the stack trie.
This is especially useful if you're not familiar with the codebase being compiled!
- How many graph breaks / distinct compilation regions are there?
(Each distinct compile is its own color coded block like {dark-green}`[0/0]`).
Frames that are potentially graph-broken are light green {green}`[2/4]`.
If there are a lot of frames, that is suspicious, and suggests that you had some catastrophic graph breaks,
or maybe your code isn't a good match for `torch.compile`.
- How many times did I recompile a particular frame? Something that recompiled a lot will look like:
{dark-green}`[10/0]` {dark-green}`[10/1]` {dark-green}`[10/2]`
\- if something is being recompiled a lot, that is very suspicious and worth looking into, even if it isn't the root cause of your problem.
- Was there a compilation error? Frames that errored will look like {red}`[0/1]`.
- What intermediate compiler products did I generate for a given frame?
For example, you can look at the high-level generated FX graph or the generated Triton code.
- Is there relevant information for a particular frame? You can find it in `compilation_metrics`.
## TORCH_LOGS
You can use the `TORCH_LOGS` environment variable to selectively enable parts of the `torch.compile` stack to log.
`TORCH_LOGS` is in fact the source of logs for `tlparse`. The format of the `TORCH_LOGS` environment variable looks like this:
```bash
TORCH_LOGS="<option1>,<option2>,..." python foo.py
```
You can also programmatically set logging options using `torch._logging.set_logs`:
```python
import logging
torch._logging.set_logs(graph_breaks=True, dynamic=logging.DEBUG)
```
The most useful options are:
- `graph_breaks`: logs locations of graph breaks in user code and the reason for the graph break
- `guards`: logs guards that are generated
- `recompiles`: logs which function recompiled and the guards that failed, leading to the recompilation
- `dynamic`: logs related to dynamic shapes
- `output_code`: logs the code generated by Inductor
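For example, to log graph break locations, recompilation reasons, and Inductor output code for the same script as above:
```bash
TORCH_LOGS="graph_breaks,recompiles,output_code" python foo.py
```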
Some more helpful `TORCH_LOGS` options include:
```{eval-rst}
.. list-table::
:widths: 25 50
:header-rows: 1
* - Option
- Description
* - +all
- Output debug logs from all ``torch.compile`` components
* - +dynamo
- Output debug logs from TorchDynamo
* - +aot
- Output debug logs from AOTAutograd
* - +inductor
- Output debug logs from TorchInductor
* - dynamic
- Output logs from dynamic shapes
* - graph_code
- Output the Python code for the FX graph that Dynamo generated
* - graph_sizes
- Output the tensor sizes of the FX graph that Dynamo generated
* - trace_bytecode
- Output the bytecode instructions that Dynamo is tracing through and the symbolic interpreter stack Dynamo is keeping track of
* - trace_source
- Output the line of code in the original source that Dynamo is currently tracing through
* - bytecode
- Output Dynamo-generated bytecode
* - guards
- Output generated guards
* - recompiles
- Output recompilation reasons (only the first guard check that fails)
* - recompiles_verbose
- Output all guard checks that fail when a recompilation occurs
* - aot_graphs
- Output graph generated by AOTAutograd
* - aot_joint_graphs
- Output the joint forward-backward graph generated by AOTAutograd
* - output_code
- Output code generated by Inductor
* - kernel_code
- Output code generated by Inductor on a per-kernel basis
* - schedule
- Output Inductor scheduling logs
* - perf_hints
- Output Inductor perf hint logs
* - fusion
- Output Inductor fusion logs
```
For the full list of options, see [torch.\_logging](https://pytorch.org/docs/stable/logging.html)
and [torch.\_logging.set_logs](https://pytorch.org/docs/stable/generated/torch._logging.set_logs.html#torch._logging.set_logs).
## tlparse vs. TORCH_LOGS
Generally, we suggest first using `tlparse` when encountering issues.
`tlparse` is ideal for debugging large models and gaining a high-level overview of how your model was compiled.
On the other hand, `TORCH_LOGS` is preferred for small examples and fine-grained debugging detail,
when you already have an idea of which `torch.compile` component is causing the problem.

View File

@ -0,0 +1,161 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
torch._logging.set_logs(recompiles=True)
```
# Dealing with Recompilations
Recompilations are necessary for `torch.compile` soundness, but can result in significantly increased compile time.
Thus, minimizing recompilations while preserving soundness is essential for reducing compile time.
You can view recompilations and their reasons using tlparse or `TORCH_LOGS=recompiles`.
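For example, from the command line (a sketch, with `your_script.py` as a placeholder for your own entry point):
```bash
# Collect a full trace and inspect it with tlparse
TORCH_TRACE="/tmp/tracedir" python your_script.py
tlparse /tmp/tracedir

# Or print recompilation reasons directly to the console
TORCH_LOGS="recompiles" python your_script.py
```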
## Is Dynamic Shapes Enabled?
In the below example, we recompile due to mismatched shapes:
```{code-cell}
@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3))
fn(torch.ones(4))
```
Make sure that the dynamic option of `torch.compile` is not set to `False`.
The default option, `dynamic=None`, will only attempt dynamic shapes after the first compilation.
You can set `dynamic=True` to compile as dynamically as possible up front:
```{code-cell}
@torch.compile(dynamic=True)
def gn(x):
    return x + 1

gn(torch.ones(3))
gn(torch.ones(4))
```
For more information on dynamic shapes, including dealing with errors/recompilations due to
dynamic shapes, see [the dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).
## Wrapping Constants with Tensors
By default, `int` / `float` variables are treated as constants and are guarded on their exact value.
In the below example, we have a recompilation for each function call.
```{code-cell}
@torch.compile
def fn(x, c):
    return x + c

for i in range(5):
    fn(torch.ones(i), 0.5 + i)
```
In particular, for LR schedulers, initializing with a constant can lead to recompilations:
```{code-cell}
mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def gn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(5):
    gn(torch.ones(3, 3))
```
In both examples, we can wrap `float` variables in tensors in order to prevent recompilations.
```{code-cell}
:tags: [remove-cell]
torch._dynamo.reset()
```
```{code-cell}
# first example
for i in range(5):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
for i in range(5):
    gn(torch.ones(3, 3))
```
(programming_model.recompilation.changing_cache_size_limit)=
## Changing the Cache Size Limit
There is a limit to how many times a function can be recompiled,
determined by `torch._dynamo.config.cache_size_limit` and `torch._dynamo.config.accumulated_cache_size_limit`
(The exact difference between these 2 values is detailed in [`torch/_dynamo/cache_size.py`](https://github.com/pytorch/pytorch/blob/4ce6e6ec8890a3f6ee604c9efb3ff153825ce575/torch/_dynamo/cache_size.py#L14)).
If the Dynamo cache limit is hit, then all future compilation attempts **will result in the function being skipped (run eagerly)**.
Dynamo will still attempt to use previously compiled bytecode for future function calls, if the guards pass.
Note that in the case of a recompilation limit hit, **all nested function calls WILL be skipped**
(Dynamo will try to use previously compiled bytecode for the nested functions).
Dynamo will also issue a warning containing the affected function and which limit was hit.
In the example below, each function call results in a recompile attempt.
When we hit the cache size limit (by default, 8), we stop attempting to recompile.
(Note that we set `dynamic=False` for demonstration purposes to force recompilation every time).
```{code-cell}
@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    # recompile every time due to dynamic=False
    fn(torch.ones(i))
```
If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit.
If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit.
```{code-cell}
torch._dynamo.config.cache_size_limit = 16

@torch.compile(dynamic=False)
def gn(x):
    return x + 1

for i in range(1, 10):
    gn(torch.ones(i))
```
## Graph Breaking to Reduce Recompilation Costs
If a large graph is recompiling and causing high compile time, you can intentionally introduce
a graph break to reduce recompilation costs, at the expense of a performance hit.
```{code-cell}
def very_large_function(x):
    return x + 1

@torch.compile(dynamic=False)
def fn(x, c):
    y = very_large_function(x)  # recompiled every time
    return y + c

for i in range(1, 5):
    fn(torch.ones(3), i)

@torch.compile(dynamic=False)
def gn(x, c):
    y = very_large_function(x)  # compiled only once
    torch._dynamo.graph_break()
    return y + c  # recompiled every time

for i in range(1, 5):
    gn(torch.ones(3), i)
```

View File

@ -0,0 +1,73 @@
# Reporting Issues
If the provided workarounds were not enough to get `torch.compile` working,
then you should consider reporting the issue to PyTorch.
But there are a few things that you can do to make our lives significantly easier.
## Ablation
Check which component of the `torch.compile` stack is the one causing the issue using the `backend=` option for `torch.compile`.
In particular, try:
- `torch.compile(fn, backend="eager")`, which only runs TorchDynamo, the graph capture component of `torch.compile`.
- `torch.compile(fn, backend="aot_eager")`, which runs TorchDynamo and AOTAutograd, which additionally generates the backward graph during compilation.
- `torch.compile(fn, backend="aot_eager_decomp_partition")`, which runs TorchDynamo and AOTAutograd with operator decompositions/partitions.
- `torch.compile(fn, backend="inductor")`, which runs TorchDynamo, AOTAutograd, and TorchInductor, the backend ML compiler that generates compiled kernels.
If you only fail with the Inductor backend, you can additionally test various Inductor modes:
- `torch.compile(fn, backend="inductor", mode="default")`
- `torch.compile(fn, backend="inductor", mode="reduce-overhead")`
- `torch.compile(fn, backend="inductor", mode="max-autotune")`
You can also check if dynamic shapes is causing issues with any backend:
- `torch.compile(fn, dynamic=True)` (always use dynamic shapes)
- `torch.compile(fn, dynamic=False)` (never use dynamic shapes)
- `torch.compile(fn, dynamic=None)` (automatic dynamic shapes)
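If it helps, these checks can be combined into a small ablation loop; the sketch below assumes a placeholder `fn` and `example_inputs` from your own workload:
```python
import torch

def ablate(fn, *example_inputs):
    # Try progressively heavier parts of the stack to localize the failure.
    configs = [
        dict(backend="eager"),
        dict(backend="aot_eager"),
        dict(backend="aot_eager_decomp_partition"),
        dict(backend="inductor", mode="default"),
        dict(backend="inductor", mode="reduce-overhead"),
        dict(backend="inductor", mode="max-autotune"),
    ]
    for cfg in configs:
        torch._dynamo.reset()  # clear compilation caches between attempts
        try:
            torch.compile(fn, **cfg)(*example_inputs)
            print(cfg, "-> ok")
        except Exception as e:
            print(cfg, "-> failed:", type(e).__name__, e)
    # Repeat the sweep with dynamic=True / dynamic=False / dynamic=None
    # to check whether dynamic shapes are involved.
```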
## Bisecting
Did you try on the latest nightly? Did something work in the past but now no longer works?
Can you bisect to determine the first nightly where your issue occurs?
Bisecting is especially helpful for performance, accuracy, or compile time regressions,
where it is not immediately obvious where the problem originates from.
## Creating a reproducer
Creating reproducers is a lot of work, and it is perfectly fine if you do not have the time to do it.
However, if you are a motivated user unfamiliar with the internals of `torch.compile`,
creating a standalone reproducer can have a huge impact on our ability to fix the bug.
Without a reproducer, your bug report must contain enough information for us to identify the root cause of the problem and write a reproducer from scratch.
Here's a list of useful reproducers, ranked from most to least preferred:
1. **Self-contained, small reproducer:** A script with no external dependencies, under 100 lines of code, that reproduces the problem when run.
2. **Self-contained, large reproducer:** Even if it's large, being self-contained is a huge advantage!
3. **Non-self-contained reproducer with manageable dependencies:**
For example, if you can reproduce the problem by running a script after `pip install transformers`,
that's manageable. We can likely run it and investigate.
4. **Non-self-contained reproducer requiring substantial setup:** This might involve downloading datasets,
multiple environment setup steps, or specific system library versions requiring a Docker image.
The more complex the setup, the harder it is for us to recreate the environment.
:::{note}
Docker simplifies setup but complicates changes to the environment, so it's not a perfect solution, though we'll use it if necessary.
:::
If possible, try to make your reproducer single-process, as those are easier to debug than a multi-process reproducer.
Additionally, below is a non-exhaustive list of aspects to check in your
issue that you can attempt to replicate in your reproducer:
- **Autograd**. Did you have tensor inputs with `requires_grad=True`? Did you call `backward()` on the output?
- **Dynamic shapes**. Did you set `dynamic=True`? Or did you run the test code multiple times with varying shapes?
- **Custom operators**. Is there a custom operator involved in the real workflow?
Can you replicate some of its important characteristics using the Python custom operator API?
- **Configuration**. Did you set all the same configuration?
This includes `torch._dynamo.config` and `torch._inductor.config` settings,
as well as arguments to `torch.compile` like `backend` / `mode`.
- **Context managers**. Did you replicate any active context managers?
This could be `torch.no_grad`, automatic mixed precision, `TorchFunctionMode` / `TorchDispatchMode`,
activation checkpointing, compiled autograd etc.
- **Tensor subclasses**. Is there a tensor subclass involved?

View File

@ -0,0 +1,199 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
import logging
torch._logging.set_logs(dynamo=logging.DEBUG)
```
# Skipped Functions
**Summary:**
- Sometimes, `torch.compile` completely gives up compiling a function and runs it eagerly instead,
resulting in potentially lost optimization opportunities.
- There are ways to work around skipped functions in order to re-enable tracing around the problematic code.
Sometimes, `torch.compile` with `fullgraph=False` is unable to resume tracing when encountering a graph break
or other compiler error. In many of these cases, `torch.compile` will skip compiling the function entirely and run it eagerly.
Note that the skip is only applied to the current function and NOT any nested function calls.
`torch.compile` will still attempt to compile nested calls.
<!-- TODO: fix logging for skipped functions. -->
```{code-cell}
def inner1(x):
    return x + 1

def inner2(x):
    return x + 2

@torch.compile
def fn(x):
    x = inner1(x)
    torch._dynamo.skip_frame()
    x = inner2(x)

fn(torch.randn(3))
```
In the above example, `torch.compile` will trace `fn` (including `inner1`) up until the `skip_frame`.
Then `fn` is skipped and run eagerly - `inner1` and `inner2` are compiled when they are called.
Skipping functions may result in lost optimization opportunities,
so it is important to check if code you want compiled is being skipped, and if so, to work around the skip.
## Graph Break in a Loop
`torch.compile` cannot resume tracing if a graph break occurs in a loop:
```{code-cell}
@torch.compile
def fn(x):
    for i in range(5):
        x = x + 1
        if i == 3:
            torch._dynamo.graph_break()
    return x

fn(torch.randn(3))
```
In this example, we can avoid skipping by unrolling the loop:
```{code-cell}
@torch.compile
def fn(x):
    def inner(i):
        nonlocal x
        x = x + 1
        if i == 3:
            torch._dynamo.graph_break()
    inner(0)
    inner(1)
    inner(2)
    inner(3)
    inner(4)
    return x

fn(torch.randn(3))
```
In general, resolving the graph break causing the skip will also resolve the skip.
## Graph Break in a Context Manager
Another common example of an unresumable graph break is a graph break in most context managers:
```{code-cell}
class CustomCtxManager:
    def __enter__(self):
        pass
    def __exit__(self, exc_type, exc_value, traceback):
        pass

@torch.compile
def fn(x):
    with CustomCtxManager():
        x = x + 1
        torch._dynamo.graph_break()
    return x + 1

fn(torch.randn(3))
```
We can avoid skipping by moving the graph break outside of the context manager:
```{code-cell}
@torch.compile
def fn(x):
    with CustomCtxManager():
        x = x + 1
    torch._dynamo.graph_break()
    with CustomCtxManager():
        return x + 1

fn(torch.randn(3))
```
There are some context managers where Dynamo can resume after a graph break.
Some of these can be found in `supported_ctx_manager_classes` in `torch/_dynamo/variables/torch.py`.
In general, any context manager represented by a `ContextWrappingVariable` subclass in
`torch/_dynamo/variables/ctx_manager.py` supports resuming after a graph break. For example:
```{code-cell}
import contextlib

@torch.compile
def fn(x):
    with contextlib.nullcontext():
        with torch.no_grad():
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

fn(torch.randn(3))
```
## Graph Break in a Try Block
A graph break in a try block cannot be resumed:
```{code-cell}
@torch.compile
def fn(x):
    try:
        x = x + 1
        torch._dynamo.graph_break()
        return x + 1
    except Exception as e:
        pass

fn(torch.randn(3))
```
We can avoid skipping by moving the graph break outside of the try block:
```{code-cell}
@torch.compile
def fn(x):
    try:
        x = x + 1
    except Exception as e:
        pass
    torch._dynamo.graph_break()
    try:
        return x + 1
    except Exception as e:
        pass

fn(torch.randn(3))
```
## Hitting a Recompilation Limit
See [Changing the Cache Size Limit.](programming_model.recompilation.changing_cache_size_limit)
## Compiler Errors
Some compiler errors will result in skipped functions.
Other compiler errors will result in a hard error rather than a skipped function.
## Dealing with Skipped Functions
In general, you can resolve a skipped function by fixing the underlying graph break or error that
is causing the function to be skipped.
If the graph break/error causing the skipped function is difficult to fix,
then consider isolating the graph break/error in its own function so that minimal things are skipped.
```{code-cell}
def inner1(x):
    return x + 1

def inner2(x):
    return x + 2

@torch.compile
def fn(x):
    x = inner1(x)
    def problematic_code():
        torch._dynamo.skip_frame()
    problematic_code()
    x = inner2(x)

fn(torch.randn(3))
```

View File

@ -0,0 +1,77 @@
# Where to apply torch.compile?
We recommend applying `torch.compile` to the highest-level function that doesn't cause excessive problems.
Typically, it is:
- your `train` or `eval` step with the optimizer but without the loop,
- your top-level `nn.Module`,
- or some sub-`nn.Module`s.
`torch.compile` specifically doesn't handle distributed wrapper modules like DDP or FSDP very well,
so consider applying `torch.compile` to the inner module passed to the wrapper.
```python
# inference
model = ...
model.compile()

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)
```
```python
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
```
```python
# DistributedDataParallel
model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)
```
<!-- TODO add examples for specific model domains, compile(model) vs. model.compile()-->
## `compile(model)` vs `model.compile()`
Due to nuances in how `torch.compile` interacts with `nn.Module` instances,
we advise using the `.compile()` method of `nn.Module` instances if you wish to compile them as
top-level functions. Nested module calls will be traced correctly -
there is no need to call `.compile()` in that case.
```python
# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)

# DO THIS
model = MyModel()
model.compile()
model(inp)

# this is also acceptable
@torch.compile
def fn(model, inp):
    return model(inp)

model = MyModel()
fn(model, inp)
```

Some files were not shown because too many files have changed in this diff.