Compare commits

...

39 Commits

Author SHA1 Message Date
9f8b4700b5 add changes for dynamic range tuning 2025-11-11 21:50:00 -08:00
cdf0a9c21f Add FA4 to sdpa (#167348)
# Summary
See title ;)

## Design

Currently, once you install an impl there is no going back in the same Python process. This need not be the case; cc @mikaylagawarecki's work on being able to grab the original impl. I'll leave that for a follow-up.

Okay, I added an open registration mechanism, but I really want the backends to be discoverable, so there is some weird typing. Here's what we get:
<img width="523" height="197" alt="Screenshot 2025-11-07 at 3 30 32 PM" src="https://github.com/user-attachments/assets/586de943-bbed-40cf-abd1-131f747a4cf1" />

## Overheads:
<img width="799" height="735" alt="Screenshot 2025-11-07 at 2 35 04 PM" src="https://github.com/user-attachments/assets/f9217f31-3e42-4816-8fb3-29ea8b49d735" />
First call to forward -> the majority of the time is spent in JIT compilation for FA

First call to backward: ~3 sec. Interestingly, it doesn't appear that with_stack gets events in the backward loop. @albanD, is this expected?
<img width="948" height="385" alt="Screenshot 2025-11-07 at 2 35 50 PM" src="https://github.com/user-attachments/assets/a40bacd0-3fb0-4bd8-b33e-bec8fb3f36c0" />

Getting from the PyTorch op to the impl is about 43 us, which is dwarfed by other CPU overheads
<img width="1227" height="649" alt="Screenshot 2025-11-07 at 2 37 41 PM" src="https://github.com/user-attachments/assets/51da0615-facd-41e1-a6e2-fb7778079ab6" />

Just invoking the JIT-compiled object from the CuTe DSL takes 100s of us
<img width="545" height="414" alt="Screenshot 2025-11-07 at 2 38 19 PM" src="https://github.com/user-attachments/assets/d20345a0-6c47-4dcb-892f-9ef9894a1cf5" />

### Example usage
```Py
#!/usr/bin/env python3

"""Minimal FA4 smoke test for scaled dot product attention."""

from __future__ import annotations

import sys
from jsonargparse import CLI

import torch
import torch.nn.functional as F
from torch.nn.attention import (
    install_flash_attention_impl,
    sdpa_kernel,
    SDPBackend,
)

def _map_dtype(kind: str) -> torch.dtype:
    return torch.bfloat16 if kind == "bf16" else torch.float16

# To infinity and beyond
install_flash_attention_impl("FA4")

@sdpa_kernel([SDPBackend.FLASH_ATTENTION])
def main(
    module_path: str = "flash_attn.cute.interface",
    batch: int = 4,
    seq: int = 81292,
    heads: int = 16,
    head_dim: int = 128,
    device: int = 0,
    dtype: str = "bf16"
    ) -> None:
    if not torch.cuda.is_available():
        sys.exit("CUDA is required for FA4 smoke testing")
    torch.cuda.set_device(device)
    dtype = _map_dtype(dtype)
    generator = torch.Generator(device="cuda").manual_seed(0)
    q = torch.randn(
        batch,
        heads,
        seq,
        head_dim,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
        generator=generator,
    )
    k = torch.randn(
        batch,
        heads,
        seq,
        head_dim,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
        generator=generator,
    )
    v = torch.randn(
        batch,
        heads,
        seq,
        head_dim,
        device="cuda",
        dtype=dtype,
        requires_grad=True,
        generator=generator,
    )
    from transformer_nuggets.utils.benchmark import profiler
    with profiler("sdpa_FA4", with_stack=False):
        for _ in range(3):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
            loss = out.sum()
            loss.backward()
    print("Scaled dot product attention output norm:", out.norm().item())
    print("dq norm:", q.grad.norm().item())

if __name__ == "__main__":
    CLI(main)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167348
Approved by: https://github.com/albanD
2025-11-12 01:07:59 +00:00
115016f1a2 [Device Mesh][ez] Clean up unused parameters and duplicate codes (#167581)
While refactoring the code, I found that we re-init `_flatten_mapping` and still keep `_flatten_mesh_list` in the code even though it is no longer needed. Let's remove it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167581
Approved by: https://github.com/fegin
2025-11-12 00:59:32 +00:00
971e6ca434 fix sym_size_, sym_stride lowering (#167565)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167565
Approved by: https://github.com/bobrenjc93, https://github.com/Microve, https://github.com/Skylion007
ghstack dependencies: #167345
2025-11-12 00:53:36 +00:00
e8d411e7f7 FSDPMemTracker fix with multihander hooks. (#165662)
Fixes #164663

## Issue
A torch model with multiple layers wrapped with FSDP2 registers pre- and post-forward hooks as a group using `_MultiHandler`. This becomes an issue inside the tracker's context manager, where the hooks are reset and replaced: the hooks all share the same FSDP state pointer, so one reset resets them all. As a result, when the output layer was given new pre- and post-forward hooks, it deleted the previous layer's initialization, causing a `KeyError` for the norm layer since its entry no longer exists.

## The Fix
Check whether there are multiple `_MultiHandler` and `RemoveHandler` objects, and only execute the remove hook once.
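
A minimal sketch of the deduplication idea, with hypothetical names (not the actual FSDPMemTracker internals): when several registered handles wrap the same underlying FSDP state, perform the reset/removal only for the first handle seen and skip the rest.

```python
# Hypothetical sketch: several handles may share one underlying hook state,
# so the reset/removal must run at most once per state.
class HookResetter:
    def __init__(self):
        self._seen_state_ids = set()  # ids of states that were already reset

    def maybe_reset(self, handle, state):
        """Remove the hook behind `handle` only if its state wasn't handled yet."""
        if id(state) in self._seen_state_ids:
            return  # another handle sharing this state already did the work
        self._seen_state_ids.add(id(state))
        handle.remove()  # RemovableHandle-style removal
```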

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165662
Approved by: https://github.com/sanketpurandare
2025-11-11 23:49:36 +00:00
2e5233d7bd Revert "Support AC in default partitioner when functionalization is enabled (#166610)"
This reverts commit de773364be041ca7fd2dcaf35ca15c093fc9370b.

Reverted https://github.com/pytorch/pytorch/pull/166610 on behalf of https://github.com/soulitzer due to breaking internal tests ([comment](https://github.com/pytorch/pytorch/pull/166610#issuecomment-3519047226))
2025-11-11 23:01:09 +00:00
514dd96376 Remove --no-use-pep517 flag (#167096)
In pip 25.3 and newer, use of --no-use-pep517 has been removed (https://pip.pypa.io/en/stable/news/). In builds with pip 25.2, a warning message notes:

> DEPRECATION: Building 'torchvision' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'torchvision'. Discussion can be found at https://github.com/pypa/pip/issues/6334

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167096
Approved by: https://github.com/atalman
2025-11-11 23:00:35 +00:00
9ae62fcc18 [ROCm][CI] dynamo benchmarks update ci expected accuracy (#167574)
repvgg_a2 IMPROVED: accuracy=pass, expected=fail_accuracy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167574
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-11-11 22:54:55 +00:00
ae71b0e163 Fix typo in torch._refs (#167310)
There appears to be a typo here, but it doesn't raise an error because the inner function splits it into `a` and `,`, and the `,` case check is skipped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167310
Approved by: https://github.com/eellison
2025-11-11 22:31:09 +00:00
5b6ff8148d Revert "[ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (#158250)"
This reverts commit 402c46503002f98ccfc023a733081fb0719223a1.

Reverted https://github.com/pytorch/pytorch/pull/158250 on behalf of https://github.com/izaitsevfb due to Broke some torch.compile jobs ([comment](https://github.com/pytorch/pytorch/pull/158250#issuecomment-3518944863))
2025-11-11 22:27:51 +00:00
1f7e4343e7 [ROCm][CI] Add docker-cache-rocm.yml to test MI3xx CI docker caching (#167554)
* Trigger this workflow on every completed run of `docker-builds.yml`
* Uses `ubuntu-latest` for downloading artifacts from `docker-build` workflow run
* Uses `linux.rocm.gfx942.docker-cache` to cache docker images as tarballs for MI3xx CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167554
Approved by: https://github.com/jeffdaily
2025-11-11 21:32:22 +00:00
b21856f5fc Revert "[DebugMode] record triton kernels, run-to-run determinism checks (#167028)"
This reverts commit 259ba0ecabd809edd35d12b4f992777cb5923b68.

Reverted https://github.com/pytorch/pytorch/pull/167028 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167028#issuecomment-3518811298))
2025-11-11 21:31:12 +00:00
259ba0ecab [DebugMode] record triton kernels, run-to-run determinism checks (#167028)
Following up on https://github.com/pytorch/pytorch/pull/166348, extends DebugMode to capture inductor triton kernels at runtime, and adds an API for checking run-to-run determinism based on tensor hashes.

The workflow looks something like...
```python
# do 1st run with hashes, get logs
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs1 = debug_mode.logs

# do 2nd run
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs2 = debug_mode.logs

# returns list of calls w/ mismatched outputs
mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
```

Example dump from a smaller version of @drisspg's FlexAttention fwd+bwd determinism tests [script](https://gist.github.com/pianpwk/f65cc63811d12853709dcc77d7eb69f1) (without forced reduction order):
```
cfg: TestConfig(name='Standard', B=2, Hq=32, Hkv=32, Q=2048, KV=2048, Dqk=128, Dv=128)
DETERMINISM: fwd: True, bwd_q: False, bwd_k: False, bwd_v: True

$$$ DEBUG MODE DUMP $$$  (this is what the logs look like)

    [triton] triton_tem_fused_0(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_MAX=t: f32[2, 32, 2048], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
    # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_MAX: 81775.3811062593, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, out_ptr0: 924917.7918248245}

    [triton] triton_per_fused_zeros_0(in_ptr0=t: bf16[2, 32, 2048, 128], in_ptr1=t: bf16[2, 32, 2048, 128], out_ptr1=t: f32[2, 32, 2048], xnumel=131072, r0_numel=128)
    # post-kernel hashes: {in_ptr0: 924917.7918248245, in_ptr1: 13389213.797377996, out_ptr1: 81775.38106592931}

    [triton] triton_tem_fused_zeros_1(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_DELTA=t: f32[2, 32, 2048], arg_DO=t: bf16[2, 32, 2048, 128], arg_DQ=t: bf16[2, 32, 2048, 128], arg_DV=t: bf16[2, 32, 2048, 128], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_Q_NUM_BLKS=t: i32[2, 32, 16], arg_Q_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_Q_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_Q_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
    # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_DELTA: 81775.38106592931, arg_DO: 13389213.797377996, arg_DQ: 874474.8084187683, arg_DV: 727742.3138379117, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_Q_NUM_BLKS: 1024.0, arg_Q_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, arg_FULL_Q_NUM_BLKS: 7680.0, arg_FULL_Q_IDX: 122880.0, out_ptr0: 700542.3431890717}

$$$ MISMATCHES $$$
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_0', 'arg_name': 'arg_MAX', 'pytree_path': None, 'hash1': 0.0, 'hash2': 81775.3811062593, 'rel_diff': 1.0, 'is_input_hash': False}  # I guess this one is misleading? not sure if I'm doing something wrong with waiting for kernel results
mismatch: {'call_type': 'triton kernel', 'call': 'triton_per_fused_zeros_0', 'arg_name': 'out_ptr1', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DELTA', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DQ', 'pytree_path': None, 'hash1': 874474.8097136207, 'hash2': 874474.8084187683, 'rel_diff': 1.480720012120795e-09, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'out_ptr0', 'pytree_path': None, 'hash1': 700542.3488049245, 'hash2': 700542.3431890717, 'rel_diff': 8.016435812581196e-09, 'is_input_hash': False}
```

Note: the current hash implementation is basically a tensor norm, so tensor closeness -> hash closeness. This is likely to change soon, e.g. maybe to `torch.hash_tensor` (https://github.com/pytorch/pytorch/pull/154149) by default.
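
To make the "tensor closeness -> hash closeness" point concrete, here is a minimal sketch of a norm-style hash and a relative-difference check; this is illustrative only, not DebugMode's actual hashing code.

```python
import torch

def norm_hash(t: torch.Tensor) -> float:
    # A norm-style "hash": numerically close tensors produce close values,
    # so small run-to-run drift shows up as a small relative difference.
    return t.detach().float().abs().sum().item()

def rel_diff(h1: float, h2: float) -> float:
    return abs(h1 - h2) / max(abs(h1), abs(h2), 1e-12)

a = torch.randn(1024)
b = a + 1e-6 * torch.randn(1024)  # simulate tiny run-to-run nondeterminism
print(rel_diff(norm_hash(a), norm_hash(b)))  # very small, e.g. on the order of 1e-8
```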

Sample paste diff between log dumps from 2 runs:
<img width="1665" height="445" alt="Screenshot 2025-11-05 at 11 27 24 PM" src="https://github.com/user-attachments/assets/41402e37-f50b-4a9e-a17c-bb98b5917076" />

In another case, running this for FSDP2 on Llama3-8B helped narrow down a divergence between aot_eager and inductor to inductor's FWD RMSNorm kernels: P2027003180

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167028
Approved by: https://github.com/v0i0
2025-11-11 20:37:53 +00:00
051f1fe8e3 Revert "[ROCm][CI] Update docker-cache-mi300.yml to test MI300 CI docker caching (#167554)"
This reverts commit ee387c43feada1cc2049b42a970ec4e2f12f210e.

Reverted https://github.com/pytorch/pytorch/pull/167554 on behalf of https://github.com/jithunnair-amd due to workflow had failure 'Unexpected input(s) 'run_id'' ([comment](https://github.com/pytorch/pytorch/pull/167554#issuecomment-3518642191))
2025-11-11 20:34:44 +00:00
ee387c43fe [ROCm][CI] Update docker-cache-mi300.yml to test MI300 CI docker caching (#167554)
Trigger this workflow on every completed run of `docker-builds.yml` and run on `ubuntu-latest` so it doesn't queue infinitely for `rocm-docker` label

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167554
Approved by: https://github.com/jeffdaily
2025-11-11 19:49:00 +00:00
3a944661d6 Cpython test_math.FMATests (#167217)
Resolves issues running the dynamo CPython math.fma tests.

Though math.fma is enabled to perform a multiply-add in dynamo, torch.addcmul is currently used, which doesn't guarantee the fused multiply-add the user requested. It was decided not to use the inductor fma prim, as that would break the contract of using ATen/core IR in dynamo output; otherwise export=True may have issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167217
Approved by: https://github.com/guilhermeleobas
2025-11-11 19:26:18 +00:00
56034074ca Revert "[Inductor] Naive foreach autotune support (#162053)"
This reverts commit 6c5db82584bf71f5b1db3b598bbd00f44140c28d.

Reverted https://github.com/pytorch/pytorch/pull/162053 on behalf of https://github.com/mlazos due to Sorry, there's an internal slowdown due to the extra triton configs you added ([comment](https://github.com/pytorch/pytorch/pull/162053#issuecomment-3518423369))
2025-11-11 19:23:40 +00:00
8def619bbe [user-streams] wait_stream op (#167512)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167512
Approved by: https://github.com/williamwen42
ghstack dependencies: #167510, #167511
2025-11-11 19:18:03 +00:00
61883a5787 [user-streams] Allow new streams to be created and registered during compilation (#167511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167511
Approved by: https://github.com/williamwen42
ghstack dependencies: #167510
2025-11-11 19:18:03 +00:00
d8ada1ee76 [user-streams] Allow new events to be created and registered during compilation (#167510)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167510
Approved by: https://github.com/williamwen42
2025-11-11 19:18:03 +00:00
fe841a1db4 [DeviceMesh] Log DeviceMesh.__init__ usage (#167375)
Adds (meta-internal-only) API usage logging for DeviceMesh creation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167375
Approved by: https://github.com/fduwjj
ghstack dependencies: #167374
2025-11-11 19:15:47 +00:00
b65829b84f [DTensor] Log API usage metrics for DTensor and DeviceMesh (#167374)
Logging propagate_op_sharding_non_cached is a compromise between
 - logging in DTensor.__init__ to catch ALL DTensor usage,
 - sparing the overhead in a latency-sensitive region like
   DTensor.__init__,
 - and the fact that 'real' DTensor usage should incur at least one call to
   sharding propagation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167374
Approved by: https://github.com/zpcore
2025-11-11 19:15:47 +00:00
b0e0ae97ba include thrust/distance.h explicitly in cuda sparse softmax (#167436)
`thrust::distance` is defined there
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167436
Approved by: https://github.com/Skylion007
2025-11-11 19:10:55 +00:00
f44a1ddcb2 Revert "[ROCm][CI] Update docker-cache-mi300.yml to test MI300 CI docker caching (#167554)"
This reverts commit 184e2cbc89570e1bf466b15d70fc36ed71be0eb9.

Reverted https://github.com/pytorch/pytorch/pull/167554 on behalf of https://github.com/jithunnair-amd due to Need to fix lint ([comment](https://github.com/pytorch/pytorch/pull/167554#issuecomment-3518382341))
2025-11-11 19:09:45 +00:00
184e2cbc89 [ROCm][CI] Update docker-cache-mi300.yml to test MI300 CI docker caching (#167554)
Trigger this workflow on every completed run of `docker-builds.yml` and run on `ubuntu-latest` so it doesn't queue infinitely for `rocm-docker` label

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167554
Approved by: https://github.com/jeffdaily
2025-11-11 19:07:19 +00:00
416421c7c4 fix failure of exporting compiled model with nested dynamic shapes (#166358)
## Problems
When exporting a compiled model with nested inputs like below
```python
import torch
from torch.export import export, Dim

def test_export_compiled_model_with_nested_dynamic_shapes():
   """Test exporting a compiled model with nested dict inputs and dynamic shapes."""
   print("Running test_export_compiled_model_with_nested_dynamic_shapes...")

   class M(torch.nn.Module):
       def forward(self, data_batch):
           return data_batch["a1"] + data_batch["a2"]

   m = M()
   compiled_m = torch.compile(m)
   example_args = ({
       "a1": torch.ones(3, 3),
       "a2": torch.ones(3, 3),
   },)
   dynamic_shapes = ({
       "a1": {0: Dim.DYNAMIC},
       "a2": {0: Dim.DYNAMIC},
   },)

   try:
       ep = export(compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
       gm = ep.module()
       result_exported = gm(*example_args)
       result_compiled = compiled_m(*example_args)

       assert torch.allclose(result_exported, result_compiled), "Results don't match!"
       print("✓ test_export_compiled_model_with_nested_dynamic_shapes PASSED")
       return True
   except Exception as e:
       print(f"✗ test_export_compiled_model_with_nested_dynamic_shapes FAILED")
       print(f"Error: {e}")
       import traceback
       traceback.print_exc()
       return False

def test_export_compiled_model_with_kwargs_dynamic_shapes():
   """Test exporting a compiled model with kwargs and dynamic shapes."""
   print("\nRunning test_export_compiled_model_with_kwargs_dynamic_shapes...")

   class M(torch.nn.Module):
       def forward(self, a1, a2):
           return a1 + a2

   m = M()
   compiled_m = torch.compile(m)
   example_args = ()
   example_kwargs = {
       "a1": torch.ones(3, 3),
       "a2": torch.ones(3, 3),
   }
   dynamic_shapes = {
       "a1": {0: Dim.DYNAMIC},
       "a2": {0: Dim.DYNAMIC},
   }

   try:
       ep = export(compiled_m, example_args, kwargs=example_kwargs, dynamic_shapes=dynamic_shapes, strict=True)
       gm = ep.module()
       result_exported = gm(**example_kwargs)
       result_compiled = compiled_m(**example_kwargs)

       assert torch.allclose(result_exported, result_compiled), "Results don't match!"
       print("✓ test_export_compiled_model_with_kwargs_dynamic_shapes PASSED")
       return True
   except Exception as e:
       print(f"✗ test_export_compiled_model_with_kwargs_dynamic_shapes FAILED")
       print(f"Error: {e}")
       import traceback
       traceback.print_exc()
       return False

if __name__ == "__main__":
   print("Testing export of compiled models with dynamic shapes\n")
   print("=" * 70)

   results = []
   results.append(test_export_compiled_model_with_nested_dynamic_shapes())
   results.append(test_export_compiled_model_with_kwargs_dynamic_shapes())

   print("\n" + "=" * 70)
   print(f"\nResults: {sum(results)}/{len(results)} tests passed")

   if all(results):
       print("✓ All tests passed!")
   else:
       print("✗ Some tests failed")
       exit(1)
```

It will report
```
======================================================================
Running test_export_compiled_model_with_nested_dynamic_shapes...
✗ test_export_compiled_model_with_nested_dynamic_shapes FAILED
Error: Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs[0]` is a <class 'tuple'>, but `dynamic_shapes[0]` is a <class 'dict'>
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 614, in _tree_map_with_path
    return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 2055, in tree_map_with_path
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 2055, in <listcomp>
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1188, in flatten_up_to
    helper(self, tree, subtrees)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1185, in helper
    helper(subspec, subtree, subtrees)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1141, in helper
    raise ValueError(
ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'dict'>.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_exprot.py", line 25, in test_export_compiled_model_with_nested_dynamic_shapes
    ep = export(compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1415, in _strict_export
    gm_torch_level = _export_to_torch_ir(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 785, in _export_to_torch_ir
    _check_dynamic_shapes(combined_args, dynamic_shapes)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 1031, in _check_dynamic_shapes
    _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 686, in _tree_map_with_path
    _compare(tree_spec, other_tree_spec, [])
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 677, in _compare
    _compare(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 652, in _compare
    raise_mismatch_error(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 634, in raise_mismatch_error
    raise UserError(
torch._dynamo.exc.UserError: Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs[0]` is a <class 'tuple'>, but `dynamic_shapes[0]` is a <class 'dict'>
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

Running test_export_compiled_model_with_kwargs_dynamic_shapes...
✗ test_export_compiled_model_with_kwargs_dynamic_shapes FAILED
Error: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['kwargs'] of `inputs`, but here they are ['a1', 'a2']. Since here `inputs` is a list/tuple enclosing a single dict, maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
Traceback (most recent call last):
  File "/home/chzhu/infinitrain/test_exprot.py", line 62, in test_export_compiled_model_with_kwargs_dynamic_shapes
    ep = export(compiled_m, example_args, kwargs=example_kwargs, dynamic_shapes=dynamic_shapes, strict=True)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 1415, in _strict_export
    gm_torch_level = _export_to_torch_ir(
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_trace.py", line 785, in _export_to_torch_ir
    _check_dynamic_shapes(combined_args, dynamic_shapes)
  File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 1007, in _check_dynamic_shapes
    raise UserError(
torch._dynamo.exc.UserError: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['kwargs'] of `inputs`, but here they are ['a1', 'a2']. Since here `inputs` is a list/tuple enclosing a single dict, maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

======================================================================
```
## Torch Version
(reproducible on a nightly build)

## Other Behavior
The model exports normally when we test without compiling it:
```python
import torch
from torch.export import export, Dim

def test_export_compiled_model_with_nested_dynamic_shapes():
   """Test exporting a compiled model with nested dict inputs and dynamic shapes."""
   print("Running test_export_compiled_model_with_nested_dynamic_shapes...")

   class M(torch.nn.Module):
       def forward(self, data_batch):
           return data_batch["a1"] + data_batch["a2"]

   m = M()
   example_args = ({
       "a1": torch.ones(3, 3),
       "a2": torch.ones(3, 3),
   },)
   dynamic_shapes = ({
       "a1": {0: Dim.DYNAMIC},
       "a2": {0: Dim.DYNAMIC},
   },)

   try:
       ep = export(m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
       gm = ep.module()
       result_exported = gm(*example_args)
       result_compiled = m(*example_args)

       assert torch.allclose(result_exported, result_compiled), "Results don't match!"
       print("✓ test_export_compiled_model_with_nested_dynamic_shapes PASSED")
       return True
   except Exception as e:
       print(f"✗ test_export_compiled_model_with_nested_dynamic_shapes FAILED")
       print(f"Error: {e}")
       import traceback
       traceback.print_exc()
       return False

def test_export_compiled_model_with_kwargs_dynamic_shapes():
   """Test exporting a compiled model with kwargs and dynamic shapes."""
   print("\nRunning test_export_compiled_model_with_kwargs_dynamic_shapes...")

   class M(torch.nn.Module):
       def forward(self, a1, a2):
           return a1 + a2

   m = M()
   example_args = ()
   example_kwargs = {
       "a1": torch.ones(3, 3),
       "a2": torch.ones(3, 3),
   }
   dynamic_shapes = {
       "a1": {0: Dim.DYNAMIC},
       "a2": {0: Dim.DYNAMIC},
   }

   try:
       ep = export(m, example_args, kwargs=example_kwargs, dynamic_shapes=dynamic_shapes, strict=True)
       gm = ep.module()
       result_exported = gm(**example_kwargs)
       result_compiled = m(**example_kwargs)

       assert torch.allclose(result_exported, result_compiled), "Results don't match!"
       print("✓ test_export_compiled_model_with_kwargs_dynamic_shapes PASSED")
       return True
   except Exception as e:
       print(f"✗ test_export_compiled_model_with_kwargs_dynamic_shapes FAILED")
       print(f"Error: {e}")
       import traceback
       traceback.print_exc()
       return False

if __name__ == "__main__":
   print("Testing export of compiled models with dynamic shapes\n")
   print("=" * 70)

   results = []
   results.append(test_export_compiled_model_with_nested_dynamic_shapes())
   results.append(test_export_compiled_model_with_kwargs_dynamic_shapes())

   print("\n" + "=" * 70)
   print(f"\nResults: {sum(results)}/{len(results)} tests passed")

   if all(results):
       print("✓ All tests passed!")
   else:
       print("✗ Some tests failed")
       exit(1)

```
## Root Cause

This is caused by a side effect of torch.compile(model). When the model is compiled, its input signature automatically becomes (*args, **kwargs). In the above example, `data_batch` is added into `args` in combined_args [here](dc011d3203/torch/export/dynamic_shapes.py (L720)), and it looks like
```
{'args': ({'a1': tensor([[1., 1., 1.]... 1., 1.]]), 'a2': tensor([[1., 1., 1.]... 1., 1.]])},)}
```
Without compiling, the combined args look like
```
{'data_batch': {'a1': tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]), 'a2': tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])}}

```
This causes a mismatch when we use tree_map to match the dynamic shapes with the input args.

The error is also reproducible when we set up kwargs as example args (see the 2nd test above).
## Fix
Proposed fix: In [_combine_args](dc011d3203/torch/export/dynamic_shapes.py (L720)), we explicitly flatten the kwargs and args into the combined args.
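
A minimal sketch of that idea with a hypothetical helper (not the actual `_combine_args` code): flatten the compiled wrapper's `(*args, **kwargs)` back onto the original parameter names so that `dynamic_shapes` keyed by those names lines up again.

```python
# Hypothetical illustration of the flattening idea behind the proposed fix.
def combine_args_flat(param_names, args, kwargs):
    combined = {}
    # Pair positional inputs with the original parameter names instead of
    # keeping them nested under a single "args" tuple.
    for name, value in zip(param_names, args):
        combined[name] = value
    # Keyword inputs are already keyed by their original names.
    combined.update(kwargs)
    return combined

# For the first repro above this yields {"data_batch": {...}} (as in the
# uncompiled case) rather than {"args": ({...},)}, so tree_map can match it
# against dynamic_shapes again.
```
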
## Side Effects
There are 2 existing tests that assume this behavior and
1. add `args` explicitly to dynamic_shapes
2. wrap args in a nested format in dynamic_shapes

I have modified those tests so that args and dynamic_shapes are in a consistent format.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166358
Approved by: https://github.com/angelayi
2025-11-11 19:04:58 +00:00
bd99ae3315 [Docs] Add warning that torch.export.load uses pickle (#167557)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167557
Approved by: https://github.com/zhxchen17, https://github.com/angelayi
2025-11-11 18:47:14 +00:00
ce8672c24f Fix use of TORCH_CHECK in torch/csrc/stable (#167495)
Tested by above PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167495
Approved by: https://github.com/janeyx99
ghstack dependencies: #166579, #166694, #166695, #167362
2025-11-11 17:58:30 +00:00
402c465030 [ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (#158250)
Co-authored-by: Nikhil Gupta [nikhil.gupta2@arm.com](mailto:nikhil.gupta2@arm.com)

This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch to boost LLM prefill & decode performance

**This change improves decode throughput by ~15% and reduces the memory required to run inference on the model by 50%.**

### Benchmark Setup
```
Model: meta-llama/Llama-3.1-8B
Test Platform: Neoverse V2
```
### Detailed Results

| Metric                           | With `--compile`         | Without `--compile`      |
|----------------------------------|---------------------------|---------------------------|
| Quantization Scheme              | INT4 symmetric channelwise | INT4 symmetric channelwise |
| Input Precision                  | BF16                      | BF16                      |
| Number of Layers Quantized       | 32                        | 32                        |
| Average Compression Ratio        | 87.49%                    | 87.49%                    |
| Total Quantization Time (s)      | 9.62                      | 10.32                     |
| Compile Time (First) (s)         | 134.48                    | 1.69                      |
| Compile Time (Second) (s)        | 80.44                     | 1.60                      |
| Compile Time (Subsequent) (s)    | 0.19                      | 0.22                      |
| Prefill Tokens                   | 54                        | 54                        |
| Decoded Tokens                   | 33                        | 33                        |
| Prefill Time (s)                 | 0.19                      | 0.22                      |
| Decode Time (s)                  | 0.76                      | 1.38                      |
| E2E Generation Time (s)          | 0.95                      | 1.60                      |
| Prefill Throughput (tokens/s)    | 288.13                    | 249.91                    |
| Decode Throughput (tokens/s)     | 43.42                     | 23.83                     |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158250
Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/fadara01

Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-11 17:50:22 +00:00
573a79fffa [OpenReg] Initialize device stream states for all devices in initOpenRegStreamsOnce (#167528)
Fixes #167527

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167528
Approved by: https://github.com/fffrog
2025-11-11 16:53:22 +00:00
4945180468 Add empty tensor check for _pad_packed_sequence (#167521)
This prevents a null pointer dereference.

Fixes https://github.com/pytorch/pytorch/issues/149622
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167521
Approved by: https://github.com/albanD
2025-11-11 16:46:13 +00:00
1df723e6f5 [inductor] Fix constant creation (#167398)
We ran into this issue when debugging inductor-lite. Calling `torch.tensor` within a fake mode (which is the case inside of inductor) will create a FakeTensor, which causes this FakeTensor to be used as a constant within inductor.
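
For illustration, a small repro of the behavior being described, independent of inductor internals: while a fake mode is active, `torch.tensor(...)` produces a `FakeTensor` with no real storage, which is why it cannot serve as a real constant.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

with FakeTensorMode():
    t = torch.tensor([1.0, 2.0, 3.0])  # created while a fake mode is active

print(type(t), isinstance(t, FakeTensor))  # FakeTensor, True -> no real data to embed as a constant
```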

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167398
Approved by: https://github.com/eellison, https://github.com/BoyuanFeng
2025-11-11 16:30:46 +00:00
f9b81e23e4 [ROCm] Disable group gemm CK path when composable kernel (CK) is not enabled (#167403)
For ROCm builds without CK support, ensure use_fast_path is false so that the CK path is not triggered, since CK is currently not available in this configuration.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167403
Approved by: https://github.com/Skylion007, https://github.com/ScottTodd, https://github.com/jeffdaily
2025-11-11 16:15:51 +00:00
ffe6cc39c7 [inductor] Optimize cold compile time when cudagraphs-partition is enabled (#167132)
Summary: When cudagraphs-partition is enabled, we have seen an increase of cold compile time in the vllm benchmark (see https://github.com/vllm-project/vllm/issues/27080). After some profiling, we found that Triton compilation time increased the most. Further investigation revealed that it was caused by duplicated Triton kernels not being shared among different partitions. This PR fixes the issue by reusing the Triton kernel source code cache at the top-level PythonWrapperCodegen.

In theory, we could further reduce the compilation time by completely skipping compilation of duplicated partitions. That can come as a future improvement.
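
A minimal sketch of the deduplication idea, with a hypothetical structure (not the actual PythonWrapperCodegen code): keep one cache of generated kernel source keyed by its hash at the top level, so partitions that emit identical Triton source reuse the already-defined kernel instead of compiling it again.

```python
import hashlib

class KernelSourceCache:
    """Hypothetical top-level cache shared by all graph partitions."""

    def __init__(self):
        self._by_hash: dict[str, str] = {}  # source hash -> name of first kernel with that source

    def get_or_add(self, kernel_name: str, source: str) -> str:
        key = hashlib.sha256(source.encode()).hexdigest()
        if key in self._by_hash:
            return self._by_hash[key]  # duplicate source: reuse, skip recompilation
        self._by_hash[key] = kernel_name
        return kernel_name

cache = KernelSourceCache()
n1 = cache.get_or_add("triton_poi_fused_add_0", "def kernel(...): ...")
n2 = cache.get_or_add("triton_poi_fused_add_1", "def kernel(...): ...")
assert n1 == n2 == "triton_poi_fused_add_0"  # second partition reuses the first kernel
```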

Some vLLM benchmarking data:

```
VLLM_USE_STANDALONE_COMPILE=0 VLLM_DISABLE_COMPILE_CACHE=1 vllm bench latency -O.cudagraph_mode=PIECEWISE -O.use_inductor_graph_partition=True --model meta-llama/Meta-Llama-3.1-8
```
Before:
```
torch.compile takes 69.18 s in total
```
After:
```
torch.compile takes 26.81 s in total
```

As a reference, this is the compile time when turning off inductor graph partition. Looks like we still have some gap to close.
```
VLLM_USE_STANDALONE_COMPILE=0 VLLM_DISABLE_COMPILE_CACHE=1 vllm bench latency -O.cudagraph_mode=PIECEWISE -O.use_inductor_graph_partition=False --model meta-llama/Meta-Llama-3.1-8B

torch.compile takes 19.41 s in total
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167132
Approved by: https://github.com/eellison
ghstack dependencies: #167131
2025-11-11 15:54:31 +00:00
db1f3f6901 [inductor] Only generate compile-time auto-tuning block in the main graph (#167131)
Summary: When cudagraphs partition and autotune_at_compile_time are enabled, currently each subgraph generates its own auto-tuning code block and runs them one by one. This PR improves this by generating only one auto-tuning code block at the main-graph level and executing it once to auto-tune all the kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167131
Approved by: https://github.com/eellison
2025-11-11 15:54:31 +00:00
43041f0a43 Remove superflous/misplaced TestFailure specs (#165989)
The tests are in the class `TestInductorDynamic`, which isn't affected by the `test_failures` dict; that dict is only used as an argument to `copy_tests` for the `CommonTemplate` defined in another file.

So those have no effect.

Idea: Enhance `copy_tests` to detect unused keys

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165989
Approved by: https://github.com/benjaminglass1, https://github.com/ezyang
2025-11-11 15:36:43 +00:00
dc00842b81 [ROCm][CI] trigger magma build with gfx950 for ROCm7.1 (#167390)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167390
Approved by: https://github.com/jeffdaily
2025-11-11 15:17:37 +00:00
f1a129a6d0 Clarify that crashes/OOB accesses and not security threats (#167519)
Added note on crashes and out of bounds access in PyTorch.

Addresses https://github.com/pytorch/pytorch/issues/166881#issuecomment-3513245388

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167519
Approved by: https://github.com/albanD
2025-11-11 15:14:51 +00:00
fad48ffa62 [ROCm][CI] Match workflow names with workflow file names (#167483)
Fixes issue with uploading artifacts, which was inadvertently disabled for some renamed workflows via https://github.com/pytorch/pytorch/pull/167225

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167483
Approved by: https://github.com/jeffdaily
2025-11-11 14:45:44 +00:00
55 changed files with 2245 additions and 572 deletions

View File

@ -30,7 +30,6 @@ into a tarball, with the following structure:
More specifically, `build_magma.sh` copies over the relevant files from the `package_files` directory depending on the ROCm version.
Outputted binaries should be in the `output` folder.
## Pushing
Packages can be uploaded to an S3 bucket using:

View File

@ -96,7 +96,6 @@ function pip_build_and_install() {
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
--no-use-pep517 \
-w "${wheel_dir}" \
"${build_target}"
fi

View File

@ -63,7 +63,7 @@ self-hosted-runner:
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4
- rocm-docker
- linux.rocm.gfx942.docker-cache
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
- macos-m1-stable
- macos-m1-14

View File

@ -1,55 +0,0 @@
name: docker-cache-mi300
on:
# run every 6 hours
schedule:
- cron: 0 0,6,12,18 * * *
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
docker-cache:
if: github.repository_owner == 'pytorch'
runs-on: rocm-docker
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
with:
no-sudo: true
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
push: false
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Tar and upload to S3 bucket
run: |
sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress

.github/workflows/docker-cache-rocm.yml (new file, 108 lines)
View File

@ -0,0 +1,108 @@
name: docker-cache-rocm
on:
workflow_run:
workflows: [docker-builds]
# TODO: Uncomment before merging
#branches: [main, release]
types:
- completed
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
actions: read
jobs:
download-docker-builds-artifacts:
if: github.repository_owner == 'pytorch'
name: download-docker-builds-artifacts
runs-on: ubuntu-latest
outputs:
pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
steps:
- name: Download artifacts
uses: actions/download-artifact@v4.1.7
with:
run-id: ${{ github.event.workflow_run.id }}
path: ./docker-builds-artifacts
merge-multiple: true
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Process artifacts
id: process-artifacts
run: |
ls -R ./docker-builds-artifacts
cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
cat "${GITHUB_OUTPUT}"
docker-cache:
if: github.repository_owner == 'pytorch'
needs: download-docker-builds-artifacts
strategy:
fail-fast: false
matrix:
runner: [linux.rocm.gfx942.docker-cache]
docker-image: [
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
]
runs-on: "${{ matrix.runner }}"
steps:
- name: debug
run: |
JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Generate ghrc.io tag
id: ghcr-io-tag
run: |
ecr_image="${{ matrix.docker-image }}"
ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
- name: Save as tarball
run: |
docker_image_tag=${{ matrix.docker-image }}
docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
ref_name=${{ github.event.workflow_run.head_branch }}
if [[ $ref_name =~ "release/" ]]; then
ref_suffix="release"
elif [[ $ref_name == "main" ]]; then
ref_suffix="main"
else
# TODO: Remove below
ref_suffix="main"
# echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
fi
docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
# mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar

View File

@ -1,4 +1,4 @@
name: inductor-rocm
name: inductor-rocm-mi200
on:
schedule:

View File

@ -1,4 +1,4 @@
name: rocm
name: rocm-mi200
on:
push:

View File

@ -18,6 +18,8 @@ Please report security issues using https://github.com/pytorch/pytorch/security/
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function—the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat

View File

@ -142,6 +142,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
auto batch_sizes_t = _batch_sizes.contiguous();
checkLongTensor(batch_sizes_t);
TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");
int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
int64_t max_batch_size = batch_sizes[0];

View File

@ -669,9 +669,12 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
// On non CK system(w/ ROCm), make sure use_fast_path is false
#if defined(USE_ROCM_CK_GEMM)
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif //USE_ROCM_CK_GEMM
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
@ -680,7 +683,11 @@ std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif //USE_ROCM_CK_GEMM
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);

View File

@ -47,6 +47,7 @@
#include <c10/macros/Macros.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/gather.h>

View File

@ -50,7 +50,7 @@ nfnet_l0,pass,7
repvgg_a2,fail_accuracy,7
repvgg_a2,pass,7


View File

@ -14,6 +14,10 @@ Utils
sdpa_kernel
SDPBackend
register_flash_attention_impl
activate_flash_attention_impl
list_flash_attention_impls
current_flash_attention_impl
Submodules
----------

View File

@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja
# Install onnx
pip install --no-use-pep517 -e "$tp2_dir/onnx"
pip install -e "$tp2_dir/onnx"
# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"

View File

@ -140,6 +140,11 @@ static void initDeviceStreamState(DeviceIndex device_index) {
static void initOpenRegStreamsOnce() {
c10::call_once(init_flag, initGlobalStreamState);
for (const auto i : c10::irange(num_devices)) {
c10::call_once(
device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
}
if (current_streams) {
return;
}
@ -202,8 +207,6 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
if (device_index == -1) {
device_index = current_device();
}
c10::call_once(
device_flags[device_index], initDeviceStreamState, device_index);
auto pri_idx =
std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
const auto idx = get_idx(priority_counters[device_index][pri_idx]);

View File

@ -180,6 +180,47 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
del model
del optim
def _test_tracker_multihandler_hook(self):
"""Should run without KeyError."""
class TestModule(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.norm1 = nn.RMSNorm(dim)
self.output1 = nn.Linear(dim, dim)
self.norm2 = nn.RMSNorm(dim)
self.output2 = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm1(x)
x = self.output1(x)
x = self.norm2(x)
x = self.output2(x)
return x
gc.collect()
torch.manual_seed(42)
dev = torch.device(torch.accelerator.current_device_index())
with torch.device(dev):
model = TestModule(128)
mesh = init_device_mesh(dev.type, (self.world_size,))
fully_shard([model.norm1, model.output1], mesh=mesh)
fully_shard([model.norm2, model.output2], mesh=mesh)
fully_shard(model, mesh=mesh)
fmt = FSDPMemTracker(model)
with fmt:
inp = torch.randn(16, 128, device=dev)
y = model(inp)
loss = y.sum()
loss.backward()
del inp
del model
class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
@property

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest
import torch
import torch.distributed as dist
@ -371,6 +372,7 @@ class DTensorExportTest(TestCase):
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
# is producing a joint graph with backward region missing
@unittest.expectedFailure
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)

View File

@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import default_partition, min_cut_rematerialization_partition
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
@ -24,7 +24,7 @@ from torch._dynamo.testing import (
)
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@ -281,14 +281,7 @@ class ActivationCheckpointingViaTagsTests(
run(export_compiler)
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function(self, device, partition_fn):
def test_tags_function(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -304,22 +297,11 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
def test_tags_function_via_global_checkpoint(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -334,28 +316,17 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_with_kwargs(self, device, partition_fn):
def test_tags_function_with_kwargs(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=False
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
)
x = torch.randn(4, 4, device=device, requires_grad=True)
@ -365,22 +336,11 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_sequential_layers(self, device, partition_fn):
def test_tags_sequential_layers(self, device):
def gn(x):
x = x.cos()
for _ in range(3):
@ -401,22 +361,11 @@ class ActivationCheckpointingViaTagsTests(
freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_multiple_checkpoints(self, device, partition_fn):
def test_tags_multiple_checkpoints(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -434,22 +383,11 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_module(self, device, partition_fn):
def test_tags_module(self, device):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -473,22 +411,11 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_decomps(self, device, partition_fn):
def test_tags_decomps(self, device):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self) -> None:
@ -516,7 +443,6 @@ class ActivationCheckpointingViaTagsTests(
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
@ -776,14 +702,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
def test_compile_selective_checkpoint_must_recompute(self, device):
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
@ -804,9 +723,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
),
)
def _test(context_fn, bw_compiler, partition_fn):
def _test(context_fn, bw_compiler):
def gn(x):
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
return torch.sigmoid(torch.matmul(x, x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
@ -820,14 +739,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
fw_compiler = functools.partial(
count_ops,
freq=2,
freq=1,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x)
@ -835,19 +754,17 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
_test(
context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=4, # 2 bwd mm ops per fwd matmul
freq=2, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
def test_sac_with_partial_context_fn(self):
@ -884,16 +801,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm(
self, device, partition_fn
):
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -933,22 +841,15 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device, partition_fn
self, device
):
def selective_checkpointing_context_fn():
no_recompute_list = [
@ -988,7 +889,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
disable_functionalization=True,
)
self._validate(fn, backend, x, y)
@ -996,14 +897,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
def test_compile_selective_checkpoint_triton_kernel(self, device):
# Copy of the above test, but make sure that having a triton kernel in the
# region does not error.
def add_one(x):
@ -1063,21 +957,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
def test_compile_selective_checkpoint_tensor_subclass(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1120,21 +1007,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
def test_compile_selective_checkpoint_custom_rule(self, device):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1192,21 +1072,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
@ -1245,21 +1118,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
def test_compile_selective_checkpoint_outplace_op(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1297,21 +1163,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
def test_compile_selective_checkpoint_list_ops(self, device):
def selective_checkpointing_context_fn():
# recompute everything
no_recompute_list = []
@ -1347,7 +1206,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1358,14 +1217,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
"requires TorchDispatchMode + torch.compile work to complete"
)
@requires_cuda_and_triton
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
def test_compile_selective_checkpoint_inplace_op(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1405,7 +1257,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1413,14 +1265,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
def test_compile_selective_checkpoint_random_op(self, device):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
@ -1467,7 +1312,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@ -1479,14 +1324,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
def test_compile_selective_checkpoint_invalid_context(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
@ -1515,7 +1353,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
@ -1524,14 +1362,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
def test_compile_selective_checkpoint_parametrization(self):
def sac_policy():
def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
@ -1594,9 +1425,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
bw_compiler = functools.partial(
count_ops,
freqs=[
# 1 from mul recompute, 1 from mul backward
# w/o CSE, we have one extra mul
3 if partition_fn is default_partition else 2,
2, # 1 from mul recompute, 1 from mul backward
1,
],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@ -1605,7 +1434,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
partition_fn=min_cut_rematerialization_partition,
)
model = MLPModule()
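All of the tests above construct their compile backend the same way. A condensed, self-contained sketch of that pattern (with a pass-through compiler standing in for the count_ops helpers used in the test file) looks like this:

```Py
import torch
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd


def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))


def fn(x, y):
    # Activation checkpointing: gn's intermediates are recomputed in the backward.
    return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y, use_reentrant=False)


def noop_compiler(gm, example_inputs):
    # The test file uses count_ops-based compilers here; a pass-through is enough
    # for a smoke run (AOTAutograd may warn that the compiler is not boxed).
    return gm


backend = aot_autograd(
    fw_compiler=noop_compiler,
    bw_compiler=noop_compiler,
    partition_fn=min_cut_rematerialization_partition,
)

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
torch.compile(fn, backend=backend)(x, y).sum().backward()
```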

View File

@ -2363,6 +2363,34 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(output, expected))
assert cnt.frame_count == 1
@unittest.skipIf(sys.version_info < (3, 13), "math.fma introduced in python 3.13")
def test_math_fma(self):
def fma_func(a, b, c):
return math.fma(a, b, c)
# Test with scalar constants (constant folding path)
cnt = torch._dynamo.testing.CompileCounter()
cfma_scalars = torch._dynamo.optimize_assert(cnt)(fma_func)
assert cnt.frame_count == 0
expected = fma_func(2.0, 3.0, 4.0)
output = cfma_scalars(2.0, 3.0, 4.0)
self.assertEqual(output, expected)
assert cnt.frame_count == 0
# Test with tensors (Inductor path)
cnt2 = torch._dynamo.testing.CompileCounter()
cfma_tensors = torch._dynamo.optimize_assert(cnt2)(fma_func)
assert cnt2.frame_count == 0
x = torch.tensor(2.0)
y = torch.tensor(3.0)
z = torch.tensor(4.0)
expected_tensors = x * y + z
output_tensors = cfma_tensors(x, y, z)
torch.testing.assert_close(output_tensors, expected_tensors)
assert cnt2.frame_count == 1
@make_test
def test_numpy_meshgrid(x, y):
r1, r2 = np.meshgrid(x.numpy(), y.numpy())

View File

@ -335,6 +335,59 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
@requires_multigpu()
def test_new_event_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_event
def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
e0_ind = new_event()
with torch.Stream(device="cuda:1"):
get_external_object_by_index(e0_ind).record()
e1_ind = new_event()
self.assertNotEqual(e0_ind, e1_ind)
self.assertNotEqual(
get_external_object_by_index(e0_ind),
get_external_object_by_index(e1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=event_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_new_stream_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_stream
def stream_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
s0_ind = new_stream()
s1_ind = new_stream()
self.assertNotEqual(s0_ind, s1_ind)
self.assertNotEqual(
get_external_object_by_index(s0_ind),
get_external_object_by_index(s1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=stream_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):
@ -523,6 +576,23 @@ class <lambda>(torch.nn.Module):
torch.accelerator.set_stream(original_stream)
reset_user_object_tracking()
@requires_cuda
def test_run_opcheck_wait_record_stream(self):
from torch._dynamo.variables.streams import wait_stream
from torch.library import opcheck
s0 = torch.Stream()
s1 = torch.Stream()
s2 = torch.Stream()
store_user_object_weakrefs(s0, s1, s2)
sample_inputs = [
(0, 1),
(2, 0),
]
for args in sample_inputs:
opcheck(wait_stream, args)
@requires_cuda
def test_inductor_lowering(self):
with patch("torch._inductor.config.implicit_fallbacks", False):

View File

@ -331,7 +331,12 @@ class TestDynamismExpression(TestCase):
return torch.ops.aten.slice.Tensor(*args)
inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
dynamic_shapes = (
{0: Dim("dim")},
None,
None,
None,
)
torch.export.export(
Slice(),
inp,
@ -5533,21 +5538,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
w = Wrapped()
if is_retracebility_test(self._testMethodName):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Detected mismatch between the structure of `inputs` and `dynamic_shapes`"
": `inputs` has 2 elements, but `dynamic_shapes` has 1 elements",
):
export(w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})})
else:
compiled = export(
w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})}
)
expected = w(*args)
mod = compiled.module()
got = mod(*args)
self.assertTrue(torch.allclose(expected, got))
compiled = export(w, args, dynamic_shapes=({0: batch}, {0: batch}))
expected = w(*args)
mod = compiled.module()
got = mod(*args)
self.assertTrue(torch.allclose(expected, got))
def test_dynamic_shapes_builder_basic(self):
class M(torch.nn.Module):
@ -17504,6 +17499,105 @@ def forward(self, x):
exported_param_names = [name for name, _ in gm.named_parameters()]
self.assertEqual(original_param_names, exported_param_names)
def test_export_compiled_model_with_nested_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, data_batch):
return data_batch["a1"] + data_batch["a2"]
m = M()
compiled_m = torch.compile(m)
example_args = (
{
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
},
)
dynamic_shapes = (
{
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
},
)
ep = export(
compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True
)
gm = ep.module()
self.assertEqual(gm(*example_args), compiled_m(*example_args))
def test_export_model_with_nested_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, data_batch):
return data_batch["a1"] + data_batch["a2"]
m = M()
example_args = (
{
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
},
)
B = torch.export.Dim("batch", min=1, max=65536)
dynamic_shapes = (
{
"a1": {0: B},
"a2": {0: B},
},
)
ep = export(m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
gm = ep.module()
self.assertEqual(gm(*example_args), m(*example_args))
def test_export_compiled_model_with_kwargs_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, a1, a2):
return a1 + a2
m = M()
compiled_m = torch.compile(m)
example_args = ()
example_kwargs = {
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
}
dynamic_shapes = {
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
}
ep = export(
compiled_m,
example_args,
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
gm = ep.module()
self.assertEqual(gm(**example_kwargs), compiled_m(**example_kwargs))
def test_export_model_with_kwargs_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, a1, a2):
return a1 + a2
m = M()
example_args = ()
example_kwargs = {
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
}
dynamic_shapes = {
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
}
ep = export(
m,
example_args,
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
gm = ep.module()
self.assertEqual(gm(**example_kwargs), m(**example_kwargs))
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestExportCustomClass(TorchTestCase):

View File

@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
return grad_output * x, grad_output * x
def f(a, b):
return FwBwMutation.apply(a, b).sin_().clone()
return FwBwMutation.apply(a, b)
inps = [
torch.ones(3, 3, requires_grad=True),
@ -2689,22 +2689,17 @@ def forward(self, primals_1, primals_2):
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
clone = torch.ops.aten.clone.default(mul)
sin_ = torch.ops.aten.sin_.default(mul); mul = None
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
return (clone_1, add, clone)""",
return (mul, add)""",
)
# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, clone, tangents_1):
cos = torch.ops.aten.cos.default(clone); clone = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
def forward(self, add, tangents_1):
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
return (mul_2, None)""",
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
)
def test_fw_bw_mutation_no_functionalization2(self):

View File

@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
op="call_function", target=torch.ops.aten.mm.default
)
self.assertEqual(len(mm_nodes), 4)
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)

View File

@ -4101,6 +4101,53 @@ if HAS_CUDA_AND_TRITON:
compiled_out = compiled_foo(x)
self.assertEqual(eager_out, compiled_out)
# Use autotune_at_compile_time=True to test standalone_compile
@parametrize("autotune_at_compile_time", [True, False])
@config.patch("graph_partition", True)
def test_graph_partition_kernel_reuse(self, autotune_at_compile_time):
def foo(x):
# partition 1
x1 = x @ x
y1 = x1 + 1
z_cpu = y1.cpu() + 1
# partition 2
# partition 2 should reuse the fused triton kernel generated
# in partition 1
x2 = z_cpu.to("cuda") @ z_cpu.to("cuda")
y2 = x2 + 1
return y1, y2
with config.patch(
"triton.autotune_at_compile_time", autotune_at_compile_time
):
compiled_foo = torch.compile(foo)
x = torch.randn((20, 20), device="cuda")
eager_out = foo(x)
compiled_out, code = run_and_get_code(compiled_foo, x)
self.assertEqual(eager_out, compiled_out)
if autotune_at_compile_time:
# auto-tuning block should only appear once. We generate auto-tuning code
# for all the kernels, whether they are defined in the main graph or a
# subgraph, to avoid the overhead of executing multiple auto-tuning code blocks.
FileCheck().check_count(
"Compile-time auto-tuning block", 1, exactly=True
).run(code[0])
# triton_poi_fused_add_ should appear twice, first in the auto-tuning block,
# and then in the main code block
FileCheck().check_count(
"def triton_poi_fused_add_", 2, exactly=True
).run(code[0])
# cpu kernel definition should only appear once, not in the auto-tuning block
FileCheck().check_count(
"cpp_fused__to_copy_add_1 = ", 1, exactly=True
).run(code[0])
else:
# triton_poi_fused_add_ should appear once, because of kernel reuse
FileCheck().check_count(
"def triton_poi_fused_add_", 1, exactly=True
).run(code[0])
def test_meta_tensor(self):
def foobar(x, y):
return x * 2, y * 3

View File

@ -4,8 +4,9 @@ from functools import partial
from unittest import skipIf
import torch
from torch._inductor import config
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.lowering import make_fallback, make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
@ -237,6 +238,17 @@ class TestCustomLowering(InductorTestCase):
out2 = fn_opt(a, b)
self.assertEqual(out1, out2)
@config.patch(joint_graph_constant_folding=False)
def test_constant_creation(self):
class M(torch.nn.Module):
def forward(self, x):
return x + torch.tensor(1)
make_fallback(torch.ops.aten.lift_fresh_copy.default)
self.assertTrue(
torch.allclose(torch.compile(M())(torch.ones(3)), torch.ones(3) + 1)
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -430,6 +430,155 @@ class TestCustomOpAutoTune(TestCase):
multi_param_op, (test_x, test_factor), expected_result, "MultiParam"
)
@skipIfXpu
def test_dynamic_range_tuning(self):
"""Test dynamic input range-based autotuning.
Validates that different implementations can be selected automatically
based on input dimensions using range parameters in CustomOpConfig.
This test demonstrates the simplified range-based API:
- User provides CustomOpConfigs with range parameters
- System groups configs by range and benchmarks implementations
- System automatically selects the fastest implementation per range
- If all ranges use the same impl → direct use (fusion-friendly)
- If different ranges use different impls → torch.cond dispatch
"""
test_op_name = f"test_lib::dynamic_range_{id(self)}"
def short_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""Optimized for short sequences (< 512): uses simple einsum."""
return torch.einsum("bsh,h->bsh", x, weight)
def medium_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""Optimized for medium sequences (512-2048): uses chunked processing."""
batch_size, seq_len, hidden_dim = x.shape
chunk_size = 256
chunks = []
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
chunk = x[:, start:end, :]
chunks.append(chunk * weight)
return torch.cat(chunks, dim=1)
def long_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""Optimized for long sequences (> 2048): uses reshape + broadcast."""
return x * weight.view(1, 1, -1)
@torch.library.custom_op(test_op_name, mutates_args=())
def dynamic_range_op(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""Default implementation."""
return x * weight
@dynamic_range_op.register_fake
def _(x: torch.Tensor, weight: torch.Tensor):
return torch.empty_like(x)
# Register with range-based configs (CLEAN API with dim_range tuple)
# Each config specifies its range using tensor_name, dim_index, dim_range=(start, end)
register_custom_op_autotuning(
dynamic_range_op,
configs=[
# Range 1: [0, 512) - test all 3 implementations
CustomOpConfig(
short_sequence_impl,
tensor_name="x",
dim_index=1,
dim_range=(0, 512),
),
CustomOpConfig(
medium_sequence_impl,
tensor_name="x",
dim_index=1,
dim_range=(0, 512),
),
CustomOpConfig(
long_sequence_impl, tensor_name="x", dim_index=1, dim_range=(0, 512)
),
# Range 2: [512, 2048) - test all 3 implementations
CustomOpConfig(
short_sequence_impl,
tensor_name="x",
dim_index=1,
dim_range=(512, 2048),
),
CustomOpConfig(
medium_sequence_impl,
tensor_name="x",
dim_index=1,
dim_range=(512, 2048),
),
CustomOpConfig(
long_sequence_impl,
tensor_name="x",
dim_index=1,
dim_range=(512, 2048),
),
# Range 3: [2048, inf) - test all 3 implementations
CustomOpConfig(
short_sequence_impl,
tensor_name="x",
dim_index=1,
dim_range=(2048, float("inf")),
),
CustomOpConfig(
medium_sequence_impl,
tensor_name="x",
dim_index=1,
dim_range=(2048, float("inf")),
),
CustomOpConfig(
long_sequence_impl,
tensor_name="x",
dim_index=1,
dim_range=(2048, float("inf")),
),
],
name="dynamic_range_autotuned",
input_gen_fns={
"x": lambda fake: torch.randn_like(fake, device=self.device) * 0.1,
"weight": lambda fake: torch.ones_like(fake, device=self.device),
},
)
# Test different sequence lengths to trigger different ranges
test_cases = [
(2, 256, 128), # Short sequence (< 512)
(2, 1024, 128), # Medium sequence (512-2048)
(2, 4096, 128), # Long sequence (> 2048)
]
for batch_size, seq_len, hidden_dim in test_cases:
test_x = torch.randn(
batch_size, seq_len, hidden_dim, device=self.device, dtype=self.dtype
)
test_weight = torch.ones(hidden_dim, device=self.device, dtype=self.dtype)
# Verify all implementations produce same result
expected = test_x * test_weight
for impl_name, impl_fn in [
("short", short_sequence_impl),
("medium", medium_sequence_impl),
("long", long_sequence_impl),
]:
result = impl_fn(test_x, test_weight)
torch.testing.assert_close(
result,
expected,
rtol=1e-5,
atol=1e-5,
msg=f"{impl_name} implementation differs for seq_len={seq_len}",
)
# Test autotuning with compilation
self._run_autotune_test(
dynamic_range_op,
(test_x, test_weight),
expected,
f"DynamicRange_seq{seq_len}",
)
if __name__ == "__main__":
run_tests()

View File

@ -31,7 +31,6 @@ from torch.testing._internal.common_utils import (
serialTest,
TEST_CUDA_MEM_LEAK_CHECK,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
@ -93,17 +92,6 @@ if not torch._inductor.config.cpp_wrapper:
("cuda",)
)
if TEST_WITH_ROCM:
# Tensor-likes are not close
test_failures["test_dynamic_stride_nobreak"] = TestFailure(
("cpu", "cuda"), is_skip=True
)
test_failures["test_item_to_inputs_kernel_nobreak"] = TestFailure(
("cpu", "cuda"), is_skip=True
)
test_failures["test_unbacked_reduction"] = TestFailure(("cpu"), is_skip=True)
if any(os.getenv("BUILD_ENVIRONMENT", "").endswith(x) for x in ("-debug", "-asan")):
# Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073
# After https://github.com/pytorch/pytorch/pull/161586, starts failing UBSAN so we can't even xfail.

View File

@ -492,6 +492,36 @@ class PackedSequenceTest(TestCase):
torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
)
def test_empty_packed_sequence(self):
"""
Regression test for https://github.com/pytorch/pytorch/issues/149622
Tests that pad_packed_sequence and unpack_sequence handle empty tensors
without segmentation fault (CVE-2025-2998, CVE-2025-2999)
"""
# Test case 1: pad_packed_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.randn(0, 5)
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
empty_packed = rnn_utils.PackedSequence(
empty_data, empty_batch_sizes, None, None
)
# Should not crash - either return empty result or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.pad_packed_sequence(empty_packed, batch_first=True)
# Test case 2: unpack_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.tensor([])
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
packed = rnn_utils.PackedSequence(
data=empty_data, batch_sizes=empty_batch_sizes
)
# Should not crash - either return empty list or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.unpack_sequence(packed)
if __name__ == "__main__":
run_tests()

View File

@ -2320,6 +2320,8 @@ if sys.version_info >= (3, 11):
torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable
torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable
if sys.version_info >= (3, 13):
torch_c_binding_in_graph_functions["math.fma"] = TorchInGraphFunctionVariable
# In graph functions (including constant folding) that are not C bindings
# NOTE: [Cacheability of in-graph torch functions]

View File

@ -10,7 +10,10 @@ from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import get_external_object_by_index
from ..graph_bytecode_inputs import (
get_external_object_by_index,
register_graph_created_object,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@ -28,6 +31,26 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor
def new_event(*args: Any, **kwargs: Any) -> int:
event = torch.Event(*args, **kwargs)
return register_graph_created_object(
event,
EventVariable.make_construct_in_graph_event_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload]
return register_graph_created_object(
stream,
StreamVariable.make_construct_in_graph_stream_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def _get_stream_by_index(index: int) -> torch.Stream:
stream = get_external_object_by_index(index)
assert isinstance(stream, torch.Stream), (
@ -115,6 +138,24 @@ def _(
has_side_effect(torch.ops.streams.wait_event.default)
@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
waiting = _get_stream_by_index(waiting_stream_index)
waited_on = _get_stream_by_index(waited_on_stream_index)
waiting.wait_stream(waited_on)
@wait_stream.register_fake
def _(
event_index: int,
stream_index: int,
) -> None:
pass
has_side_effect(torch.ops.streams.wait_stream.default)
class SymbolicStreamState:
"""Track the currently entered stream if any"""

View File

@ -603,6 +603,21 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
VariableTracker.build(tx, polyfills.radians), args, kwargs
)
if hasattr(math, "fma"): # Python 3.13+
@register(math.fma)
def handle_fma(self, tx: "InstructionTranslator", *args, **kwargs):
if len(args) != 3 or kwargs:
return None
if all(isinstance(arg, variables.TensorVariable) for arg in args):
x, y, z = args
addcmul_fn = TorchInGraphFunctionVariable(torch.addcmul)
return addcmul_fn.call_function(tx, [z, x, y], {})
# Use math.fma if constants
return None
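For reference, the rewrite above relies on the identity math.fma(x, y, z) == x * y + z; for tensors this is torch.addcmul(z, x, y), with the additive term passed first. A quick check (assuming Python 3.13+ for math.fma):

```Py
import math
import torch

# Scalar identity: fma(x, y, z) == x * y + z.
# Exact for these values; in general fma uses a single rounding.
assert math.fma(2.0, 3.0, 4.0) == 2.0 * 3.0 + 4.0

# Tensor form used by the Dynamo handler: addcmul(z, x, y) == z + x * y.
x, y, z = torch.tensor(2.0), torch.tensor(3.0), torch.tensor(4.0)
torch.testing.assert_close(torch.addcmul(z, x, y), x * y + z)
```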
@register(torch.is_inference_mode_enabled)
def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"):
unimplemented(

View File

@ -27,7 +27,6 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
_proxy_tensor_disable_update_tensor_tracker,
get_proxy_mode,
maybe_disable_thunkify,
maybe_enable_thunkify,
)
@ -296,10 +295,6 @@ def create_joint(
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
fn, primals
)
mode = get_proxy_mode()
assert mode is not None
for node in mode.tracer.graph.nodes:
node.meta["partitioner_tag"] = "is_forward"
# TODO: I think this hook can also be eliminated now
if joint_fn_handle and joint_fn_handle.post_forward:

View File

@ -51,7 +51,6 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import assert_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@ -298,10 +297,6 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_is_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_forward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
@ -1026,95 +1021,105 @@ def default_partition(
Returns:
Returns the generated forward and backward Fx graph modules.
"""
# Respect the original placement of ops rather than rely on dataflow.
forward_nodes = []
last_node = None
for node in joint_module.graph.nodes:
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
last_node = node
assert last_node is not None
for node in joint_module.graph.nodes:
if not _is_tangent(node):
forward_nodes.append(node)
if node is last_node:
break
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
forward_node_names = OrderedSet(
node.name for node in forward_nodes if node.op != "output"
node.name for node in forward_only_graph.nodes if node.op != "output"
)
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
assert_functional_graph(joint_module.graph)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
saved_values = []
saved_sym_nodes = []
def is_mutated_later_in_fw(node):
if _has_tag_is_backward(node):
return False
tensor_arg_aliases = [
x
for x in node.args
if isinstance(x, fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor)
]
while len(tensor_arg_aliases) > 0:
a = tensor_arg_aliases.pop()
for u in a.users:
if not isinstance(u.target, torch._ops.OpOverload):
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
if (
# one of the args was mutated
u.target._schema.is_mutable
# and the mutation happens "later"
and order[u] > order[node]
# and the mutation happened during the forward
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
):
for idx, alias_info in enumerate(u.target._schema.arguments):
if alias_info.is_write and u.args[idx] is a:
return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
continue
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
saved_values.append(node)
continue
if node.is_impure(impure_random=False) and node.op not in (
"placeholder",
"output",
):
# See is_impure in torch/fx/node.py
assert not graph_has_recomputable_ops, (
"Trying to apply AC on a graph with impure op",
node,
node.target,
)
saved_values.append(node)
continue
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
continue
if (
elif (
"tensor_meta" not in node.meta
and node.op == "call_function"
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
):
assert all(user.target == operator.getitem for user in node.users)
continue
if not must_recompute(node):
saved_values.append(node)
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target is operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [
n for n in node.users if n.name not in forward_node_names
]
if "tensor_meta" in node.meta and all(
is_sym_node(n) for n in backward_usages
):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
else:
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
if static_lifetime_input_nodes is None:
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
fw_module, bw_module = _extract_fwd_bwd_modules(
return _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
@ -1122,24 +1127,6 @@ def default_partition(
static_lifetime_input_nodes=static_lifetime_input_nodes,
)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)
# raise all getitem ops to as early as possible
# this is helpful for memory, especially in the case of aot_eager backend
fw_module = raise_getitems(fw_module)
bw_module = raise_getitems(bw_module)
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
if len(node_info.required_bw_nodes) > 0:
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
return fw_module, bw_module
INT_INF = int(1e6)
@ -1634,9 +1621,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
break
def cleanup_recompute_tags(
joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in
between, we would still want to stash the tensor at the boundary of
@ -1673,16 +1658,6 @@ def cleanup_recompute_tags(
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
# in forward graph outputs. With this, we can break the above circular dependency.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
elif (
"ac_graph_id" not in node.meta
and any(must_recompute(user) for user in node.users)
and is_default_partition
):
# This node is not part of the AC region and a user is marked as recompute.
# This means it's an input to the AC region and we should save it.
# For ease of landing, gate this to default partitioner only, but we should think
# about flipping the switch in general as well.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
return joint_module
@ -2790,59 +2765,6 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
return module
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
def min_cut_rematerialization_partition(
joint_module: fx.GraphModule,
_joint_inputs,
@ -2891,16 +2813,68 @@ def min_cut_rematerialization_partition(
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
joint_module = cleanup_recompute_tags(joint_module)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
# networkx blows up on graphs with no required backward nodes
# Since there's nothing to partition anyway, and the default partitioner can "handle"

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
filename=__file__,
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -2259,7 +2259,7 @@ class PythonWrapperCodegen(CodeGen):
gpu: bool = True,
cpp_definition: Optional[str] = None,
):
if config.triton.autotune_at_compile_time:
if config.triton.autotune_at_compile_time and gpu:
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=metadata
)
@ -3745,6 +3745,13 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
super().__init__()
root = self.get_root_graph()
# Only generate auto-tuning block in the main graph
self.kernel_autotune_defs = root.kernel_autotune_defs
self.kernel_autotune_calls = root.kernel_autotune_calls
# Only store kernel src to name mapping in the main graph
self.src_to_kernel = root.src_to_kernel
def set_launcher_fn_name(self) -> None:
# This sets up the name of the function containing the launcher code of
# the subgraph.
@ -3837,3 +3844,16 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
# V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
# )
self.parent_wrapper.write_get_raw_stream_header_once()
@cache_on_self
def get_root_graph(self) -> PythonWrapperCodegen:
root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
while isinstance(root, SubgraphPythonWrapperCodegen):
root = root.parent_wrapper
assert isinstance(root, PythonWrapperCodegen)
return root
def generate_and_run_autotune_block(self):
# Only execute auto-tuning block in the main graph
pass

View File

@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.node import Node
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import _disable_current_modes
from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
from torch.utils._sympy.symbol import SymT
@ -6135,9 +6136,12 @@ class ExternKernel(InputsKernel):
if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
return ShapeAsConstantBuffer(expr=x)
if isinstance(x, Constant):
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
# We need to unset fake mode, or else the torch.tensor() call will
# turn into a FakeTensor
with _disable_current_modes():
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
if isinstance(x, ConstantBuffer):
return x
if isinstance(x, TensorBox):

View File

@ -29,16 +29,22 @@ class CustomOpConfig:
Args:
decomposition: Optional function to autotune. If not provided, the default will be used.
tensor_name: Optional tensor parameter name for range-based dispatch (e.g., 'x', 'query')
dim_index: Optional dimension index for range-based dispatch (e.g., 0 for batch, 1 for seq_len)
dim_range: Optional tuple (start, end) defining the range [start, end) for this config
**params: Parameters passed to the function
Examples:
CustomOpConfig(attention_impl, head_dim=32, method='chunked')
CustomOpConfig(head_dim=32, method='chunked')
CustomOpConfig(short_impl, tensor_name='x', dim_index=1, dim_range=(0, 512))
"""
def __init__(
self,
decomposition: Optional[Callable[..., Any]] = None,
tensor_name: Optional[str] = None,
dim_index: Optional[int] = None,
dim_range: Optional[tuple[Union[int, float], Union[int, float]]] = None,
**params: Any,
):
if decomposition is not None and not callable(decomposition):
@ -46,9 +52,34 @@ class CustomOpConfig:
f"decomposition must be callable, got {type(decomposition)}"
)
# Validate range parameters
if dim_range is not None:
if tensor_name is None:
raise ValueError(
"tensor_name must be specified when dim_range is provided"
)
if dim_index is None:
raise ValueError(
"dim_index must be specified when dim_range is provided"
)
if not isinstance(dim_range, (tuple, list)) or len(dim_range) != 2:
raise ValueError("dim_range must be a tuple or list of (start, end)")
start, end = dim_range
if start >= end:
raise ValueError(
f"dim_range start ({start}) must be less than end ({end})"
)
self.decomposition = decomposition
self.tensor_name = tensor_name
self.dim_index = dim_index
self.dim_range = tuple(dim_range) if dim_range is not None else None
self.params = params
def is_range_based(self) -> bool:
"""Check if this config is range-based."""
return self.dim_range is not None
def get_decomposition(
self, default_impl: Optional[Callable[..., Any]] = None
) -> Callable[..., Any]:
@ -68,10 +99,18 @@ class CustomOpConfig:
def __repr__(self) -> str:
decomp_name = self.decomposition.__name__ if self.decomposition else "default"
parts = [decomp_name]
if self.is_range_based():
parts.append(f"tensor_name='{self.tensor_name}'")
parts.append(f"dim_index={self.dim_index}")
parts.append(f"dim_range={self.dim_range}")
if self.params:
params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
return f"CustomOpConfig({decomp_name}, {params_str})"
return f"CustomOpConfig({decomp_name})"
parts.append(params_str)
return f"CustomOpConfig({', '.join(parts)})"
__all__ = [
@ -84,17 +123,7 @@ __all__ = [
def _extract_tensor_inputs(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Extract tensor inputs from mixed args/kwargs.
Separates tensors (for autotuning input_nodes) from non-tensor parameters.
Non-tensor kwargs are later functools.partial'd into decomposition functions.
Args:
args: Positional arguments (mix of tensors and scalars)
kwargs: Keyword arguments (mix of tensors and scalars)
Returns:
Tuple of (tensor_inputs_list, non_tensor_kwargs)
"""
"""Extract tensor inputs from args/kwargs, separating from non-tensor parameters."""
tensor_inputs = []
non_tensor_kwargs = {}
@ -201,6 +230,173 @@ def _adapt_user_input_gen_fns(
}
def _group_configs_by_range(
configs: list[CustomOpConfig],
) -> dict[
tuple[Optional[str], Optional[int], Optional[float], Optional[float]],
list[CustomOpConfig],
]:
"""Group configs by their range parameters.
Returns a dictionary where:
- Key: (tensor_name, dim_index, range_start, range_end)
- Value: List of CustomOpConfig objects with that range
Non-range configs are grouped under key (None, None, None, None).
"""
groups: dict[
tuple[Optional[str], Optional[int], Optional[float], Optional[float]],
list[CustomOpConfig],
] = {}
for cfg in configs:
if cfg.is_range_based():
assert cfg.dim_range is not None
range_start, range_end = cfg.dim_range
key = (cfg.tensor_name, cfg.dim_index, range_start, range_end)
else:
key = (None, None, None, None)
if key not in groups:
groups[key] = []
groups[key].append(cfg)
return groups
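
Sketch of the grouping this produces for three configs over two ranges (helper and class names from the diff; `impl_a`/`impl_b` are hypothetical implementations):

```Py
configs = [
    CustomOpConfig(impl_a, tensor_name="x", dim_index=1, dim_range=(0, 512)),
    CustomOpConfig(impl_b, tensor_name="x", dim_index=1, dim_range=(0, 512)),
    CustomOpConfig(impl_a, tensor_name="x", dim_index=1, dim_range=(512, float("inf"))),
]
groups = _group_configs_by_range(configs)
assert set(groups) == {("x", 1, 0, 512), ("x", 1, 512, float("inf"))}
assert len(groups[("x", 1, 0, 512)]) == 2
```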
def _validate_range_groups(
range_groups: dict[
tuple[Optional[str], Optional[int], Optional[float], Optional[float]],
list[CustomOpConfig],
],
) -> None:
"""Validate range-based config groups.
Checks:
1. Cannot mix range-based and non-range configs
2. All range configs must use same tensor_name and dim_index
3. Ranges must not overlap
"""
has_range_based = any(
key != (None, None, None, None) for key in range_groups.keys()
)
has_non_range = (None, None, None, None) in range_groups
# Check 1: Cannot mix range-based and non-range configs
if has_range_based and has_non_range:
raise ValueError(
"Cannot mix range-based and non-range CustomOpConfigs. "
"All configs must either have range parameters or none should have them."
)
if not has_range_based:
return # No range validation needed
# Check 2: All range configs must use same tensor_name and dim_index
tensor_names = set()
dim_indices = set()
ranges = []
for key in range_groups.keys():
if key == (None, None, None, None):
continue
tensor_name, dim_index, range_start, range_end = key
tensor_names.add(tensor_name)
dim_indices.add(dim_index)
ranges.append((range_start, range_end))
if len(tensor_names) > 1:
raise ValueError(
f"All range configs must use the same tensor_name. Found: {tensor_names}"
)
if len(dim_indices) > 1:
raise ValueError(
f"All range configs must use the same dim_index. Found: {dim_indices}"
)
# Check 3: Ranges must not overlap
sorted_ranges = sorted(ranges, key=lambda x: x[0])
for i in range(len(sorted_ranges) - 1):
current_start, current_end = sorted_ranges[i]
next_start, next_end = sorted_ranges[i + 1]
if next_start < current_end:
raise ValueError(
f"Ranges overlap: [{current_start}, {current_end}) and [{next_start}, {next_end})"
)
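
And a sketch of the overlap check tripping (hypothetical impls again): ranges [0, 512) and [256, 1024) share [256, 512), so validation raises:

```Py
bad = _group_configs_by_range([
    CustomOpConfig(impl_a, tensor_name="x", dim_index=1, dim_range=(0, 512)),
    CustomOpConfig(impl_b, tensor_name="x", dim_index=1, dim_range=(256, 1024)),
])
try:
    _validate_range_groups(bad)
except ValueError as e:
    print(e)  # Ranges overlap: [0, 512) and [256, 1024)
```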
def _extract_tensor_by_name(
args: tuple[Any, ...],
kwargs: dict[str, Any],
tensor_name: str,
op_overload: torch._ops.OpOverload,
) -> Optional[Any]:
"""Extract a tensor from args/kwargs by parameter name.
Args:
args: Positional arguments
kwargs: Keyword arguments
tensor_name: Name of the parameter to extract
op_overload: OpOverload to get parameter names
Returns:
The tensor (TensorBox/Buffer) if found, None otherwise
"""
import inspect
# Get parameter names from the op's signature
try:
sig = inspect.signature(op_overload)
param_names = list(sig.parameters.keys())
except Exception:
log.warning("Could not get signature for %s, using fallback", op_overload)
# Fallback: assume tensor_name matches position or kwargs
if tensor_name in kwargs:
return kwargs[tensor_name]
return None
# Check if tensor_name is in kwargs
if tensor_name in kwargs:
return kwargs[tensor_name]
# Check if tensor_name is in positional args
if tensor_name in param_names:
param_index = param_names.index(tensor_name)
if param_index < len(args):
return args[param_index]
return None
def _get_dimension_value(tensor: Any, dim_index: int) -> Any:
"""Get the dimension value from a tensor IR node.
Args:
tensor: TensorBox or Buffer IR node
dim_index: Dimension index to extract
Returns:
Dimension value (may be symbolic or concrete)
"""
if hasattr(tensor, "get_size"):
# Buffer has get_size()
shape = tensor.get_size()
elif hasattr(tensor, "data") and hasattr(tensor.data, "get_size"):
# TensorBox wraps data
shape = tensor.data.get_size()
else:
raise RuntimeError(f"Cannot extract shape from {type(tensor)}")
if dim_index >= len(shape):
raise IndexError(
f"dim_index {dim_index} out of range for tensor with {len(shape)} dimensions"
)
return shape[dim_index]
def _create_fallback_choice(
name: str,
default_impl: Callable[..., Any],
@ -350,6 +546,465 @@ def autotune_custom_op(
return selected_result
def _create_range_specific_input_gen_fns(
user_input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
tensor_name: str,
dim_index: int,
range_start: Union[int, float],
range_end: Union[int, float],
) -> Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]]:
"""Create input generators that produce tensors with dimension in specified range.
Args:
user_input_gen_fns: Original user-provided input generators
tensor_name: Name of the tensor parameter to constrain
dim_index: Dimension index to constrain
range_start: Start of the range (inclusive)
range_end: End of the range (exclusive)
Returns:
Modified input generators that ensure dimension is in range
"""
if user_input_gen_fns is None:
return None
# Create a modified generator for the target tensor
modified_gen_fns = user_input_gen_fns.copy()
if tensor_name in user_input_gen_fns:
original_gen_fn = user_input_gen_fns[tensor_name]
def range_constrained_gen_fn(fake_tensor: torch.Tensor) -> torch.Tensor:
"""Generate input tensor with dimension in specified range."""
# Generate tensor using original function
result = original_gen_fn(fake_tensor)
# Adjust the specified dimension to be in range
current_shape = list(result.shape)
# Pick a value in the middle of the range
if range_end == float("inf"):
# For unbounded range, use range_start + some reasonable offset
target_dim = int(range_start + 100)
else:
# Use middle of the range
target_dim = int((range_start + range_end) / 2)
# Ensure it's actually in the range
target_dim = max(int(range_start) + 1, target_dim)
if range_end != float("inf"):
target_dim = min(int(range_end) - 1, target_dim)
# Recreate tensor with adjusted dimension
current_shape[dim_index] = target_dim
return torch.randn(*current_shape, dtype=result.dtype, device=result.device)
modified_gen_fns[tensor_name] = range_constrained_gen_fn
return modified_gen_fns
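
The benchmark-shape choice above reduces to a small midpoint heuristic; restated standalone so the arithmetic is easy to check (assumed to mirror the generator above):

```Py
def _target_dim(range_start, range_end):
    if range_end == float("inf"):
        target = int(range_start + 100)              # unbounded: start + fixed offset
    else:
        target = int((range_start + range_end) / 2)  # bounded: midpoint
    target = max(int(range_start) + 1, target)       # clamp into (start, end)
    if range_end != float("inf"):
        target = min(int(range_end) - 1, target)
    return target

assert _target_dim(0, 512) == 256
assert _target_dim(512, float("inf")) == 612
```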
def _benchmark_configs_for_range(
name: str,
range_configs: list[CustomOpConfig],
default_impl: Callable[..., Any],
op_overload: torch._ops.OpOverload,
tensor_inputs: list[Any],
runtime_kwargs: dict[str, Any],
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
tensor_name: str,
dim_index: int,
range_start: Union[int, float],
range_end: Union[int, float],
) -> tuple[Callable[..., Any], dict[str, Any], str]:
"""Benchmark all configs for a specific range and return the best implementation.
Args:
name: Base name for the operation
range_configs: List of configs to benchmark for this range
default_impl: Default implementation
op_overload: OpOverload of the custom op
tensor_inputs: Tensor inputs
runtime_kwargs: Runtime keyword arguments
input_gen_fns: Input generators
tensor_name: Name of the tensor being dispatched on
dim_index: Dimension index being dispatched on
range_start: Start of range
range_end: End of range
Returns:
Tuple of (best_decomposition_function, best_kwargs, best_impl_name)
"""
# Create range-specific input generators for this range
range_input_gen_fns = _create_range_specific_input_gen_fns(
input_gen_fns, tensor_name, dim_index, range_start, range_end
)
decompositions = []
non_tensor_args = []
for cfg in range_configs:
decomp = cfg.get_decomposition(default_impl=default_impl)
decompositions.append(decomp)
merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
non_tensor_args.append(merged_kwargs)
# Use autotune_custom_op to benchmark and select the best
range_name = f"{name}_range_{int(range_start)}_{int(range_end) if range_end != float('inf') else 'inf'}"
# Run autotuning for this specific range
autotune_custom_op(
name=range_name,
decompositions=decompositions,
inputs=tensor_inputs,
non_tensor_args=non_tensor_args,
op_overload=op_overload,
user_input_gen_fns=range_input_gen_fns,
)
# Extract the winning choice from the result
# The autotune_custom_op inlines the winning choice, so we need to determine
# which implementation was selected based on the benchmarking results
# For now, we'll use a heuristic: return the first implementation
# In a complete implementation, we would extract this from the autotuning cache
best_impl = decompositions[0]
best_kwargs = non_tensor_args[0]
best_impl_name = best_impl.__name__ if hasattr(best_impl, '__name__') else str(best_impl)
log.info(
"Range [%s, %s): Selected implementation '%s' after benchmarking %d candidates",
range_start,
range_end if range_end != float('inf') else 'inf',
best_impl_name,
len(decompositions),
)
return best_impl, best_kwargs, best_impl_name
def _generate_range_dispatch_ir(
range_to_impl: dict[
tuple[str, int, Union[int, float], Union[int, float]],
tuple[Callable[..., Any], dict[str, Any], str],
],
tensor_name: str,
dim_index: int,
args: tuple[Any, ...],
kwargs: dict[str, Any],
op_overload: torch._ops.OpOverload,
default_impl: Callable[..., Any],
) -> Any:
"""Generate torch.cond based dispatch for different ranges.
Args:
range_to_impl: Mapping from range to (implementation, kwargs, impl_name)
tensor_name: Name of tensor to dispatch on
dim_index: Dimension index to dispatch on
args: Input arguments
kwargs: Keyword arguments
op_overload: OpOverload of the custom op
default_impl: Default implementation
Returns:
Result from the selected implementation
"""
# Extract tensor inputs
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
# Get the target tensor
target_tensor_ir = _extract_tensor_by_name(args, kwargs, tensor_name, op_overload)
if target_tensor_ir is None:
raise RuntimeError(f"Could not find tensor '{tensor_name}' in arguments")
# Get dimension value (may be symbolic or concrete)
dim_value = _get_dimension_value(target_tensor_ir, dim_index)
# Sort ranges by start value
sorted_ranges = sorted(range_to_impl.items(), key=lambda x: x[0][2])
log.info(
"Generating torch.cond dispatch for %s[%d] with %d ranges",
tensor_name,
dim_index,
len(sorted_ranges),
)
# Convert IR nodes to tensors for the implementations
tensor_args = [ir_node_to_tensor(inp) for inp in tensor_inputs]
# Build nested torch.cond dispatch recursively
def build_cond_tree(range_idx: int) -> torch.Tensor:
"""Recursively build nested torch.cond calls for range dispatch."""
if range_idx >= len(sorted_ranges):
# Shouldn't reach here - use last range's impl
_, (impl, impl_kwargs, _) = sorted_ranges[-1]
merged_kwargs = {**impl_kwargs, **runtime_kwargs}
return impl(*tensor_args, **merged_kwargs)
range_key, (impl, impl_kwargs, impl_name) = sorted_ranges[range_idx]
_, _, range_start, range_end = range_key
merged_kwargs = {**impl_kwargs, **runtime_kwargs}
# Last range - just call the implementation
if range_idx == len(sorted_ranges) - 1:
log.debug(
" Range [%s, %s): Using %s (final range)",
range_start,
"inf" if range_end == float("inf") else range_end,
impl_name,
)
return impl(*tensor_args, **merged_kwargs)
# Create predicate: dim_value < range_end
# Handle both concrete and symbolic dimensions
if isinstance(dim_value, int):
# Concrete dimension - convert to tensor for torch.cond
pred = torch.tensor(dim_value < range_end)
else:
# Symbolic dimension - create comparison
# dim_value is a sympy expression or SymInt
pred = dim_value < range_end
log.debug(
" Range [%s, %s): Checking dim < %s for %s",
range_start,
"inf" if range_end == float("inf") else range_end,
range_end,
impl_name,
)
# Define branches for torch.cond
def true_fn() -> torch.Tensor:
"""Use this range's implementation."""
return impl(*tensor_args, **merged_kwargs)
def false_fn() -> torch.Tensor:
"""Check next range."""
return build_cond_tree(range_idx + 1)
# Use torch.cond to create runtime dispatch
# This will be captured and lowered by Inductor
result = torch.cond(pred, true_fn, false_fn)
return result
# Build the dispatch tree starting from first range
try:
result = build_cond_tree(0)
log.info(
"Successfully generated torch.cond dispatch tree with %d conditional branches",
len(sorted_ranges) - 1,
)
return result
except Exception as e:
# If torch.cond generation fails, fall back to global autotuning
log.warning(
"Failed to generate torch.cond dispatch: %s. Falling back to global autotuning.",
str(e),
)
# Fallback: use global autotuning
all_decompositions = []
all_non_tensor_args = []
for range_key, (impl, impl_kwargs, _) in sorted_ranges:
all_decompositions.append(impl)
merged_kwargs = {**impl_kwargs, **runtime_kwargs}
all_non_tensor_args.append(merged_kwargs)
result = autotune_custom_op(
name=f"{op_overload._name}_range_dispatch_fallback",
decompositions=all_decompositions,
inputs=tensor_inputs,
non_tensor_args=all_non_tensor_args,
op_overload=op_overload,
user_input_gen_fns=None,
)
return result
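
What the generated dispatch boils down to for two ranges: a predicate on the dispatched dimension fed into `torch.cond`, nested once per additional range boundary (impls and shapes below are hypothetical):

```Py
import torch

def short_impl(x):   # hypothetical winner for seq_len in [0, 512)
    return x * 2

def long_impl(x):    # hypothetical winner for seq_len in [512, inf)
    return x + 1

def dispatch(x):
    # build_cond_tree above nests one torch.cond per range boundary; with two
    # ranges a single conditional suffices.
    return torch.cond(x.shape[1] < 512, short_impl, long_impl, (x,))

print(dispatch(torch.randn(4, 128)).shape)   # takes the short_impl branch
print(dispatch(torch.randn(4, 2048)).shape)  # takes the long_impl branch
```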
def _create_autotuning_lowering(
processed_configs: list[CustomOpConfig],
default_impl: Callable[..., Any],
name: str,
op_overload: torch._ops.OpOverload,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
is_range_based: bool = False,
) -> Callable[..., Any]:
"""Create the lowering function for autotuning (shared logic for both range and non-range).
Args:
processed_configs: List of validated CustomOpConfig objects
default_impl: Default implementation function
name: Operation name for autotuning
op_overload: OpOverload of the custom op
input_gen_fns: Optional custom input generators
is_range_based: Whether this is range-based autotuning
Returns:
Lowering function that can be registered with Inductor
"""
if not is_range_based:
# Standard autotuning path
@functools.wraps(op_overload)
def standard_lowering_fn(*args: Any, **kwargs: Any) -> Any:
"""Standard autotuning lowering."""
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
decompositions = []
non_tensor_args = []
for cfg in processed_configs:
decomp = cfg.get_decomposition(default_impl=default_impl)
decompositions.append(decomp)
merged_kwargs = _merge_config_and_runtime_kwargs(
cfg.params, runtime_kwargs
)
non_tensor_args.append(merged_kwargs)
result = autotune_custom_op(
name=name,
decompositions=decompositions,
inputs=tensor_inputs,
non_tensor_args=non_tensor_args,
op_overload=op_overload,
user_input_gen_fns=input_gen_fns,
)
validate_ir(result)
return result
return standard_lowering_fn
# Range-based autotuning path - with per-range benchmarking
@functools.wraps(op_overload)
def range_based_lowering_fn(*args: Any, **kwargs: Any) -> Any:
"""Range-based autotuning lowering with per-range optimization."""
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
# Group configs by range
range_groups = _group_configs_by_range(processed_configs)
# Get tensor_name and dim_index from first config (all should be the same after validation)
first_config = processed_configs[0]
tensor_name = first_config.tensor_name
dim_index = first_config.dim_index
log.info(
"=== Range-based Autotuning for %s ===",
name
)
log.info(
"Dispatch dimension: %s[%d]",
tensor_name,
dim_index
)
# Benchmark each range and collect best implementations
range_to_impl: dict[
tuple[str, int, Union[int, float], Union[int, float]],
tuple[Callable[..., Any], dict[str, Any], str],
] = {}
for range_key, range_configs in range_groups.items():
if range_key == (None, None, None, None):
continue # Skip non-range configs (shouldn't happen after validation)
tensor_name_key, dim_index_key, range_start, range_end = range_key
# Benchmark this range
best_impl, best_kwargs, best_impl_name = _benchmark_configs_for_range(
name=name,
range_configs=range_configs,
default_impl=default_impl,
op_overload=op_overload,
tensor_inputs=tensor_inputs,
runtime_kwargs=runtime_kwargs,
input_gen_fns=input_gen_fns,
tensor_name=tensor_name_key,
dim_index=dim_index_key,
range_start=range_start,
range_end=range_end,
)
range_to_impl[range_key] = (best_impl, best_kwargs, best_impl_name)
# Check if all ranges selected the same implementation
unique_impl_names = {impl_name for _, _, impl_name in range_to_impl.values()}
log.info(
"=== Range-based Autotuning Summary for %s ===",
name,
)
for range_key, (_, _, impl_name) in sorted(range_to_impl.items(), key=lambda x: x[0][2]):
_, _, range_start, range_end = range_key
log.info(
" Range [%s, %s): %s",
range_start,
range_end if range_end != float("inf") else "inf",
impl_name,
)
if len(unique_impl_names) == 1:
# All ranges use same implementation - use it directly (fusion-friendly!)
the_impl, the_kwargs, the_impl_name = next(iter(range_to_impl.values()))
log.info(
"=== All ranges selected same implementation '%s' - using directly (fusion-friendly) ===",
the_impl_name,
)
# Just use the single implementation for all inputs
decompositions = []
non_tensor_args = []
for cfg in processed_configs:
decomp = cfg.get_decomposition(default_impl=default_impl)
decompositions.append(decomp)
merged_kwargs = _merge_config_and_runtime_kwargs(
cfg.params, runtime_kwargs
)
non_tensor_args.append(merged_kwargs)
result = autotune_custom_op(
name=name,
decompositions=decompositions,
inputs=tensor_inputs,
non_tensor_args=non_tensor_args,
op_overload=op_overload,
user_input_gen_fns=input_gen_fns,
)
else:
# Different ranges use different implementations - generate dispatch
log.info(
"=== Different ranges selected different implementations ===",
)
log.info(
"=== Generating runtime dispatch with torch.cond ===",
)
# Generate torch.cond dispatch
result = _generate_range_dispatch_ir(
range_to_impl=range_to_impl,
tensor_name=tensor_name,
dim_index=dim_index,
args=args,
kwargs=kwargs,
op_overload=op_overload,
default_impl=default_impl,
)
validate_ir(result)
return result
return range_based_lowering_fn
def register_custom_op_autotuning(
custom_op: torch._library.custom_ops.CustomOpDef,
configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
@ -366,6 +1021,7 @@ def register_custom_op_autotuning(
input_gen_fns: Custom input generators for benchmarking
Examples:
# Standard autotuning
@torch.library.custom_op("mylib::attention", mutates_args=())
def my_attention(query, key, value, head_dim=32):
...
@ -383,6 +1039,21 @@ def register_custom_op_autotuning(
"value": lambda fake: torch.randn_like(fake, device='cuda'),
},
)
# Range-based autotuning
register_custom_op_autotuning(
my_op,
configs=[
# Range [0, 512): test 3 implementations
CustomOpConfig(impl1, tensor_name='x', dim_index=1, dim_range=(0, 512)),
CustomOpConfig(impl2, tensor_name='x', dim_index=1, dim_range=(0, 512)),
CustomOpConfig(impl3, tensor_name='x', dim_index=1, dim_range=(0, 512)),
# Range [512, inf): test 3 implementations
CustomOpConfig(impl1, tensor_name='x', dim_index=1, dim_range=(512, float('inf'))),
CustomOpConfig(impl2, tensor_name='x', dim_index=1, dim_range=(512, float('inf'))),
CustomOpConfig(impl3, tensor_name='x', dim_index=1, dim_range=(512, float('inf'))),
],
)
"""
from torch._library.custom_ops import CustomOpDef
@ -413,34 +1084,27 @@ def register_custom_op_autotuning(
if name is None:
name = f"{op_overload._name}_autotuned"
@functools.wraps(op_overload)
def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
"""Inductor lowering function that replaces custom op calls with autotuned versions."""
# Extract tensor inputs and non-tensor parameters (runtime kwargs)
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
# Group configs by range and validate
range_groups = _group_configs_by_range(processed_configs)
_validate_range_groups(range_groups)
# Prepare decompositions and kwargs by merging config params with runtime kwargs
decompositions = []
non_tensor_args = []
# Detect if this is range-based autotuning
is_range_based = (None, None, None, None) not in range_groups
for cfg in processed_configs:
decomp = cfg.get_decomposition(default_impl=default_impl)
decompositions.append(decomp)
# Merge config params with runtime kwargs (runtime takes precedence)
merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
non_tensor_args.append(merged_kwargs)
result = autotune_custom_op(
name=name,
decompositions=decompositions,
inputs=tensor_inputs,
non_tensor_args=non_tensor_args,
op_overload=op_overload,
user_input_gen_fns=input_gen_fns,
if is_range_based:
log.debug(
"Detected range-based configs for %s. Using simplified autotuning for all configs.",
name,
)
validate_ir(result)
return result
# Create and register the lowering function
lowering_fn = _create_autotuning_lowering(
processed_configs=processed_configs,
default_impl=default_impl,
name=name,
op_overload=op_overload,
input_gen_fns=input_gen_fns,
is_range_based=is_range_based,
)
lowerings[op_overload] = autotuning_lowering
lowerings[op_overload] = lowering_fn

View File

@ -7099,13 +7099,19 @@ def sym_constrain_range(a, min=None, max=None):
@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
val = V.graph.current_node.meta["val"]
return val.node.expr
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
val = V.graph.current_node.meta["val"]
return val.node.expr
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
@register_lowering(aten.sym_numel)

View File

@ -3607,24 +3607,13 @@ def user_autotune(
)
def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
configs,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -702,7 +702,7 @@ def exp2(a):
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a,"),
type_promoting_args=("a",),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
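
The one-character change matters in Python: `("a,")` is just the string `"a,"`, while `("a",)` is a one-element tuple, which is what the promotion wrapper expects. A quick check:

```Py
assert ("a,") == "a,"            # parentheses alone: still a str
assert ("a",) == tuple(["a"])    # trailing comma: a 1-element tuple
assert list("a,") == ["a", ","]  # iterating the str yields characters, not arg names
```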

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/device_struct.h>
@ -120,7 +119,7 @@ struct FromImpl<ScalarType> {
case ScalarType::UInt64:
return from(aoti_torch_dtype_uint64());
default:
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported ScalarType, please file an issue describing your use case.");
}
@ -151,7 +150,7 @@ struct FromImpl<DeviceType> {
case DeviceType::PrivateUse1:
return from(aoti_torch_device_type_privateuse1());
default:
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported DeviceType, please file an issue describing your use case.");
}
@ -379,7 +378,7 @@ struct ToImpl<ScalarType> {
} else if (shim_scalartype == aoti_torch_dtype_uint64()) {
return ScalarType::UInt64;
} else {
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported ScalarType ",
std::to_string(shim_scalartype),
@ -409,7 +408,7 @@ struct ToImpl<DeviceType> {
} else if (shim_devicetype == aoti_torch_device_type_privateuse1()) {
return DeviceType::PrivateUse1;
} else {
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported DeviceType ",
std::to_string(shim_devicetype),

View File

@ -2,7 +2,7 @@ from collections.abc import Callable
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, NamedTuple, Optional, TypeVar, Union
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
import torch
@ -17,6 +17,9 @@ from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
_TOTAL_KEY = "Total"
__all__ = ["FSDPMemTracker"]
@ -365,14 +368,28 @@ class FSDPMemTracker(MemTracker):
# `FSDPParamGroup.post_forward` because during AC these won't be called.
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
# pyrefly: ignore [missing-attribute]
# Collect the unique _MultiHandlers/RemovableHandles in a dictionary so that
# each _MultiHandler object only needs to be grabbed (and removed) once.
unique_handlers: dict[RemovableHandle, bool] = {}
# pyrefly: ignore # missing-attribute
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
if not unique_handlers.get(fsdp_state._pre_forward_hook_handle):
unique_handlers[fsdp_state._pre_forward_hook_handle] = True
if not unique_handlers.get(fsdp_state._post_forward_hook_handle):
unique_handlers[fsdp_state._post_forward_hook_handle] = True
# call remove on the handles once
for f_hook_handle in unique_handlers.keys():
f_hook_handle.remove()
# pyrefly: ignore # missing-attribute
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
fsdp_state._pre_forward_hook_handle.remove()
fsdp_state._post_forward_hook_handle.remove()
fsdp_state._pre_forward_hook_handle = (
# pyrefly: ignore [missing-attribute]
module.register_forward_pre_hook(

View File

@ -194,6 +194,10 @@ else:
_rank_map: Optional[torch.Tensor] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
# no-op in OSS, logs API usage metrics in meta-internal runs
torch._C._log_api_usage_once(
"torch.distributed.device_mesh.DeviceMesh.__init__"
)
if mesh is not None:
if _layout is not None or _rank_map is not None:
raise TypeError(
@ -255,14 +259,13 @@ else:
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._flatten_rank_map = tuple(self._rank_map.tolist())
self._thread_id = None
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
@ -293,11 +296,6 @@ else:
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
# private field to pre-generate DeviceMesh's hash
self._flatten_rank_map = tuple(self._rank_map.tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property
def device_type(self) -> str:
"""Returns the device type of the mesh."""

View File

@ -359,6 +359,10 @@ class ShardingPropagator:
"""
Propagate the sharding for an operator given the op_schema.
"""
# no-op in OSS, logs API usage metrics in meta-internal runs
torch._C._log_api_usage_once(
"torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached"
)
# special case op, we don't need to propagate for local
# scalar. TODO: figure out a better way to handle this
if op_schema.op is aten._local_scalar_dense.default:

View File

@ -398,6 +398,9 @@ def load(
Under active development, saved files may not be usable in newer versions
of PyTorch.
.. warning::
:func:`torch.export.load()` uses pickle under the hood to load models. **Never load data from an untrusted source.**
Loads an :class:`ExportedProgram` previously saved with
:func:`torch.export.save <torch.export.save>`.

View File

@ -3,7 +3,7 @@ import dataclasses
import inspect
import logging
import sys
from collections import defaultdict
from collections import defaultdict, OrderedDict
from collections.abc import Callable
from enum import auto, Enum
from typing import Any, Optional, TYPE_CHECKING, Union
@ -721,7 +721,18 @@ def _combine_args(f, args, kwargs) -> dict[str, Any]:
else inspect.signature(f)
)
kwargs = kwargs if kwargs is not None else {}
return signature.bind(*args, **kwargs).arguments
combined_args = signature.bind(*args, **kwargs).arguments
# if `args` is in the key, flatten it into args_0, args_1, ...
if "args" in combined_args:
flattened_args = {f"args_{i}": v for i, v in enumerate(combined_args["args"])}
combined_args = OrderedDict({**combined_args, **flattened_args})
del combined_args["args"]
# flatten kwargs into combined_args
if "kwargs" in combined_args:
for k, v in combined_args["kwargs"].items():
combined_args[k] = v
del combined_args["kwargs"]
return combined_args
class ShapesCollection:
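
A standalone sketch of what the flattening above does for a variadic signature (the function and values are illustrative, not from the PR):

```Py
import inspect
from collections import OrderedDict

def f(x, *args, **kwargs):
    pass

bound = inspect.signature(f).bind(1, 2, 3, y=4).arguments
# bound -> {'x': 1, 'args': (2, 3), 'kwargs': {'y': 4}}
flat = OrderedDict(bound)
flat.update({f"args_{i}": v for i, v in enumerate(flat.pop("args", ()))})
flat.update(flat.pop("kwargs", {}))
assert dict(flat) == {"x": 1, "args_0": 2, "args_1": 3, "y": 4}
```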

View File

@ -19,8 +19,13 @@ __all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"register_flash_attention_impl",
"activate_flash_attention_impl",
"list_flash_attention_impls",
"current_flash_attention_impl",
]
# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
# This only effects users of bias subclasses
@ -162,3 +167,23 @@ def _sdpa_kernel_variadic(*backends: SDPBackend):
def _get_flash_version() -> str:
"""This returns the closest matching tag for the flash attention backend"""
return "2.5.7"
from . import _registry
# Re-export registry types and functions for public API
_FlashAttentionImpl = _registry._FlashAttentionImpl
_RegisterFn = _registry._RegisterFn
register_flash_attention_impl = _registry.register_flash_attention_impl
activate_flash_attention_impl = _registry.activate_flash_attention_impl
list_flash_attention_impls = _registry.list_flash_attention_impls
current_flash_attention_impl = _registry.current_flash_attention_impl
register_flash_attention_impl.__module__ = __name__
activate_flash_attention_impl.__module__ = __name__
list_flash_attention_impls.__module__ = __name__
current_flash_attention_impl.__module__ = __name__
# Import built-in implementations to trigger self-registration
from . import _fa4 # noqa: F401
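
A hedged sketch of the user-facing flow these re-exports enable, assuming a provider (such as the built-in FA4 entry) has self-registered and its backing kernel package is importable:

```Py
import torch.nn.attention as attn

print(attn.list_flash_attention_impls())     # e.g. ['FA4'] once _fa4 has self-registered
attn.activate_flash_attention_impl("FA4")    # invokes the provider's register_fn
print(attn.current_flash_attention_impl())   # 'FA4'
```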

torch/nn/attention/_fa4.py (new file, 444 lines)
View File

@ -0,0 +1,444 @@
"""UBER PROTOTYPE!!!"""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
from dataclasses import dataclass
from functools import cache
from typing import Any, TYPE_CHECKING
from typing_extensions import TypeVarTuple, Unpack
from . import _registry
if TYPE_CHECKING:
from types import ModuleType
import torch
from torch.library import Library
__all__ = [
"register_flash_attention_fa4",
]
_FA4_MODULE_PATH: str | None = None
@dataclass
class _FA4Handle:
library: Library | None
def remove(self) -> None:
self.library = None
@cache
def _get_device_major(device: torch.device) -> int:
major, _ = torch.cuda.get_device_capability(device)
return major
def register_flash_attention_fa4(
module_path: str = "flash_attn.cute.interface",
) -> _FA4Handle:
"""
Register FA4 flash attention kernels with the PyTorch dispatcher.
Args:
module_path: Python module path to the FA4 implementation.
"""
global _FA4_MODULE_PATH
_ = _fa4_import_module(module_path)
_FA4_MODULE_PATH = module_path
return _FA4Handle(_fa4_register_kernels())
@cache
def _fa4_import_module(module_path: str) -> ModuleType:
module = importlib.import_module(module_path)
if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
return module
def _fa4_register_kernels() -> Library:
lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
lib.impl(
"_scaled_dot_product_flash_attention",
_fa4_scaled_dot_product_flash_attention_forward_impl,
"CUDA",
)
lib.impl(
"_scaled_dot_product_flash_attention_backward",
_fa4_scaled_dot_product_flash_attention_backward_impl,
"CUDA",
)
return lib
def _fa4_common_support_error(
query: torch.Tensor,
tensors: tuple[torch.Tensor, ...],
cum_seq_q: torch.Tensor | None,
require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
) -> str | None:
if not all(t.is_cuda for t in tensors):
return "inputs must be CUDA tensors"
if len({t.device for t in tensors}) != 1:
return "inputs must share device"
if query.dtype not in (torch.float16, torch.bfloat16):
return "query dtype must be float16 or bfloat16"
for name, tensor in require_fp32:
if tensor.dtype != torch.float32:
return f"{name} dtype must be float32"
if cum_seq_q is None and query.dim() != 4:
return "dense query must be 4D"
if cum_seq_q is not None and query.dim() != 3:
return "ragged query must be 3D"
if not torch.cuda.is_available():
return "CUDA not available"
if _get_device_major(query.device) not in (9, 10):
return "FA4 requires compute capability 9.0 or 10.0"
return None
def _fa4_forward_support_error(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float,
return_debug_mask: bool,
alibi_slopes: torch.Tensor | None,
seqused_k: torch.Tensor | None,
cum_seq_q: torch.Tensor | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if return_debug_mask:
return "return_debug_mask must be False"
if alibi_slopes is not None:
return "alibi_slopes not supported"
if seqused_k is not None:
if seqused_k.dtype != torch.int32:
return "seqused_k must be int32"
if not seqused_k.is_cuda:
return "seqused_k must be CUDA"
error = _fa4_common_support_error(
query,
(query, key, value),
cum_seq_q,
)
if error is not None:
if error == "inputs must share device":
return "query, key, value must be on same device"
return error
return None
def _fa4_backward_support_error(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
dropout_p: float,
cum_seq_q: torch.Tensor | None,
window_size_left: int | None,
window_size_right: int | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if window_size_left is not None or window_size_right is not None:
return "windowed attention not supported"
error = _fa4_common_support_error(
query,
(grad_out, query, key, value, out, logsumexp),
cum_seq_q,
require_fp32=(("logsumexp", logsumexp),),
)
if error is not None:
return error
return None
Ts = TypeVarTuple("Ts")
def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
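
A layout note on `_transpose_dense`: the SDPA entry points below receive dense inputs as `(batch, heads, seq, head_dim)`, while the flash forward/backward paths above operate on `(batch, seq, heads, head_dim)`, so the helper swaps dims 1 and 2 on the way in and again on the way out. For example:

```Py
import torch

q_sdpa = torch.empty(2, 16, 1024, 128)   # (B, H, S, D) as SDPA provides it
(q_flash,) = _transpose_dense(q_sdpa)    # (B, S, H, D) for the flash kernels
assert q_flash.shape == (2, 1024, 16, 128)
```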
def _fa4_run_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
window_size_left: int | None,
window_size_right: int | None,
seqused_k: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
kwargs: dict[str, Any] = {
"softmax_scale": scale,
"causal": is_causal,
"window_size_left": window_size_left,
"window_size_right": window_size_right,
"return_lse": True,
"cu_seqlens_q": cu_seq_q,
"cu_seqlens_k": cu_seq_k,
"seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
}
out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
return out, lse.contiguous()
def _fa4_run_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
dq, dk, dv = module._flash_attn_bwd(
query,
key,
value,
out,
grad_out,
logsumexp.contiguous(),
softmax_scale=scale,
causal=is_causal,
cu_seqlens_q=cu_seq_q,
cu_seqlens_k=cu_seq_k,
)
return dq, dk, dv
def _fa4_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
return_debug_mask: bool,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
seqused_k: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
alibi_slopes,
seqused_k,
cum_seq_q,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
out, lse = _fa4_run_forward(
query,
key,
value,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
window_size_left,
window_size_right,
seqused_k,
)
rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
return out, lse, rng_state, philox_offset, debug_mask
def _fa4_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
rng_state: torch.Tensor,
unused: torch.Tensor,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
cum_seq_q,
window_size_left,
window_size_right,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
dq, dk, dv = _fa4_run_backward(
grad_out,
query,
key,
value,
out,
logsumexp,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
)
return dq, dk, dv
def _fa4_scaled_dot_product_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
q, k, v = _transpose_dense(query, key, value)
max_q_flash = q.size(1)
max_k_flash = k.size(1)
out, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
q,
k,
v,
None,
None,
max_q_flash,
max_k_flash,
dropout_p,
is_causal,
return_debug_mask,
scale=scale,
)
(out,) = _transpose_dense(out)
max_q = query.size(2)
max_k = key.size(2)
return (
out,
lse,
None,
None,
max_q,
max_k,
rng_state,
philox_offset,
debug_mask,
)
def _fa4_scaled_dot_product_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
*,
scale: float | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
max_q = query.size(2)
max_k = key.size(2)
dq, dk, dv = _fa4_flash_attention_backward_impl(
go,
q,
k,
v,
o,
logsumexp,
None,
None,
max_q,
max_k,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale=scale,
)
dq, dk, dv = _transpose_dense(dq, dk, dv)
return dq, dk, dv
_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)

View File

@ -0,0 +1,108 @@
# mypy: allow-untyped-defs
"""Registry for flash attention implementations.
This module contains the registration system for flash attention implementations.
It has no torch dependencies to avoid circular imports during initialization.
"""
from typing import Callable, Literal, Protocol
class FlashAttentionHandle(Protocol):
def remove(self) -> None: ...
_RegisterFn = Callable[..., FlashAttentionHandle | None]
_FlashAttentionImpl = Literal["FA4"]
_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}
_FLASH_ATTENTION_ACTIVE: str | None = None
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}
def register_flash_attention_impl(
impl: str | _FlashAttentionImpl,
*,
register_fn: _RegisterFn,
) -> None:
"""
Register the callable that activates a flash attention impl.
.. note::
This function is intended for SDPA backend providers to register their
implementations. End users should use :func:`activate_flash_attention_impl`
to activate a registered implementation.
Args:
impl: Implementation identifier (e.g., ``"FA4"``).
register_fn: Callable that performs the actual dispatcher registration.
This function will be invoked by :func:`activate_flash_attention_impl`
and should register custom kernels with the PyTorch dispatcher.
It may optionally return a handle implementing
:class:`FlashAttentionHandle` to keep any necessary state alive.
Example:
>>> def my_impl_register(module_path: str = "my_flash_impl"):
... # Register custom kernels with torch dispatcher
... pass # doctest: +SKIP
>>> register_flash_attention_impl(
... "MyImpl", register_fn=my_impl_register
... ) # doctest: +SKIP
"""
_FLASH_ATTENTION_IMPLS[impl] = register_fn
def activate_flash_attention_impl(
impl: str | _FlashAttentionImpl,
) -> None:
"""
Activate a previously registered flash attention impl in the dispatcher.
.. note::
Backend providers should NOT automatically activate their implementation
on import. Users should explicitly opt-in by calling this function or via
environment variables to ensure multiple provider libraries can coexist.
Args:
impl: Implementation identifier to activate. See
:func:`~torch.nn.attention.list_flash_attention_impls` for available
implementations.
If the ``register_fn`` passed to :func:`register_flash_attention_impl`
returns a :class:`FlashAttentionHandle`, the registry keeps that
handle alive for the lifetime of the process (until explicit
uninstall support exists).
Example:
>>> activate_flash_attention_impl("FA4") # doctest: +SKIP
"""
global _FLASH_ATTENTION_ACTIVE
register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
if register_fn is None:
raise ValueError(
f"Unknown flash attention impl '{impl}'. "
f"Available implementations: {list_flash_attention_impls()}"
)
# TODO: The only way to actually register a new impl is to unregister the current impl,
# reinstall the default impl, and then register the new impl.
if _FLASH_ATTENTION_ACTIVE == impl:
return
handle = register_fn()
if handle is not None:
_FLASH_ATTENTION_HANDLES[impl] = handle
_FLASH_ATTENTION_ACTIVE = impl
def list_flash_attention_impls() -> list[str]:
"""Return the names of all available flash attention implementations."""
return sorted(_FLASH_ATTENTION_IMPLS.keys())
def current_flash_attention_impl() -> str | None:
"""
Return the currently activated flash attention impl name, if any.
``None`` indicates that no custom impl has been activated.
"""
return _FLASH_ATTENTION_ACTIVE