[Triton] [Inductor] Add a Blackwell specific Template for persistent matmul (#162916)

Summary:
This ports the persistent matmul with device-side TMA from the Triton tutorials and registers it as an Inductor template option for Blackwell. The template uses newer Triton features such as automatic warp specialization and loop flattening which, while still rough in places, can improve performance on Blackwell. The epilogue subtiling section is not included here; that will be a follow-up PR.
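
At the kernel level, the template pairs device-side TMA tensor descriptors with a persistent tile loop that opts into warp specialization and loop flattening through tl.range. Below is a simplified, row-major-only sketch of that structure (illustrative only: the kernel name is made up, and the real template in this diff also handles transposed inputs, grouped tile ordering via _compute_pid, and Inductor's epilogue/store codegen):

import triton
import triton.language as tl


@triton.jit
def blackwell_persistent_mm_sketch(
    A, B, C, M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_SMS: tl.constexpr, FLATTEN: tl.constexpr, WARP_SPECIALIZE: tl.constexpr,
):
    # Device-side TMA: descriptors are built inside the kernel from raw pointers
    # (the host must install a TMA workspace allocator first, which Inductor
    # handles). Row strides are assumed to satisfy TMA's 16-byte alignment.
    a_desc = tl.make_tensor_descriptor(
        base=A, shape=[M, K], strides=[K, 1], block_shape=[BLOCK_M, BLOCK_K]
    )
    b_desc = tl.make_tensor_descriptor(
        base=B, shape=[K, N], strides=[N, 1], block_shape=[BLOCK_K, BLOCK_N]
    )
    c_desc = tl.make_tensor_descriptor(
        base=C, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N]
    )
    start_pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = grid_m * grid_n
    # Persistent loop: one program per SM, each striding over output tiles.
    # flatten and warp_specialize are the newer Triton features the summary
    # refers to; the compiler performs the warp specialization automatically.
    for tile_id in tl.range(
        start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE
    ):
        pid_m = tile_id // grid_n
        pid_n = tile_id % grid_n
        offs_m = pid_m * BLOCK_M
        offs_n = pid_n * BLOCK_N
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_K
            a = tl.load_tensor_descriptor(a_desc, [offs_m, offs_k])
            b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_n])
            acc += tl.dot(a, b)
        c_desc.store([offs_m, offs_n], acc.to(C.dtype.element_ty))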

This PR does not include any tuning. I am running a larger benchmarking sweep to determine the best initial configs and will open a follow-up PR with better defaults soon.

Test Plan:
Tested on a Blackwell machine with test_max_autotune.py and confirmed the new tests pass.
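
The new tests force the autotuner to pick the Blackwell template and then FileCheck the generated code for the TMA descriptor APIs. A condensed sketch of that flow follows (shapes and the helper name are illustrative; the config keys and check strings come from the test code in the diff below):

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def check_blackwell_template_selected():
    # float16 with these shapes keeps row strides 16-byte aligned for TMA.
    a = torch.randn(256, 384, device="cuda", dtype=torch.float16)
    b = torch.randn(384, 128, device="cuda", dtype=torch.float16)

    def mm(x, y):
        return torch.mm(x, y)

    with config.patch(
        {
            "max_autotune": True,
            "triton.enable_persistent_tma_matmul": True,
            # Restrict autotune choices so FileCheck inspects the new template.
            "test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
        }
    ):
        out, code = run_and_get_code(torch.compile(mm), a, b)
    torch.testing.assert_close(out, mm(a, b), atol=1e-2, rtol=1e-2)
    # The generated kernel should build TMA descriptors and read through them.
    FileCheck().check("triton_tem_fused_mm").check(
        "triton.language.make_tensor_descriptor"
    ).check("tl.load_tensor_descriptor").run(code[0])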

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162916
Approved by: https://github.com/NikhilAPatel
Author: Nick Riasanovsky
Date: 2025-09-15 23:23:01 +00:00
Committed by: PyTorch MergeBot
Parent: c77726b1d7
Commit: 955e195c7d
6 changed files with 348 additions and 1 deletion

View File

@@ -52,7 +52,11 @@ from torch.testing._internal.common_utils import (
TEST_WITH_ROCM,
)
from torch.testing._internal.logging_utils import multiple_logs_to_string
from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device
from torch.utils._triton import (
has_datacenter_blackwell_tma_device,
has_triton_stable_tma_api,
has_triton_tma_device,
)
aten = torch.ops.aten
@@ -259,6 +263,69 @@ class TestMaxAutotune(TestCase):
check_str = "triton.language.make_tensor_descriptor"
FileCheck().check("triton_tem_fused_mm").check(check_str).run(code[0])
@unittest.skipIf(
not has_datacenter_blackwell_tma_device(),
"Need Blackwell with device-side TMA support in Triton",
)
@parametrize("a_transposed", (False, True))
@parametrize("b_transposed", (False, True))
@parametrize("dynamic", (False, True))
@parametrize("tma_store", (False, True))
def test_blackwell_max_autotune_regular_mm_persistent_tma(
self,
a_transposed: bool,
b_transposed: bool,
dynamic: bool,
tma_store: bool,
):
def mm(a, b):
# TMA requires 16-byte alignment: here we repeat the dims
# by the factor of 8, as float16 is 2-byte. All dims are
# repeated due to the possible transpositions below.
a = a.repeat(8, 8)
b = b.repeat(8, 8)
if a_transposed:
a = a.T
if b_transposed:
b = b.T
return torch.mm(a, b)
M, N, K = 32, 16, 48
a = (
torch.randn(*((K, M) if a_transposed else (M, K)))
.to(torch.float16)
.to(GPU_TYPE)
)
b = (
torch.randn(*((N, K) if b_transposed else (K, N)))
.to(torch.float16)
.to(GPU_TYPE)
)
with config.patch(
{
"max_autotune": True,
"triton.enable_persistent_tma_matmul": True,
"triton.enable_template_tma_store": tma_store,
"test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
}
):
c_actual, code = run_and_get_code(torch.compile(mm, dynamic=dynamic), a, b)
c_expected = mm(a, b)
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
if tma_store:
# Verify that we are using a TMA implementation
# Note: The tma_descriptor0 is generated by the kernel. If the
# code generation process changes this could change.
write_api = "tma_descriptor0.store"
else:
write_api = "tl.store"
FileCheck().check("triton_tem_fused_mm").check(
"triton.language.make_tensor_descriptor"
).check("tl.load_tensor_descriptor").check(write_api).run(code[0])
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@@ -451,6 +518,79 @@ class TestMaxAutotune(TestCase):
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
@unittest.skipIf(
not has_datacenter_blackwell_tma_device(),
"Need Blackwell with device-side TMA support in Triton",
)
@parametrize("a_transposed", (False, True))
@parametrize("b_transposed", (False, True))
@parametrize("dynamic", (False, True))
@parametrize("tma_store", (False, True))
def test_blackwell_max_autotune_addmm_persistent_tma(
self,
a_transposed: bool,
b_transposed: bool,
dynamic: bool,
tma_store: bool,
):
def addmm(x, a, b):
# TMA requires 16-byte alignment: here we repeat the dims
# by the factor of 8, as float16 is 2-byte. All dims are
# repeated due to the possible transpositions below.
x = x.repeat(8)
a = a.repeat(8, 8)
b = b.repeat(8, 8)
if a_transposed:
a = a.T
if b_transposed:
b = b.T
return torch.addmm(x, a, b)
M, N, K = 21, 31, 11
a = (
torch.randn(*((K, M) if a_transposed else (M, K)))
.to(torch.float16)
.to(GPU_TYPE)
)
b = (
torch.randn(*((N, K) if b_transposed else (K, N)))
.to(torch.float16)
.to(GPU_TYPE)
)
x = torch.randn(N).to(torch.float16).to(GPU_TYPE)
with config.patch(
{
"max_autotune": True,
"triton.enable_persistent_tma_matmul": True,
"triton.enable_template_tma_store": tma_store,
"test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
}
):
c_actual, code = run_and_get_code(
torch.compile(addmm, dynamic=dynamic), x, a, b
)
c_expected = addmm(x, a, b)
make_desc_api = "triton.language.make_tensor_descriptor"
read_api = "tl.load_tensor_descriptor"
if tma_store:
# Verify that we are using a TMA implementation
# Note: The tma_descriptor0 is generated by the kernel. If the
# code generation process changes this could change.
write_api = "tma_descriptor0.store"
else:
write_api = "tl.store"
# Verify that we are using a TMA implementation
FileCheck().check("triton_tem_fused_addmm").check(make_desc_api).check(
read_api
).check(write_api).run(code[0])
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)

View File

@@ -27,6 +27,16 @@ def get_cuda_arch() -> Optional[str]:
return None
@clear_on_fresh_cache
@functools.lru_cache(1)
def is_datacenter_blackwell_arch() -> bool:
arch = get_cuda_arch()
if arch is None:
return False
arch_number = int(arch)
return arch_number >= 100 and arch_number < 110
@clear_on_fresh_cache
@functools.lru_cache(1)
def get_cuda_version() -> Optional[str]:

View File

@@ -41,6 +41,7 @@ from ..utils import (
use_cpp_gemm_template,
use_cutlass_template,
use_decompose_k_choice,
use_triton_blackwell_tma_template,
use_triton_template,
use_triton_tma_template,
)
@@ -563,6 +564,103 @@ scaled_mm_device_tma_template = TritonTemplate(
source=device_tma + load_scales + apply_scaling,
)
_compute_blackwell_pid = r"""
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_M
GROUP_M = min(grid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (tile_id % GROUP_M)
pid_n = (tile_id % num_pid_in_group) // GROUP_M
return pid_m, pid_n
"""
_blackwell_ws_persistent_device_tma = r"""
{{def_kernel("A", "B")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
start_pid = tl.program_id(0)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles = grid_m * grid_n
# Note: We require TMA_EXPERIMENTAL_API == False, which
# we will check before invoking this template.
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}
a_desc = triton.language.make_tensor_descriptor(
base=A,
shape=[M, K] if A_ROW_MAJOR else [K, M],
strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1],
block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
)
b_desc = triton.language.make_tensor_descriptor(
base=B,
shape=[K, N] if B_ROW_MAJOR else [N, K],
strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1],
block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
)
# tile_id_c is used in the epilogue to break the dependency between
# the prologue and the epilogue
tile_id_c = start_pid - NUM_SMS
num_pid_in_group = GROUP_M * grid_n
for tile_id in tl.range(
start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE
):
pid_m, pid_n = _compute_pid(
tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS
)
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for ki in range(k_tiles):
offs_k = ki * BLOCK_K
a = tl.load_tensor_descriptor(
a_desc,
[offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am],
)
b = tl.load_tensor_descriptor(
b_desc,
[offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k],
)
accumulator += tl.dot(
a if A_ROW_MAJOR else a.T,
b if B_ROW_MAJOR else b.T,
allow_tf32=ALLOW_TF32,
)
tile_id_c += NUM_SMS
pid_m, pid_n = _compute_pid(
tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS
)
offs_cm = pid_m * BLOCK_M
offs_cn = pid_n * BLOCK_N
# TODO: Add EPILOGUE_SUBTILE
{{store_output(
("offs_cm", "offs_cn"),
"accumulator",
indent_width=8,
val_shape=("BLOCK_M", "BLOCK_N"),
block_indexing=True
)}}
"""
blackwell_ws_persistent_device_tma_mm_template = TritonTemplate(
name="blackwell_ws_persistent_device_tma",
grid=persistent_mm_grid,
source=_blackwell_ws_persistent_device_tma + _compute_blackwell_pid,
)
# prevent duplicate registration of extern functions
@functools.cache
@@ -777,6 +875,9 @@ def tuned_mm(mat1, mat2, *, layout=None):
if use_triton_tma_template(mat1, mat2, output_layout=layout):
templates_to_use.append(persistent_tma_mm_template)
if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout):
templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template)
if use_decompose_k_choice(m, n, k):
templates_to_use.append(decompose_k_subgraph_template)
@@ -980,6 +1081,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
if use_triton_tma_template(mat1, mat2, output_layout=layout):
templates_to_use.append(persistent_tma_mm_template)
if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout):
templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template)
templates_to_use.append(addmm_contiguous_subgraph_template)
# Single unified call for all templates

View File

@@ -18,6 +18,7 @@ from torch.utils._triton import has_triton_stable_tma_api
from .. import config, config as inductor_config
from ..kernel.bmm import bmm_template
from ..kernel.mm import (
blackwell_ws_persistent_device_tma_mm_template,
mm_template,
persistent_tma_mm_template,
scaled_mm_device_tma_template,
@@ -1652,6 +1653,35 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
yield {**template_kwargs, **tma_opts}
# TMA mixins for Blackwell templates
class BlackwellTMATemplateConfigMixin(TMATemplateConfigMixin):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Generate TMA template configs by calling super and adding TMA-specific options.
"""
base_ops = {
"NUM_SMS": get_num_sms(),
# TODO: Consider making this tunable.
"FLATTEN": True,
}
# Get base template configs from superclass
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs,
op_name,
):
# Some Triton versions require num_warps >= 4 for WS
# to avoid compilation issues. Triton disables WS if num_warps < 4
# or num_stages < 2. Similar issues have been seen with num_stages=1
ws = (
template_kwargs["num_warps"] >= 4 and template_kwargs["num_stages"] >= 2
)
yield {**template_kwargs, **base_ops, "WARP_SPECIALIZE": ws}
# Scaled MM-specific mixin for scaled MM templates
class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
"""
@@ -1889,6 +1919,22 @@ class CUDAPersistentTMATemplateConfigHeuristic(
self.mm_configs = self.persistent_mm_configs
@register_template_heuristic(
blackwell_ws_persistent_device_tma_mm_template.uid,
"cuda",
register=torch.version.hip is None,
)
class CUDABlackwellPersistentTMATemplateConfigHeuristic(
BlackwellTMATemplateConfigMixin, CUDAConfigHeuristic
):
"""Blackwell Persistent TMA template"""
def __init__(self) -> None:
super().__init__()
# TODO: Tune mm_configs for blackwell.
self.mm_configs = self.persistent_mm_configs
@register_template_heuristic(
persistent_tma_mm_template.uid,
"cuda",
@@ -1901,6 +1947,22 @@ class CUDAAddmmPersistentTMATemplateConfigHeuristic(
"""Addmm specific mixin for CUDA"""
@register_template_heuristic(
blackwell_ws_persistent_device_tma_mm_template.uid,
"cuda",
register=torch.version.hip is None,
)
class CUDABlackwellAddmmPersistentTMATemplateConfigHeuristic(
AddMMConfigMixin, CUDABlackwellPersistentTMATemplateConfigHeuristic
):
"""Addmm extension for DataCenter Blackwell Templates"""
def __init__(self) -> None:
super().__init__()
# TODO: Tune mm_configs for blackwell.
self.mm_configs = self.persistent_mm_configs
@register_template_heuristic(
mm_template.uid, "cuda", register=torch.version.hip is None, op_name="scaled_mm"
)

View File

@@ -1790,6 +1790,22 @@ def use_triton_tma_template(
)
def use_triton_blackwell_tma_template(
*matrices: IRNode, output_layout: Layout, add_guards: bool = False
) -> bool:
if not use_triton_tma_template(
*matrices, output_layout=output_layout, add_guards=add_guards
):
return False
from torch.utils._triton import has_triton_tensor_descriptor_host_tma
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
# Blackwell templates require the tensor descriptor API, not the experimental API.
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V

View File

@@ -105,6 +105,21 @@ def has_triton_tma_device() -> bool:
return False
@functools.cache
def has_datacenter_blackwell_tma_device() -> bool:
import torch
if (
torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (10, 0)
and torch.cuda.get_device_capability() < (11, 0)
and not torch.version.hip
):
return has_triton_tma_device() and has_triton_tensor_descriptor_host_tma()
return False
@functools.lru_cache(None)
def has_triton_stable_tma_api() -> bool:
if has_triton_package():