Commit Graph

91547 Commits

a7abf57aab [ROCm] Support large inputs for coalesceValuesKernel (#158281)
# Description

`.coalesce` cannot handle large inputs on ROCm due to the maximum grid size limit.

This PR splits axis `X` into axes `X` and `Y`, and repurposes `Z` for the original `Y` on ROCm, to avoid this limitation.

Confirmed that the new approach can handle large inputs; correctness still needs validation (a sketch of one possible check follows).
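
A minimal sketch (not part of this PR) of one way such a correctness check could look, comparing the GPU `coalesce()` result against the CPU path on a smaller input:

```python
import torch

def check_coalesce_matches_cpu(n=1000, nnz=10000, seed=0):
    # Build the same random COO tensor on CPU and GPU and compare the
    # coalesced indices/values. Sizes here are small and illustrative.
    torch.manual_seed(seed)
    rows = torch.randint(0, n, (nnz,), dtype=torch.int64)
    cols = torch.randint(0, n, (nnz,), dtype=torch.int64)
    indices = torch.stack([rows, cols], dim=0)
    values = torch.randn(nnz)

    cpu = torch.sparse_coo_tensor(indices, values, size=(n, n)).coalesce()
    gpu = torch.sparse_coo_tensor(indices.cuda(), values.cuda(), size=(n, n)).coalesce()

    assert torch.equal(cpu.indices(), gpu.indices().cpu())
    assert torch.allclose(cpu.values(), gpu.values().cpu(), atol=1e-5)
    print("coalesce results match between CPU and GPU")

if __name__ == "__main__":
    check_coalesce_matches_cpu()
```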

# Testing Command

`python torch_spmv.py 22500000 272500000`

## Script `torch_spmv.py`

``` python
import torch
import argparse

def parse_args():
    parser = argparse.ArgumentParser(
        description="Sparse COO Matrix by Dense Vector Multiplication using PyTorch"
    )
    parser.add_argument("n", type=int, help="Size of the NxN matrix")
    parser.add_argument("nnz", type=int, help="Number of non-zero entries")
    return parser.parse_args()

def main():
    args = parse_args()
    n = args.n
    nnz = args.nnz
    dtype = torch.float32
    device = torch.device('cuda')

    # Generate random indices for the sparse matrix in COO format.
    torch.manual_seed(42)
    rows = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device)
    cols = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device)
    indices = torch.stack([rows, cols], dim=0)

    # Generate random values.
    values = torch.randn(nnz, dtype=dtype, device=device)

    # Create the sparse COO matrix and move it to the target device.
    sparse_matrix = torch.sparse_coo_tensor(indices, values, size=(n, n), dtype=dtype, device=device)
    sparse_matrix = sparse_matrix.coalesce()

    # Generate a random dense vector.
    dense_vector = torch.randn(n, dtype=dtype, device=device)

    # Perform sparse matrix - dense vector multiplication.
    # Using torch.sparse.mm which expects a 2D tensor for the vector.
    result = torch.sparse.mm(sparse_matrix, dense_vector.unsqueeze(1)).squeeze()
    # result = torch.mv(sparse_matrix, dense_vector)

    # Print the result.
    print("Result of the multiplication:")
    print(torch.sum(result))

if __name__ == "__main__":
    main()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158281
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
2025-08-12 16:42:55 +00:00
f7b2f3314c Revert "[triton_heuristics] Optimize the triton launcher in pt2 (#160000)"
This reverts commit d0e2240f680ea2a553f7ee8188f52482e130bfd0.

Reverted https://github.com/pytorch/pytorch/pull/160000 on behalf of https://github.com/davidberard98 due to D80054972 failing with test_triton_kernel_2d_autotune_grad_False_dynamic_True_backend_inductor_grid_type_1_tdlp_1 ([comment](https://github.com/pytorch/pytorch/pull/160000#issuecomment-3180144676))
2025-08-12 16:33:02 +00:00
9d37c960a4 [ROCm][CI] use new benchmark image for dynamo (#160421)
Follow-up to #160047 that separated the rocm image into default CI and benchmarks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160421
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-08-12 16:07:19 +00:00
b219ca2a00 Revert "Update triton xpu commit to support python 3.14 (#160183)"
This reverts commit 7fbc22855c17741ae016992803b2e147a13aa22d.

Reverted https://github.com/pytorch/pytorch/pull/160183 on behalf of https://github.com/clee2000 due to I'm not sure how, but it seems to have broken inductor/test_extension_backend.py::ExtensionBackendTests::test_open_device_registration [GH job link](https://github.com/pytorch/pytorch/actions/runs/16911267995/job/47917091939) [HUD commit link](7fbc22855c).  Maybe because the docker build changed?  Note to self: not bad TD ([comment](https://github.com/pytorch/pytorch/pull/160183#issuecomment-3179840160))
2025-08-12 15:29:19 +00:00
b7db86600a Fix Tensor illustration, use permalinks for image embedding in Readme.md (#160416)
Fixes the Tensor illustration being broken on pypi.org. Also uses permalinks instead of regular links to the images for embedding, per Alban's suggestion: https://github.com/pytorch/pytorch/pull/160187#discussion_r2262978006

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160416
Approved by: https://github.com/malfet
2025-08-12 15:15:12 +00:00
9708fcf92d Account for triton kernel source code hidden in custom ops properly in AOTAutogradCache (#160120)
This PR fixes a bug where user-defined triton kernels hidden behind `triton_op` do not register source code changes. If a user *only* changes a triton kernel's source code, the change is missed because the kernel is hidden under the custom op and dynamo has not traced into it yet.

This means that at AOTAutograd time we don't know the list of triton kernels defined by custom ops. This is an initial fix for the issue: parse the AST of the custom op and look for triton kernels. It won't catch more degenerate cases where the custom op calls other custom ops/functions that in turn call triton kernels, since the top-level compiled graph still doesn't know about them. To handle that, we'd have to trace through the custom op at dynamo time.

This should handle 99% of cases, though. I added an expectedFailure test to show the limitation.
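
Roughly, the detection can be thought of as walking the custom op's Python source with the `ast` module and collecting the names it calls, which can then be matched against known triton kernels. A simplified, self-contained sketch of that idea (the helper and kernel names here are illustrative, not the actual AOTAutogradCache code):

```python
import ast
import inspect

def find_called_names(fn):
    # Parse fn's source and collect the names of everything it calls; in the
    # real fix these names would be checked against registered triton kernels
    # so their source can be folded into the cache key.
    tree = ast.parse(inspect.getsource(fn))
    names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                names.add(node.func.id)
            elif isinstance(node.func, ast.Attribute):
                names.add(node.func.attr)
    return names

def my_custom_op(x):
    return some_triton_kernel(x)  # hypothetical kernel call, for illustration

print(find_called_names(my_custom_op))  # {'some_triton_kernel'}
```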

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160120
Approved by: https://github.com/zou3519
2025-08-12 14:11:06 +00:00
a288b15ea9 [CI] Reduce XPU Windows build time (#159763)
Reduces the XPU Windows build time from 2.5 hours to about 1.5 hours.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159763
Approved by: https://github.com/EikanWang, https://github.com/atalman
2025-08-12 14:04:29 +00:00
7fbc22855c Update triton xpu commit to support python 3.14 (#160183)
Follow-up to PR #159725
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160183
Approved by: https://github.com/EikanWang, https://github.com/atalman
2025-08-12 14:02:36 +00:00
f33ce40bc0 [bucketing] Bucket only adjacent collectives to prevent reordering (#159983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159983
Approved by: https://github.com/wconstab, https://github.com/eellison
2025-08-12 11:57:00 +00:00
4d5b3f2d5a [dynamo][guards] Install dict watchers for recursive dict tag optimization (#159796)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159796
Approved by: https://github.com/jansel
2025-08-12 09:49:11 +00:00
f990490a23 Add label_smoothing param in nn.BCELoss and nn.BCEWithLogitsLoss (#150282)
Fixes #91545

## Changes

- Add `label_smoothing` param and docs
- Add test case for `label_smoothing`
- Remove duplicate description in `nn.BCELoss` and `nn.BCEWithLogitsLoss`
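
For reference, a minimal sketch of the conventional label-smoothing transform for binary targets, assuming the usual `y * (1 - eps) + eps / 2` convention (the PR's docs are authoritative for the exact semantics of the new parameter):

```python
import torch
import torch.nn.functional as F

def bce_with_label_smoothing(logits, target, label_smoothing=0.1):
    # Conventional binary label smoothing: pull hard 0/1 targets toward 0.5.
    # This is a reference formulation, not necessarily the exact PR code path.
    smoothed = target * (1.0 - label_smoothing) + 0.5 * label_smoothing
    return F.binary_cross_entropy_with_logits(logits, smoothed)

logits = torch.randn(8)
target = torch.randint(0, 2, (8,)).float()
print(bce_with_label_smoothing(logits, target))
```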

##  Test Result

```bash
pytest -s test/test_nn.py -k test_bce
```

![image](https://github.com/user-attachments/assets/30c0b7fe-fe49-4aa0-9b05-4d70403a7b05)

![image](https://github.com/user-attachments/assets/4fe3fd1c-54b8-4012-afd9-133ce9fb4964)

![image](https://github.com/user-attachments/assets/5cad019a-3a4c-475a-9fde-9c1acad5792d)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150282
Approved by: https://github.com/cyyever, https://github.com/mikaylagawarecki
2025-08-12 09:37:03 +00:00
b9003ed3d8 Dynamo Deep Dive Documentation Fix (#158860)
changed SourceBuilder to VariableBuilder

Fixes #158447

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158860
Approved by: https://github.com/mlazos
2025-08-12 08:53:33 +00:00
fea7e9dd37 extract shape in _view_has_unbacked_input (#160255)
Summary: We were still getting a DDE (data-dependent error) on reshape. I looked deeper and found an issue in `_view_has_unbacked_input`: when the input is `[[,,]]` it needs to be normalized to `[..]` (see the sketch below).
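
For illustration, the kind of normalization meant here, as a hypothetical sketch rather than the actual `_view_has_unbacked_input` code: a nested size argument such as `view((a, b))` is flattened to the same form as `view(a, b)` before the shape is inspected.

```python
def _normalize_sizes(sizes):
    # Hypothetical helper: accept either view(a, b) -> (a, b) or
    # view((a, b)) -> ((a, b),) and always return a flat list [a, b].
    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
        return list(sizes[0])
    return list(sizes)

print(_normalize_sizes((3, 4)))     # [3, 4]
print(_normalize_sizes(([3, 4],)))  # [3, 4]
```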

Test Plan:
existing tests.

Rollback Plan:

Differential Revision: D79951119

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160255
Approved by: https://github.com/bobrenjc93
2025-08-12 08:38:19 +00:00
9a0f7a3bb0 [retry-land][pytorch][dynamo_compile] Log stack_trace to dynamo_compile (#160348)
Refer to: https://github.com/pytorch/pytorch/pull/159655

The earlier PR failed on dynamo/test_utils.py::TestDynamoTimed::test_dynamo_timed.
Updated test_dynamo_timed and re-ran it locally to verify.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160348
Approved by: https://github.com/masnesral
2025-08-12 06:24:54 +00:00
01bcf9a40d Bump transformers pin (#159291)
Trying to update the Hugging Face transformers pin.

Benchmarking run to figure out issues:

<img width="1356" height="123" alt="image" src="https://github.com/user-attachments/assets/fbc435f3-a7cb-4280-9636-2ea6d15d7b6d" />

Retrying - https://github.com/pytorch/pytorch/pull/156118

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159291
Approved by: https://github.com/BoyuanFeng, https://github.com/huydhn

Co-authored-by: Huy Do <huydhn@gmail.com>
2025-08-12 05:14:17 +00:00
8d3d1c8443 [dynamo] fixes to propagate tag safeness (#159807)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159807
Approved by: https://github.com/jansel
2025-08-12 04:50:13 +00:00
0f3b10b8ee [audio hash update] update the pinned audio hash (#160384)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160384
Approved by: https://github.com/pytorchbot
2025-08-12 04:38:04 +00:00
5f1010fbb3 [Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to a 6.2% speedup on vision_maskrcnn and a 5.8% speedup on yolov3 [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), a 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), and an 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).

Running the same diff on two different days, both runs show a speedup on average.

[first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)
<img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" />

[second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)
<img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
2025-08-12 04:37:58 +00:00
edaa151d0d [CI] Move CUDA tests to trunk workflow (#160379)
The trunk workflow is run before a PR is merged anyway, but roughly 3x less frequently than the pull workflow, according to [Flambeau](https://pytorchci.grafana.net/public-dashboards/1c571e79090443eaaa9811db71f8d23b).
<img width="796" height="573" alt="image" src="https://github.com/user-attachments/assets/0235e610-4e1c-4be5-88bf-ea8278d1c656" />

I.e. this will probably result in a somewhat longer time-to-signal, but considering that the frequency of changes to eager PyTorch-on-CUDA has slowed down and Inductor changes are decorated with ciflow/inductor, this looks like an acceptable tradeoff to reduce costs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160379
Approved by: https://github.com/izaitsevfb
2025-08-12 04:23:50 +00:00
10bc36fe84 Get tensor subclasses and torch.library.triton_op to dispatch correctly (#160341)
Short-term fix for https://github.com/pytorch/pytorch/issues/160333

The problem is:
1) `triton_op` adds a decomposition for FunctionalTensorMode for this operation
2) Tensor Subclasses rely on FunctionalTensorMode's `__torch_dispatch__` returning NotImplemented.
3) `triton_op`'s FunctionalTensorMode decomposition takes precedence over FunctionalTensorMode's own `__torch_dispatch__`.

The easy fix is to copy-paste the FunctionalTensorMode's NotImplemented
return logic into the decomposition.
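
Conceptually, the decomposition now bails out the same way the mode does when a tensor subclass is involved. A schematic sketch of that behavior (names are illustrative, not the actual PyTorch internals):

```python
import torch

class MyTensor(torch.Tensor):
    pass

def functional_decomp(func, types, args=(), kwargs=None):
    # Schematic only: return NotImplemented when a tensor subclass appears in
    # the arguments, so that subclass's own __torch_dispatch__ gets a chance
    # to handle the op first; otherwise run the op as usual.
    kwargs = kwargs or {}
    if any(
        isinstance(a, torch.Tensor) and type(a) is not torch.Tensor
        for a in (*args, *kwargs.values())
    ):
        return NotImplemented
    return func(*args, **kwargs)

x = torch.randn(2).as_subclass(MyTensor)
print(functional_decomp(torch.add, (type(x),), (x, x)))  # NotImplemented
```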

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160341
Approved by: https://github.com/drisspg
2025-08-12 04:09:37 +00:00
32e5e2f596 [vllm hash update] update the pinned vllm hash (#160259)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vllm hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160259
Approved by: https://github.com/pytorchbot
2025-08-12 04:04:53 +00:00
bfc873d02e [ROCm][Windows] Revert copying hipblaslt and rocblas dirs. (#159083)
This reverts the changes from b367e5f6a6. This will also close https://github.com/pytorch/pytorch/pull/158922.

Since 30387ab2e4, ROCm is bootstrapped using the 'rocm' Python module which contains these files (see https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md), so they do not need to be bundled into torch/lib.

There was also a bug here: if `ROCM_DIR` is unset, the code crashes:
```
  File "D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\setuptools\_distutils\dist.py", line 1002, in run_command
    cmd_obj.run()
  File "D:\b\pytorch_main\setup.py", line 853, in run
    rocm_dir_path = Path(os.environ["ROCM_DIR"])
                         ~~~~~~~~~~^^^^^^^^^^^^
  File "<frozen os>", line 714, in __getitem__
KeyError: 'ROCM_DIR'
```
The code could have checked for `ROCM_PATH` too.
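
For illustration, a defensive lookup along those lines (a sketch, not the actual setup.py code):

```python
import os
from pathlib import Path

def find_rocm_dir() -> Path:
    # Check ROCM_DIR first, fall back to ROCM_PATH, and fail with a clear
    # message instead of a bare KeyError if neither variable is set.
    for var in ("ROCM_DIR", "ROCM_PATH"):
        value = os.environ.get(var)
        if value:
            return Path(value)
    raise RuntimeError("Neither ROCM_DIR nor ROCM_PATH is set")
```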

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159083
Approved by: https://github.com/jeffdaily
2025-08-12 02:45:49 +00:00
eed9dbf70f [ROCm] Add torch/_rocm_init.py to .gitignore. (#159806)
Follow-up to https://github.com/pytorch/pytorch/pull/155285.

Build scripts like https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py generate this file with contents like:

```python
def initialize():
    import rocm_sdk
    rocm_sdk.initialize_process(
        preload_shortnames=['amd_comgr', 'amdhip64', 'hiprtc', 'hipblas', 'hipfft', 'hiprand', 'hipsparse', 'hipsolver', 'hipblaslt', 'miopen'],
        check_version='7.0.0rc20250804')
```

We may also have https://github.com/pytorch/pytorch/blob/main/tools/amd_build/build_amd.py do the same thing as more of that build support moves into the upstream PyTorch repository itself (see https://github.com/pytorch/pytorch/issues/159520).

This file is then loaded if present here: a7f3bdf550/torch/__init__.py (L145-L157)
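
For context, a simplified sketch of that conditional-load pattern (the actual logic lives at the `torch/__init__.py` permalink above; this is illustrative only):

```python
# Import the build-generated hook if it exists and call its initialize();
# silently continue when the file was never generated.
try:
    import torch._rocm_init as _rocm_init
except ImportError:
    _rocm_init = None

if _rocm_init is not None:
    _rocm_init.initialize()
```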

Given that the file is generated by build scripts, I think adding it to `.gitignore` makes sense, as that will prevent accidental check-ins and keep local history cleaner.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159806
Approved by: https://github.com/jeffdaily
2025-08-12 02:24:21 +00:00
be53f609aa fix retaining multimem in symmetric memory (#160343)
fixes OOM in #160289

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160343
Approved by: https://github.com/eqy
2025-08-12 02:03:20 +00:00
95210cc409 [BE] Isolate pre-push hook dependencies in dedicated virtual environment (#160048)
This adds two changes:
- Isolates pre-push hook dependencies into a dedicated venv, so they no longer affect your system environment
- Lets you manually run the pre-push lintrunner (including with `lintrunner -a`) by invoking `python scripts/lintrunner.py [-a]` (it's ugly, but better than nothing... for now)

This is a follow up to:
- https://github.com/pytorch/pytorch/pull/158389

## Problem
The current pre-push hook setup installs lintrunner and related dependencies globally, which makes developers nervous about system pollution and can cause version conflicts with existing installations.

Also, if the pre-push lintrunner found errors, you had to hope your normal lintrunner could fix them (which wasn't always the case, e.g. if those errors only manifested in certain python versions)

##  Key Changes:
  - Isolated Environment: Creates .git/hooks/linter/.venv/ with Python 3.9 (the python used in CI) and an isolated lintrunner installation
  - User-Friendly CLI: New python scripts/lintrunner.py wrapper allows developers to run lintrunner (including -a auto-fix) from any environment
  - Simplified Architecture: Eliminates pre-commit dependency entirely - uses direct git hooks

  File Changes:
  - scripts/setup_hooks.py: Rewritten to create isolated uv-managed virtual environment
  - scripts/lintrunner.py: New wrapper script with shared hash management logic
  - scripts/run_lintrunner.py: Removed (functionality merged into lintrunner.py)
  - .pre-commit-config.yaml: Removed (no longer needed)

##  Usage:
```
  # Setup (run once)
  python scripts/setup_hooks.py

  # Manual linting (works from any environment)
  python scripts/lintrunner.py        # Check mode
  python scripts/lintrunner.py -a     # Auto-fix mode

  # Git hooks work automatically
  git push  # Runs lintrunner in isolated environment

  # Need to skip the pre-push hook?
  git push --no-verify
```

##  Benefits:
  -  Zero global dependency installation
  -  Per-repository isolation prevents version conflicts
  -  Full lintrunner functionality is now accessible

##  Implementation Notes:
  - Virtual env is kept in a dedicated dir in .git, to keep per-repo mechanics
  - lintrunner.py does not need to be invoked from a specific venv.  It'll invoke the right venv itself.
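
A rough sketch of how such a wrapper can locate the hook venv and delegate to it (illustrative only; the real logic lives in `scripts/lintrunner.py`, and on Windows the venv layout differs):

```python
import subprocess
import sys
from pathlib import Path

def run_lintrunner(repo_root: Path, args: list) -> int:
    # Invoke the lintrunner binary installed in the dedicated hook venv,
    # regardless of which Python environment launched this wrapper.
    lintrunner = repo_root / ".git" / "hooks" / "linter" / ".venv" / "bin" / "lintrunner"
    return subprocess.call([str(lintrunner), *args])

if __name__ == "__main__":
    sys.exit(run_lintrunner(Path.cwd(), sys.argv[1:]))
```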

A minor bug: It tends to garble the lintrunner output a bit, like the screenshot below shows, but I haven't found a workaround so far and it remains understandable to users:
<img width="241" height="154" alt="image" src="https://github.com/user-attachments/assets/9496f925-8524-4434-8486-dc579442d688" />

## What's next?
Features that could be added:
- Check for lintrunner updates, auto-update if needed
- Depending on dev response, this could be enabled by default for all pytorch/pytorch environments
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160048
Approved by: https://github.com/seemethere
2025-08-12 01:58:46 +00:00
7a974a88f2 [ROCm] Fix resource_strings.h (#159996)
This PR fixes errors like the one below:

```
[rank7]: RuntimeError: /tmp/comgr-c3c81b/input/CompileSourceejOPx6:34:8: error: unknown type name 'uint64_t'; did you mean
'__hip_internal::uint64_t'? [rank7]: 34 | if(((uint64_t) t0.data) % (4 * sizeof(half)) != 0) flag_vec4 = false;
```

The following datatypes need to be defined in `torch/csrc/jit/codegen/fuser/cuda/resource_strings.h` for ROCm versions >= 7.0:

```
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef short int  int16_t;
typedef long long int int64_t;
typedef unsigned long long int uint64_t;
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159996
Approved by: https://github.com/pruthvistony, https://github.com/Skylion007, https://github.com/jeffdaily
2025-08-12 01:58:02 +00:00
f3f159ff8c [BE][cutlass backend] Reduce severity of log message for no cutlass config found (#160148)
This is not really a problem. Sometimes we cannot find a cutlass config due to shape, e.g. when k is odd.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160148
Approved by: https://github.com/mlazos, https://github.com/Skylion007
2025-08-12 01:41:58 +00:00
b90feeac86 [BE][cutlass backend] Fix subproc addmm tests (#160295)
Differential Revision: [D79977421](https://our.internmc.facebook.com/intern/diff/D79977421/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160295
Approved by: https://github.com/jingsh
2025-08-12 01:41:06 +00:00
0d40ff3b49 [inductor] fix test_different_file_paths_local_pgo on Windows. (#160382)
fix test_different_file_paths_local_pgo on Windows.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160382
Approved by: https://github.com/angelayi
2025-08-12 01:35:39 +00:00
cae2b5e3d2 [ROCm][Windows] Enable USE_ROCM, disable USE_RCCL on Windows. (#159079)
This allows setting `USE_ROCM` on Windows. A few other patches are still required to build (see https://github.com/ROCm/TheRock/issues/589), but we have instructions using open source code and rocm python packages available at https://github.com/ROCm/TheRock/tree/main/external-builds/pytorch#build-pytorch-with-rocm-support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159079
Approved by: https://github.com/jeffdaily
2025-08-12 01:28:20 +00:00
ee89cc7a0a [ROCm][Windows] Fix LoadHIP handling of environment variable paths on Windows. (#159080)
See https://cmake.org/cmake/help/latest/command/file.html#path-conversion. Paths stored in environment variables may use `/` or `\` (e.g. on Windows), while cmake-style paths always use `/`.

This fixes configure errors like:
```
CMake Error at D:/b/pytorch_main/build/CMakeFiles/CMakeScratch/TryCompile-srhq07/CMakeLists.txt:2 (set):
  Syntax error in cmake code at

    D:/b/pytorch_main/build/CMakeFiles/CMakeScratch/TryCompile-srhq07/CMakeLists.txt:2

  when parsing string

    D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\_rocm_sdk_devel/cmake/;D:/b/pytorch_main/cmake/Modules

  Invalid character escape '\p'.

CMake Error at D:/projects/TheRock/external-builds/pytorch/.venv/Lib/site-packages/cmake/data/share/cmake-3.31/Modules/Internal/CheckSourceCompiles.cmake:108 (try_compile):
  Failed to configure test project build system.
```

(note the mixed usage of `\` and `/` in that string)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159080
Approved by: https://github.com/jeffdaily
2025-08-12 00:18:19 +00:00
e63c2b21c1 [PP] Initialize P2P communicators on first step (#160210)
We were hitting hangs in multi-node settings; initializing the NCCL communicators needed for batched p2p ops ahead of time fixes this.

This change adds extra communication, since it sends a dummy tensor to the next and previous stage ranks. However, this cost is only paid on the first step, so it is negligible.
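
For illustration, a minimal sketch of the warm-up under the assumption of an already-initialized `torch.distributed` process group (not the pipeline-parallel code itself):

```python
import torch
import torch.distributed as dist

def warmup_p2p(prev_rank=None, next_rank=None, device="cuda"):
    # Exchange a 1-element dummy tensor with the previous/next stage ranks so
    # the NCCL communicators used by later batched p2p ops are created on the
    # first step instead of lazily in the middle of the schedule.
    ops = []
    if next_rank is not None:
        ops.append(dist.P2POp(dist.isend, torch.zeros(1, device=device), next_rank))
    if prev_rank is not None:
        ops.append(dist.P2POp(dist.irecv, torch.empty(1, device=device), prev_rank))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
```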

Debug history: https://docs.google.com/document/d/1EKVJYmW2hj_VsvDvnSggXhZzJyvMu9dA0iDJWOZAtjY/edit?tab=t.0

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160210
Approved by: https://github.com/wconstab
2025-08-11 23:46:58 +00:00
3626ba711b [FlexAttention] Swap from `and` to `&` for new triton (#160227)
Fixes #158463

On B200 I am getting a bunch of error spew:
```Shell
/tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: error: Failures have been detected while processing an MLIR pass pipeline
/tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: note: Pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
Triton compilation failed: triton_tem_fused_zeros_1
def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
    PRESCALE_QK : tl.constexpr = False
```
```Shell
74 = arith.subi %170, %166 : i32
          %175 = arith.muli %174, %c128_i32 : i32
          %176 = arith.subi %175, %c64_i32 : i32
          %177 = arith.extui %173 : i1 to i32
          %178 = arith.muli %176, %177 : i32
          %179 = arith.subi %c1_i32, %177 : i32
          %180 = arith.muli %179, %c64_i32 : i32
          %181 = arith.addi %178, %180 : i32
          %182 = arith.muli %181, %c64_i32 : i32
          %183 = tt.splat %182 : i32 -> tensor<64x64xi32>
          %184 = tt.addptr %arg19, %183 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
          %185 = tt.addptr %arg20, %183 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
          %186 = tt.splat %181 : i32 -> tensor<64xi32>
          %187 = arith.addi %arg21, %186 : tensor<64xi32>
          scf.yield %163, %184, %185, %187 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>
        }
        %114 = tt.expand_dims %113#3 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
        %115 = arith.cmpi slt, %114, %cst_7 : tensor<1x64xi32>
        %116 = tt.broadcast %115 : tensor<1x64xi1> -> tensor<64x64xi1>
        %117 = tt.load %113#1, %116, %cst_8 : tensor<64x64x!tt.ptr<f16>>
        %118 = tt.dot %46, %117, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
        %119 = arith.mulf %118, %cst_13 : tensor<64x64xf32>
        %120 = arith.mulf %119, %cst_3 : tensor<64x64xf32>
        %121 = arith.select %116, %120, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32>
        %122 = arith.select %115, %cst_4, %cst_5 : tensor<1x64xi1>, tensor<1x64xi1>
        %123 = tt.broadcast %122 : tensor<1x64xi1> -> tensor<64x64xi1>
        %124 = arith.select %123, %121, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32>
        %125 = arith.mulf %124, %cst_2 : tensor<64x64xf32>
        %126 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32>
        %127 = arith.subf %125, %126 : tensor<64x64xf32>
        %128 = math.exp2 %127 : tensor<64x64xf32>
        %129 = tt.load %113#2, %116, %cst_8 : tensor<64x64x!tt.ptr<f16>>
        %130 = tt.dot %51, %129, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
        %131 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
        %132 = tt.broadcast %131 : tensor<64x1xf32> -> tensor<64x64xf32>
        %133 = arith.subf %130, %132 : tensor<64x64xf32>
        %134 = arith.mulf %128, %133 : tensor<64x64xf32>
        %135 = arith.mulf %134, %cst_3 : tensor<64x64xf32>
        %136 = arith.select %116, %135, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32>
        %137 = arith.select %115, %122, %cst_5 : tensor<1x64xi1>, tensor<1x64xi1>
        %138 = tt.broadcast %137 : tensor<1x64xi1> -> tensor<64x64xi1>
        %139 = arith.select %138, %136, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32>
        %140 = arith.truncf %139 : tensor<64x64xf32> to tensor<64x64xf16>
        %141 = tt.trans %117 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
        %142 = tt.dot %140, %141, %113#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
        scf.yield %142 : tensor<64x64xf32>
      } else {
        scf.yield %cst_9 : tensor<64x64xf32>
      }
      %84 = tt.addptr %arg13, %22 : !tt.ptr<i32>, i32
      %85 = tt.load %84 : !tt.ptr<i32>
      %86 = arith.muli %85, %c128_i32 : i32
      %87 = tt.addptr %arg12, %21 : !tt.ptr<i32>, i32
      %88 = tt.load %87 : !tt.ptr<i32>
      %89 = tt.splat %86 : i32 -> tensor<64xi32>
      %90 = arith.addi %89, %14 : tensor<64xi32>
      %91 = tt.expand_dims %90 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %92 = arith.muli %91, %cst_11 : tensor<1x64xi32>
      %93 = tt.addptr %71, %92 : tensor<1x64x!tt.ptr<f16>>, tensor<1x64xi32>
      %94 = tt.broadcast %93 : tensor<1x64x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
      %95 = tt.addptr %94, %74 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
      %96 = tt.addptr %76, %92 : tensor<1x64x!tt.ptr<f16>>, tensor<1x64xi32>
      %97 = tt.broadcast %96 : tensor<1x64x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
      %98 = tt.addptr %97, %74 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
      %99 = arith.muli %88, %c2_i32 : i32
      %100 = arith.minsi %99, %c4_i32 : i32
      %101 = arith.cmpi sge, %100, %c1_i32 : i32
      %102 = scf.if %101 -> (tensor<64x64xf32>) {
        %112 = arith.subi %100, %c1_i32 : i32
        %113:4 = scf.for %arg17 = %c0_i32 to %112 step %c1_i32 iter_args(%arg18 = %83, %arg19 = %95, %arg20 = %98, %arg21 = %90) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>)  : i32 {
          %137 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
          %138 = arith.cmpi slt, %137, %cst_7 : tensor<1x64xi32>
          %139 = tt.broadcast %138 : tensor<1x64xi1> -> tensor<64x64xi1>
          %140 = tt.load %arg19, %139, %cst_8 : tensor<64x64x!tt.ptr<f16>>
          %141 = tt.dot %46, %140, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %142 = arith.mulf %141, %cst_13 : tensor<64x64xf32>
          %143 = arith.mulf %142, %cst_3 : tensor<64x64xf32>
          %144 = arith.mulf %143, %cst_2 : tensor<64x64xf32>
          %145 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32>
          %146 = arith.subf %144, %145 : tensor<64x64xf32>
          %147 = math.exp2 %146 : tensor<64x64xf32>
          %148 = tt.load %arg20, %139, %cst_8 : tensor<64x64x!tt.ptr<f16>>
          %149 = tt.dot %51, %148, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %150 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
          %151 = tt.broadcast %150 : tensor<64x1xf32> -> tensor<64x64xf32>
          %152 = arith.subf %149, %151 : tensor<64x64xf32>
          %153 = arith.mulf %147, %152 : tensor<64x64xf32>
          %154 = arith.mulf %153, %cst_3 : tensor<64x64xf32>
          %155 = arith.truncf %154 : tensor<64x64xf32> to tensor<64x64xf16>
          %156 = tt.trans %140 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
          %157 = tt.dot %155, %156, %arg18, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %158 = arith.divsi %arg17, %c2_i32 : i32
          %159 = tt.addptr %84, %158 : !tt.ptr<i32>, i32
          %160 = tt.load %159 evictionPolicy = evict_last : !tt.ptr<i32>
          %161 = arith.addi %158, %c1_i32 : i32
          %162 = arith.cmpi slt, %161, %88 : i32
          %163 = tt.addptr %159, %c1_i32 : !tt.ptr<i32>, i32
          %164 = tt.load %163, %162 evictionPolicy = evict_last : !tt.ptr<i32>
          %165 = arith.addi %arg17, %c1_i32 : i32
          %166 = arith.remsi %165, %c2_i32 : i32
          %167 = arith.cmpi eq, %166, %c0_i32 : i32
          %168 = arith.subi %164, %160 : i32
          %169 = arith.muli %168, %c128_i32 : i32
          %170 = arith.subi %169, %c64_i32 : i32
          %171 = arith.extui %167 : i1 to i32
          %172 = arith.muli %170, %171 : i32
          %173 = arith.subi %c1_i32, %171 : i32
          %174 = arith.muli %173, %c64_i32 : i32
          %175 = arith.addi %172, %174 : i32
          %176 = arith.muli %175, %c64_i32 : i32
          %177 = tt.splat %176 : i32 -> tensor<64x64xi32>
          %178 = tt.addptr %arg19, %177 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
          %179 = tt.addptr %arg20, %177 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
          %180 = tt.splat %175 : i32 -> tensor<64xi32>
          %181 = arith.addi %arg21, %180 : tensor<64xi32>
          scf.yield %157, %178, %179, %181 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>
        }
        %114 = tt.expand_dims %113#3 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
        %115 = arith.cmpi slt, %114, %cst_7 : tensor<1x64xi32>
        %116 = tt.broadcast %115 : tensor<1x64xi1> -> tensor<64x64xi1>
        %117 = tt.load %113#1, %116, %cst_8 : tensor<64x64x!tt.ptr<f16>>
        %118 = tt.dot %46, %117, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
        %119 = arith.mulf %118, %cst_13 : tensor<64x64xf32>
        %120 = arith.mulf %119, %cst_3 : tensor<64x64xf32>
        %121 = arith.select %116, %120, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32>
        %122 = arith.mulf %121, %cst_2 : tensor<64x64xf32>
        %123 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32>
        %124 = arith.subf %122, %123 : tensor<64x64xf32>
        %125 = math.exp2 %124 : tensor<64x64xf32>
        %126 = tt.load %113#2, %116, %cst_8 : tensor<64x64x!tt.ptr<f16>>
        %127 = tt.dot %51, %126, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
        %128 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
        %129 = tt.broadcast %128 : tensor<64x1xf32> -> tensor<64x64xf32>
        %130 = arith.subf %127, %129 : tensor<64x64xf32>
        %131 = arith.mulf %125, %130 : tensor<64x64xf32>
        %132 = arith.mulf %131, %cst_3 : tensor<64x64xf32>
        %133 = arith.select %116, %132, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32>
        %134 = arith.truncf %133 : tensor<64x64xf32> to tensor<64x64xf16>
        %135 = tt.trans %117 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
        %136 = tt.dot %134, %135, %113#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
        scf.yield %136 : tensor<64x64xf32>
      } else {
        scf.yield %83 : tensor<64x64xf32>
      }
      %103 = tt.splat %33 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>>
      %104 = tt.addptr %103, %37 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32>
      %105 = tt.broadcast %104 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
      %106 = tt.addptr %105, %42 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
      %107 = arith.mulf %102, %cst_13 : tensor<64x64xf32>
      %108 = arith.cmpi slt, %40, %cst_11 : tensor<1x64xi32>
      %109 = tt.broadcast %108 : tensor<1x64xi1> -> tensor<64x64xi1>
      %110 = arith.andi %45, %109 : tensor<64x64xi1>
      %111 = arith.truncf %107 : tensor<64x64xf32> to tensor<64x64xf16>
      tt.store %106, %111, %110 : tensor<64x64x!tt.ptr<f16>>
    } else {
      %16 = arith.divsi %0, %c2_i32 : i32
      %17 = arith.muli %0, %c64_i32 : i32
      %18 = tt.splat %17 : i32 -> tensor<64xi32>
      %19 = arith.addi %18, %14 : tensor<64xi32>
      %20 = tt.expand_dims %19 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %21 = arith.muli %20, %cst_14 : tensor<64x1xi32>
      %22 = tt.splat %11 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>>
      %23 = tt.addptr %22, %21 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32>
      %24 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %25 = tt.broadcast %23 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
      %26 = tt.broadcast %24 : tensor<1x64xi32> -> tensor<64x64xi32>
      %27 = tt.addptr %25, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
      %28 = arith.cmpi slt, %20, %cst_10 : tensor<64x1xi32>
      %29 = tt.broadcast %28 : tensor<64x1xi1> -> tensor<64x64xi1>
      %30 = tt.load %27, %29, %cst_8 : tensor<64x64x!tt.ptr<f16>>
      %31 = tt.splat %12 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>>
      %32 = tt.addptr %31, %21 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32>
      %33 = tt.broadcast %32 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
      %34 = tt.addptr %33, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
      %35 = tt.load %34, %29, %cst_8 : tensor<64x64x!tt.ptr<f16>>
      %36:2 = scf.for %arg17 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg18 = %cst_9, %arg19 = %cst_9) -> (tensor<64x64xf32>, tensor<64x64xf32>)  : i32 {
        %55 = arith.muli %2, %c4_i32 : i32
        %56 = arith.addi %55, %arg17 : i32
        %57 = arith.muli %56, %c2048_i32 : i32
        %58 = arith.muli %1, %c32768_i32 : i32
        %59 = arith.addi %57, %58 : i32
        %60 = arith.extsi %59 : i32 to i64
        %61 = arith.muli %1, %c16_i32 : i32
        %62 = arith.addi %61, %56 : i32
        %63 = arith.muli %62, %c32_i32 : i32
        %64 = arith.extsi %63 : i32 to i64
        %65 = tt.addptr %arg0, %60 : !tt.ptr<f16>, i64
        %66 = tt.addptr %arg5, %60 : !tt.ptr<f16>, i64
        %67 = tt.addptr %arg3, %64 : !tt.ptr<f32>, i64
        %68 = tt.addptr %arg4, %64 : !tt.ptr<f32>, i64
        %69 = arith.remsi %56, %c16_i32 : i32
        %70 = arith.muli %3, %c16_i32 : i32
        %71 = arith.addi %70, %69 : i32
        %72 = arith.muli %71, %c2_i32 : i32
        %73 = arith.addi %72, %16 : i32
        %74 = tt.addptr %arg11, %73 : !tt.ptr<i32>, i32
        %75 = tt.load %74 : !tt.ptr<i32>
        %76 = arith.muli %75, %c128_i32 : i32
        %77 = tt.addptr %arg10, %73 : !tt.ptr<i32>, i32
        %78 = tt.load %77 : !tt.ptr<i32>
        %79 = tt.splat %76 : i32 -> tensor<64xi32>
        %80 = arith.addi %79, %14 : tensor<64xi32>
        %81 = tt.expand_dims %80 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
        %82 = arith.muli %81, %cst_11 : tensor<1x64xi32>
        %83 = tt.splat %65 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>>
        %84 = tt.addptr %83, %82 : tensor<1x64x!tt.ptr<f16>>, tensor<1x64xi32>
        %85 = tt.expand_dims %14 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
        %86 = tt.broadcast %84 : tensor<1x64x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
        %87 = tt.broadcast %85 : tensor<64x1xi32> -> tensor<64x64xi32>
        %88 = tt.addptr %86, %87 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
        %89 = tt.expand_dims %80 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
        %90 = arith.muli %89, %cst_14 : tensor<64x1xi32>
        %91 = tt.splat %66 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>>
        %92 = tt.addptr %91, %90 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32>
        %93 = tt.broadcast %92 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
        %94 = tt.addptr %93, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
        %95 = arith.muli %78, %c2_i32 : i32
        %96 = arith.minsi %95, %c1_i32 : i32
        %97 = arith.cmpi sge, %96, %c1_i32 : i32
        %98:2 = scf.if %97 -> (tensor<64x64xf32>, tensor<64x64xf32>) {
          %120 = arith.subi %96, %c1_i32 : i32
          %121:5 = scf.for %arg20 = %c0_i32 to %120 step %c1_i32 iter_args(%arg21 = %arg18, %arg22 = %arg19, %arg23 = %88, %arg24 = %94, %arg25 = %80) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>)  : i32 {
            %167 = tt.expand_dims %arg25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
            %168 = arith.cmpi slt, %167, %cst_1 : tensor<1x64xi32>
            %169 = tt.broadcast %168 : tensor<1x64xi1> -> tensor<64x64xi1>
            %170 = tt.load %arg23, %169, %cst_8 : tensor<64x64x!tt.ptr<f16>>
            %171 = arith.cmpi slt, %arg25, %cst_17 : tensor<64xi32>
            %172 = tt.splat %67 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
            %173 = tt.addptr %172, %arg25 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
            %174 = tt.load %173, %171 : tensor<64x!tt.ptr<f32>>
            %175 = arith.cmpf oeq, %174, %cst_16 : tensor<64xf32>
            %176 = arith.select %175, %cst_15, %174 : tensor<64xi1>, tensor<64xf32>
            %177 = tt.dot %30, %170, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
            %178 = arith.mulf %177, %cst_13 : tensor<64x64xf32>
            %179 = arith.mulf %178, %cst_3 : tensor<64x64xf32>
            %180 = arith.mulf %179, %cst_2 : tensor<64x64xf32>
            %181 = tt.expand_dims %176 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
            %182 = tt.broadcast %181 : tensor<1x64xf32> -> tensor<64x64xf32>
            %183 = arith.subf %180, %182 : tensor<64x64xf32>
            %184 = math.exp2 %183 : tensor<64x64xf32>
            %185 = tt.expand_dims %arg25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
            %186 = arith.cmpi slt, %185, %cst_12 : tensor<64x1xi32>
            %187 = tt.broadcast %186 : tensor<64x1xi1> -> tensor<64x64xi1>
            %188 = tt.load %arg24, %187, %cst_8 : tensor<64x64x!tt.ptr<f16>>
            %189 = arith.truncf %184 : tensor<64x64xf32> to tensor<64x64xf16>
            %190 = tt.dot %189, %188, %arg22, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
            %191 = tt.splat %68 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
            %192 = tt.addptr %191, %arg25 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
            %193 = tt.load %192, %171 : tensor<64x!tt.ptr<f32>>
            %194 = tt.trans %188 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
            %195 = tt.dot %35, %194, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
            %196 = tt.expand_dims %193 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
            %197 = tt.broadcast %196 : tensor<1x64xf32> -> tensor<64x64xf32>
            %198 = arith.subf %195, %197 : tensor<64x64xf32>
            %199 = arith.mulf %184, %198 : tensor<64x64xf32>
            %200 = arith.mulf %199, %cst_3 : tensor<64x64xf32>
            %201 = arith.truncf %200 : tensor<64x64xf32> to tensor<64x64xf16>
            %202 = tt.trans %170 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
            %203 = tt.dot %201, %202, %arg21, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
            %204 = arith.divsi %arg20, %c2_i32 : i32
            %205 = tt.addptr %74, %204 : !tt.ptr<i32>, i32
            %206 = tt.load %205 evictionPolicy = evict_last : !tt.ptr<i32>
            %207 = arith.addi %204, %c1_i32 : i32
            %208 = arith.cmpi slt, %207, %78 : i32
            %209 = tt.addptr %205, %c1_i32 : !tt.ptr<i32>, i32
            %210 = tt.load %209, %208 evictionPolicy = evict_last : !tt.ptr<i32>
            %211 = arith.addi %arg20, %c1_i32 : i32
            %212 = arith.remsi %211, %c2_i32 : i32
            %213 = arith.cmpi eq, %212, %c0_i32 : i32
            %214 = arith.subi %210, %206 : i32
            %215 = arith.muli %214, %c128_i32 : i32
            %216 = arith.subi %215, %c64_i32 : i32
            %217 = arith.extui %213 : i1 to i32
            %218 = arith.muli %216, %217 : i32
            %219 = arith.subi %c1_i32, %217 : i32
            %220 = arith.muli %219, %c64_i32 : i32
            %221 = arith.addi %218, %220 : i32
            %222 = arith.muli %221, %c64_i32 : i32
            %223 = tt.splat %222 : i32 -> tensor<64x64xi32>
            %224 = tt.addptr %arg23, %223 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
            %225 = tt.addptr %arg24, %223 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
            %226 = tt.splat %221 : i32 -> tensor<64xi32>
            %227 = arith.addi %arg25, %226 : tensor<64xi32>
            scf.yield %203, %190, %224, %225, %227 : tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>
          }
          %122 = tt.expand_dims %121#4 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
          %123 = arith.cmpi slt, %122, %cst_1 : tensor<1x64xi32>
          %124 = tt.broadcast %123 : tensor<1x64xi1> -> tensor<64x64xi1>
          %125 = tt.load %121#2, %124, %cst_8 : tensor<64x64x!tt.ptr<f16>>
          %126 = arith.cmpi slt, %121#4, %cst_17 : tensor<64xi32>
          %127 = tt.splat %67 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
          %128 = tt.addptr %127, %121#4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
          %129 = tt.load %128, %126 : tensor<64x!tt.ptr<f32>>
          %130 = arith.cmpf oeq, %129, %cst_16 : tensor<64xf32>
          %131 = arith.select %130, %cst_15, %129 : tensor<64xi1>, tensor<64xf32>
          %132 = tt.dot %30, %125, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %133 = arith.mulf %132, %cst_13 : tensor<64x64xf32>
          %134 = arith.mulf %133, %cst_3 : tensor<64x64xf32>
          %135 = arith.select %29, %134, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32>
          %136 = arith.select %28, %cst, %cst_0 : tensor<64x1xi1>, tensor<64x1xi1>
          %137 = tt.broadcast %136 : tensor<64x1xi1> -> tensor<64x64xi1>
          %138 = arith.select %137, %135, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32>
          %139 = arith.mulf %138, %cst_2 : tensor<64x64xf32>
          %140 = tt.expand_dims %131 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
          %141 = tt.broadcast %140 : tensor<1x64xf32> -> tensor<64x64xf32>
          %142 = arith.subf %139, %141 : tensor<64x64xf32>
          %143 = math.exp2 %142 : tensor<64x64xf32>
          %144 = tt.expand_dims %121#4 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
          %145 = arith.cmpi slt, %144, %cst_12 : tensor<64x1xi32>
          %146 = tt.broadcast %145 : tensor<64x1xi1> -> tensor<64x64xi1>
          %147 = tt.load %121#3, %146, %cst_8 : tensor<64x64x!tt.ptr<f16>>
          %148 = arith.truncf %143 : tensor<64x64xf32> to tensor<64x64xf16>
          %149 = tt.dot %148, %147, %121#1, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %150 = tt.splat %68 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
          %151 = tt.addptr %150, %121#4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
          %152 = tt.load %151, %126 : tensor<64x!tt.ptr<f32>>
          %153 = tt.trans %147 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
          %154 = tt.dot %35, %153, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %155 = tt.expand_dims %152 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
          %156 = tt.broadcast %155 : tensor<1x64xf32> -> tensor<64x64xf32>
          %157 = arith.subf %154, %156 : tensor<64x64xf32>
          %158 = arith.mulf %143, %157 : tensor<64x64xf32>
          %159 = arith.mulf %158, %cst_3 : tensor<64x64xf32>
          %160 = arith.select %29, %159, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32>
          %161 = arith.select %28, %136, %cst_0 : tensor<64x1xi1>, tensor<64x1xi1>
          %162 = tt.broadcast %161 : tensor<64x1xi1> -> tensor<64x64xi1>
          %163 = arith.select %162, %160, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32>
          %164 = arith.truncf %163 : tensor<64x64xf32> to tensor<64x64xf16>
          %165 = tt.trans %125 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
          %166 = tt.dot %164, %165, %121#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          scf.yield %166, %149 : tensor<64x64xf32>, tensor<64x64xf32>
        } else {
          scf.yield %arg18, %arg19 : tensor<64x64xf32>, tensor<64x64xf32>
        }
        %99 = tt.addptr %arg15, %73 : !tt.ptr<i32>, i32
        %100 = tt.load %99 : !tt.ptr<i32>
        %101 = arith.muli %100, %c128_i32 : i32
        %102 = tt.addptr %arg14, %73 : !tt.ptr<i32>, i32
        %103 = tt.load %102 : !tt.ptr<i32>
        %104 = tt.splat %101 : i32 -> tensor<64xi32>
        %105 = arith.addi %104, %14 : tensor<64xi32>
        %106 = tt.expand_dims %105 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
        %107 = arith.muli %106, %cst_11 : tensor<1x64xi32>
        %108 = tt.addptr %83, %107 : tensor<1x64x!tt.ptr<f16>>, tensor<1x64xi32>
        %109 = tt.broadcast %108 : tensor<1x64x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
        %110 = tt.addptr %109, %87 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
        %111 = tt.expand_dims %105 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
        %112 = arith.muli %111, %cst_14 : tensor<64x1xi32>
        %113 = tt.addptr %91, %112 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32>
        %114 = tt.broadcast %113 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
        %115 = tt.addptr %114, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
        %116 = arith.muli %103, %c2_i32 : i32
        %117 = arith.minsi %116, %c1_i32 : i32
        %118 = arith.cmpi sge, %117, %c1_i32 : i32
        %119:2 = scf.if %118 -> (tensor<64x64xf32>, tensor<64x64xf32>) {
          %120 = arith.subi %117, %c1_i32 : i32
          %121:5 = scf.for %arg20 = %c0_i32 to %120 step %c1_i32 iter_args(%arg21 = %98#0, %arg22 = %98#1, %arg23 = %110, %arg24 = %115, %arg25 = %105) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>)  : i32 {
            %161 = tt.expand_dims %arg25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
            %162 = arith.cmpi slt, %161, %cst_1 : tensor<1x64xi32>
            %163 = tt.broadcast %162 : tensor<1x64xi1> -> tensor<64x64xi1>
            %164 = tt.load %arg23, %163, %cst_8 : tensor<64x64x!tt.ptr<f16>>
            %165 = arith.cmpi slt, %arg25, %cst_17 : tensor<64xi32>
            %166 = tt.splat %67 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
            %167 = tt.addptr %166, %arg25 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
            %168 = tt.load %167, %165 : tensor<64x!tt.ptr<f32>>
            %169 = arith.cmpf oeq, %168, %cst_16 : tensor<64xf32>
            %170 = arith.select %169, %cst_15, %168 : tensor<64xi1>, tensor<64xf32>
            %171 = tt.dot %30, %164, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
            %172 = arith.mulf %171, %cst_13 : tensor<64x64xf32>
            %173 = arith.mulf %172, %cst_3 : tensor<64x64xf32>
            %174 = arith.mulf %173, %cst_2 : tensor<64x64xf32>
            %175 = tt.expand_dims %170 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
            %176 = tt.broadcast %175 : tensor<1x64xf32> -> tensor<64x64xf32>
            %177 = arith.subf %174, %176 : tensor<64x64xf32>
            %178 = math.exp2 %177 : tensor<64x64xf32>
            %179 = tt.expand_dims %arg25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
            %180 = arith.cmpi slt, %179, %cst_12 : tensor<64x1xi32>
            %181 = tt.broadcast %180 : tensor<64x1xi1> -> tensor<64x64xi1>
            %182 = tt.load %arg24, %181, %cst_8 : tensor<64x64x!tt.ptr<f16>>
            %183 = arith.truncf %178 : tensor<64x64xf32> to tensor<64x64xf16>
            %184 = tt.dot %183, %182, %arg22, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
            %185 = tt.splat %68 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
            %186 = tt.addptr %185, %arg25 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
            %187 = tt.load %186, %165 : tensor<64x!tt.ptr<f32>>
            %188 = tt.trans %182 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
            %189 = tt.dot %35, %188, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
            %190 = tt.expand_dims %187 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
            %191 = tt.broadcast %190 : tensor<1x64xf32> -> tensor<64x64xf32>
            %192 = arith.subf %189, %191 : tensor<64x64xf32>
            %193 = arith.mulf %178, %192 : tensor<64x64xf32>
            %194 = arith.mulf %193, %cst_3 : tensor<64x64xf32>
            %195 = arith.truncf %194 : tensor<64x64xf32> to tensor<64x64xf16>
            %196 = tt.trans %164 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
            %197 = tt.dot %195, %196, %arg21, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
            %198 = arith.divsi %arg20, %c2_i32 : i32
            %199 = tt.addptr %99, %198 : !tt.ptr<i32>, i32
            %200 = tt.load %199 evictionPolicy = evict_last : !tt.ptr<i32>
            %201 = arith.addi %198, %c1_i32 : i32
            %202 = arith.cmpi slt, %201, %103 : i32
            %203 = tt.addptr %199, %c1_i32 : !tt.ptr<i32>, i32
            %204 = tt.load %203, %202 evictionPolicy = evict_last : !tt.ptr<i32>
            %205 = arith.addi %arg20, %c1_i32 : i32
            %206 = arith.remsi %205, %c2_i32 : i32
            %207 = arith.cmpi eq, %206, %c0_i32 : i32
            %208 = arith.subi %204, %200 : i32
            %209 = arith.muli %208, %c128_i32 : i32
            %210 = arith.subi %209, %c64_i32 : i32
            %211 = arith.extui %207 : i1 to i32
            %212 = arith.muli %210, %211 : i32
            %213 = arith.subi %c1_i32, %211 : i32
            %214 = arith.muli %213, %c64_i32 : i32
            %215 = arith.addi %212, %214 : i32
            %216 = arith.muli %215, %c64_i32 : i32
            %217 = tt.splat %216 : i32 -> tensor<64x64xi32>
            %218 = tt.addptr %arg23, %217 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
            %219 = tt.addptr %arg24, %217 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
            %220 = tt.splat %215 : i32 -> tensor<64xi32>
            %221 = arith.addi %arg25, %220 : tensor<64xi32>
            scf.yield %197, %184, %218, %219, %221 : tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>
          }
          %122 = tt.expand_dims %121#4 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
          %123 = arith.cmpi slt, %122, %cst_1 : tensor<1x64xi32>
          %124 = tt.broadcast %123 : tensor<1x64xi1> -> tensor<64x64xi1>
          %125 = tt.load %121#2, %124, %cst_8 : tensor<64x64x!tt.ptr<f16>>
          %126 = arith.cmpi slt, %121#4, %cst_17 : tensor<64xi32>
          %127 = tt.splat %67 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
          %128 = tt.addptr %127, %121#4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
          %129 = tt.load %128, %126 : tensor<64x!tt.ptr<f32>>
          %130 = arith.cmpf oeq, %129, %cst_16 : tensor<64xf32>
          %131 = arith.select %130, %cst_15, %129 : tensor<64xi1>, tensor<64xf32>
          %132 = tt.dot %30, %125, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %133 = arith.mulf %132, %cst_13 : tensor<64x64xf32>
          %134 = arith.mulf %133, %cst_3 : tensor<64x64xf32>
          %135 = arith.select %29, %134, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32>
          %136 = arith.mulf %135, %cst_2 : tensor<64x64xf32>
          %137 = tt.expand_dims %131 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
          %138 = tt.broadcast %137 : tensor<1x64xf32> -> tensor<64x64xf32>
          %139 = arith.subf %136, %138 : tensor<64x64xf32>
          %140 = math.exp2 %139 : tensor<64x64xf32>
          %141 = tt.expand_dims %121#4 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
          %142 = arith.cmpi slt, %141, %cst_12 : tensor<64x1xi32>
          %143 = tt.broadcast %142 : tensor<64x1xi1> -> tensor<64x64xi1>
          %144 = tt.load %121#3, %143, %cst_8 : tensor<64x64x!tt.ptr<f16>>
          %145 = arith.truncf %140 : tensor<64x64xf32> to tensor<64x64xf16>
          %146 = tt.dot %145, %144, %121#1, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %147 = tt.splat %68 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
          %148 = tt.addptr %147, %121#4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
          %149 = tt.load %148, %126 : tensor<64x!tt.ptr<f32>>
          %150 = tt.trans %144 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
          %151 = tt.dot %35, %150, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          %152 = tt.expand_dims %149 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
          %153 = tt.broadcast %152 : tensor<1x64xf32> -> tensor<64x64xf32>
          %154 = arith.subf %151, %153 : tensor<64x64xf32>
          %155 = arith.mulf %140, %154 : tensor<64x64xf32>
          %156 = arith.mulf %155, %cst_3 : tensor<64x64xf32>
          %157 = arith.select %29, %156, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32>
          %158 = arith.truncf %157 : tensor<64x64xf32> to tensor<64x64xf16>
          %159 = tt.trans %125 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16>
          %160 = tt.dot %158, %159, %121#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32>
          scf.yield %160, %146 : tensor<64x64xf32>, tensor<64x64xf32>
        } else {
          scf.yield %98#0, %98#1 : tensor<64x64xf32>, tensor<64x64xf32>
        }
        scf.yield %119#0, %119#1 : tensor<64x64xf32>, tensor<64x64xf32>
      }
      %37 = tt.splat %13 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>>
      %38 = tt.addptr %37, %21 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32>
      %39 = tt.broadcast %38 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>>
      %40 = tt.addptr %39, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
      %41 = arith.cmpi slt, %24, %cst_11 : tensor<1x64xi32>
      %42 = tt.broadcast %41 : tensor<1x64xi1> -> tensor<64x64xi1>
      %43 = arith.andi %29, %42 : tensor<64x64xi1>
      %44 = arith.truncf %36#1 : tensor<64x64xf32> to tensor<64x64xf16>
      tt.store %40, %44, %43 : tensor<64x64x!tt.ptr<f16>>
      %45 = arith.mulf %36#0, %cst_13 : tensor<64x64xf32>
      %46 = tt.broadcast %21 : tensor<64x1xi32> -> tensor<64x64xi32>
      %47 = arith.addi %26, %46 : tensor<64x64xi32>
      %48 = tt.splat %4 : i32 -> tensor<64x64xi32>
      %49 = arith.addi %47, %48 : tensor<64x64xi32>
      %50 = tt.splat %8 : i32 -> tensor<64x64xi32>
      %51 = arith.addi %49, %50 : tensor<64x64xi32>
      %52 = tt.splat %arg16 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>>
      %53 = tt.addptr %52, %51 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
      %54 = arith.truncf %45 : tensor<64x64xf32> to tensor<64x64xf16>
      tt.store %53, %54, %29 : tensor<64x64x!tt.ptr<f16>>
    }
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=4 target=cuda:100 threads-per-warp=32}, tritongpu-coalesce, tritongpu-F32DotTC, triton-nvidia-gpu-plan-cta, tritongpu-remove-layout-conversions, tritongpu-optimize-thread-locality, tritongpu-accelerate-matmul, tritongpu-remove-layout-conversions, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, triton-nvidia-optimize-descriptor-encoding, triton-loop-aware-cse, tritongpu-fuse-nested-loops, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, tritongpu-optimize-accumulator-init, tritongpu-hoist-tmem-alloc, tritongpu-promote-lhs-to-tmem, tritongpu-assign-latencies{num-stages=3}, tritongpu-schedule-loops, tritongpu-automatic-warp-specialization{num-stages=3}, tritongpu-pipeline{dump-intermediate-steps=false num-stages=3}, tritongpu-combine-tensor-select-and-if, triton-nvidia-gpu-remove-tmem-tokens, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-loop-aware-cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritongpu-coalesce-async-copy, triton-nvidia-optimize-tmem-layouts, tritongpu-remove-layout-conversions, triton-nvidia-interleave-tmem, tritongpu-reduce-data-duplication, tritongpu-reorder-instructions, triton-loop-aware-cse, symbol-dce, triton-nvidia-tma-lowering, triton-nvidia-gpu-fence-insertion{compute-capability=90}, sccp, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true})",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
/tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: error: Failures have been detected while processing an MLIR pass pipeline
/tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: note: Pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
Triton compilation failed: triton_tem_fused_zeros_1
def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.125
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 64
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
    V_HEAD_DIM : tl.constexpr = 64
    V_HEAD_DIM_ROUNDED : tl.constexpr = 64
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 64
    BLOCK_M2 : tl.constexpr = 64
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 32768, 2048, 64, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 65536, 16384, 64, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 65536, 16384, 64, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 32768, 2048, 64, 1

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 32768, 2048, 64, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 65536, 16384, 64, 1

    ZQ = 2
    HQ = 16
    HKV = 4
    Q_LEN = 32
    ZKV = 2
    KV_LEN = 256

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)

    off_zq = tl.program_id(1) # q batch idx
    off_hkv = tl.program_id(2) # kv head idx
    off_zkv = off_zq % ZKV # kv batch idx

    SPARSE_Z = 2
    SPARSE_HQ = 16

    sparse_idx_z = off_zq % SPARSE_Z

    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)

    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = 1
        stride_kv_idx_h = 2
        stride_kv_idx_m = 2

        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2

        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m  # noqa: B950

        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)

        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2

        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]

        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )

        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        pid_mask = pid // SPARSE_KV_MULTIPLE

        stride_q_num_blks_h = 2
        stride_q_idx_h = 2
        stride_q_idx_n = 1

        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)

        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g

            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)

            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1

            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1

            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n  # noqa: B950

            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)

            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )

            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )

        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd

        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]

        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))

        dk *= SM_SCALE

        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)

        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        xindex = index_k + 64*index_n + 16384*off_hkv + 65536*off_zq
        tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)

metadata: {'signature': {'arg_Q': '*fp16', 'arg_K': '*fp16', 'arg_V': '*fp16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*fp16', 'arg_DQ': '*fp16', 'arg_DV': '*fp16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*fp16'}, 'device': 0, 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}], 'device_type': 'cuda', 'num_warps': 4, 'num_stages': 3, 'debug': True, 'cc': 100}
Traceback (most recent call last):
  File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 748, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/compiler/compiler.py", line 359, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 456, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 298, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed
frames [('total', 3), ('ok', 3)]
inline_call []
stats [('calls_captured', 8), ('unique_graphs', 3)]
aot_autograd [('total', 1), ('autograd_cache_miss', 1), ('ok', 1)]
inductor [('triton_bundler_save_kernel', 8), ('async_compile_cache_miss', 3), ('fxgraph_cache_miss', 1), ('triton_bundler_save_static_autotuner', 1), ('fxgraph_cache_bypass', 1)]
graph_break []
F

==================================================== FAILURES =====================================================
_____________________________ TestFlexAttentionCUDA.test_GQA_score_mod1_cuda_float16 ______________________________
Traceback (most recent call last):
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
    yield
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 634, in run
    self._callTestMethod(testMethod)
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper
    method(*args, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper
    method(*args, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 446, in instantiated_test
    raise rte
  File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 426, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 1349, in dep_fn
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 1215, in dep_fn
    return fn(slf, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/test/inductor/test_flex_attention.py", line 1430, in test_GQA
    self.run_test(*inputs)
  File "/home/drisspg/meta/pytorch/test/inductor/test_flex_attention.py", line 566, in run_test
    compiled_out.backward(backward_grad)
  File "/home/drisspg/meta/pytorch/torch/_tensor.py", line 625, in backward
    torch.autograd.backward(
  File "/home/drisspg/meta/pytorch/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/home/drisspg/meta/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/autograd/function.py", line 315, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2303, in backward
    return impl_fn()
           ^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2289, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2394, in _backward_impl
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/schemas.py", line 1256, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_dynamo/backends/common.py", line 76, in _wrapped_bw_compiler
    disable(
  File "/home/drisspg/meta/pytorch/torch/_dynamo/eval_frame.py", line 1005, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_utils_internal.py", line 92, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 2428, in bw_compiler
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 773, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 952, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1652, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1506, in codegen_and_compile
    compiled_module = graph.compile_to_module()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2318, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2328, in _compile_to_module
    mod = self._compile_to_module_lines(wrapper_code)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2396, in _compile_to_module_lines
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/codecache.py", line 3466, in load_by_key_path
    mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/compile_tasks.py", line 33, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/tmp0yiz3c94/az/caza2gzmsagyuusmf2ka3oat3na4xv6zudssk244xmlzsbv2knze.py", line 117, in <module>
  File "/home/drisspg/meta/pytorch/torch/_inductor/async_compile.py", line 489, in triton
    kernel.precompile(
  File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 437, in precompile
    self._precompile_worker()
  File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 459, in _precompile_worker
    compile_results.append(self._precompile_config(c))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 748, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/compiler/compiler.py", line 359, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 456, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 298, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

To execute this test, run the following from the base repo dir:
    python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_score_mod1_cuda_float16

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
============================================= short test summary info =============================================
FAILED [5.1441s] test/inductor/test_flex_attention.py::TestFlexAttentionCUDA::test_GQA_score_mod1_cuda_float16 - RuntimeError: PassManager::run failed
================================== 1 failed, 1 passed, 1404 deselected in 18.10s ==================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160227
Approved by: https://github.com/Skylion007, https://github.com/Chillee
2025-08-11 23:30:20 +00:00
99bc2f94c1 Update export/schema.py (#160220)
Summary:
Model could have multiple ExportedPrograms:
- for different methods. They can have different weights.
- for different delegates. They can also have different weights.

For this reason, we make weights per-ExportedProgram.

Also, we clean up Model and Program. IIUC, Model and Program are not used anywhere, so it's OK to make a BC-breaking change.
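
As an illustration only, a minimal sketch of what per-ExportedProgram weights could look like; the class and field names below are hypothetical stand-ins, not the actual definitions in export/schema.py:

```python
from dataclasses import dataclass, field

# Hypothetical schema sketch: each exported program carries its own weights,
# so different methods/delegates of one model can have different weights.
@dataclass
class ExportedProgramEntry:
    graph: bytes                                              # serialized graph for one method/delegate
    weights: dict[str, bytes] = field(default_factory=dict)   # per-program weights

@dataclass
class Model:
    name: str
    # method or delegate name -> its ExportedProgram (with its own weights)
    programs: dict[str, ExportedProgramEntry] = field(default_factory=dict)
```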

Test Plan:
CI

Rollback Plan:

Differential Revision: D79917395

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160220
Approved by: https://github.com/angelayi, https://github.com/dolpm, https://github.com/jingsh
2025-08-11 23:14:08 +00:00
fc25c68f20 [hop][exc] make UncapturedHigherOrderOpError print user code and avoid re-raise (#159296)
After the change, the error stack trace is annotated with the user code stack and the re-raise is suppressed into a single exception (without the "scroll up" message). For example:
```python
    class Test(torch.nn.Module):
        def forward(self, c, x):
            def cond_fn(c, x):
                return c > 0 and x.size(0) < 20

            def body_fn(c, x):
                return c - 1, x.sin()

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
```

Now gives the following error message:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1705, in test_while_loop_size_mismatch_tensor_expansion
    self._run_test(
    ~~~~~~~~~~~~~~^
        model=WhileLoopModels.SizeMismatchTensorExpansion(),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<2 lines>...
        dynamic=dynamic,
        ^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1417, in _run_test
    result = model(*inputs_with_counters)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1053, in forward
    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 176, in while_loop
    return torch.compile(
           ~~~~~~~~~~~~~~
        _while_loop_op_wrapper, backend=backend, fullgraph=True
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
    ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1595, in __call__
    result = self._torchdynamo_orig_backend(
        frame, cache_entry, self.hooks, frame_state, skip=1
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1353, in __call__
    result = self._inner_convert(
        frame, cache_entry, hooks, frame_state, skip=skip + 1
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 682, in __call__
    result = _compile(
        frame.f_code,
    ...<16 lines>...
        convert_frame_box=self._box,
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1172, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
    return function(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 858, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
    transformations(instructions, code_options)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 300, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 818, in transform
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3528, in run
    super().run()
    ~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 91, in graph_break_as_hard_error
    raise exc.with_traceback(sys.exc_info()[2]) from None
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 77, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1287, in call_function
    ) = speculate_subgraph(
        ~~~~~~~~~~~~~~~~~~^
        tx,
        ^^^
    ...<33 lines>...
        supports_aliasing=self.supports_aliasing,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 877, in speculate_subgraph
    raise ex
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 718, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
    return tracer.inline_call_()
           ~~~~~~~~~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
    self.run()
    ~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
    return tracer.inline_call_()
           ~~~~~~~~~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
    self.run()
    ~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 830, in inner
    unimplemented_v2(
    ~~~~~~~~~~~~~~~~^
        gb_type="Data-dependent branching",
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<5 lines>...
        ],
        ^^
    )
    ^
  File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 580, in unimplemented_v2
    raise Unsupported(msg)
torch._dynamo.exc.UncapturedHigherOrderOpError: while_loop doesn't work unless it is captured completely with torch.compile. Got Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()

 For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html

from user code:
   File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 167, in _while_loop_op_wrapper
    return while_loop_op(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 137, in flat_cond_fn
    return cond_fn(*carried, *additional)
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1047, in cond_fn
    return c > 0 and x.size(0) < 20

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

To execute this test, run the following from the base repo dir:
    python test/inductor/test_control_flow.py WhileLoopTests.test_while_loop_size_mismatch_tensor_expansion_device_cpu_dynamic_False

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```
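
The hint in the message above points at `torch.cond` for genuinely data-dependent branching. For the specific `cond_fn` in this example, a hedged sketch of one way to avoid the Python `and` (assuming the intent is simply the conjunction of the two conditions; whether this traces cleanly still depends on the surrounding `while_loop` constraints):

```python
import torch

def cond_fn(c, x):
    # Python `and` forces a data-dependent branch on a traced tensor;
    # a tensor-level logical_and keeps the predicate as one boolean tensor.
    return torch.logical_and(c > 0, torch.tensor(x.size(0) < 20))
```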

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159296
Approved by: https://github.com/zou3519
2025-08-11 22:48:10 +00:00
5a40c57844 [MTIA] Implement isAvailable() for MTIA hooks (#160304)
Summary: MTIA is missing the `isAvailable()` override, which is necessary for some of the device-agnostic methods.

Test Plan:
`torch._C._get_accelerator()`

Rollback Plan:

Differential Revision: D79981115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160304
Approved by: https://github.com/nautsimon
2025-08-11 21:45:11 +00:00
7d2ec704e4 Fix MPS autocast for ConvTranspose3d (#160345)
## Summary
- ensure ConvTranspose3d uses fp32 under MPS autocast (a minimal check is sketched below)
- add MPS autocast test for ConvTranspose3d
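
For illustration only, a hedged sketch of the behavior being fixed (not the test added in this PR); it assumes an MPS-capable machine:

```python
import torch

if torch.backends.mps.is_available():
    conv = torch.nn.ConvTranspose3d(2, 2, kernel_size=3, device="mps")
    x = torch.randn(1, 2, 4, 4, 4, device="mps")
    with torch.autocast(device_type="mps", dtype=torch.float16):
        out = conv(x)
    # With this fix, ConvTranspose3d is expected to run in fp32 under MPS autocast.
    print(out.dtype)
```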

Generated by Codex, see https://chatgpt.com/codex/tasks/task_e_689a360388288327a2cac6f55bbfc42c

Fixes https://github.com/pytorch/pytorch/issues/160332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160345
Approved by: https://github.com/dcci
2025-08-11 21:01:52 +00:00
fc80f6859e Fix collective schedule logging and runtime tests (#160260)
Summary:

- Fix collective schedule logging so that it only logs when collectives are present
- Fix runtime estimate test to check that each op has a numeric value

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160260
Approved by: https://github.com/Skylion007
2025-08-11 20:58:52 +00:00
cf0a0dcb0a Make user defined Triton kernels serializable for fx_graph_runnable (#160002)
Resolves issue https://github.com/pytorch/pytorch/issues/153475 where `fx_graph_runnable` didn't work with user-defined Triton kernels.
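
For context, a hedged sketch of the kind of user-defined Triton kernel this refers to (an illustrative example, not the repro from the linked issue): a `triton.jit` kernel launched from a function compiled with `torch.compile`.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

@torch.compile
def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    add_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK=1024)
    return out

if torch.cuda.is_available():
    a = torch.randn(4096, device="cuda")
    print(add(a, a).sum())
```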

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160002
Approved by: https://github.com/eellison
2025-08-11 20:54:33 +00:00
b149c7204c Revert "port distributed pipeline test files for Intel GPU (#159033)"
This reverts commit 76a0609b6bddb2bc40f1eb4ade12885023653d59.

Reverted https://github.com/pytorch/pytorch/pull/159033 on behalf of https://github.com/clee2000 due to broke test_cpp_extensions_stream_and_event.py::TestCppExtensionStreamAndEvent::test_stream_event [GH job link](https://github.com/pytorch/pytorch/actions/runs/16890370216/job/47849586456) [HUD commit link](76a0609b6b) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/159033#issuecomment-3176833314))
2025-08-11 20:44:45 +00:00
09381f5dac Revert "[Graph Partition] Pass all OSS unit tests (#154667)"
This reverts commit ca7315c17162ea21b1ca5ba23f4bf6168766c7b9.

Reverted https://github.com/pytorch/pytorch/pull/154667 on behalf of https://github.com/clee2000 due to broke inductor/test_memory.py::TestOperatorReorderForPeakMemory::test_reorder_peak_memory_lpmf [GH job link](https://github.com/pytorch/pytorch/actions/runs/16885961204/job/47836769279) [HUD commit link](ca7315c171) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/154667#issuecomment-3176805477))
2025-08-11 20:34:27 +00:00
9eedd2a20b [PGO] no counterfactual suggestions for dynamic allowlist (#160231)
Be more conservative with whitelist suggestions as we roll them out: now we only suggest sources that were dynamic in previous runs.
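
Conceptually (a hedged sketch only; the function and variable names below are not the actual PGO code), the filtering amounts to:

```python
def suggest_dynamic_whitelist(candidate_sources, previously_dynamic_sources):
    # Only suggest sources that were actually observed to be dynamic in a prior run,
    # instead of counterfactually suggesting sources that have never been dynamic.
    return sorted(s for s in candidate_sources if s in previously_dynamic_sources)
```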

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160231
Approved by: https://github.com/bobrenjc93
2025-08-11 20:13:25 +00:00
c3dc8dc412 159965 is merged, no need to patch it in (#160275)
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160275
Approved by: https://github.com/albanD, https://github.com/ZainRizvi
2025-08-11 19:55:04 +00:00
76a0609b6b port distributed pipeline test files for Intel GPU (#159033)
In this PR we port all distributed pipeline test files.
We enable Intel GPU with the following methods and try our best to keep the original code style (a short device-agnostic sketch follows the list):

1. use instantiate_device_type_tests()
2. use "torch.accelerator.current_accelerator()" to determine the accelerator backend
3. use "requires_accelerator_dist_backend()" to replace requires_nccl()
4. use "get_default_backend_for_device()" to get the backend
5. enable XPU for some test paths
6. add TEST_MULTIACCELERATOR in common_utils for all backends.
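
For illustration only, a hedged sketch of the device-agnostic pattern these helpers enable; the exact location of `get_default_backend_for_device` is assumed here and may differ in the actual tests:

```python
import torch
import torch.distributed as dist
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class PipelineSmokeTest(TestCase):
    def test_device_and_backend(self, device):
        # Pick whichever accelerator is present (cuda, xpu, ...) instead of hard-coding CUDA.
        acc = torch.accelerator.current_accelerator()
        self.assertIsNotNone(acc)
        # Assumed helper: maps a device type to its default distributed backend (e.g. nccl, xccl).
        backend = dist.get_default_backend_for_device(acc.type)
        self.assertIsInstance(backend, str)

instantiate_device_type_tests(PipelineSmokeTest, globals(), only_for=("cuda", "xpu"))

if __name__ == "__main__":
    run_tests()
```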

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159033
Approved by: https://github.com/guangyey, https://github.com/d4l3k

Co-authored-by: Daisy Deng <daisy.deng@intel.com>
2025-08-11 19:43:15 +00:00
c8205cb354 [autograd] match 0-dim gradients device type regardless of subclassness (#160165)
Not sure if there are some subclasses where `outer.dim() == 0` but you wouldn't want to move it?

FIXES https://github.com/pytorch/pytorch/issues/160084

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160165
Approved by: https://github.com/ezyang, https://github.com/albanD
2025-08-11 17:57:32 +00:00
d25c4f954d [MPS] Type-promote tensor-iterator common dtype (#160334)
Otherwise, a mixed-dtype call such as `torch.add(FloatTensor, IntTensor, alpha=2)` could be dispatched to a kernel that does not match the promoted common dtype.
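
For illustration only (shown on CPU; the fix itself concerns the MPS dispatch path), the common dtype that type promotion should yield for such a call:

```python
import torch

x = torch.randn(3)                      # float32
y = torch.ones(3, dtype=torch.int32)    # int32
print(torch.result_type(x, y))          # torch.float32 -- the promoted common dtype
print(torch.add(x, y, alpha=2).dtype)   # torch.float32
```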

Fixes https://github.com/pytorch/pytorch/issues/160208
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160334
Approved by: https://github.com/Skylion007, https://github.com/dcci
2025-08-11 17:53:56 +00:00
d0e2240f68 [triton_heuristics] Optimize the triton launcher in pt2 (#160000)
Summary:

(Original author: Xu Zhao. Commandeered by David to land this since it is relatively urgent)

We observed a ~10us PT2-Triton launch overhead regression after the Triton pin update.

Before Triton pin-update:
 {F1980557238}

After Triton pin-update:
 {F1980557240}

The root cause is that https://github.com/pytorch/pytorch/pull/145051 added `_get_args_with_constexprs` to the cubin launcher caller function, which is on the critical path.

The motivation for `_get_args_with_constexprs` was that between triton 3.2 and triton 3.3, the convention for calling Triton kernels (at the level that non-static-cuda-launcher inductor integrates) changed. Previously, the callable did not take constexpr arguments as parameters; after 3.3, it does. With pointwise/reduction kernels, we don't know the constexpr values until after autotuning occurs; so `_get_args_with_constexprs` would inject constexprs into the arguments list before calling the Triton kernel. The fix (in this PR) is to instead inject the constexpr args into the launcher string - this avoids the cost of sorting/reordering arguments which previously occurred upon execution of each kernel.
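
A hedged, simplified sketch of the idea (illustrative only, not the actual inductor codegen): once autotuning has fixed the constexpr values, they can be formatted into the generated launcher source a single time, so no argument sorting/reordering happens per call.

```python
# Hypothetical illustration of "inject the constexpr args into the launcher string".
def make_launcher_src(runtime_arg_names, constexpr_values):
    consts = ", ".join(f"{k}={v!r}" for k, v in constexpr_values.items())
    args = ", ".join(runtime_arg_names)
    return (
        f"def launcher(kernel, grid, {args}):\n"
        f"    return kernel[grid]({args}, {consts})\n"
    )

src = make_launcher_src(["x_ptr", "y_ptr", "n"], {"BLOCK": 1024})
namespace = {}
exec(src, namespace)                 # build the launcher once; calls do no per-call reordering
launcher = namespace["launcher"]     # launcher(kernel, grid, x_ptr, y_ptr, n)
```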

Note that static_cuda_launcher.py does not require constants to be passed to the cubin launcher (e96c7c4bb0/torch/_inductor/runtime/static_cuda_launcher.py (L220)), so there is no need to pass constexprs to the generated launcher code.

The new launcher code needs to work on three cases:
- StaticallyLaunchedCudaKernel
- triton.compile.CompiledKernel
- AOTInductor

Analysis: https://docs.google.com/document/d/1PHaSmx2w59K8qpjw5_qzKWShfEgptf_Zpv_DL7YxiWU/edit?tab=t.0

Test Plan:
Before:
```
$ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BERT_pytorch --performance --backend=inductor --training --amp --disable-cudagraphs

1.893x
```

```

$ buck2 run mode/opt //pytorch/tritonbench:run -- --op launch_latency
  x_val    nop_python_function-walltime    nop_triton_kernel-walltime    nop_triton_compiled_kernel_run-walltime    nop_inductor_kernel-walltime    nop_inductor_kernel_cudagraph-walltime
-------  ------------------------------  ----------------------------  -----------------------------------------  ------------------------------  ----------------------------------------
      0                      0.00760921                       1.80298                                   0.623282                         5.25024                                  0.203722
     19                      0.00799885                       4.78223                                   1.00226                          5.8213                                   0.239084
average                      0.00780403                       3.29261                                   0.812769                         5.53577                                  0.221403
```

After:

```
buck2 run mode/opt //pytorch/tritonbench:run -- --op launch_latency
  x_val    nop_python_function-walltime    nop_triton_kernel-walltime    nop_triton_compiled_kernel_run-walltime    nop_inductor_kernel-walltime    nop_inductor_kernel_cudagraph-walltime
-------  ------------------------------  ----------------------------  -----------------------------------------  ------------------------------  ----------------------------------------
      0                      0.00747067                       1.92589                                   0.726509                         4.35459                                  0.204205
     19                      0.00747823                       7.36852                                   1.26241                          6.28208                                  0.239278
average                      0.00747445                       4.6472                                    0.994459                         5.31834                                  0.221741
```

```
$ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BERT_pytorch --performance --backend=inductor --training --amp --disable-cudagraphs

1.985x
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160000
Approved by: https://github.com/jansel

Co-authored-by: Xu Zhao <xzhao9@meta.com>
2025-08-11 17:22:40 +00:00
9ccd0f5e31 Fix unbacked symint and memory leak in inductor memory planning (#159839)
Summary:

In memory planning, some allocation sizes involve unbacked symints. These unbacked symints are not known until they are computed at run time, so **allocation pools that involve unbacked symints cannot be allocated until we have the values of those unbacked symints**.

So we add a notion of `earliest_available` to Allocation nodes. If an allocation node involves an unbacked symint, it only becomes available when its live range begins.

Then in AllocationPool, if a pool involves an Allocation node that has an earliest-available time, we restrict the pool's live range.

If a block's earliest-available time is later than the start of a pool's live range, we cannot allocate it from that pool.
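
For illustration only, a hedged sketch of that constraint; the class and field names below are simplified stand-ins, not the actual inductor memory-planning types:

```python
from dataclasses import dataclass

@dataclass
class Allocation:
    live_begin: int
    live_end: int
    has_unbacked_symint: bool

    @property
    def earliest_available(self) -> int:
        # With an unbacked symint, the size is only known once this node's live range starts.
        return self.live_begin if self.has_unbacked_symint else 0

@dataclass
class AllocationPool:
    live_begin: int
    live_end: int

    def can_allocate(self, block: Allocation) -> bool:
        # A block that only becomes available after the pool starts cannot come from this pool.
        return block.earliest_available <= self.live_begin
```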

We also fix a memory leak caused by allocating a tensor without wrapping it in an RAIIAtenTensorHandle.

In the Python wrapper for JIT Inductor, `codegen_alloc_from_pool` doesn't actually write the alloc lines to the wrapper; it just returns the allocation string. However, in cpp_wrapper, `codegen_alloc_from_pool` actually writes to the wrapper. Specifically, it writes the following and returns an `RAIIAtenTensorHandle` string.

```
AtenTensorHandle handle_name;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(....);
```

This is bug-prone. **If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle as well**, otherwise you get memory leaks.
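
A hedged sketch of that invariant as a wrapper-codegen helper (illustrative Python that emits C++ lines; not the actual cpp_wrapper implementation, and the allocation arguments are left opaque on purpose):

```python
def emit_alloc_from_pool(wrapper_lines, handle_name, alloc_args):
    # Emit the raw handle, the pool allocation, and the RAII owner together,
    # so the tensor is released when the RAIIAtenTensorHandle goes out of scope.
    wrapper_lines.append(f"AtenTensorHandle {handle_name};")
    wrapper_lines.append(
        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({alloc_args}));"
    )
    wrapper_lines.append(f"RAIIAtenTensorHandle raii_{handle_name}({handle_name});")
    return f"raii_{handle_name}"
```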

We remove the alloc_from_pool call from codegen_create, because this doesn't work for AOTI. In the Python wrapper, we can generate the same alloc_from_pool variable name for the same block, but cpp_wrapper will generate a different variable name for each call to alloc_from_pool.

Test Plan:
```
 python test/inductor/test_memory_planning.py
```

Rollback Plan:

Differential Revision: D79603119

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159839
Approved by: https://github.com/jansel
2025-08-11 17:16:15 +00:00
ca7315c171 [Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to a 6.2% speedup on vision_maskrcnn and a 5.8% speedup on yolov3 [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), a 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), and an 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).

Running the same diff on two different days shows a speedup on average in both runs.

[first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)
<img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" />

[second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)
<img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
2025-08-11 16:25:12 +00:00
68a4b4b2e3 [codemod] Fix unreachable-break issue in caffe2/c10/cuda/CUDAFunctions.cpp +2 (#160257)
Summary:
LLVM has a warning, `-Wunreachable-code-break`, which identifies `break` statements that cannot be reached. These compromise readability, are misleading, and may indicate bugs. This diff removes such statements.

For questions/comments, contact r-barnes.

 - If you approve of this diff, please use the "Accept & Ship" button :-)

Test Plan:
Sandcastle

Rollback Plan:

Differential Revision: D79835614

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160257
Approved by: https://github.com/Skylion007
2025-08-11 16:09:24 +00:00