Compare commits


3 Commits

SHA1 Message Date
19e52556fa Fix unintended updates to submodules 2025-11-07 08:57:52 -08:00
1d43f171d6 Fix signals 2025-11-07 06:51:15 -08:00
910471526d Type functions 2025-11-07 06:49:52 -08:00
50 changed files with 1963 additions and 1878 deletions


@ -337,7 +337,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

.github/labeler.yml (vendored, 9 changed lines)

@ -138,8 +138,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -149,8 +148,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -160,8 +158,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm

.gitignore (vendored, 1 changed line)

@ -127,6 +127,7 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*


@ -210,12 +210,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg
# Low Precision & Grouped GEMMs
# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58


@ -14,10 +14,6 @@ Utils
sdpa_kernel
SDPBackend
register_flash_attention_impl
activate_flash_attention_impl
list_flash_attention_impls
current_flash_attention_impl
Submodules
----------


@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1616,6 +1647,8 @@ def main() -> None:
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
cmdclass,
@ -1649,6 +1682,7 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",


@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
store=dist.FileStore(self.file_name, self.world_size),
)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_transformer(self):
"""
This tests that replicate works on a transformer model with fully_shard and replicate layers
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Shard(dim=0),))
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_transformer_managed_modules(self):
"""
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
replicate_model = replicate(replicate_model)
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_tp_device_mesh(self):
"""
This tests that a user can pass in a device mesh to replicate a module
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_train_replicate_fsdp(self):
"""
Tests that replicate_model has the same behavior as original model when training
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(replicate_loss, loss)
check_sharded_parity(self, model, replicate_model)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_train_parity_2d_mlp(self):
"""
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training


@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
"""
Saving a dtensor with uneven shards.
@ -436,7 +436,6 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_checkpointable_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
@ -499,7 +498,6 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.


@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self.assertEqual(cpu_model_value, meta_model_value)
@with_comms
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
# This test verifies that we can set model state dict by a meta device model
# With the correlated changes in state_dict, meta device model should be accepted


@ -39,7 +39,6 @@ from torch.nn.modules.loss import MSELoss
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
check_leaked_tensors,
@ -232,7 +231,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
@skip_if_lt_x_gpu(4)
def test_forward_only(self, ScheduleClass):
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
x_clone = x.clone()
@ -276,7 +274,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_eval_inference_mode(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -354,7 +351,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_return_output(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -410,7 +406,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_multi_iter(self, ScheduleClass):
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
chunks = 4
@ -434,7 +429,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_kwargs_with_tracer(self, ScheduleClass):
mod = ModelWithKwargs(d_hid, splits=self.world_size)
mod.to(self.device)
@ -487,7 +481,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_grad_with_tracer(self, ScheduleClass):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -530,7 +523,6 @@ class ScheduleTest(MultiProcContinuousTest):
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@parametrize("shape_inference", [True, False])
@skip_if_lt_x_gpu(4)
def test_grad_with_manual(self, ScheduleClass, shape_inference):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -594,7 +586,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_grad_with_manual_interleaved(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -659,7 +650,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
@skip_if_lt_x_gpu(4)
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -746,7 +736,6 @@ class ScheduleTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
)
@skip_if_lt_x_gpu(4)
def test_v_shape_schedules(self, schedule_class):
n_stages = 8
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
@ -791,7 +780,6 @@ class ScheduleTest(MultiProcContinuousTest):
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@skip_if_lt_x_gpu(4)
def test_custom_function_callback(self):
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
n_stages = 8
@ -991,7 +979,6 @@ class ScheduleTest(MultiProcContinuousTest):
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
)
@skip_if_lt_x_gpu(4)
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -1085,7 +1072,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleVShaped, ScheduleUnbalanced],
)
@skip_if_lt_x_gpu(4)
def test_non_symmetric_stage_ids(self, schedule_class):
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
@ -1135,7 +1121,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
@skip_if_lt_x_gpu(4)
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
n_stages = 2
stages_per_rank = 1
@ -1196,7 +1181,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithW])
@skip_if_lt_x_gpu(4)
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
n_stages = ScheduleClass.n_stages
num_microbatches = ScheduleClass.num_microbatches


@ -26,7 +26,6 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
@ -765,7 +764,6 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(grad1_norm.device_mesh, mesh_y)
@with_comms
@skip_if_lt_x_gpu(4)
def test_foreach_add_different_mesh(self):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(


@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
self.assertEqual(
comm_mode.get_comm_counts(),
{
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
torch.ops.c10d_functional.all_gather_into_tensor: 4,
},
)
expected_cost = [


@ -485,7 +485,7 @@ elif TEST_XPU:
def exit_if_lt_x_accelerators(x):
if torch.accelerator.is_available():
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
def with_comms(func=None):


@ -2109,52 +2109,6 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
with self.assertRaises(Unsupported):
outer_f2(inp)
def test_disable_recursive_flags(self):
class SimpleLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = torch.nn.Linear(4, 4)
def forward(self, inp):
return self.layer0(torch.sigmoid(inp))
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = SimpleLinear()
self.layer1 = torch.nn.Linear(4, 4)
def forward(self, inp):
z = self.layer0(torch.sin(inp))
return self.layer1(z)
for recursive_flag in [True, False]:
model = SimpleModel()
other_model = SimpleModel()
model.forward = torch._dynamo.disable(
model.forward,
recursive=recursive_flag,
)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(model.forward),
recursive_flag,
)
other_model = torch._dynamo.disable(other_model, recursive=recursive_flag)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(
other_model.forward
if isinstance(other_model, torch.nn.Module)
else other_model
),
recursive_flag,
)
# check the model is compilable
torch.compile(model)
torch.compile(other_model)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests


@ -422,41 +422,34 @@ from user code:
import optree
@torch.compile(backend="eager")
def fn1(x):
tree = {"a": x, "b": (x - 1, 2 * x)}
sin, cos = optree.tree_transpose_map(
lambda t: (torch.sin(t), torch.cos(t)),
tree,
def fn(x):
d = {"a": 1}
optree.tree_flatten_with_path(d)
return torch.sin(x)
def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return sin, cos
fn1(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 0)
@torch.compile(backend="eager")
def fn2(x):
spec = optree.treespec_deque([])
return spec, x
fn2(torch.randn(4))
self.assertGreaterEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
def post_munge(string):
return re.sub(
r"(optree\.|qualname: )\S*(\.make_from_collection)",
r"\1<path>\2",
string,
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
)
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
self.assertExpectedInline(
post_munge(first_graph_break),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.make_from_collection.
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: <path>.make_from_collection, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)


@ -0,0 +1,154 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)
M_total = torch.sum(M_sizes).item()
# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
return (A, B, offsets)
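# A minimal worked example (editorial note, not part of the commit), assuming
# M_sizes = [128, 64, 256] with alignment=16: torch.cumsum gives
# offsets = [128, 192, 448] (int32, no leading zero, strictly increasing), and
# A is packed as a (448, K) tensor whose row ranges 0:128, 128:192 and 192:448
# belong to groups 0, 1 and 2 respectively.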
@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)
# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16
G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base
if layout_A == "offset":
# allocate bigger buffer than needed, use nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128 # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)
# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
if __name__ == "__main__":
run_tests()


@ -1913,29 +1913,6 @@ class TestMaxAutotune(TestCase):
# Check that contiguous transform was used
FileCheck().check("contiguous_mm").run(code[0])
@unittest.skipIf(config.cpp_wrapper, "out_dtype override not supported for AOTI")
@unittest.skipIf(TEST_WITH_ROCM, "out_dtype override only available on NVIDIA")
def test_bmm_out_dtype(self):
def f(a, b):
return torch.bmm(a, b, out_dtype=torch.float32)
a = torch.randn(2, 3, 4, device=GPU_TYPE, dtype=torch.float16)
b = torch.randn(2, 4, 5, device=GPU_TYPE, dtype=torch.float16)
with config.patch(
max_autotune=True,
max_autotune_gemm_backends="TRITON",
):
compiled_f = torch.compile(f)
with self.assertRaisesRegex(
torch._inductor.exc.InductorError,
r"LoweringException: NoValidChoicesError: No choices to select",
):
out, code = run_and_get_code(compiled_f, a, b)
compiled_f = torch.compile(f)
out, code = run_and_get_code(compiled_f, a, b)
FileCheck().check("extern_kernels.bmm_dtype").run(code[0])
def test_triton_template_generated_code_cache_key(self):
generate_and_load_args = len(
inspect.signature(


@ -1217,43 +1217,6 @@ class TestPatternMatcher(TestCase):
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_addmm_alpha_beta_with_pointwise(self):
# Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops
# See https://github.com/pytorch/pytorch/issues/167313
x = torch.rand(2, device=GPU_TYPE)
a = torch.rand(2, 3, device=GPU_TYPE)
b = torch.rand(3, 2, device=GPU_TYPE)
def f(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2))
fc = torch.compile(f)
expected = f(x, a, b)
actual = fc(x, a, b)
# The compiled version should produce the same result as eager
torch.testing.assert_close(actual, expected)
# Verify that addmm is unfused (should not use extern_kernels.addmm)
# The pattern should be replaced with beta * x + alpha * (a @ b)
_, (code) = run_and_get_code(fc, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
# Test with alpha=1, beta=1 (default) - should also unfuse
def f_default(x, a, b):
return torch.nn.functional.relu(torch.addmm(x, a, b))
fc_default = torch.compile(f_default)
expected_default = f_default(x, a, b)
actual_default = fc_default(x, a, b)
torch.testing.assert_close(actual_default, expected_default)
# Should unfuse and not use extern_kernels.addmm
_, (code) = run_and_get_code(fc_default, x, a, b)
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_serialized_patterns_up_to_date(self):
import torch.utils._pytree as pytree
from torch._inductor.fx_passes import joint_graph


@ -7486,7 +7486,7 @@ class TestFXMemoryProfiler(TestCase):
return fx_frames
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_fx_memory_profiler_augmentation(self):
"""Test that memory snapshots are augmented with FX debug information."""


@ -4251,7 +4251,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
@ -4307,7 +4307,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
@ -4351,7 +4351,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)


@ -165,6 +165,21 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str,
return flattened
def get_tests_for_circleci(
workflow_run_id: int, workflow_run_attempt: int
) -> list[dict[str, Any]]:
# Parse the reports and transform them to JSON
test_cases = []
for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
test_cases.extend(
parse_xml_report(
"testcase", xml_report, workflow_run_id, workflow_run_attempt
)
)
return test_cases
def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Group test cases by classname, file, and job_id. We perform the aggregation
manually instead of using the `test-suite` XML tag because xmlrunner does
@ -243,11 +258,21 @@ if __name__ == "__main__":
required=True,
help="Head repository of the workflow",
)
parser.add_argument(
"--circleci",
action="store_true",
help="If this is being run through circleci",
)
args = parser.parse_args()
print(f"Workflow id is: {args.workflow_run_id}")
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
if args.circleci:
test_cases = get_tests_for_circleci(
args.workflow_run_id, args.workflow_run_attempt
)
else:
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
# Flush stdout so that any errors in the upload show up last in the logs.
sys.stdout.flush()


@ -32,7 +32,6 @@ from .decorators import (
error_on_graph_break,
forbid_in_graph,
graph_break,
is_dynamo_disable_recursive,
mark_dynamic,
mark_static,
mark_static_address,
@ -88,7 +87,6 @@ __all__ = [
"forbid_in_graph",
"graph_break",
"is_compiling",
"is_dynamo_disable_recursive",
"list_backends",
"lookup_backend",
"mark_dynamic",


@ -15,7 +15,7 @@ import dataclasses
import re
import sys
import types
from collections import Counter, deque
from collections import Counter
from collections.abc import Callable, Iterable
from typing import Any, Optional, TYPE_CHECKING, Union
@ -597,35 +597,32 @@ class PyCodegen:
graphargs = self.tx.output.graphargs
def extract_nested_sources(source: Source) -> list[Source]:
nested_sources: list[Source] = []
seen_sources: OrderedSet[Source] = OrderedSet()
def collect_temp_source(source: Source) -> None:
if source in seen_sources:
# This source is used at least twice, so it can be reused
self.mark_source_temp(source)
# Dont trace source further. This prevents us from marking too
# many nodes as temp sources.
return
seen_sources.add(source)
if isinstance(source, ChainedSource):
nested_sources.append(source.base)
collect_temp_source(source.base)
if isinstance(source, DictGetItemSource) and isinstance(
source.index, Source
):
nested_sources.append(source.index)
return nested_sources
def collect_temp_sources(sources: deque[Source], codegen: PyCodegen) -> None:
seen_sources: OrderedSet[Source] = OrderedSet()
while sources:
current_source = sources.popleft()
if current_source in seen_sources:
# This source is used at least twice, so it can be reused
codegen.mark_source_temp(current_source)
# Dont trace source further. This prevents us from marking too
# many nodes as temp sources.
continue
seen_sources.add(current_source)
sources.extend(extract_nested_sources(current_source))
collect_temp_source(source.index)
# Collect all the sources that are used more than once, so that we can
# generate tmp variables in the generated pre-graph bytecode. This
# essentially implements CSE.
collect_temp_sources(
deque([arg.source for arg in graphargs if arg.source is not None]), self
)
for arg in graphargs:
if arg.source is not None:
collect_temp_source(arg.source)
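# A minimal sketch (editorial note with assumed source names, not part of the
# commit) of the CSE effect of the loop above: if two graph args carry chained
# sources such as L['mod'].weight and L['mod'].bias, collect_temp_source()
# reaches the shared base L['mod'] twice; the second visit hits the seen_sources
# check and marks the base as a temp, so the pre-graph bytecode emits a single
# temporary for L['mod'] instead of re-evaluating it for each use.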
cm_var = None
if config.record_runtime_overhead:


@ -740,8 +740,11 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
# Deprecated! Please use the config in torch/fx/experimental/_config instead.
enrich_profiler_metadata: bool = False
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403


@ -828,7 +828,6 @@ def trace_frame(
raise
finally:
tracer.output.call_cleanup_hooks()
tracer.f_locals = {}
try:
run_tracer()


@ -96,7 +96,6 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_recursive = False # type: ignore[attr-defined]
# pyrefly: ignore [bad-return]
return nonrecursive_disable_wrapper
@ -1024,13 +1023,3 @@ def error_on_graph_break(
The default value of torch.compile's `error_on_graph_break` setting is False.
"""
return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)
def is_dynamo_disable_recursive(method: Callable[[Any], Any]) -> Optional[bool]:
"""
Check if a method is marked as `dynamo_disable` recursively. It returns:
- True if disable(recursive=True)
- False if disable(recursive=False)
- None if method is not a disable decorator
"""
return getattr(method, "_torchdynamo_disable_recursive", None)


@ -1155,8 +1155,6 @@ class DisableContext(_TorchDynamoContext):
# of decorators.
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
_fn._torchdynamo_disable_recursive = True # type: ignore[attr-defined]
return _fn
def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]:


@ -4,8 +4,6 @@
import importlib
from typing import TYPE_CHECKING
import torch.utils._pytree as python_pytree
from .. import polyfills, trace_rules
@ -21,14 +19,12 @@ POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
"itertools",
"operator",
"os",
"pytree",
"struct",
"sys",
"fx",
"tensor",
)
if python_pytree._cxx_pytree_dynamo_traceable:
POLYFILLED_MODULE_NAMES += ("pytree",)
POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
importlib.import_module(f".{submodule}", package=polyfills.__name__)
for submodule in POLYFILLED_MODULE_NAMES

File diff suppressed because it is too large.


@ -1991,7 +1991,7 @@ class BuiltinVariable(VariableTracker):
# If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
# with an integer argument starting at 0, until __getitem__ raises IndexError
ret = variables.UserFunctionVariable(
polyfills.builtins.iter_
polyfills.builtins.iter_ # type: ignore[arg-type]
).call_function(tx, [obj, *args], {})
if args:

File diff suppressed because it is too large.


@ -590,7 +590,7 @@ class FilterVariable(IteratorVariable):
else:
res = self.fn.call_function(tx, [item], {})
pred_res = variables.UserFunctionVariable(
polyfills.predicate
polyfills.predicate # type: ignore[arg-type]
).call_function(tx, [res], {})
if pred_res.as_python_constant():
return item


@ -1498,6 +1498,7 @@ class NamedTupleVariable(TupleVariable):
variables.UserDefinedClassVariable(self.tuple_cls),
)
elif isinstance(method, staticmethod):
# pyrefly: ignore[bad-argument-type]
return UserFunctionVariable(method.__func__)
elif inspect.isfunction(method):
return UserMethodVariable(method, self)


@ -472,7 +472,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
)
elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined]
name_to_arg_map = bind_args_cached(
self.value, tx, self.source, args, kwargs
# pyrefly: ignore[bad-argument-type]
self.value,
tx,
self.source,
args,
kwargs,
)
backends = name_to_arg_map["backends"].as_python_constant()
set_priority = name_to_arg_map["set_priority"].as_python_constant()
@ -1429,7 +1434,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
packed_input_vt = TupleVariable.build(
tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
)
out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type]
tx, [packed_input_vt], {}
)
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2


@ -550,6 +550,10 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False


@ -913,10 +913,6 @@ def _get_optimization_cflags(
if not config.is_fbcode():
if platform.machine() == "ppc64le":
cflags.append("mcpu=native")
elif platform.machine() == "riscv64":
cflags.append("march=rv64gc")
elif platform.machine() == "riscv32":
cflags.append("march=rv32gc")
else:
cflags.append("march=native")


@ -1516,29 +1516,17 @@ def should_prefer_unfused_addmm(match):
@register_graph_pattern(
CallFunction(
aten.addmm,
KeywordArg("inp"),
Arg(),
Arg(),
beta=KeywordArg("beta"),
alpha=KeywordArg("alpha"),
),
CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
# pyrefly: ignore [bad-argument-type]
pass_dict=pass_patterns[2],
extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
def repl(inp, x1, x2, alpha, beta):
mm_result = x1 @ x2
if alpha != 1:
mm_result = alpha * mm_result
if beta != 1:
inp = beta * inp
return inp + mm_result
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
def repl(inp, x1, x2):
return x1 @ x2 + inp
# pyrefly: ignore [bad-argument-type]
match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta])
match.replace_by_example(repl, [inp, mat1, mat2])
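# An illustrative note (editorial, not part of the commit): with this pattern, a
# call like torch.nn.functional.relu(torch.addmm(inp, a, b)) is rewritten, when
# should_prefer_unfused_addmm() approves, into the unfused form
# torch.nn.functional.relu(a @ b + inp), so the bias add can fuse with the
# pointwise epilogue rather than dispatching to the extern addmm kernel.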
def is_valid_addmm_fusion(match):


@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any
import torch
@ -12,6 +14,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
from .. import config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import load_template
log = logging.getLogger(__name__)
@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
return False
return True
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)


@ -1,11 +1,13 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
@ -22,11 +24,13 @@ from ..utils import (
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_blackwell_cutedsl_grouped_mm,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
load_kernel_template,
persistent_grouped_mm_grid,
)
@ -513,6 +517,11 @@ triton_scaled_grouped_mm_template = TritonTemplate(
source=triton_grouped_mm_source,
)
cutedsl_grouped_mm_template = CuteDSLTemplate(
name="grouped_gemm_cutedsl",
source=load_kernel_template("cutedsl_mm_grouped"),
)
def grouped_mm_args(
mat1: TensorBox,
@ -714,43 +723,44 @@ def _tuned_grouped_mm_common(
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
a_is_k_major = mat_a.get_stride()[-1] == 1
b_is_k_major = mat_b.get_stride()[-2] == 1
@ -788,6 +798,22 @@ def _tuned_grouped_mm_common(
**config.kwargs,
)
if use_blackwell_cutedsl_grouped_mm(
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
):
for config in get_groupgemm_configs():
kwargs = dict(
ACC_DTYPE="cutlass.Float32",
)
cutedsl_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
**asdict(config),
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None


@ -0,0 +1,333 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
GroupedGemmKernel,
)
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
# Caches the compiled executor for build_group_ptrs_from_bases(). This
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
# and can therefore be safely reused across runs with different group
# partitioning (`offs`).
#
# 2. gemm_cache
# Caches the compiled Grouped GEMM executor. Its key extends the prep
# cache key with hardware- and grid-specific parameters:
# (prep_cache_key, max_active_clusters, total_num_clusters).
# This is necessary because different `offs` tensors can change the
# per-group problem sizes and thus alter `total_num_clusters`, which in
# turn changes the grid shape and persistent scheduler configuration.
# Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
prep_cache = {}
gemm_cache = {}
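# A condensed sketch (editorial, not part of the commit) of the two-level lookup
# described in the note above, mirroring the generated code further below:
#
#   key = get_prep_cache_key(input_a, input_b, output)
#   if prep_cache.get(key) is None:          # shape/stride/dtype miss
#       prep_cache[key] = cute.compile(launch_build_group_ptrs_from_bases, ...)
#   gemm_key = get_gemm_cache_key(key, max_active_clusters, total_num_clusters)
#   if gemm_cache.get(gemm_key) is None:     # hardware/grid miss
#       gemm_cache[gemm_key] = cute.compile(grouped_gemm, ...)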
@functools.lru_cache
def get_hardware_info():
hw = cutlass.utils.HardwareInfo()
sm_count = hw.get_max_active_clusters(1)
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
return (sm_count, max_active_clusters)
def get_prep_cache_key(input_a, input_b, output):
"""
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
shapes, strides, and dtypes of input/output tensors.
"""
return (
tuple(input_a.shape),
tuple(input_a.stride()),
input_a.dtype,
tuple(input_b.shape),
tuple(input_b.stride()),
input_b.dtype,
tuple(output.shape),
tuple(output.stride()),
output.dtype,
)
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
"""
Returns a tuple key for caching the gemm kernel executor by extending the
prep cache key with hardware- and grid-specific parameters.
"""
return (
prep_cache_key,
max_active_clusters,
total_num_clusters,
)
@cute.kernel
def build_group_ptrs_from_bases_kernel(
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Int32, # bytes
# -------- STRIDES (in ELEMENTS) --------
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
# -------- OUTPUTS --------
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
tidx, _, _ = cute.arch.thread_idx()
g = tidx
m_beg_i32 = 0
if g > 0:
m_beg_i32 = offs[g - 1]
m_end_i32 = offs[g]
m_g_i32 = m_end_i32 - m_beg_i32
a_byte_off = (
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
)
c_byte_off = (
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
)
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
# ---- pointers ----
out_ptrs[g, 0] = base_A_u64 + a_byte_off
out_ptrs[g, 1] = base_B_u64 + b_byte_off
out_ptrs[g, 2] = base_C_u64 + c_byte_off
# ---- (m, n, k, 1) ----
out_problem[g, 0] = m_g_i32
out_problem[g, 1] = N
out_problem[g, 2] = K
out_problem[g, 3] = cutlass.Int32(1)
# ---- strides ----
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
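# A worked example (editorial, not part of the commit) of the pointer math above,
# assuming offs = [128, 192, 448], row-major A/C and bf16 inputs (sizeof_element = 2):
# for group g = 1, m_beg = offs[0] = 128, m_end = offs[1] = 192, so m_g = 64;
# a_byte_off = 128 * stride_A_m_elems * 2 and c_byte_off = 128 * stride_C_m_elems * 2
# skip the first group's 128 rows of A and C, while b_byte_off = 1 * stride_B0_elems * 2
# jumps the B pointer to batch index g.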
@cute.jit
def launch_build_group_ptrs_from_bases(
base_A_u64: cutlass.Int64,
base_B_u64: cutlass.Int64,
base_C_u64: cutlass.Int64,
offs: cute.Tensor,
G: cutlass.Constexpr,
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Constexpr,
stride_A_m_elems: cutlass.Constexpr,
stride_A_k_elems: cutlass.Constexpr,
stride_B0_elems: cutlass.Constexpr,
stride_Bk_elems: cutlass.Constexpr,
stride_Bn_elems: cutlass.Constexpr,
stride_C_m_elems: cutlass.Constexpr,
stride_C_n_elems: cutlass.Constexpr,
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
out_problem: cute.Tensor, # [G,4] cutlass.Int32
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
stream: cuda.CUstream,
):
build_group_ptrs_from_bases_kernel(
base_A_u64,
base_B_u64,
base_C_u64,
offs,
K,
N,
sizeof_element,
stride_A_m_elems,
stride_A_k_elems,
stride_B0_elems,
stride_Bk_elems,
stride_Bn_elems,
stride_C_m_elems,
stride_C_n_elems,
out_ptrs,
out_problem,
out_strides_abc,
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
{{def_kernel("input_a", "input_b", "input_a_offs")}}
stream = cuda.CUstream(stream)
input_b = input_b.transpose(1, 2)
sumM, K = input_a.shape
G, N, Kb = input_b.shape
dev = input_a.device
base_A_u64 = int(input_a.data_ptr())
base_B_u64 = int(input_b.data_ptr())
base_C_u64 = int({{get_output()}}.data_ptr())
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
ptrs = from_dlpack(ptrs_t)
probs = from_dlpack(probs_t)
strides = from_dlpack(strides_t)
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
prep_executor = prep_cache.get(prep_cache_key)
if prep_executor is None:
sizeof_element = int(input_a.element_size())
sA_m, sA_k = map(int, input_a.stride())
sB_0, sB_n, sB_k = map(int, input_b.stride())
sC_m, sC_n = map(int, {{get_output()}}.stride())
prep_executor = cute.compile(
launch_build_group_ptrs_from_bases,
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
G=int(G),
K=int(K),
N=int(N),
sizeof_element=sizeof_element,
stride_A_m_elems=sA_m,
stride_A_k_elems=sA_k,
stride_B0_elems=sB_0,
stride_Bk_elems=sB_k,
stride_Bn_elems=sB_n,
stride_C_m_elems=sC_m,
stride_C_n_elems=sC_n,
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
prep_cache[prep_cache_key] = prep_executor
prep_executor(
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
# --- Tensormap workspace per SM ---
num_tensormap_buffers, max_active_clusters = get_hardware_info()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
tensormap_workspace = from_dlpack(tensormap_workspace_t)
# --- Total clusters ---
def compute_total_num_clusters(
problem_sizes_mnkl,
cluster_tile_shape_mn,
):
total_num_clusters = 0
for m, n, _, _ in problem_sizes_mnkl:
num_clusters_mn = tuple(
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
)
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
return total_num_clusters
# Compute cluster tile shape
def compute_cluster_tile_shape(
mma_tiler_mn,
cluster_shape_mn,
use_2cta_instrs,
):
cta_tile_shape_mn = list(mma_tiler_mn)
if use_2cta_instrs:
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
cluster_tile_shape_mn = compute_cluster_tile_shape(
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
)
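# A worked example (editorial, not part of the commit) of the two helpers above,
# assuming the constexprs TILE_M=128, TILE_N=256, CLUSTER_M=2, CLUSTER_N=1,
# USE_2_CTA=False (the first entry of the default heuristic configs): the CTA
# tile stays (128, 256), the cluster tile becomes (128 * 2, 256 * 1) = (256, 256),
# and a group with (m, n) = (512, 512) contributes
# ceildiv(512, 256) * ceildiv(512, 256) = 4 clusters to total_num_clusters below.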
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
gemm_cache_key = get_gemm_cache_key(
prep_cache_key, max_active_clusters, total_num_clusters
)
gemm_executor = gemm_cache.get(gemm_cache_key)
if gemm_executor is None:
grouped_gemm = GroupedGemmKernel(
acc_dtype=ACC_DTYPE,
use_2cta_instrs=USE_2_CTA,
mma_tiler_mn=(TILE_M, TILE_N),
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
)
gemm_executor = cute.compile(
grouped_gemm,
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
G,
probs,
strides,
ptrs,
total_num_clusters,
tensormap_workspace,
max_active_clusters,
stream,
)
gemm_cache[gemm_cache_key] = gemm_executor
gemm_executor(
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
probs,
strides,
ptrs,
tensormap_workspace,
stream,
)


@ -0,0 +1,141 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product
import torch._inductor.config as config
class TensorMapUpdateMode(Enum):
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
SMEM = auto()
GMEM = auto()
@dataclass(frozen=True)
class CuTeGemmConfig:
TILE_M: int = 128
TILE_N: int = 192
CLUSTER_M: int = 2
CLUSTER_N: int = 1
USE_2_CTA: bool = False
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
For information regarding valid config sets, see:
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
"""
# Tile_n is always the same regardless of 2cta
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
# Valid clusters
clusters_no_2cta = [
(1, 1),
(1, 2),
(1, 4),
(1, 8),
(1, 16),
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
clusters_2cta = [
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
configs: list[CuTeGemmConfig] = []
for use_2cta, cluster_set, tile_m_range in [
(False, clusters_no_2cta, [64, 128]),
(True, clusters_2cta, [128, 256]),
]:
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
tile_m_range,
tile_n_vals,
cluster_set,
):
configs.append(
CuTeGemmConfig(
tile_m,
tile_n,
cluster_m,
cluster_n,
USE_2_CTA=use_2cta,
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
)
)
return configs
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
"""
config_tuples = [
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
]
return [CuTeGemmConfig(*args) for args in config_tuples]
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
or unstable results. By default, autotuning is disabled and we return only
a single baseline config.
"""
if (
config.cutedsl_enable_autotuning
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
):
return get_exhaustive_groupgemm_configs()
elif config.cutedsl_enable_autotuning:
return get_default_groupgemm_configs()
else:
return [get_default_groupgemm_configs()[0]]
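# A minimal usage sketch (editorial, not part of the commit), assuming the
# Inductor config knobs referenced above are patched directly rather than set
# via their environment variables:
#
#   import torch._inductor.config as inductor_config
#   with inductor_config.patch(cutedsl_enable_autotuning=True,
#                              max_autotune_gemm_search_space="EXHAUSTIVE"):
#       cfgs = get_groupgemm_configs()   # full exhaustive sweep
#   with inductor_config.patch(cutedsl_enable_autotuning=True):
#       cfgs = get_groupgemm_configs()   # curated default list
#   cfgs = get_groupgemm_configs()       # autotuning off: single baseline config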


@ -1911,6 +1911,84 @@ def use_triton_blackwell_tma_template(
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
"""Check if CuTeDSL is importable; cache the result for reuse.
Call ensure_cute_available.cache_clear() after installing CuTeDSL
in the same interpreter to retry the import.
"""
try:
return importlib.util.find_spec("cutlass.cute") is not None
except ImportError:
return False
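# A usage note (editorial, not part of the commit): the probe above is wrapped in
# functools.lru_cache(maxsize=1), so a negative result persists for the
# interpreter's lifetime. Per the docstring, after installing CuTeDSL one can
# retry with:
#
#   from torch._inductor.utils import ensure_cute_available
#   ensure_cute_available.cache_clear()   # drop the cached result
#   ensure_cute_available()               # re-runs importlib.util.find_spec("cutlass.cute")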
def use_blackwell_cutedsl_grouped_mm(
mat_a: Any,
mat_b: Any,
layout: Layout,
a_is_2d: bool,
b_is_2d: bool,
offs: Optional[Any],
bias: Optional[Any],
scale_result: Optional[Any],
) -> bool:
"""
Returns True if we can use the blackwell kernel for grouped mm.
Required conditions:
1. CuTeDSL backend is enabled
2. CuTeDSL is available
3. We are on a blackwell arch
4. The dtype is bf16
5. Max autotune or max autotune gemm is enabled
6. A, B, and the output are 16B aligned
7. We are not using dynamic shapes
8. A is 2d
9. B is 3d
10. Offsets are provided
11. Bias and Scale are not provided
"""
if not ensure_cute_available():
return False
if not _use_autotune_backend("CUTEDSL"):
return False
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
if not is_gpu(layout.device.type):
return False
if not is_datacenter_blackwell_arch():
return False
layout_dtypes = [torch.bfloat16]
if not _use_template_for_gpu(layout, layout_dtypes):
return False
if not (config.max_autotune or config.max_autotune_gemm):
return False
# Checks for 16B ptr and stride alignment
if not can_use_tma(mat_a, mat_b, output_layout=layout):
return False
if any(is_dynamic(x) for x in [mat_a, mat_b]):
return False
if not a_is_2d or b_is_2d:
return False
if offs is None:
return False
if bias is not None or scale_result is not None:
return False
return True
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V


@ -1485,7 +1485,6 @@ class PipelineScheduleMulti(_PipelineSchedule):
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
use_full_backward: Optional[bool] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# Init parent
super().__init__(
@ -1518,11 +1517,6 @@ class PipelineScheduleMulti(_PipelineSchedule):
# This will be set during init of derived schedules
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
# When using a custom backward function, we may or may not need autograd to be used
# for the backward pass. This flag is used to determine whether or torch.is_grad_enabled()
# check should be performed before the step function.
self._backward_requires_autograd = backward_requires_autograd
if use_full_backward is not None:
logger.warning(
"Deprecation warning: 'use_full_backward' is no longer supported. "
@ -1615,11 +1609,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
losses: a list to store the losses for each microbatch.
return_outputs: whether to return the outputs from the last stage.
"""
if (
self._has_backward
and self._backward_requires_autograd
and not torch.is_grad_enabled()
):
if self._has_backward and not torch.is_grad_enabled():
raise RuntimeError(
"step() requires gradients to be enabled for backward computation; "
"it should not be used under torch.no_grad() context. "
@ -1901,7 +1891,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
Args:
computation_type: The computation type for which to register the custom function
custom_function: The function to execute when this computation type is encountered.
Must have signature: (action: _Action, ctx: _PipelineContext) -> None
Must have signature: (stage: _PipelineStageBase, mb_index: int, *args, **kwargs) -> None
"""
# Ensure that the computation type is valid
if computation_type not in (
@ -1910,13 +1900,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
BACKWARD_INPUT,
BACKWARD_WEIGHT,
OVERLAP_F_B,
UNSHARD,
RESHARD,
REDUCE_GRAD,
):
raise ValueError(
f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \
BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported."
BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
)
# Check if computation_type is already registered
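A hedged sketch of registering a custom action, assuming the (stage, mb_index, *args, **kwargs) signature variant shown in the docstring above; the method name (assumed here to be register_custom_function), the FORWARD constant, and the existing schedule instance are taken as given from the surrounding hunks:
def _log_forward(stage, mb_index, *args, **kwargs):
    # Illustrative only: wrap or replace the default FORWARD computation.
    print(f"custom FORWARD: stage={stage}, microbatch={mb_index}")

schedule.register_custom_function(FORWARD, _log_forward)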
@ -2309,7 +2296,6 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
loss_fn: Optional[Union[Callable, _Loss]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
super().__init__(
stages=stages,
@ -2317,7 +2303,6 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
loss_fn=loss_fn,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
# 1. Create the pipeline_order (all ranks do this calculation)
@ -2525,7 +2510,6 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
self.pp_group_size = stages[0].group_size
super().__init__(
@ -2536,7 +2520,6 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
@ -2639,7 +2622,6 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we don't support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -2652,7 +2634,6 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
@ -2838,7 +2819,6 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we don't support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -2851,7 +2831,6 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
self.pp_group_size, self._num_stages, style="v"
@ -3016,7 +2995,6 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
backward_requires_autograd: bool = True,
):
# TODO: we don't support input/weight backward split with torch.compile
_check_torch_compile_compatibility(stages, self.__class__.__name__)
@ -3029,7 +3007,6 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
backward_requires_autograd=backward_requires_autograd,
)
self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
self.pp_group_size, self._num_stages, style="v"

View File

@ -2,8 +2,6 @@ import os
import sys
from typing import Optional
from torch.utils._config_module import Config, install_config_module
# [@compile_ignored: debug] Fails hard instead of graph breaking on guard-on-data-dependent errors.
no_data_dependent_graph_break = (
@ -102,11 +100,7 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that do their own dtype checking.
skip_dtype_check_in_meta_registrations = False
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])

View File

@ -20,7 +20,6 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
_BoxedCodeGen,
_custom_builtins,
@ -859,15 +858,14 @@ class {module_name}(torch.nn.Module):
called after editing the contained ``graph``, otherwise the generated
code of this ``GraphModule`` will be out of date.
"""
# Do not import anything inside recompile; it might slow down the
# function and cause a perf regression. Import outside of the method instead.
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self",
record_func=fx_experimental_config.enrich_profiler_metadata,
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
@ -876,7 +874,7 @@ class {module_name}(torch.nn.Module):
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
if fx_experimental_config.enrich_profiler_metadata:
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):

View File

@ -19,13 +19,8 @@ __all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"register_flash_attention_impl",
"activate_flash_attention_impl",
"list_flash_attention_impls",
"current_flash_attention_impl",
]
# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
# This only affects users of bias subclasses
@ -167,23 +162,3 @@ def _sdpa_kernel_variadic(*backends: SDPBackend):
def _get_flash_version() -> str:
"""This returns the closest matching tag for the flash attention backend"""
return "2.5.7"
from . import _registry
# Re-export registry types and functions for public API
_FlashAttentionImpl = _registry._FlashAttentionImpl
_RegisterFn = _registry._RegisterFn
register_flash_attention_impl = _registry.register_flash_attention_impl
activate_flash_attention_impl = _registry.activate_flash_attention_impl
list_flash_attention_impls = _registry.list_flash_attention_impls
current_flash_attention_impl = _registry.current_flash_attention_impl
register_flash_attention_impl.__module__ = __name__
activate_flash_attention_impl.__module__ = __name__
list_flash_attention_impls.__module__ = __name__
current_flash_attention_impl.__module__ = __name__
# Import built-in implementations to trigger self-registration
from . import _fa4 # noqa: F401
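A short, hedged sketch of the end-user flow these re-exports enable, assuming the built-in FA4 module has registered itself on import as shown above:
import torch.nn.attention as attn

print(attn.list_flash_attention_impls())       # expected to include "FA4"
attn.activate_flash_attention_impl("FA4")      # installs the registered kernels into the dispatcher
assert attn.current_flash_attention_impl() == "FA4"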

View File

@ -1,444 +0,0 @@
"""UBER PROTOTYPE!!!"""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
from dataclasses import dataclass
from functools import cache
from typing import Any, TYPE_CHECKING
from typing_extensions import TypeVarTuple, Unpack
from . import _registry
if TYPE_CHECKING:
from types import ModuleType
import torch
from torch.library import Library
__all__ = [
"register_flash_attention_fa4",
]
_FA4_MODULE_PATH: str | None = None
@dataclass
class _FA4Handle:
library: Library | None
def remove(self) -> None:
self.library = None
@cache
def _get_device_major(device: torch.device) -> int:
major, _ = torch.cuda.get_device_capability(device)
return major
def register_flash_attention_fa4(
module_path: str = "flash_attn.cute.interface",
) -> _FA4Handle:
"""
Register FA4 flash attention kernels with the PyTorch dispatcher.
Args:
module_path: Python module path to the FA4 implementation.
"""
global _FA4_MODULE_PATH
_ = _fa4_import_module(module_path)
_FA4_MODULE_PATH = module_path
return _FA4Handle(_fa4_register_kernels())
@cache
def _fa4_import_module(module_path: str) -> ModuleType:
module = importlib.import_module(module_path)
if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
return module
def _fa4_register_kernels() -> Library:
lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
lib.impl(
"_scaled_dot_product_flash_attention",
_fa4_scaled_dot_product_flash_attention_forward_impl,
"CUDA",
)
lib.impl(
"_scaled_dot_product_flash_attention_backward",
_fa4_scaled_dot_product_flash_attention_backward_impl,
"CUDA",
)
return lib
def _fa4_common_support_error(
query: torch.Tensor,
tensors: tuple[torch.Tensor, ...],
cum_seq_q: torch.Tensor | None,
require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
) -> str | None:
if not all(t.is_cuda for t in tensors):
return "inputs must be CUDA tensors"
if len({t.device for t in tensors}) != 1:
return "inputs must share device"
if query.dtype not in (torch.float16, torch.bfloat16):
return "query dtype must be float16 or bfloat16"
for name, tensor in require_fp32:
if tensor.dtype != torch.float32:
return f"{name} dtype must be float32"
if cum_seq_q is None and query.dim() != 4:
return "dense query must be 4D"
if cum_seq_q is not None and query.dim() != 3:
return "ragged query must be 3D"
if not torch.cuda.is_available():
return "CUDA not available"
if _get_device_major(query.device) not in (9, 10):
return "FA4 requires compute capability 9.0 or 10.0"
return None
def _fa4_forward_support_error(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float,
return_debug_mask: bool,
alibi_slopes: torch.Tensor | None,
seqused_k: torch.Tensor | None,
cum_seq_q: torch.Tensor | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if return_debug_mask:
return "return_debug_mask must be False"
if alibi_slopes is not None:
return "alibi_slopes not supported"
if seqused_k is not None:
if seqused_k.dtype != torch.int32:
return "seqused_k must be int32"
if not seqused_k.is_cuda:
return "seqused_k must be CUDA"
error = _fa4_common_support_error(
query,
(query, key, value),
cum_seq_q,
)
if error is not None:
if error == "inputs must share device":
return "query, key, value must be on same device"
return error
return None
def _fa4_backward_support_error(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
dropout_p: float,
cum_seq_q: torch.Tensor | None,
window_size_left: int | None,
window_size_right: int | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if window_size_left is not None or window_size_right is not None:
return "windowed attention not supported"
error = _fa4_common_support_error(
query,
(grad_out, query, key, value, out, logsumexp),
cum_seq_q,
require_fp32=(("logsumexp", logsumexp),),
)
if error is not None:
return error
return None
Ts = TypeVarTuple("Ts")
def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
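As a small illustration of why this helper exists: SDPA passes dense tensors as (batch, heads, seq, head_dim), while the flash kernels below consume (batch, seq, heads, head_dim), so the wrappers swap dims 1 and 2 in both directions. A hedged sketch (shapes are arbitrary; relies on the module-level import torch):
q_sdpa = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)  # (B, H, S, D)
(q_flash,) = _transpose_dense(q_sdpa)                                     # (B, S, H, D)
assert q_flash.shape == (2, 128, 8, 64)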
def _fa4_run_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
window_size_left: int | None,
window_size_right: int | None,
seqused_k: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
kwargs: dict[str, Any] = {
"softmax_scale": scale,
"causal": is_causal,
"window_size_left": window_size_left,
"window_size_right": window_size_right,
"return_lse": True,
"cu_seqlens_q": cu_seq_q,
"cu_seqlens_k": cu_seq_k,
"seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
}
out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
return out, lse.contiguous()
def _fa4_run_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
dq, dk, dv = module._flash_attn_bwd(
query,
key,
value,
out,
grad_out,
logsumexp.contiguous(),
softmax_scale=scale,
causal=is_causal,
cu_seqlens_q=cu_seq_q,
cu_seqlens_k=cu_seq_k,
)
return dq, dk, dv
def _fa4_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
return_debug_mask: bool,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
seqused_k: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
alibi_slopes,
seqused_k,
cum_seq_q,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
out, lse = _fa4_run_forward(
query,
key,
value,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
window_size_left,
window_size_right,
seqused_k,
)
rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
return out, lse, rng_state, philox_offset, debug_mask
def _fa4_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
rng_state: torch.Tensor,
unused: torch.Tensor,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
cum_seq_q,
window_size_left,
window_size_right,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
dq, dk, dv = _fa4_run_backward(
grad_out,
query,
key,
value,
out,
logsumexp,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
)
return dq, dk, dv
def _fa4_scaled_dot_product_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
q, k, v = _transpose_dense(query, key, value)
max_q_flash = q.size(1)
max_k_flash = k.size(1)
out, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
q,
k,
v,
None,
None,
max_q_flash,
max_k_flash,
dropout_p,
is_causal,
return_debug_mask,
scale=scale,
)
(out,) = _transpose_dense(out)
max_q = query.size(2)
max_k = key.size(2)
return (
out,
lse,
None,
None,
max_q,
max_k,
rng_state,
philox_offset,
debug_mask,
)
def _fa4_scaled_dot_product_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
*,
scale: float | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
max_q = query.size(2)
max_k = key.size(2)
dq, dk, dv = _fa4_flash_attention_backward_impl(
go,
q,
k,
v,
o,
logsumexp,
None,
None,
max_q,
max_k,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale=scale,
)
dq, dk, dv = _transpose_dense(dq, dk, dv)
return dq, dk, dv
_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)
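For comparison, a hedged sketch of how a third-party provider could follow the same pattern as the module above: expose a register function that installs its kernels via torch.library.Library and hand it to the registry. The "MyImpl" identifier and the commented-out kernel callables are hypothetical:
from torch.library import Library
from torch.nn.attention import _registry

def _register_flash_attention_myimpl():
    lib = Library("aten", "IMPL", "CUDA")  # noqa: TOR901
    # lib.impl("_flash_attention_forward", my_forward_impl, "CUDA")    # hypothetical kernels
    # lib.impl("_flash_attention_backward", my_backward_impl, "CUDA")
    class _Handle:
        def __init__(self, library):
            self.library = library
        def remove(self):
            self.library = None  # dropping the Library releases the overrides
    return _Handle(lib)

_registry.register_flash_attention_impl("MyImpl", register_fn=_register_flash_attention_myimpl)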

View File

@ -1,109 +0,0 @@
# mypy: allow-untyped-defs
"""Registry for flash attention implementations.
This module contains the registration system for flash attention implementations.
It has no torch dependencies to avoid circular imports during initialization.
"""
from collections.abc import Callable
from typing import Literal, Protocol
class FlashAttentionHandle(Protocol):
def remove(self) -> None: ...
_RegisterFn = Callable[..., FlashAttentionHandle | None]
_FlashAttentionImpl = Literal["FA4"]
_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}
_FLASH_ATTENTION_ACTIVE: str | None = None
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}
def register_flash_attention_impl(
impl: str | _FlashAttentionImpl,
*,
register_fn: _RegisterFn,
) -> None:
"""
Register the callable that activates a flash attention impl.
.. note::
This function is intended for SDPA backend providers to register their
implementations. End users should use :func:`activate_flash_attention_impl`
to activate a registered implementation.
Args:
impl: Implementation identifier (e.g., ``"FA4"``).
register_fn: Callable that performs the actual dispatcher registration.
This function will be invoked by :func:`activate_flash_attention_impl`
and should register custom kernels with the PyTorch dispatcher.
It may optionally return a handle implementing
:class:`FlashAttentionHandle` to keep any necessary state alive.
Example:
>>> def my_impl_register(module_path: str = "my_flash_impl"):
... # Register custom kernels with torch dispatcher
... pass # doctest: +SKIP
>>> register_flash_attention_impl(
... "MyImpl", register_fn=my_impl_register
... ) # doctest: +SKIP
"""
_FLASH_ATTENTION_IMPLS[impl] = register_fn
def activate_flash_attention_impl(
impl: str | _FlashAttentionImpl,
) -> None:
"""
Activate into the dispatcher a previously registered flash attention impl.
.. note::
Backend providers should NOT automatically activate their implementation
on import. Users should explicitly opt-in by calling this function or via
environment variables to ensure multiple provider libraries can coexist.
Args:
impl: Implementation identifier to activate. See
:func:`~torch.nn.attention.list_flash_attention_impls` for available
implementations.
If the backend's :func:`register_flash_attention_impl` callable
returns a :class:`FlashAttentionHandle`, the registry keeps that
handle alive for the lifetime of the process (until explicit
uninstall support exists).
Example:
>>> activate_flash_attention_impl("FA4") # doctest: +SKIP
"""
global _FLASH_ATTENTION_ACTIVE
register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
if register_fn is None:
raise ValueError(
f"Unknown flash attention impl '{impl}'. "
f"Available implementations: {list_flash_attention_impls()}"
)
# TODO: The only way to actually register a new impl is to unregister the current impl,
# reinstall the default impl, and then register the new impl
if _FLASH_ATTENTION_ACTIVE == impl:
return
handle = register_fn()
if handle is not None:
_FLASH_ATTENTION_HANDLES[impl] = handle
_FLASH_ATTENTION_ACTIVE = impl
def list_flash_attention_impls() -> list[str]:
"""Return the names of all available flash attention implementations."""
return sorted(_FLASH_ATTENTION_IMPLS.keys())
def current_flash_attention_impl() -> str | None:
"""
Return the currently activated flash attention impl name, if any.
``None`` indicates that no custom impl has been activated.
"""
return _FLASH_ATTENTION_ACTIVE

View File

@ -1771,22 +1771,6 @@ class MultiProcContinuousTest(TestCase):
cls._run_test_given_id(test_id)
completion_queue.put(test_id)
except BaseException as ex: # noqa: B036
if isinstance(ex, SystemExit):
# Get exit code from the process
exit_code = getattr(ex, "code", None)
# Look up exit code in TEST_SKIPS to see if it is a valid skip
skip_entry = next(
(v for v in TEST_SKIPS.values() if v.exit_code == exit_code),
None,
)
# If we found an entry, we want to skip the test and send the SkipTest object back to the main process
if skip_entry:
completion_queue.put(unittest.SkipTest(skip_entry.message))
# Skip exception handling below, move to main thread for processing the skip
continue
raised_exception = True
# Send the exception and stack trace back to the dispatcher
exc_info = sys.exc_info()
@ -1908,8 +1892,6 @@ class MultiProcContinuousTest(TestCase):
# Wait for the workers to finish the test
for i, completion_queue in enumerate(self.completion_queues):
rv = completion_queue.get()
if isinstance(rv, unittest.SkipTest):
raise rv
if isinstance(rv, BaseException):
# Hit an exception, re-raise it in the main process.
logger.warning(

View File

@ -114,6 +114,8 @@ class ProfilingMode(Enum):
PROFILING = 3
# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
@ -957,6 +959,8 @@ def _get_test_report_path():
return os.path.join('test-reports', test_source)
def parse_cmd_line_args():
global CI_FUNCTORCH_ROOT
global CI_PT_ROOT
global CI_TEST_PREFIX
global DISABLED_TESTS_FILE
global GRAPH_EXECUTOR
@ -1035,8 +1039,10 @@ def parse_cmd_line_args():
set_rng_seed()
# CI Prefix path used only on CI environment
CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
def wait_for_process(p, timeout=None):
try: