[Inductor] [Triton] Apply feedback to Enable padded stride support (#160614)

Summary:
This is an issue I noticed while fixing tests for TMA store: the triton.language.make_tensor_descriptor call in the persistent TMA matmul templates hardcodes the shape information as the stride, which is not necessarily correct.

In particular, it is legal for a stride to be larger than the corresponding dimension of the shape (e.g. a dimension padded to a larger allocated size). A good example of this usage is allocating a tensor whose leading stride is always rounded up to a multiple of 16 and simply padding the data so that TMA is legal.
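
As an illustration (not part of this PR), a minimal PyTorch sketch of such a padded layout: the logical row length is 11, but each row is allocated with a stride of 16, so a descriptor built from the shape alone would describe the wrong memory layout.

```python
import torch

# Minimal sketch (not from this PR): a float16 matrix whose rows are padded
# so the leading stride is a multiple of 16 elements, while the logical row
# length stays 11. A descriptor must use the real stride (16), not the shape (11).
M, K = 21, 11
padded_k = ((K + 15) // 16) * 16  # round the row length up to a multiple of 16
a = torch.empty_strided((M, K), (padded_k, 1), dtype=torch.float16)
a[:] = torch.randn((M, K), dtype=torch.float16)

print(a.shape)     # torch.Size([21, 11])
print(a.stride())  # (16, 1)  <- the leading stride is larger than the shape's 11
```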

This is a redo of https://github.com/pytorch/pytorch/pull/160493, which I accidentally broke by trying to land it internally first instead of merging through GitHub directly.

Test Plan:
Tested with `buck2 run mode/opt-split-dwarf mode/inplace -c fbcode.nvcc_arch=h100 caffe2/test/inductor:max_autotune 2>&1 | tee ~/test_logs.log` and confirmed all max autotune tests passed.

Rollback Plan:

Differential Revision: D80224578

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160614
Approved by: https://github.com/eellison
Author: Nick Riasanovsky
Date: 2025-08-15 02:06:11 +00:00
Committed by: PyTorch MergeBot
Parent: d387a48c38
Commit: 25ccc4716e
3 changed files with 71 additions and 5 deletions

View File

@@ -165,6 +165,65 @@ class TestMaxAutotune(TestCase):
         torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
 
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("a_transposed", (False, True))
+    @parametrize("b_transposed", (False, True))
+    @parametrize("dynamic", (False, True))
+    def test_max_autotune_regular_mm_persistent_tma_strided(
+        self,
+        a_transposed: bool,
+        b_transposed: bool,
+        dynamic: bool,
+    ):
+        def mm(a, b):
+            # TMA requires 16-byte alignment: here we repeat the dims
+            # by the factor of 8, as float16 is 2-byte. All dims are
+            # repeated due to the possible transpositions below.
+            a = a.repeat(8, 8)
+            b = b.repeat(8, 8)
+            if a_transposed:
+                a = a.T
+            if b_transposed:
+                b = b.T
+            return torch.mm(a, b)
+
+        def next_multiple_16(a: int) -> int:
+            return ((a + 15) // 16) * 16
+
+        M, N, K = 21, 31, 11
+        a_shape = (K, M) if a_transposed else (M, K)
+        a_stride = (
+            (next_multiple_16(M), 1) if a_transposed else (next_multiple_16(K), 1)
+        )
+        a = torch.empty_strided(a_shape, a_stride, dtype=torch.float16).to(GPU_TYPE)
+        a[:] = torch.randn(a_shape, dtype=torch.float16)
+        a = a.to(GPU_TYPE)
+        b_shape = (N, K) if b_transposed else (K, N)
+        b_stride = (
+            (next_multiple_16(K), 1) if b_transposed else (next_multiple_16(N), 1)
+        )
+        b = torch.empty_strided(b_shape, b_stride, dtype=torch.float16)
+        b[:] = torch.randn(b_shape, dtype=torch.float16)
+        b = b.to(GPU_TYPE)
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "triton.enable_persistent_tma_matmul": "1",
+                "test_configs.autotune_choice_name_regex": "mm_persistent_tma",
+            }
+        ):
+            c_actual, code = run_and_get_code(torch.compile(mm, dynamic=dynamic), a, b)
+            c_expected = mm(a, b)
+
+            torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
+
+            # Verify that we are using a TMA implementation
+            FileCheck().check("triton_tem_fused_mm").check(
+                "triton.language.make_tensor_descriptor"
+            ).run(code[0])
 
     @unittest.skipIf(
         not has_triton_tma_device(), "Need device-side TMA support in Triton"
     )

View File

@@ -283,16 +283,20 @@ persistent_tma_mm_template = TritonTemplate(
     tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
     {%- else %}
+    stride_am = {{stride("A", 0)}}
+    stride_ak = {{stride("A", 1)}}
+    stride_bk = {{stride("B", 0)}}
+    stride_bn = {{stride("B", 1)}}
     a_desc = triton.language.make_tensor_descriptor(
         base=A,
         shape=[M, K] if A_ROW_MAJOR else [K, M],
-        strides=[K, 1] if A_ROW_MAJOR else [M, 1],
+        strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1],
         block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
     )
     b_desc = triton.language.make_tensor_descriptor(
         base=B,
         shape=[K, N] if B_ROW_MAJOR else [N, K],
-        strides=[N, 1] if B_ROW_MAJOR else [K, 1],
+        strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1],
         block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
     )
     {%- endif %}
@@ -461,16 +465,18 @@ device_tma = r"""
     tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
     {%- else %}
+    stride_am = {{stride("A", 0)}}
+    stride_bn = {{stride("B", 1)}}
     a_desc = triton.language.make_tensor_descriptor(
         base=A,
         shape=[M, K],
-        strides=[K, 1],
+        strides=[stride_am, 1],
         block_shape=[BLOCK_M, BLOCK_K],
     )
     b_desc = triton.language.make_tensor_descriptor(
         base=B,
         shape=[N, K],
-        strides=[K, 1],
+        strides=[stride_bn, 1],
         block_shape=[BLOCK_N, BLOCK_K],
     )
     {%- endif %}
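
To make the intent of the template change concrete, here is a minimal standalone sketch (not from this PR; it assumes a recent Triton build with device-side TMA support, and a `triton.set_allocator` hook whose exact signature may vary across releases) that builds a tensor descriptor from the tensor's actual leading stride rather than its logical shape:

```python
import torch
import triton
import triton.language as tl

# Device-side tensor descriptors allocate scratch space on the host; recent
# Triton versions expect an allocator hook along these lines (assumption:
# the exact signature may differ between releases).
triton.set_allocator(
    lambda size, alignment, stream: torch.empty(size, dtype=torch.int8, device="cuda")
)


@triton.jit
def copy_tile_kernel(
    a_ptr, out_ptr, M, K, stride_am, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr
):
    # Build the input descriptor from the tensor's *actual* leading stride,
    # which may be padded beyond K; the output is contiguous, so K is its stride.
    a_desc = tl.make_tensor_descriptor(
        base=a_ptr, shape=[M, K], strides=[stride_am, 1], block_shape=[BLOCK_M, BLOCK_K]
    )
    out_desc = tl.make_tensor_descriptor(
        base=out_ptr, shape=[M, K], strides=[K, 1], block_shape=[BLOCK_M, BLOCK_K]
    )
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    tile = a_desc.load([pid_m * BLOCK_M, pid_k * BLOCK_K])
    out_desc.store([pid_m * BLOCK_M, pid_k * BLOCK_K], tile)


# Input padded so its leading stride (32 elements) differs from its row length (16).
M, K = 64, 16
a = torch.empty_strided((M, K), (32, 1), dtype=torch.float16, device="cuda")
a[:] = torch.randn((M, K), dtype=torch.float16, device="cuda")
out = torch.empty((M, K), dtype=torch.float16, device="cuda")

BLOCK_M, BLOCK_K = 16, 16
copy_tile_kernel[(M // BLOCK_M, K // BLOCK_K)](
    a, out, M, K, a.stride(0), BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K
)
torch.testing.assert_close(out, a)
```

This is the same idea the template now expresses via `{{stride("A", 0)}}` and friends: the descriptor's strides come from the tensor's layout, while its shape stays the logical [M, K].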

View File

@@ -1730,7 +1730,8 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
 
 def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
     return (
-        can_use_tma(*matrices, add_guards=add_guards)
+        all(len(m.get_size()) == 2 for m in matrices)
+        and can_use_tma(*matrices, add_guards=add_guards)
         and config.triton.enable_persistent_tma_matmul
     )
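
Finally, a condensed user-level sketch (adapted from the new test above, with assumed sizes; it requires a GPU and Triton build with device-side TMA support) of exercising the persistent TMA matmul template through torch.compile on a padded-stride input:

```python
import torch
from torch._inductor import config


def mm(a, b):
    return torch.mm(a, b)


M, N, K = 64, 64, 32
# Pad A's leading stride beyond its logical row length K.
a = torch.empty_strided((M, K), (K + 16, 1), dtype=torch.float16, device="cuda")
a[:] = torch.randn((M, K), dtype=torch.float16)
b = torch.randn((K, N), dtype=torch.float16, device="cuda")

with config.patch(
    {
        "max_autotune": True,
        "triton.enable_persistent_tma_matmul": True,
    }
):
    compiled = torch.compile(mm)
    torch.testing.assert_close(compiled(a, b), mm(a, b), atol=1e-2, rtol=1e-2)
```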