Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20)
[Inductor] [Triton] Apply feedback to Enable padded stride support (#160614)
Summary: This addresses an issue I noticed while fixing tests for TMA store. The triton.language.make_tensor_descriptor call hardcodes the shape information as the stride, which is not necessarily correct. In particular, it is legal for a stride to be larger than the shape (e.g. padded to a size): a common pattern is to allocate a tensor whose leading stride is always a multiple of 16 and simply pad the result so that TMA is legal. This is a redo of https://github.com/pytorch/pytorch/pull/160493, which I broke by accidentally trying to land it internally first instead of merging through GitHub directly.

Test Plan: Tested with `buck2 run mode/opt-split-dwarf mode/inplace -c fbcode.nvcc_arch=h100 caffe2/test/inductor:max_autotune 2>&1 | tee ~/test_logs.log` and confirmed that all max autotune tests pass.

Rollback Plan:

Differential Revision: D80224578

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160614
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: d387a48c38
Commit: 25ccc4716e
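To make the padded-stride case from the summary concrete, here is a minimal standalone sketch (illustrative only, not part of the diff; next_multiple_16 mirrors the helper defined in the new test below): a row-major float16 matrix whose leading stride is rounded up to a multiple of 16 elements, so the stride is larger than the logical shape and a descriptor that assumed stride == shape would be wrong.

import torch


def next_multiple_16(x: int) -> int:
    # Round up to the next multiple of 16 elements.
    return ((x + 15) // 16) * 16


M, K = 21, 11
padded_k = next_multiple_16(K)  # 16 rather than 11
a = torch.empty_strided((M, K), (padded_k, 1), dtype=torch.float16)
a[:] = torch.randn(M, K, dtype=torch.float16)

print(a.shape)     # torch.Size([21, 11]) -- the logical shape is unchanged
print(a.stride())  # (16, 1) -- the row pitch is 16 elements (32 bytes), not 11

A tensor descriptor for this layout must carry the real stride (16), not the shape (11); otherwise the padding elements would be read as if they were real columns.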
@@ -165,6 +165,65 @@ class TestMaxAutotune(TestCase):
 
         torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
 
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("a_transposed", (False, True))
+    @parametrize("b_transposed", (False, True))
+    @parametrize("dynamic", (False, True))
+    def test_max_autotune_regular_mm_persistent_tma_strided(
+        self,
+        a_transposed: bool,
+        b_transposed: bool,
+        dynamic: bool,
+    ):
+        def mm(a, b):
+            # TMA requires 16-byte alignment: here we repeat the dims
+            # by a factor of 8, as float16 is 2 bytes. All dims are
+            # repeated due to the possible transpositions below.
+            a = a.repeat(8, 8)
+            b = b.repeat(8, 8)
+            if a_transposed:
+                a = a.T
+            if b_transposed:
+                b = b.T
+
+            return torch.mm(a, b)
+
+        def next_multiple_16(a: int) -> int:
+            return ((a + 15) // 16) * 16
+
+        M, N, K = 21, 31, 11
+        a_shape = (K, M) if a_transposed else (M, K)
+        a_stride = (
+            (next_multiple_16(M), 1) if a_transposed else (next_multiple_16(K), 1)
+        )
+        a = torch.empty_strided(a_shape, a_stride, dtype=torch.float16).to(GPU_TYPE)
+        a[:] = torch.randn(a_shape, dtype=torch.float16)
+        a = a.to(GPU_TYPE)
+        b_shape = (N, K) if b_transposed else (K, N)
+        b_stride = (
+            (next_multiple_16(K), 1) if b_transposed else (next_multiple_16(N), 1)
+        )
+        b = torch.empty_strided(b_shape, b_stride, dtype=torch.float16)
+        b[:] = torch.randn(b_shape, dtype=torch.float16)
+        b = b.to(GPU_TYPE)
+        with config.patch(
+            {
+                "max_autotune": True,
+                "triton.enable_persistent_tma_matmul": "1",
+                "test_configs.autotune_choice_name_regex": "mm_persistent_tma",
+            }
+        ):
+            c_actual, code = run_and_get_code(torch.compile(mm, dynamic=dynamic), a, b)
+            c_expected = mm(a, b)
+
+        torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
+        # Verify that we are using a TMA implementation
+        FileCheck().check("triton_tem_fused_mm").check(
+            "triton.language.make_tensor_descriptor"
+        ).run(code[0])
+
     @unittest.skipIf(
         not has_triton_tma_device(), "Need device-side TMA support in Triton"
     )
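Worked through for the non-transposed case in the new test (M, N, K = 21, 31, 11): next_multiple_16 yields a_stride = (16, 1) for a_shape = (21, 11) and b_stride = (32, 1) for b_shape = (11, 31). Both leading strides exceed the corresponding logical widths (11 and 31), and at 2 bytes per float16 element the row pitches come to 32 and 64 bytes, keeping the buffers 16-byte aligned for TMA. These are exactly the padded strides that the hardcoded shape-as-stride descriptors changed below could not express.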
@@ -283,16 +283,20 @@ persistent_tma_mm_template = TritonTemplate(
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
 
     {%- else %}
+        stride_am = {{stride("A", 0)}}
+        stride_ak = {{stride("A", 1)}}
+        stride_bk = {{stride("B", 0)}}
+        stride_bn = {{stride("B", 1)}}
         a_desc = triton.language.make_tensor_descriptor(
             base=A,
             shape=[M, K] if A_ROW_MAJOR else [K, M],
-            strides=[K, 1] if A_ROW_MAJOR else [M, 1],
+            strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1],
             block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
         )
         b_desc = triton.language.make_tensor_descriptor(
             base=B,
             shape=[K, N] if B_ROW_MAJOR else [N, K],
-            strides=[N, 1] if B_ROW_MAJOR else [K, 1],
+            strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1],
             block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
         )
     {%- endif %}
@@ -461,16 +465,18 @@ device_tma = r"""
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
 
     {%- else %}
+        stride_am = {{stride("A", 0)}}
+        stride_bn = {{stride("B", 1)}}
         a_desc = triton.language.make_tensor_descriptor(
             base=A,
             shape=[M, K],
-            strides=[K, 1],
+            strides=[stride_am, 1],
             block_shape=[BLOCK_M, BLOCK_K],
         )
         b_desc = triton.language.make_tensor_descriptor(
             base=B,
             shape=[N, K],
-            strides=[K, 1],
+            strides=[stride_bn, 1],
             block_shape=[BLOCK_N, BLOCK_K],
         )
     {%- endif %}
@@ -1730,7 +1730,8 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
 
 def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
     return (
-        can_use_tma(*matrices, add_guards=add_guards)
+        all(len(m.get_size()) == 2 for m in matrices)
+        and can_use_tma(*matrices, add_guards=add_guards)
         and config.triton.enable_persistent_tma_matmul
     )
 