mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
a7abf57aabec0ce686092e2d66e53ba185dbc56b
91547 Commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
| a7abf57aab |
[ROCm] Support large inputs for coalesceValuesKernel (#158281)
# Description
`.coalesce` cannot handle large inputs on ROCM due to maximal grid size limit.
This PR splits axis `X` into axes `X` and `Y`, and repurposes `Z` for original `Y` on ROCm to avoid such limitation.
Confirmed the new approach can handle large inputs. Correctness needs validation.
# Testing Command
`python torch_spmv.py 22500000 272500000`
## Script `torch_spmv.py`
``` python
import torch
import argparse
def parse_args():
parser = argparse.ArgumentParser(
description="Sparse COO Matrix by Dense Vector Multiplication using PyTorch"
)
parser.add_argument("n", type=int, help="Size of the NxN matrix")
parser.add_argument("nnz", type=int, help="Number of non-zero entries")
return parser.parse_args()
def main():
args = parse_args()
n = args.n
nnz = args.nnz
dtype = torch.float32
device = torch.device('cuda')
# Generate random indices for the sparse matrix in COO format.
torch.manual_seed(42)
rows = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device)
cols = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device)
indices = torch.stack([rows, cols], dim=0)
# Generate random values.
values = torch.randn(nnz, dtype=torch.float32, device=device)
# Create the sparse COO matrix and move it to the target device.
sparse_matrix = torch.sparse_coo_tensor(indices, values, size=(n, n), dtype=torch.float32, device=device)
sparse_matrix = sparse_matrix.coalesce()
# Generate a random dense vector.
dense_vector = torch.randn(n, dtype=torch.float32, device=device)
# Perform sparse matrix - dense vector multiplication.
# Using torch.sparse.mm which expects a 2D tensor for the vector.
result = torch.sparse.mm(sparse_matrix, dense_vector.unsqueeze(1)).squeeze()
# result = torch.mv(sparse_matrix, dense_vector)
# Print the result.
print("Result of the multiplication:")
print(torch.sum(result))
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158281
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
|
|||
| f7b2f3314c |
Revert "[triton_heuristics] Optimize the triton launcher in pt2 (#160000)"
This reverts commit d0e2240f680ea2a553f7ee8188f52482e130bfd0. Reverted https://github.com/pytorch/pytorch/pull/160000 on behalf of https://github.com/davidberard98 due to D80054972 failing with test_triton_kernel_2d_autotune_grad_False_dynamic_True_backend_inductor_grid_type_1_tdlp_1 ([comment](https://github.com/pytorch/pytorch/pull/160000#issuecomment-3180144676)) |
|||
| 9d37c960a4 |
[ROCm][CI] use new benchmark image for dynamo (#160421)
Follow-up to #160047 that separated the rocm image into default CI and benchmarks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160421 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com> |
|||
| b219ca2a00 |
Revert "Update triton xpu commit to support python 3.14 (#160183)"
This reverts commit 7fbc22855c17741ae016992803b2e147a13aa22d.
Reverted https://github.com/pytorch/pytorch/pull/160183 on behalf of https://github.com/clee2000 due to I'm not sure how, but it seems to have broken inductor/test_extension_backend.py::ExtensionBackendTests::test_open_device_registration [GH job link](https://github.com/pytorch/pytorch/actions/runs/16911267995/job/47917091939) [HUD commit link](
|
|||
| b7db86600a |
Fix Tensor illustration, use permalinks for image embedding in Readme.md (#160416)
Fixes Tensor illustration being broken on pypi.org. Also uses permalinks instead of links to images for embedding as per this suggestion of Alban: https://github.com/pytorch/pytorch/pull/160187#discussion_r2262978006 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160416 Approved by: https://github.com/malfet |
|||
| 9708fcf92d |
Account for triton kernel source code hidden in custom ops properly in AOTAutogradCache (#160120)
This PR fixes a bug where user defined triton kernels hidden behind `triton_op` do not register source code changes. If a user *only* changes a triton kernel source_code, because triton kernels are hidden under the custom op, dynamo hasn't traced into them yet. This means at AOTAutograd time, we don't know the list of triton kernels that are defined by custom ops. This is an initial fix for the issue by parsing the AST of the custom op looking for triton kernels. This won't catch more degenerate cases if the custom op calls other custom ops/functions that then call triton kernels, and then the toplevel compiled graph doesn't know about it. To handle that, we'd have to trace through the custom op at dynamo time. This should handle 99% of cases, though. I added an expectedFailure test to show the limitation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160120 Approved by: https://github.com/zou3519 |
|||
| a288b15ea9 |
[CI] Reduce XPU Windows build time (#159763)
Reduce the time cost from 2.5 hours to about 1.5 hours. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159763 Approved by: https://github.com/EikanWang, https://github.com/atalman |
|||
| 7fbc22855c |
Update triton xpu commit to support python 3.14 (#160183)
Follow PR #159725 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160183 Approved by: https://github.com/EikanWang, https://github.com/atalman |
|||
| f33ce40bc0 |
[bucketing] Bucket only adjacent collectives to prevent reordering (#159983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159983 Approved by: https://github.com/wconstab, https://github.com/eellison |
|||
| 4d5b3f2d5a |
[dynamo][guards] Install dict watchers for recrusive dict tag optimization (#159796)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159796 Approved by: https://github.com/jansel |
|||
| f990490a23 |
Add label_smoothing param in nn.BCELoss and nn.BCEWithLogitsLoss (#150282)
Fixes #91545 ## Changes - Add `label_smoothing` param and docs - Add test case for `label_smoothing` - Remove duplicate description in `nn.BCELoss` and `nn.BCEWithLogitsLoss` ## Test Result ```bash pytest -s test/test_nn.py -k test_bce ```    Pull Request resolved: https://github.com/pytorch/pytorch/pull/150282 Approved by: https://github.com/cyyever, https://github.com/mikaylagawarecki |
|||
| b9003ed3d8 |
Dynamo Deep Dive Documentation Fix (#158860)
changed SourceBuilder to VariableBuilder Fixes #158447 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158860 Approved by: https://github.com/mlazos |
|||
| fea7e9dd37 |
extract shape in _view_has_unbacked_input (#160255)
Summary: We were getting DDE on reshape still!! i looked deeper and found an issue in _view_has_unbacked_input namely when input is [[,,]] it need to be normalized to [..] Test Plan: existing tests. Rollback Plan: Differential Revision: D79951119 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160255 Approved by: https://github.com/bobrenjc93 |
|||
| 9a0f7a3bb0 |
[retry-land][pytorch][dynamo_compile] Log stack_trace to dynamo_compile (#160348)
refer: https://github.com/pytorch/pytorch/pull/159655 Earlier pr failed on dynamo/test_utils.py::TestDynamoTimed::test_dynamo_timed. Updated test_dynamo_timed + re-ran locally to test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160348 Approved by: https://github.com/masnesral |
|||
| 01bcf9a40d |
Bump transformers pin (#159291)
Trying to update hf pin. Benchmarking run to figure out issues <img width="1356" height="123" alt="image" src="https://github.com/user-attachments/assets/fbc435f3-a7cb-4280-9636-2ea6d15d7b6d" /> Retrying - https://github.com/pytorch/pytorch/pull/156118 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159291 Approved by: https://github.com/BoyuanFeng, https://github.com/huydhn Co-authored-by: Huy Do <huydhn@gmail.com> |
|||
| 8d3d1c8443 |
[dynamo] fixes to propagate tag safeness (#159807)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159807 Approved by: https://github.com/jansel |
|||
| 0f3b10b8ee |
[audio hash update] update the pinned audio hash (#160384)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160384 Approved by: https://github.com/pytorchbot |
|||
| 5f1010fbb3 |
[Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to 6.2% speedup on vision_maskrcnn, 5.8% speedup on yolov3. [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315). Run the same diff on two days and both show speedup on average. [first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d) <img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" /> [second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf) <img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" /> Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667 Approved by: https://github.com/eellison |
|||
| edaa151d0d |
[CI] Move CUDA tests to trunk workflow (#160379)
Which is getting run before PR is merged anyway, but according to 3X less frequently than pull workflow according to [Flambeau](https://pytorchci.grafana.net/public-dashboards/1c571e79090443eaaa9811db71f8d23b) <img width="796" height="573" alt="image" src="https://github.com/user-attachments/assets/0235e610-4e1c-4be5-88bf-ea8278d1c656" /> I.e. that will probably results in some longer time to signal, but considering that frequency of changes to eager PyTorch-on-CUDA slowed down and Inductor changes are decorated with ciflow/inductor, this looks like an acceptable tradeoff to reduce costs Pull Request resolved: https://github.com/pytorch/pytorch/pull/160379 Approved by: https://github.com/izaitsevfb |
|||
| 10bc36fe84 |
Get tensor subclasses and torch.library.triton_op to dispatch correctly (#160341)
Short-term fix for https://github.com/pytorch/pytorch/issues/160333 The problem is: 1) `triton_op` adds a decomposition for FunctionalTensorMode for this operation 2) Tensor Subclasses rely on FunctionalTensorMode's `__torch_dispatch__` returning NotImplemented. 3) `triton_op`'s FunctionalTensorMode decomposition takes precedence over FunctionalTensorMode's decomposition. The easy fix is to copy-paste the FunctionalTensorMode's NotImplemented return logic into the decomposition. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160341 Approved by: https://github.com/drisspg |
|||
| 32e5e2f596 |
[vllm hash update] update the pinned vllm hash (#160259)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vllm hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160259 Approved by: https://github.com/pytorchbot |
|||
| bfc873d02e |
[ROCm][Windows] Revert copying hipblaslt and rocblas dirs. (#159083)
This reverts the changes from |
|||
| eed9dbf70f |
[ROCm] Add torch/_rocm_init.py to .gitignore. (#159806)
Follow-up to https://github.com/pytorch/pytorch/pull/155285.
Build scripts like https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py generate this file with contents like:
```python
def initialize():
import rocm_sdk
rocm_sdk.initialize_process(
preload_shortnames=['amd_comgr', 'amdhip64', 'hiprtc', 'hipblas', 'hipfft', 'hiprand', 'hipsparse', 'hipsolver', 'hipblaslt', 'miopen'],
check_version='7.0.0rc20250804')
```
We may also have https://github.com/pytorch/pytorch/blob/main/tools/amd_build/build_amd.py do the same thing as more of that build support moves here into the upstream PyTorch repository itself (see https://github.com/pytorch/pytorch/issues/159520).
This file is then loaded if present here:
|
|||
| be53f609aa |
fix retaining multimem in symmetric memory (#160343)
fixes OOM in #160289 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160343 Approved by: https://github.com/eqy |
|||
| 95210cc409 |
[BE] Isolate pre-push hook dependencies in dedicated virtual environment (#160048)
This adds two changes: - Isolates pre-push hook dependencies into an isolated venv, no longer affect your system environment - Lets you manually run the pre-push lintrunner (including with lintrunner -a) by invoking `python scripts/lintrunner.py [-a]` (it's ugly, but better than nothing...for now) This is a follow up to: - https://github.com/pytorch/pytorch/pull/158389 ## Problem The current pre-push hook setup installs lintrunner and related dependencies globally, which makes developers nervous about system pollution and can cause version conflicts with existing installations. Also, if the pre-push lintrunner found errors, you had to hope your normal lintrunner could fix them (which wasn't always the case, e.g. if those errors only manifested in certain python versions) ## Key Changes: - Isolated Environment: Creates .git/hooks/linter/.venv/ with Python 3.9 (the python used in CI) and an isolated lintrunner installation - User-Friendly CLI: New python scripts/lintrunner.py wrapper allows developers to run lintrunner (including -a auto-fix) from any environment - Simplified Architecture: Eliminates pre-commit dependency entirely - uses direct git hooks File Changes: - scripts/setup_hooks.py: Rewritten to create isolated uv-managed virtual environment - scripts/lintrunner.py: New wrapper script with shared hash management logic - scripts/run_lintrunner.py: Removed (functionality merged into lintrunner.py) - .pre-commit-config.yaml: Removed (no longer needed) ## Usage: ``` # Setup (run once) python scripts/setup_hooks.py # Manual linting (works from any environment) python scripts/lintrunner.py # Check mode python scripts/lintrunner.py -a # Auto-fix mode # Git hooks work automatically git push # Runs lintrunner in isolated environment # Need to skip the pre-push hook? git push --no-verify ``` ## Benefits: - ✅ Zero global dependency installation - ✅ Per-repository isolation prevents version conflicts - ✅ Full lintrunner functionality is now accessible ## Implementation Notes: - Virtual env is kept in a dedicated dir in .git, to keep per-repo mechanics - lintrunner.py does not need to be invoked from a specific venv. It'll invoke the right venv itself. A minor bug: It tends to garble the lintrunner output a bit, like the screenshot below shows, but I haven't found a workaround so far and it remains understandable to users: <img width="241" height="154" alt="image" src="https://github.com/user-attachments/assets/9496f925-8524-4434-8486-dc579442d688" /> ## What's next? Features that could be added: - Check for lintrunner updates, auto-update if needed - Depending on dev response, this could be enabled by default for all pytorch/pytorch environments Pull Request resolved: https://github.com/pytorch/pytorch/pull/160048 Approved by: https://github.com/seemethere |
|||
| 7a974a88f2 |
[ROCm] Fix resource_strings.h (#159996)
This PR fixes the errors like below: ``` [rank7]: RuntimeError: /tmp/comgr-c3c81b/input/CompileSourceejOPx6:34:8: error: unknown type name 'uint64_t'; did you mean '__hip_internal::uint64_t'? [rank7]: 34 | if(((uint64_t) t0.data) % (4 * sizeof(half)) != 0) flag_vec4 = false; ``` The following datatypes needs to be defined in `torch/csrc/jit/codegen/fuser/cuda/resource_strings.h` for ROCm versions >= 7.0. ``` typedef unsigned char uint8_t; typedef signed char int8_t; typedef short int int16_t; typedef long long int int64_t; typedef unsigned long long int uint64_t; ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159996 Approved by: https://github.com/pruthvistony, https://github.com/Skylion007, https://github.com/jeffdaily |
|||
| f3f159ff8c |
[BE][cutlass backend] Reduce severity of log message for no cutlass config found (#160148)
This is not really a problem. Sometimes we cannot find a cutlass config due to shape, e.g. when k is odd. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160148 Approved by: https://github.com/mlazos, https://github.com/Skylion007 |
|||
| b90feeac86 |
[BE][cutlass backend] Fix subproc addmm tests (#160295)
Differential Revision: [D79977421](https://our.internmc.facebook.com/intern/diff/D79977421/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160295 Approved by: https://github.com/jingsh |
|||
| 0d40ff3b49 |
[inductor] fix test_different_file_paths_local_pgo on Windows. (#160382)
fix test_different_file_paths_local_pgo on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160382 Approved by: https://github.com/angelayi |
|||
| cae2b5e3d2 |
[ROCm][Windows] Enable USE_ROCM, disable USE_RCCL on Windows. (#159079)
This allows setting `USE_ROCM` on Windows. A few other patches are still required to build (see https://github.com/ROCm/TheRock/issues/589), but we have instructions using open source code and rocm python packages available at https://github.com/ROCm/TheRock/tree/main/external-builds/pytorch#build-pytorch-with-rocm-support. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159079 Approved by: https://github.com/jeffdaily |
|||
| ee89cc7a0a |
[ROCm][Windows] Fix LoadHIP handling of environment variable paths on Windows. (#159080)
See https://cmake.org/cmake/help/latest/command/file.html#path-conversion. Paths stored in environment variables may use `/` or `\` (e.g. on Windows), while cmake-style paths always use `/`. This fixes configure errors like: ``` CMake Error at D:/b/pytorch_main/build/CMakeFiles/CMakeScratch/TryCompile-srhq07/CMakeLists.txt:2 (set): Syntax error in cmake code at D:/b/pytorch_main/build/CMakeFiles/CMakeScratch/TryCompile-srhq07/CMakeLists.txt:2 when parsing string D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\_rocm_sdk_devel/cmake/;D:/b/pytorch_main/cmake/Modules Invalid character escape '\p'. CMake Error at D:/projects/TheRock/external-builds/pytorch/.venv/Lib/site-packages/cmake/data/share/cmake-3.31/Modules/Internal/CheckSourceCompiles.cmake:108 (try_compile): Failed to configure test project build system. ``` (note the mixed usage of `\` and `/` in that string) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159080 Approved by: https://github.com/jeffdaily |
|||
| e63c2b21c1 |
[PP] Initialize P2P communicators on first step (#160210)
Was hitting hangs in multi-node settings and initializing the NCCL communicators needed for batch p2p ops ahead of time fixes this. This change adds extra communication since it communicates a dummy tensor to next and previous stage ranks. However, this is only paid on the first step so it is negligible. Debug history: https://docs.google.com/document/d/1EKVJYmW2hj_VsvDvnSggXhZzJyvMu9dA0iDJWOZAtjY/edit?tab=t.0 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160210 Approved by: https://github.com/wconstab |
|||
| 3626ba711b |
[FlexAttention] Swap from and to & for new triton (#160227)
Fixes #158463 On B200 I am getting a bunch of error spew: ```Shell /tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: error: Failures have been detected while processing an MLIR pass pipeline /tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: note: Pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` Triton compilation failed: triton_tem_fused_zeros_1 def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): PRESCALE_QK : tl.constexpr = False ``` ```Shell 74 = arith.subi %170, %166 : i32 %175 = arith.muli %174, %c128_i32 : i32 %176 = arith.subi %175, %c64_i32 : i32 %177 = arith.extui %173 : i1 to i32 %178 = arith.muli %176, %177 : i32 %179 = arith.subi %c1_i32, %177 : i32 %180 = arith.muli %179, %c64_i32 : i32 %181 = arith.addi %178, %180 : i32 %182 = arith.muli %181, %c64_i32 : i32 %183 = tt.splat %182 : i32 -> tensor<64x64xi32> %184 = tt.addptr %arg19, %183 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %185 = tt.addptr %arg20, %183 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %186 = tt.splat %181 : i32 -> tensor<64xi32> %187 = arith.addi %arg21, %186 : tensor<64xi32> scf.yield %163, %184, %185, %187 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32> } %114 = tt.expand_dims %113#3 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %115 = arith.cmpi slt, %114, %cst_7 : tensor<1x64xi32> %116 = tt.broadcast %115 : tensor<1x64xi1> -> tensor<64x64xi1> %117 = tt.load %113#1, %116, %cst_8 : tensor<64x64x!tt.ptr<f16>> %118 = tt.dot %46, %117, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %119 = arith.mulf %118, %cst_13 : tensor<64x64xf32> %120 = arith.mulf %119, %cst_3 : tensor<64x64xf32> %121 = arith.select %116, %120, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %122 = arith.select %115, %cst_4, %cst_5 : tensor<1x64xi1>, tensor<1x64xi1> %123 = tt.broadcast %122 : tensor<1x64xi1> -> tensor<64x64xi1> %124 = arith.select %123, %121, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %125 = arith.mulf %124, %cst_2 : tensor<64x64xf32> %126 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32> %127 = arith.subf %125, %126 : tensor<64x64xf32> %128 = math.exp2 %127 : tensor<64x64xf32> %129 = tt.load %113#2, %116, %cst_8 : tensor<64x64x!tt.ptr<f16>> %130 = tt.dot %51, %129, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %131 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> %132 = tt.broadcast %131 : tensor<64x1xf32> -> tensor<64x64xf32> %133 = arith.subf %130, %132 : tensor<64x64xf32> %134 = arith.mulf %128, %133 : tensor<64x64xf32> %135 = arith.mulf %134, %cst_3 : tensor<64x64xf32> %136 = arith.select %116, %135, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %137 = arith.select %115, %122, %cst_5 : tensor<1x64xi1>, tensor<1x64xi1> %138 = tt.broadcast %137 : tensor<1x64xi1> -> tensor<64x64xi1> %139 = arith.select %138, %136, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %140 = arith.truncf %139 : tensor<64x64xf32> to tensor<64x64xf16> %141 = tt.trans %117 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %142 = tt.dot %140, %141, %113#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> scf.yield %142 : tensor<64x64xf32> } else { scf.yield %cst_9 : tensor<64x64xf32> } %84 = tt.addptr %arg13, %22 : !tt.ptr<i32>, i32 %85 = tt.load %84 : !tt.ptr<i32> %86 = arith.muli %85, %c128_i32 : i32 %87 = tt.addptr %arg12, %21 : !tt.ptr<i32>, i32 %88 = tt.load %87 : !tt.ptr<i32> %89 = tt.splat %86 : i32 -> tensor<64xi32> %90 = arith.addi %89, %14 : tensor<64xi32> %91 = tt.expand_dims %90 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %92 = arith.muli %91, %cst_11 : tensor<1x64xi32> %93 = tt.addptr %71, %92 : tensor<1x64x!tt.ptr<f16>>, tensor<1x64xi32> %94 = tt.broadcast %93 : tensor<1x64x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %95 = tt.addptr %94, %74 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %96 = tt.addptr %76, %92 : tensor<1x64x!tt.ptr<f16>>, tensor<1x64xi32> %97 = tt.broadcast %96 : tensor<1x64x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %98 = tt.addptr %97, %74 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %99 = arith.muli %88, %c2_i32 : i32 %100 = arith.minsi %99, %c4_i32 : i32 %101 = arith.cmpi sge, %100, %c1_i32 : i32 %102 = scf.if %101 -> (tensor<64x64xf32>) { %112 = arith.subi %100, %c1_i32 : i32 %113:4 = scf.for %arg17 = %c0_i32 to %112 step %c1_i32 iter_args(%arg18 = %83, %arg19 = %95, %arg20 = %98, %arg21 = %90) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>) : i32 { %137 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %138 = arith.cmpi slt, %137, %cst_7 : tensor<1x64xi32> %139 = tt.broadcast %138 : tensor<1x64xi1> -> tensor<64x64xi1> %140 = tt.load %arg19, %139, %cst_8 : tensor<64x64x!tt.ptr<f16>> %141 = tt.dot %46, %140, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %142 = arith.mulf %141, %cst_13 : tensor<64x64xf32> %143 = arith.mulf %142, %cst_3 : tensor<64x64xf32> %144 = arith.mulf %143, %cst_2 : tensor<64x64xf32> %145 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32> %146 = arith.subf %144, %145 : tensor<64x64xf32> %147 = math.exp2 %146 : tensor<64x64xf32> %148 = tt.load %arg20, %139, %cst_8 : tensor<64x64x!tt.ptr<f16>> %149 = tt.dot %51, %148, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %150 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> %151 = tt.broadcast %150 : tensor<64x1xf32> -> tensor<64x64xf32> %152 = arith.subf %149, %151 : tensor<64x64xf32> %153 = arith.mulf %147, %152 : tensor<64x64xf32> %154 = arith.mulf %153, %cst_3 : tensor<64x64xf32> %155 = arith.truncf %154 : tensor<64x64xf32> to tensor<64x64xf16> %156 = tt.trans %140 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %157 = tt.dot %155, %156, %arg18, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %158 = arith.divsi %arg17, %c2_i32 : i32 %159 = tt.addptr %84, %158 : !tt.ptr<i32>, i32 %160 = tt.load %159 evictionPolicy = evict_last : !tt.ptr<i32> %161 = arith.addi %158, %c1_i32 : i32 %162 = arith.cmpi slt, %161, %88 : i32 %163 = tt.addptr %159, %c1_i32 : !tt.ptr<i32>, i32 %164 = tt.load %163, %162 evictionPolicy = evict_last : !tt.ptr<i32> %165 = arith.addi %arg17, %c1_i32 : i32 %166 = arith.remsi %165, %c2_i32 : i32 %167 = arith.cmpi eq, %166, %c0_i32 : i32 %168 = arith.subi %164, %160 : i32 %169 = arith.muli %168, %c128_i32 : i32 %170 = arith.subi %169, %c64_i32 : i32 %171 = arith.extui %167 : i1 to i32 %172 = arith.muli %170, %171 : i32 %173 = arith.subi %c1_i32, %171 : i32 %174 = arith.muli %173, %c64_i32 : i32 %175 = arith.addi %172, %174 : i32 %176 = arith.muli %175, %c64_i32 : i32 %177 = tt.splat %176 : i32 -> tensor<64x64xi32> %178 = tt.addptr %arg19, %177 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %179 = tt.addptr %arg20, %177 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %180 = tt.splat %175 : i32 -> tensor<64xi32> %181 = arith.addi %arg21, %180 : tensor<64xi32> scf.yield %157, %178, %179, %181 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32> } %114 = tt.expand_dims %113#3 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %115 = arith.cmpi slt, %114, %cst_7 : tensor<1x64xi32> %116 = tt.broadcast %115 : tensor<1x64xi1> -> tensor<64x64xi1> %117 = tt.load %113#1, %116, %cst_8 : tensor<64x64x!tt.ptr<f16>> %118 = tt.dot %46, %117, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %119 = arith.mulf %118, %cst_13 : tensor<64x64xf32> %120 = arith.mulf %119, %cst_3 : tensor<64x64xf32> %121 = arith.select %116, %120, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %122 = arith.mulf %121, %cst_2 : tensor<64x64xf32> %123 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32> %124 = arith.subf %122, %123 : tensor<64x64xf32> %125 = math.exp2 %124 : tensor<64x64xf32> %126 = tt.load %113#2, %116, %cst_8 : tensor<64x64x!tt.ptr<f16>> %127 = tt.dot %51, %126, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %128 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> %129 = tt.broadcast %128 : tensor<64x1xf32> -> tensor<64x64xf32> %130 = arith.subf %127, %129 : tensor<64x64xf32> %131 = arith.mulf %125, %130 : tensor<64x64xf32> %132 = arith.mulf %131, %cst_3 : tensor<64x64xf32> %133 = arith.select %116, %132, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %134 = arith.truncf %133 : tensor<64x64xf32> to tensor<64x64xf16> %135 = tt.trans %117 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %136 = tt.dot %134, %135, %113#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> scf.yield %136 : tensor<64x64xf32> } else { scf.yield %83 : tensor<64x64xf32> } %103 = tt.splat %33 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>> %104 = tt.addptr %103, %37 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32> %105 = tt.broadcast %104 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %106 = tt.addptr %105, %42 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %107 = arith.mulf %102, %cst_13 : tensor<64x64xf32> %108 = arith.cmpi slt, %40, %cst_11 : tensor<1x64xi32> %109 = tt.broadcast %108 : tensor<1x64xi1> -> tensor<64x64xi1> %110 = arith.andi %45, %109 : tensor<64x64xi1> %111 = arith.truncf %107 : tensor<64x64xf32> to tensor<64x64xf16> tt.store %106, %111, %110 : tensor<64x64x!tt.ptr<f16>> } else { %16 = arith.divsi %0, %c2_i32 : i32 %17 = arith.muli %0, %c64_i32 : i32 %18 = tt.splat %17 : i32 -> tensor<64xi32> %19 = arith.addi %18, %14 : tensor<64xi32> %20 = tt.expand_dims %19 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %21 = arith.muli %20, %cst_14 : tensor<64x1xi32> %22 = tt.splat %11 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>> %23 = tt.addptr %22, %21 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32> %24 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %25 = tt.broadcast %23 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %26 = tt.broadcast %24 : tensor<1x64xi32> -> tensor<64x64xi32> %27 = tt.addptr %25, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %28 = arith.cmpi slt, %20, %cst_10 : tensor<64x1xi32> %29 = tt.broadcast %28 : tensor<64x1xi1> -> tensor<64x64xi1> %30 = tt.load %27, %29, %cst_8 : tensor<64x64x!tt.ptr<f16>> %31 = tt.splat %12 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>> %32 = tt.addptr %31, %21 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32> %33 = tt.broadcast %32 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %34 = tt.addptr %33, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %35 = tt.load %34, %29, %cst_8 : tensor<64x64x!tt.ptr<f16>> %36:2 = scf.for %arg17 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg18 = %cst_9, %arg19 = %cst_9) -> (tensor<64x64xf32>, tensor<64x64xf32>) : i32 { %55 = arith.muli %2, %c4_i32 : i32 %56 = arith.addi %55, %arg17 : i32 %57 = arith.muli %56, %c2048_i32 : i32 %58 = arith.muli %1, %c32768_i32 : i32 %59 = arith.addi %57, %58 : i32 %60 = arith.extsi %59 : i32 to i64 %61 = arith.muli %1, %c16_i32 : i32 %62 = arith.addi %61, %56 : i32 %63 = arith.muli %62, %c32_i32 : i32 %64 = arith.extsi %63 : i32 to i64 %65 = tt.addptr %arg0, %60 : !tt.ptr<f16>, i64 %66 = tt.addptr %arg5, %60 : !tt.ptr<f16>, i64 %67 = tt.addptr %arg3, %64 : !tt.ptr<f32>, i64 %68 = tt.addptr %arg4, %64 : !tt.ptr<f32>, i64 %69 = arith.remsi %56, %c16_i32 : i32 %70 = arith.muli %3, %c16_i32 : i32 %71 = arith.addi %70, %69 : i32 %72 = arith.muli %71, %c2_i32 : i32 %73 = arith.addi %72, %16 : i32 %74 = tt.addptr %arg11, %73 : !tt.ptr<i32>, i32 %75 = tt.load %74 : !tt.ptr<i32> %76 = arith.muli %75, %c128_i32 : i32 %77 = tt.addptr %arg10, %73 : !tt.ptr<i32>, i32 %78 = tt.load %77 : !tt.ptr<i32> %79 = tt.splat %76 : i32 -> tensor<64xi32> %80 = arith.addi %79, %14 : tensor<64xi32> %81 = tt.expand_dims %80 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %82 = arith.muli %81, %cst_11 : tensor<1x64xi32> %83 = tt.splat %65 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>> %84 = tt.addptr %83, %82 : tensor<1x64x!tt.ptr<f16>>, tensor<1x64xi32> %85 = tt.expand_dims %14 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %86 = tt.broadcast %84 : tensor<1x64x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %87 = tt.broadcast %85 : tensor<64x1xi32> -> tensor<64x64xi32> %88 = tt.addptr %86, %87 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %89 = tt.expand_dims %80 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %90 = arith.muli %89, %cst_14 : tensor<64x1xi32> %91 = tt.splat %66 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>> %92 = tt.addptr %91, %90 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32> %93 = tt.broadcast %92 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %94 = tt.addptr %93, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %95 = arith.muli %78, %c2_i32 : i32 %96 = arith.minsi %95, %c1_i32 : i32 %97 = arith.cmpi sge, %96, %c1_i32 : i32 %98:2 = scf.if %97 -> (tensor<64x64xf32>, tensor<64x64xf32>) { %120 = arith.subi %96, %c1_i32 : i32 %121:5 = scf.for %arg20 = %c0_i32 to %120 step %c1_i32 iter_args(%arg21 = %arg18, %arg22 = %arg19, %arg23 = %88, %arg24 = %94, %arg25 = %80) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>) : i32 { %167 = tt.expand_dims %arg25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %168 = arith.cmpi slt, %167, %cst_1 : tensor<1x64xi32> %169 = tt.broadcast %168 : tensor<1x64xi1> -> tensor<64x64xi1> %170 = tt.load %arg23, %169, %cst_8 : tensor<64x64x!tt.ptr<f16>> %171 = arith.cmpi slt, %arg25, %cst_17 : tensor<64xi32> %172 = tt.splat %67 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %173 = tt.addptr %172, %arg25 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %174 = tt.load %173, %171 : tensor<64x!tt.ptr<f32>> %175 = arith.cmpf oeq, %174, %cst_16 : tensor<64xf32> %176 = arith.select %175, %cst_15, %174 : tensor<64xi1>, tensor<64xf32> %177 = tt.dot %30, %170, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %178 = arith.mulf %177, %cst_13 : tensor<64x64xf32> %179 = arith.mulf %178, %cst_3 : tensor<64x64xf32> %180 = arith.mulf %179, %cst_2 : tensor<64x64xf32> %181 = tt.expand_dims %176 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %182 = tt.broadcast %181 : tensor<1x64xf32> -> tensor<64x64xf32> %183 = arith.subf %180, %182 : tensor<64x64xf32> %184 = math.exp2 %183 : tensor<64x64xf32> %185 = tt.expand_dims %arg25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %186 = arith.cmpi slt, %185, %cst_12 : tensor<64x1xi32> %187 = tt.broadcast %186 : tensor<64x1xi1> -> tensor<64x64xi1> %188 = tt.load %arg24, %187, %cst_8 : tensor<64x64x!tt.ptr<f16>> %189 = arith.truncf %184 : tensor<64x64xf32> to tensor<64x64xf16> %190 = tt.dot %189, %188, %arg22, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %191 = tt.splat %68 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %192 = tt.addptr %191, %arg25 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %193 = tt.load %192, %171 : tensor<64x!tt.ptr<f32>> %194 = tt.trans %188 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %195 = tt.dot %35, %194, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %196 = tt.expand_dims %193 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %197 = tt.broadcast %196 : tensor<1x64xf32> -> tensor<64x64xf32> %198 = arith.subf %195, %197 : tensor<64x64xf32> %199 = arith.mulf %184, %198 : tensor<64x64xf32> %200 = arith.mulf %199, %cst_3 : tensor<64x64xf32> %201 = arith.truncf %200 : tensor<64x64xf32> to tensor<64x64xf16> %202 = tt.trans %170 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %203 = tt.dot %201, %202, %arg21, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %204 = arith.divsi %arg20, %c2_i32 : i32 %205 = tt.addptr %74, %204 : !tt.ptr<i32>, i32 %206 = tt.load %205 evictionPolicy = evict_last : !tt.ptr<i32> %207 = arith.addi %204, %c1_i32 : i32 %208 = arith.cmpi slt, %207, %78 : i32 %209 = tt.addptr %205, %c1_i32 : !tt.ptr<i32>, i32 %210 = tt.load %209, %208 evictionPolicy = evict_last : !tt.ptr<i32> %211 = arith.addi %arg20, %c1_i32 : i32 %212 = arith.remsi %211, %c2_i32 : i32 %213 = arith.cmpi eq, %212, %c0_i32 : i32 %214 = arith.subi %210, %206 : i32 %215 = arith.muli %214, %c128_i32 : i32 %216 = arith.subi %215, %c64_i32 : i32 %217 = arith.extui %213 : i1 to i32 %218 = arith.muli %216, %217 : i32 %219 = arith.subi %c1_i32, %217 : i32 %220 = arith.muli %219, %c64_i32 : i32 %221 = arith.addi %218, %220 : i32 %222 = arith.muli %221, %c64_i32 : i32 %223 = tt.splat %222 : i32 -> tensor<64x64xi32> %224 = tt.addptr %arg23, %223 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %225 = tt.addptr %arg24, %223 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %226 = tt.splat %221 : i32 -> tensor<64xi32> %227 = arith.addi %arg25, %226 : tensor<64xi32> scf.yield %203, %190, %224, %225, %227 : tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32> } %122 = tt.expand_dims %121#4 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %123 = arith.cmpi slt, %122, %cst_1 : tensor<1x64xi32> %124 = tt.broadcast %123 : tensor<1x64xi1> -> tensor<64x64xi1> %125 = tt.load %121#2, %124, %cst_8 : tensor<64x64x!tt.ptr<f16>> %126 = arith.cmpi slt, %121#4, %cst_17 : tensor<64xi32> %127 = tt.splat %67 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %128 = tt.addptr %127, %121#4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %129 = tt.load %128, %126 : tensor<64x!tt.ptr<f32>> %130 = arith.cmpf oeq, %129, %cst_16 : tensor<64xf32> %131 = arith.select %130, %cst_15, %129 : tensor<64xi1>, tensor<64xf32> %132 = tt.dot %30, %125, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %133 = arith.mulf %132, %cst_13 : tensor<64x64xf32> %134 = arith.mulf %133, %cst_3 : tensor<64x64xf32> %135 = arith.select %29, %134, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %136 = arith.select %28, %cst, %cst_0 : tensor<64x1xi1>, tensor<64x1xi1> %137 = tt.broadcast %136 : tensor<64x1xi1> -> tensor<64x64xi1> %138 = arith.select %137, %135, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %139 = arith.mulf %138, %cst_2 : tensor<64x64xf32> %140 = tt.expand_dims %131 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %141 = tt.broadcast %140 : tensor<1x64xf32> -> tensor<64x64xf32> %142 = arith.subf %139, %141 : tensor<64x64xf32> %143 = math.exp2 %142 : tensor<64x64xf32> %144 = tt.expand_dims %121#4 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %145 = arith.cmpi slt, %144, %cst_12 : tensor<64x1xi32> %146 = tt.broadcast %145 : tensor<64x1xi1> -> tensor<64x64xi1> %147 = tt.load %121#3, %146, %cst_8 : tensor<64x64x!tt.ptr<f16>> %148 = arith.truncf %143 : tensor<64x64xf32> to tensor<64x64xf16> %149 = tt.dot %148, %147, %121#1, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %150 = tt.splat %68 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %151 = tt.addptr %150, %121#4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %152 = tt.load %151, %126 : tensor<64x!tt.ptr<f32>> %153 = tt.trans %147 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %154 = tt.dot %35, %153, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %155 = tt.expand_dims %152 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %156 = tt.broadcast %155 : tensor<1x64xf32> -> tensor<64x64xf32> %157 = arith.subf %154, %156 : tensor<64x64xf32> %158 = arith.mulf %143, %157 : tensor<64x64xf32> %159 = arith.mulf %158, %cst_3 : tensor<64x64xf32> %160 = arith.select %29, %159, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %161 = arith.select %28, %136, %cst_0 : tensor<64x1xi1>, tensor<64x1xi1> %162 = tt.broadcast %161 : tensor<64x1xi1> -> tensor<64x64xi1> %163 = arith.select %162, %160, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %164 = arith.truncf %163 : tensor<64x64xf32> to tensor<64x64xf16> %165 = tt.trans %125 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %166 = tt.dot %164, %165, %121#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> scf.yield %166, %149 : tensor<64x64xf32>, tensor<64x64xf32> } else { scf.yield %arg18, %arg19 : tensor<64x64xf32>, tensor<64x64xf32> } %99 = tt.addptr %arg15, %73 : !tt.ptr<i32>, i32 %100 = tt.load %99 : !tt.ptr<i32> %101 = arith.muli %100, %c128_i32 : i32 %102 = tt.addptr %arg14, %73 : !tt.ptr<i32>, i32 %103 = tt.load %102 : !tt.ptr<i32> %104 = tt.splat %101 : i32 -> tensor<64xi32> %105 = arith.addi %104, %14 : tensor<64xi32> %106 = tt.expand_dims %105 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %107 = arith.muli %106, %cst_11 : tensor<1x64xi32> %108 = tt.addptr %83, %107 : tensor<1x64x!tt.ptr<f16>>, tensor<1x64xi32> %109 = tt.broadcast %108 : tensor<1x64x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %110 = tt.addptr %109, %87 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %111 = tt.expand_dims %105 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %112 = arith.muli %111, %cst_14 : tensor<64x1xi32> %113 = tt.addptr %91, %112 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32> %114 = tt.broadcast %113 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %115 = tt.addptr %114, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %116 = arith.muli %103, %c2_i32 : i32 %117 = arith.minsi %116, %c1_i32 : i32 %118 = arith.cmpi sge, %117, %c1_i32 : i32 %119:2 = scf.if %118 -> (tensor<64x64xf32>, tensor<64x64xf32>) { %120 = arith.subi %117, %c1_i32 : i32 %121:5 = scf.for %arg20 = %c0_i32 to %120 step %c1_i32 iter_args(%arg21 = %98#0, %arg22 = %98#1, %arg23 = %110, %arg24 = %115, %arg25 = %105) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32>) : i32 { %161 = tt.expand_dims %arg25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %162 = arith.cmpi slt, %161, %cst_1 : tensor<1x64xi32> %163 = tt.broadcast %162 : tensor<1x64xi1> -> tensor<64x64xi1> %164 = tt.load %arg23, %163, %cst_8 : tensor<64x64x!tt.ptr<f16>> %165 = arith.cmpi slt, %arg25, %cst_17 : tensor<64xi32> %166 = tt.splat %67 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %167 = tt.addptr %166, %arg25 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %168 = tt.load %167, %165 : tensor<64x!tt.ptr<f32>> %169 = arith.cmpf oeq, %168, %cst_16 : tensor<64xf32> %170 = arith.select %169, %cst_15, %168 : tensor<64xi1>, tensor<64xf32> %171 = tt.dot %30, %164, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %172 = arith.mulf %171, %cst_13 : tensor<64x64xf32> %173 = arith.mulf %172, %cst_3 : tensor<64x64xf32> %174 = arith.mulf %173, %cst_2 : tensor<64x64xf32> %175 = tt.expand_dims %170 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %176 = tt.broadcast %175 : tensor<1x64xf32> -> tensor<64x64xf32> %177 = arith.subf %174, %176 : tensor<64x64xf32> %178 = math.exp2 %177 : tensor<64x64xf32> %179 = tt.expand_dims %arg25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %180 = arith.cmpi slt, %179, %cst_12 : tensor<64x1xi32> %181 = tt.broadcast %180 : tensor<64x1xi1> -> tensor<64x64xi1> %182 = tt.load %arg24, %181, %cst_8 : tensor<64x64x!tt.ptr<f16>> %183 = arith.truncf %178 : tensor<64x64xf32> to tensor<64x64xf16> %184 = tt.dot %183, %182, %arg22, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %185 = tt.splat %68 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %186 = tt.addptr %185, %arg25 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %187 = tt.load %186, %165 : tensor<64x!tt.ptr<f32>> %188 = tt.trans %182 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %189 = tt.dot %35, %188, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %190 = tt.expand_dims %187 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %191 = tt.broadcast %190 : tensor<1x64xf32> -> tensor<64x64xf32> %192 = arith.subf %189, %191 : tensor<64x64xf32> %193 = arith.mulf %178, %192 : tensor<64x64xf32> %194 = arith.mulf %193, %cst_3 : tensor<64x64xf32> %195 = arith.truncf %194 : tensor<64x64xf32> to tensor<64x64xf16> %196 = tt.trans %164 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %197 = tt.dot %195, %196, %arg21, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %198 = arith.divsi %arg20, %c2_i32 : i32 %199 = tt.addptr %99, %198 : !tt.ptr<i32>, i32 %200 = tt.load %199 evictionPolicy = evict_last : !tt.ptr<i32> %201 = arith.addi %198, %c1_i32 : i32 %202 = arith.cmpi slt, %201, %103 : i32 %203 = tt.addptr %199, %c1_i32 : !tt.ptr<i32>, i32 %204 = tt.load %203, %202 evictionPolicy = evict_last : !tt.ptr<i32> %205 = arith.addi %arg20, %c1_i32 : i32 %206 = arith.remsi %205, %c2_i32 : i32 %207 = arith.cmpi eq, %206, %c0_i32 : i32 %208 = arith.subi %204, %200 : i32 %209 = arith.muli %208, %c128_i32 : i32 %210 = arith.subi %209, %c64_i32 : i32 %211 = arith.extui %207 : i1 to i32 %212 = arith.muli %210, %211 : i32 %213 = arith.subi %c1_i32, %211 : i32 %214 = arith.muli %213, %c64_i32 : i32 %215 = arith.addi %212, %214 : i32 %216 = arith.muli %215, %c64_i32 : i32 %217 = tt.splat %216 : i32 -> tensor<64x64xi32> %218 = tt.addptr %arg23, %217 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %219 = tt.addptr %arg24, %217 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %220 = tt.splat %215 : i32 -> tensor<64xi32> %221 = arith.addi %arg25, %220 : tensor<64xi32> scf.yield %197, %184, %218, %219, %221 : tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr<f16>>, tensor<64x64x!tt.ptr<f16>>, tensor<64xi32> } %122 = tt.expand_dims %121#4 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %123 = arith.cmpi slt, %122, %cst_1 : tensor<1x64xi32> %124 = tt.broadcast %123 : tensor<1x64xi1> -> tensor<64x64xi1> %125 = tt.load %121#2, %124, %cst_8 : tensor<64x64x!tt.ptr<f16>> %126 = arith.cmpi slt, %121#4, %cst_17 : tensor<64xi32> %127 = tt.splat %67 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %128 = tt.addptr %127, %121#4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %129 = tt.load %128, %126 : tensor<64x!tt.ptr<f32>> %130 = arith.cmpf oeq, %129, %cst_16 : tensor<64xf32> %131 = arith.select %130, %cst_15, %129 : tensor<64xi1>, tensor<64xf32> %132 = tt.dot %30, %125, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %133 = arith.mulf %132, %cst_13 : tensor<64x64xf32> %134 = arith.mulf %133, %cst_3 : tensor<64x64xf32> %135 = arith.select %29, %134, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %136 = arith.mulf %135, %cst_2 : tensor<64x64xf32> %137 = tt.expand_dims %131 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %138 = tt.broadcast %137 : tensor<1x64xf32> -> tensor<64x64xf32> %139 = arith.subf %136, %138 : tensor<64x64xf32> %140 = math.exp2 %139 : tensor<64x64xf32> %141 = tt.expand_dims %121#4 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %142 = arith.cmpi slt, %141, %cst_12 : tensor<64x1xi32> %143 = tt.broadcast %142 : tensor<64x1xi1> -> tensor<64x64xi1> %144 = tt.load %121#3, %143, %cst_8 : tensor<64x64x!tt.ptr<f16>> %145 = arith.truncf %140 : tensor<64x64xf32> to tensor<64x64xf16> %146 = tt.dot %145, %144, %121#1, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %147 = tt.splat %68 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %148 = tt.addptr %147, %121#4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %149 = tt.load %148, %126 : tensor<64x!tt.ptr<f32>> %150 = tt.trans %144 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %151 = tt.dot %35, %150, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %152 = tt.expand_dims %149 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %153 = tt.broadcast %152 : tensor<1x64xf32> -> tensor<64x64xf32> %154 = arith.subf %151, %153 : tensor<64x64xf32> %155 = arith.mulf %140, %154 : tensor<64x64xf32> %156 = arith.mulf %155, %cst_3 : tensor<64x64xf32> %157 = arith.select %29, %156, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %158 = arith.truncf %157 : tensor<64x64xf32> to tensor<64x64xf16> %159 = tt.trans %125 {order = array<i32: 1, 0>} : tensor<64x64xf16> -> tensor<64x64xf16> %160 = tt.dot %158, %159, %121#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> scf.yield %160, %146 : tensor<64x64xf32>, tensor<64x64xf32> } else { scf.yield %98#0, %98#1 : tensor<64x64xf32>, tensor<64x64xf32> } scf.yield %119#0, %119#1 : tensor<64x64xf32>, tensor<64x64xf32> } %37 = tt.splat %13 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>> %38 = tt.addptr %37, %21 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32> %39 = tt.broadcast %38 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x64x!tt.ptr<f16>> %40 = tt.addptr %39, %26 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %41 = arith.cmpi slt, %24, %cst_11 : tensor<1x64xi32> %42 = tt.broadcast %41 : tensor<1x64xi1> -> tensor<64x64xi1> %43 = arith.andi %29, %42 : tensor<64x64xi1> %44 = arith.truncf %36#1 : tensor<64x64xf32> to tensor<64x64xf16> tt.store %40, %44, %43 : tensor<64x64x!tt.ptr<f16>> %45 = arith.mulf %36#0, %cst_13 : tensor<64x64xf32> %46 = tt.broadcast %21 : tensor<64x1xi32> -> tensor<64x64xi32> %47 = arith.addi %26, %46 : tensor<64x64xi32> %48 = tt.splat %4 : i32 -> tensor<64x64xi32> %49 = arith.addi %47, %48 : tensor<64x64xi32> %50 = tt.splat %8 : i32 -> tensor<64x64xi32> %51 = arith.addi %49, %50 : tensor<64x64xi32> %52 = tt.splat %arg16 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>> %53 = tt.addptr %52, %51 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32> %54 = arith.truncf %45 : tensor<64x64xf32> to tensor<64x64xf16> tt.store %53, %54, %29 : tensor<64x64x!tt.ptr<f16>> } tt.return } } {-# external_resources: { mlir_reproducer: { pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=4 target=cuda:100 threads-per-warp=32}, tritongpu-coalesce, tritongpu-F32DotTC, triton-nvidia-gpu-plan-cta, tritongpu-remove-layout-conversions, tritongpu-optimize-thread-locality, tritongpu-accelerate-matmul, tritongpu-remove-layout-conversions, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, triton-nvidia-optimize-descriptor-encoding, triton-loop-aware-cse, tritongpu-fuse-nested-loops, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, tritongpu-optimize-accumulator-init, tritongpu-hoist-tmem-alloc, tritongpu-promote-lhs-to-tmem, tritongpu-assign-latencies{num-stages=3}, tritongpu-schedule-loops, tritongpu-automatic-warp-specialization{num-stages=3}, tritongpu-pipeline{dump-intermediate-steps=false num-stages=3}, tritongpu-combine-tensor-select-and-if, triton-nvidia-gpu-remove-tmem-tokens, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-loop-aware-cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritongpu-coalesce-async-copy, triton-nvidia-optimize-tmem-layouts, tritongpu-remove-layout-conversions, triton-nvidia-interleave-tmem, tritongpu-reduce-data-duplication, tritongpu-reorder-instructions, triton-loop-aware-cse, symbol-dce, triton-nvidia-tma-lowering, triton-nvidia-gpu-fence-insertion{compute-capability=90}, sccp, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true})", disable_threading: false, verify_each: true } } #-} /tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: error: Failures have been detected while processing an MLIR pass pipeline /tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: note: Pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` Triton compilation failed: triton_tem_fused_zeros_1 def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): PRESCALE_QK : tl.constexpr = False ROWS_GUARANTEED_SAFE : tl.constexpr = False BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False WRITE_DQ : tl.constexpr = True OUTPUT_LOGSUMEXP : tl.constexpr = True FLOAT32_PRECISION : tl.constexpr = 'tf32' IS_DIVISIBLE : tl.constexpr = False SM_SCALE : tl.constexpr = 0.125 GQA_SHARED_HEADS : tl.constexpr = 4 HAS_FULL_BLOCKS : tl.constexpr = True QK_HEAD_DIM : tl.constexpr = 64 QK_HEAD_DIM_ROUNDED : tl.constexpr = 64 V_HEAD_DIM : tl.constexpr = 64 V_HEAD_DIM_ROUNDED : tl.constexpr = 64 SAFE_HEAD_DIM : tl.constexpr = True BLOCK_M1 : tl.constexpr = 64 BLOCK_N1 : tl.constexpr = 64 BLOCK_M2 : tl.constexpr = 64 BLOCK_N2 : tl.constexpr = 64 SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 Q = arg_Q K = arg_K V = arg_V LSE = arg_LSE DELTA = arg_DELTA DO = arg_DO DQ = arg_DQ DV = arg_DV KV_NUM_BLKS = arg_KV_NUM_BLKS KV_IDX = arg_KV_IDX Q_NUM_BLKS = arg_Q_NUM_BLKS Q_IDX = arg_Q_IDX FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS FULL_KV_IDX = arg_FULL_KV_IDX FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS FULL_Q_IDX = arg_FULL_Q_IDX # Sub notation for this kernel: # # Q: Query, K: Key, V: Value # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) # DELTA: Precomputed sum(OUT*DO, axis=-1) # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value # DK: Derivative of Key, is the written to via the store_output call due to some limitations with # inductor codegen # M: Number of queries, N: Number of keys/values # QK_HEAD_DIM: The dimension of the query and key embeddings # V_HEAD_DIM: The dimension of the value embeddings # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. # (Modifiable) Performance tuning options # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. # # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. # The below are kernel options that can be applied for certain score_mods, # or involve a numerics vs. perf tradeoff # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has # about 20% more numerical error, but slightly faster. # Define strides of inputs stride_qz, stride_qh, stride_qm, stride_qd = 32768, 2048, 64, 1 stride_kz, stride_kh, stride_kn, stride_kd = 65536, 16384, 64, 1 stride_vz, stride_vh, stride_vn, stride_vd = 65536, 16384, 64, 1 stride_doz, stride_doh, stride_dom, stride_dod = 32768, 2048, 64, 1 stride_dqz, stride_dqh, stride_dqm, stride_dqd = 32768, 2048, 64, 1 stride_dvz, stride_dvh, stride_dvm, stride_dvd = 65536, 16384, 64, 1 ZQ = 2 HQ = 16 HKV = 4 Q_LEN = 32 ZKV = 2 KV_LEN = 256 MATMUL_PRECISION = Q.dtype.element_ty pid = tl.program_id(0) NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) off_zq = tl.program_id(1) # q batch idx off_hkv = tl.program_id(2) # kv head idx off_zkv = off_zq % ZKV # kv batch idx SPARSE_Z = 2 SPARSE_HQ = 16 sparse_idx_z = off_zq % SPARSE_Z k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) # offset K, V, DV pointers for batch/kv-head K += k_adj V += v_adj DV += dv_adj RCP_LN2 = 1.44269504 offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) if pid >= NUM_KV_BLOCKS: off_pid = pid - NUM_KV_BLOCKS # THIS BLOCK DOES DQ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS start_m2_block = off_pid % NUM_Q_BLOCKS off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE stride_kv_num_blks_h = 1 stride_kv_idx_h = 2 stride_kv_idx_m = 2 sparse_idx_hq2 = off_hq2 % SPARSE_HQ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) Q2 = Q + q_adj2 DO2 = DO + do_adj2 # TODO: This does not work if DQ is not the same layout as Q (for example, # if Q is broadcasted) DQ2 = DQ + dq_adj2 LSE2 = LSE + off_chz2 DELTA2 = DELTA + off_chz2 # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) start_m2 = start_m2_block * BLOCK_M2 offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) # load Q and do: they stay in SRAM throughout the inner loop. q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) if PRESCALE_QK: q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) if IS_DIVISIBLE: Di = tl.load(DELTA2 + offs_m2) lse = tl.load(LSE2 + offs_m2) else: Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) lse = tl.where(lse == -float("inf"), 0.0, lse) lse = lse[:, None] # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # KV_IDX and KV_NUM_BLKS are always contiguous. kv_indices = KV_IDX + sparse_kv_idx_offset kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) offs_n2 = kv_start + tl.arange(0, BLOCK_N2) dq = bwd_dq_inner( arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, K, V, dq, q, do, Di, lse, off_zq, off_hq2, offs_m2, offs_n2, stride_kn, stride_kd, stride_vn, stride_vd, kv_indices, sparse_kv_num_blocks, MATMUL_PRECISION, IS_FULL_BLOCKS=False, ) if HAS_FULL_BLOCKS: # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. kv_indices = FULL_KV_IDX + sparse_kv_idx_offset kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) offs_n2 = kv_start + tl.arange(0, BLOCK_N2) dq = bwd_dq_inner( arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, K, V, dq, q, do, Di, lse, off_zq, off_hq2, offs_m2, offs_n2, stride_kn, stride_kd, stride_vn, stride_vd, kv_indices, sparse_kv_num_blocks, MATMUL_PRECISION, IS_FULL_BLOCKS=True, ) # Write back dQ. dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd dq *= SM_SCALE if IS_DIVISIBLE and SAFE_HEAD_DIM: tl.store(dq_ptrs, dq) else: tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) else: # THIS BLOCK DOES DK & DV SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) pid_mask = pid // SPARSE_KV_MULTIPLE stride_q_num_blks_h = 2 stride_q_idx_h = 2 stride_q_idx_n = 1 dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) start_n1 = pid * BLOCK_N1 offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) # load K and V: they stay in SRAM throughout the inner loop. k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) if PRESCALE_QK: k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) for off_g in range(0, GQA_SHARED_HEADS): off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) Q1 = Q + q_adj1 DO1 = DO + do_adj1 # TODO: This does not work if DQ is not the same layout as Q (for example, # if Q is broadcasted) LSE1 = LSE + off_chz1 DELTA1 = DELTA + off_chz1 sparse_idx_hq1 = off_hq1 % SPARSE_HQ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Q_IDX and Q_NUM_BLKS are always contiguous. q_indices = Q_IDX + sparse_q_idx_offset q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) offs_m1 = q_start + tl.arange(0, BLOCK_M1) dk, dv = bwd_dkdv_inner( arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, Q1, DO1, DELTA1, LSE1, dk, dv, k, v, off_zq, off_hq1, offs_n1, offs_m1, stride_qm, stride_qd, stride_dom, stride_dod, q_indices, sparse_q_num_blocks, MATMUL_PRECISION, IS_FULL_BLOCKS=False, ) if HAS_FULL_BLOCKS: # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. q_indices = FULL_Q_IDX + sparse_q_idx_offset q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) offs_m1 = q_start + tl.arange(0, BLOCK_M1) dk, dv = bwd_dkdv_inner( arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, Q1, DO1, DELTA1, LSE1, dk, dv, k, v, off_zq, off_hq1, offs_n1, offs_m1, stride_qm, stride_qd, stride_dom, stride_dod, q_indices, sparse_q_num_blocks, MATMUL_PRECISION, IS_FULL_BLOCKS=True, ) # Write back dV and dK. dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd index_n = offs_n1[:, None] index_k = offs_k[None, :] index_v = offs_v[None, :] if IS_DIVISIBLE and SAFE_HEAD_DIM: tl.store(dv_ptrs, dv) else: tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) dk *= SM_SCALE if SAFE_HEAD_DIM: mask = index_n < KV_LEN else: mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] xindex = index_k + 64*index_n + 16384*off_hkv + 65536*off_zq tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) metadata: {'signature': {'arg_Q': '*fp16', 'arg_K': '*fp16', 'arg_V': '*fp16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*fp16', 'arg_DQ': '*fp16', 'arg_DV': '*fp16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*fp16'}, 'device': 0, 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}], 'device_type': 'cuda', 'num_warps': 4, 'num_stages': 3, 'debug': True, 'cc': 100} Traceback (most recent call last): File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 748, in _precompile_config binary = triton.compile(*compile_args, **compile_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/compiler/compiler.py", line 359, in compile next_module = compile_ir(module, metadata) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 456, in <lambda> stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 298, in make_ttgir pm.run(mod) RuntimeError: PassManager::run failed frames [('total', 3), ('ok', 3)] inline_call [] stats [('calls_captured', 8), ('unique_graphs', 3)] aot_autograd [('total', 1), ('autograd_cache_miss', 1), ('ok', 1)] inductor [('triton_bundler_save_kernel', 8), ('async_compile_cache_miss', 3), ('fxgraph_cache_miss', 1), ('triton_bundler_save_static_autotuner', 1), ('fxgraph_cache_bypass', 1)] graph_break [] F ==================================================== FAILURES ===================================================== _____________________________ TestFlexAttentionCUDA.test_GQA_score_mod1_cuda_float16 ______________________________ Traceback (most recent call last): File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 58, in testPartExecutor yield File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 634, in run self._callTestMethod(testMethod) File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 589, in _callTestMethod if method() is not None: ^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper method(*args, **kwargs) File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper method(*args, **kwargs) File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 446, in instantiated_test raise rte File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 426, in instantiated_test result = test(self, **param_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 1349, in dep_fn return fn(self, *args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 1215, in dep_fn return fn(slf, *args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/test/inductor/test_flex_attention.py", line 1430, in test_GQA self.run_test(*inputs) File "/home/drisspg/meta/pytorch/test/inductor/test_flex_attention.py", line 566, in run_test compiled_out.backward(backward_grad) File "/home/drisspg/meta/pytorch/torch/_tensor.py", line 625, in backward torch.autograd.backward( File "/home/drisspg/meta/pytorch/torch/autograd/__init__.py", line 354, in backward _engine_run_backward( File "/home/drisspg/meta/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/autograd/function.py", line 315, in apply return user_fn(self, *args) ^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2303, in backward return impl_fn() ^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2289, in impl_fn out = CompiledFunction._backward_impl(ctx, all_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2394, in _backward_impl CompiledFunction.compiled_bw = aot_config.bw_compiler( ^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/schemas.py", line 1256, in __call__ return self.compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_dynamo/backends/common.py", line 76, in _wrapped_bw_compiler disable( File "/home/drisspg/meta/pytorch/torch/_dynamo/eval_frame.py", line 1005, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_utils_internal.py", line 92, in wrapper_function return function(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 2428, in bw_compiler return inner_compile( ^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 773, in compile_fx_inner return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper inner_compiled_fn = compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 952, in _compile_fx_inner mb_compiled_graph = fx_codegen_and_compile( ^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1652, in fx_codegen_and_compile return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1506, in codegen_and_compile compiled_module = graph.compile_to_module() ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2318, in compile_to_module return self._compile_to_module() ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2328, in _compile_to_module mod = self._compile_to_module_lines(wrapper_code) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2396, in _compile_to_module_lines mod = PyCodeCache.load_by_key_path( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/codecache.py", line 3466, in load_by_key_path mod = _reload_python_module(key, path, set_sys_modules=in_toplevel) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/compile_tasks.py", line 33, in _reload_python_module exec(code, mod.__dict__, mod.__dict__) File "/tmp/tmp0yiz3c94/az/caza2gzmsagyuusmf2ka3oat3na4xv6zudssk244xmlzsbv2knze.py", line 117, in <module> File "/home/drisspg/meta/pytorch/torch/_inductor/async_compile.py", line 489, in triton kernel.precompile( File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 437, in precompile self._precompile_worker() File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 459, in _precompile_worker compile_results.append(self._precompile_config(c)) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 748, in _precompile_config binary = triton.compile(*compile_args, **compile_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/compiler/compiler.py", line 359, in compile next_module = compile_ir(module, metadata) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 456, in <lambda> stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 298, in make_ttgir pm.run(mod) RuntimeError: PassManager::run failed To execute this test, run the following from the base repo dir: python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_score_mod1_cuda_float16 This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 ============================================= short test summary info ============================================= FAILED [5.1441s] test/inductor/test_flex_attention.py::TestFlexAttentionCUDA::test_GQA_score_mod1_cuda_float16 - RuntimeError: PassManager::run failed ================================== 1 failed, 1 passed, 1404 deselected in 18.10s ================================== ~/meta/pytorch flex-warning !1 ❯ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/160227 Approved by: https://github.com/Skylion007, https://github.com/Chillee |
|||
| 99bc2f94c1 |
Update export/schema.py (#160220)
Summary: Model could have multiple ExportedPrograms - for different methods. They can have different weights. - for different delegates. They can also have different weights. For this reason, we make weight per ExportedProgram. Also, we cleanup Model, and Program. IIUC, Model and Program are not used anywhere, so it's ok to make BC breaking change. Test Plan: CI Rollback Plan: Differential Revision: D79917395 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160220 Approved by: https://github.com/angelayi, https://github.com/dolpm, https://github.com/jingsh |
|||
| fc25c68f20 |
[hop][exc] make UncapturedHigherOrderOpError print user code and avoid re-raise (#159296)
After the change, the error stacktrace is attached with user code stack and is suppressed into 1 (without the scrolling up mssage). For example:
```python
class Test(torch.nn.Module):
def forward(self, c, x):
def cond_fn(c, x):
return c > 0 and x.size(0) < 20
def body_fn(c, x):
return c - 1, x.sin()
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
```
Now gives the following error message:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1705, in test_while_loop_size_mismatch_tensor_expansion
self._run_test(
~~~~~~~~~~~~~~^
model=WhileLoopModels.SizeMismatchTensorExpansion(),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<2 lines>...
dynamic=dynamic,
^^^^^^^^^^^^^^^^
)
^
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1417, in _run_test
result = model(*inputs_with_counters)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1053, in forward
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 176, in while_loop
return torch.compile(
~~~~~~~~~~~~~~
_while_loop_op_wrapper, backend=backend, fullgraph=True
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1595, in __call__
result = self._torchdynamo_orig_backend(
frame, cache_entry, self.hooks, frame_state, skip=1
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1353, in __call__
result = self._inner_convert(
frame, cache_entry, hooks, frame_state, skip=skip + 1
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 682, in __call__
result = _compile(
frame.f_code,
...<16 lines>...
convert_frame_box=self._box,
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1172, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
return function(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 858, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
transformations(instructions, code_options)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 300, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 818, in transform
tracer.run()
~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3528, in run
super().run()
~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 91, in graph_break_as_hard_error
raise exc.with_traceback(sys.exc_info()[2]) from None
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 77, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1287, in call_function
) = speculate_subgraph(
~~~~~~~~~~~~~~~~~~^
tx,
^^^
...<33 lines>...
supports_aliasing=self.supports_aliasing,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 877, in speculate_subgraph
raise ex
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 718, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
return super().call_function(tx, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
return tracer.inline_call_()
~~~~~~~~~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
self.run()
~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
return super().call_function(tx, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
return tracer.inline_call_()
~~~~~~~~~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
self.run()
~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 830, in inner
unimplemented_v2(
~~~~~~~~~~~~~~~~^
gb_type="Data-dependent branching",
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<5 lines>...
],
^^
)
^
File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 580, in unimplemented_v2
raise Unsupported(msg)
torch._dynamo.exc.UncapturedHigherOrderOpError: while_loop doesn't work unless it is captured completely with torch.compile. Got Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html
from user code:
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 167, in _while_loop_op_wrapper
return while_loop_op(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 137, in flat_cond_fn
return cond_fn(*carried, *additional)
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1047, in cond_fn
return c > 0 and x.size(0) < 20
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
To execute this test, run the following from the base repo dir:
python test/inductor/test_control_flow.py WhileLoopTests.test_while_loop_size_mismatch_tensor_expansion_device_cpu_dynamic_False
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159296
Approved by: https://github.com/zou3519
|
|||
| 5a40c57844 |
[MTIA] Implement isAvailable() for MTIA hooks (#160304)
Summary: MTIA is missing the `isAvailable()` override, which is necessary for some of the device agnostic methods. Test Plan: `torch._C._get_accelerator()` Rollback Plan: Differential Revision: D79981115 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160304 Approved by: https://github.com/nautsimon |
|||
| 7d2ec704e4 |
Fix MPS autocast for ConvTranspose3d (#160345)
## Summary - ensure ConvTranspose3d uses fp32 under MPS autocast - add MPS autocast test for ConvTranspose3d Generated by Codex, see https://chatgpt.com/codex/tasks/task_e_689a360388288327a2cac6f55bbfc42c Fixes https://github.com/pytorch/pytorch/issues/160332 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160345 Approved by: https://github.com/dcci |
|||
| fc80f6859e |
Fix collective schedule logging and runtime tests (#160260)
Summary: - Fix collective schedule logging so that only logs when collectives present - Fix runtime estimate test to check if each op has a number value Pull Request resolved: https://github.com/pytorch/pytorch/pull/160260 Approved by: https://github.com/Skylion007 |
|||
| cf0a0dcb0a |
Make user defined Triton kernels serializable for fx_graph_runnable (#160002)
Resolves issue https://github.com/pytorch/pytorch/issues/153475 where `fx_graph_runnable` didn't work with user defined triton kernels. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160002 Approved by: https://github.com/eellison |
|||
| b149c7204c |
Revert "port distributed pipeline test files for Intel GPU (#159033)"
This reverts commit 76a0609b6bddb2bc40f1eb4ade12885023653d59.
Reverted https://github.com/pytorch/pytorch/pull/159033 on behalf of https://github.com/clee2000 due to broke test_cpp_extensions_stream_and_event.py::TestCppExtensionStreamAndEvent::test_stream_event [GH job link](https://github.com/pytorch/pytorch/actions/runs/16890370216/job/47849586456) [HUD commit link](
|
|||
| 09381f5dac |
Revert "[Graph Partition] Pass all OSS unit tests (#154667)"
This reverts commit ca7315c17162ea21b1ca5ba23f4bf6168766c7b9.
Reverted https://github.com/pytorch/pytorch/pull/154667 on behalf of https://github.com/clee2000 due to broke inductor/test_memory.py::TestOperatorReorderForPeakMemory::test_reorder_peak_memory_lpmf [GH job link](https://github.com/pytorch/pytorch/actions/runs/16885961204/job/47836769279) [HUD commit link](
|
|||
| 9eedd2a20b |
[PGO] no counterfactual suggestions for dynamic allowlist (#160231)
Being more conservative with whitelist suggestions as we roll out suggestions; now we only suggest sources that were dynamic in previous runs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160231 Approved by: https://github.com/bobrenjc93 |
|||
| c3dc8dc412 |
159965 is merged, no need to patch it in (#160275)
Signed-off-by: Edward Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/160275 Approved by: https://github.com/albanD, https://github.com/ZainRizvi |
|||
| 76a0609b6b |
port distributed pipeline test files for Intel GPU (#159033)
In this PR we will port all distributed pipeline test files. We could enable Intel GPU with following methods and try the best to keep the original code styles: 1. instantiate_device_type_tests() 2. use "torch.accelerator.current_accelerator()" to determine the accelerator backend 3. use "requires_accelerator_dist_backend()" to replace requires_nccl() 4. use "get_default_backend_for_device()" to get backend 5. enabled XPU for some test path 6. add TEST_MULTIACCELERATOR in common_utils for all backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159033 Approved by: https://github.com/guangyey, https://github.com/d4l3k Co-authored-by: Daisy Deng <daisy.deng@intel.com> |
|||
| c8205cb354 |
[autograd] match 0-dim gradients device type regardless of subclassness (#160165)
Not sure if there some subclasses where the outer.dim() == 0 but you wouldn't want to move it? FIXES https://github.com/pytorch/pytorch/issues/160084 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160165 Approved by: https://github.com/ezyang, https://github.com/albanD |
|||
| d25c4f954d |
[MPS] Type-promote tensor-iterator common dtype (#160334)
Otherwise, `torch.add(FloatTensor, IntTensor, alpha=2)` and `torch.add(FloatTensor, IntTensor, alpha=2)` were dispatched to different kernels Fixes https://github.com/pytorch/pytorch/issues/160208 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160334 Approved by: https://github.com/Skylion007, https://github.com/dcci |
|||
| d0e2240f68 |
[triton_heuristics] Optimize the triton launcher in pt2 (#160000)
Summary:
(Original author: Xu Zhao. Commandeered by David to land this since it is relatively urgent)
We observed ~10us PT2-Triton launch overhead regression after pin update.
Before Triton pin-update:
{F1980557238}
After Triton pin-update:
{F1980557240}
The root cause is because https://github.com/pytorch/pytorch/pull/145051 adds `_get_args_with_constexprs` to the cubin launcher caller function, which is on the critical path.
The motivation for `_get_args_with_constexprs` was that between triton 3.2 and triton 3.3, the convention for calling Triton kernels (at the level that non-static-cuda-launcher inductor integrates) changed. Previously, the callable did not take constexpr arguments as parameters; after 3.3, it does. With pointwise/reduction kernels, we don't know the constexpr values until after autotuning occurs; so `_get_args_with_constexprs` would inject constexprs into the arguments list before calling the Triton kernel. The fix (in this PR) is to instead inject the constexpr args into the launcher string - this avoids the cost of sorting/reordering arguments which previously occurred upon execution of each kernel.
Note that the static_cuda_launcher.py does not require constants to be passed to the cubin launcher (
|
|||
| 9ccd0f5e31 |
Fix unbacked symint and memory leak in inductor memory planning (#159839)
Summary: In memory planning, some allocation sizes involve unbacked symints. These unbacked symints are not known before they are computed in run time, so **allocation pools that involve unbacked symints cannot be allocated until we have the values of the unbacked symints** . So we add a notion of `earliest_available` to Allocation nodes. If an allocation node has unbacked symint, it is available at only when its live range begin. Then in AllocationPool, if a pool involves an Allocation node that has an earliest available time, we restrict its life range. If a block's earliest available time is later than a pool's life range's start time, we cannot allocate it from the pool. We also fix a memory leak that's caused by allocating tensor without wrapping it with RAIIAtenTensor. In python wrapper for JIT inductor, `codegen_alloc_from_pool` doesn't actually write the alloc lines to wrapper, it just returns the string to alloc. However, in cpp_wrapper, `codegen_alloc_from_pool` actually write to the wrapper. Specifically, it writes the following and returns string `RAIIAtenTensorHandle`. ``` AtenTensorHandle handle_name; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(....); ``` This is bug prune. **If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle as well**, otherwise you get memory leaks. We remove the alloc_from_pool call from codegen_create, because this doesn't work for AOTI. In python wrapper, we can generate the same alloc_from_pool variable name for the same block, but cpp_wrapper will generate a different variable name for each call to alloc_from_pool. Test Plan: ``` python test/inductor/test_memory_planning.py ``` Rollback Plan: Differential Revision: D79603119 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159839 Approved by: https://github.com/jansel |
|||
| ca7315c171 |
[Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to 6.2% speedup on vision_maskrcnn, 5.8% speedup on yolov3. [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315). Run the same diff on two days and both show speedup on average. [first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d) <img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" /> [second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf) <img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" /> Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667 Approved by: https://github.com/eellison |
|||
| 68a4b4b2e3 |
[codemod] Fix unreachable-break issue in caffe2/c10/cuda/CUDAFunctions.cpp +2 (#160257)
Summary: LLVM has a warning `-Wunreachable-code-break` which identifies `break` statements that cannot be reached. These compromise readability, are misleading, and may identify bugs. This diff removes such statements. For questions/comments, contact r-barnes. - If you approve of this diff, please use the "Accept & Ship" button :-) Test Plan: Sandcastle Rollback Plan: Differential Revision: D79835614 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160257 Approved by: https://github.com/Skylion007 |