Compare commits

...

2 Commits

093ff36791  Update on "Freeze layout for potentially padded strides in template autotuning"  2025-11-17 14:10:02 -08:00

    cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

    [ghstack-poisoned]

fa8f17a9fd  Freeze layout for potentially padded strides in template autotuning  2025-11-17 13:40:40 -08:00

    [ghstack-poisoned]

2 changed files with 66 additions and 3 deletions

View File

@@ -2483,6 +2483,66 @@ class TestMaxAutotune(TestCase):
         finally:
             clear_preprocessing_fns(clear_defaults=False)
 
+    @config.patch(
+        {"test_configs.max_mm_configs": 4, "max_autotune_gemm_backends": "TRITON"}
+    )
+    def test_fixed_layout_at_lowering(self):
+        """
+        Test that max-autotune with addmm/bmm/mm_plus_mm correctly handles
+        padding and maintains correct output strides. Specifically, when matrix
+        b with shape (4608, 1490) is padded, its stride should become 1536.
+        """
+
+        def mm_func(a, b) -> torch.Tensor:
+            a_t = torch.permute(a, [1, 0]).to(torch.bfloat16)
+            b_dtype = b.to(torch.bfloat16)
+            # Add .to() to make sure that mm could be potentially padded
+            # Strides for output are not padded
+            return (a_t @ b_dtype).to(torch.float32)
+
+        def addmm_func(a, b, bias) -> torch.Tensor:
+            a_t = torch.permute(a, [1, 0]).to(torch.bfloat16)
+            b_dtype = b.to(torch.bfloat16)
+            bias_dtype = bias.to(torch.bfloat16)
+            return torch.addmm(bias_dtype, a_t, b_dtype).to(torch.float32)
+
+        def bmm_func(a, b) -> torch.Tensor:
+            a_t = torch.permute(a, [2, 0, 1]).to(torch.bfloat16)
+            b_dtype = b.to(torch.bfloat16)
+            return torch.bmm(a_t, b_dtype).to(torch.float32)
+
+        def mm_plus_mm_func(a1, b1, a2, b2) -> torch.Tensor:
+            a1_t = torch.permute(a1, [1, 0]).to(torch.bfloat16)
+            b1_dtype = b1.to(torch.bfloat16)
+            a2_t = torch.permute(a2, [1, 0]).to(torch.bfloat16)
+            b2_dtype = b2.to(torch.bfloat16)
+            return (a1_t @ b1_dtype + a2_t @ b2_dtype).to(torch.float32)
+
+        a = torch.randn((4608, 512), device=GPU_TYPE, dtype=torch.bfloat16)
+        b = torch.randn((4608, 1490), device=GPU_TYPE)
+        bias = torch.randn(1490, device=GPU_TYPE)
+        a_bmm = torch.randn((512, 4608, 8), device=GPU_TYPE, dtype=torch.bfloat16)
+        b_bmm = torch.randn((8, 4608, 1490), device=GPU_TYPE)
+        # Test mm_plus_mm
+        a2 = torch.randn((4608, 512), device=GPU_TYPE, dtype=torch.bfloat16)
+        b2 = torch.randn((4608, 1490), device=GPU_TYPE)
+
+        # 1490 padded to 1536, check in template code
+        output_code_padding_check = "stride_bk = 1536"
+
+        funcs_and_args = [
+            (mm_func, (a, b)),
+            (addmm_func, (a, b, bias)),
+            (bmm_func, (a_bmm, b_bmm)),
+            (mm_plus_mm_func, (a, b, a2, b2)),
+        ]
+        for f, args in funcs_and_args:
+            c_f = torch.compile(f, mode="max-autotune-no-cudagraphs")
+            _, code_out = run_and_get_code(c_f, *args)
+            FileCheck().check(output_code_padding_check).run(code_out[0])
+
 
 class TestMaxAutotunePrecompile(TestCase):
     def test_precompilation_threads(self):
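For context on the "stride_bk = 1536" value the FileCheck pattern greps for: the compiled functions cast b (shape (4608, 1490), float32) to bfloat16, and Inductor may pad the innermost stride of that intermediate so each row starts on an aligned address. Below is a minimal sketch of the arithmetic, assuming the default 128-byte alignment used by Inductor's comprehensive padding (config.padding_alignment_bytes); the padded_stride helper is written only for this example and is not an Inductor API.

import torch

# Back-of-the-envelope check of the expected stride (a sketch, not Inductor's
# padding code). Assumes 128-byte alignment; padded_stride is a hypothetical
# helper for this example only.
def padded_stride(innermost_dim: int, dtype: torch.dtype, align_bytes: int = 128) -> int:
    # Round the innermost dimension up so every row starts on an aligned address.
    elem_size = torch.tensor([], dtype=dtype).element_size()
    align_elems = align_bytes // elem_size
    return (innermost_dim + align_elems - 1) // align_elems * align_elems

# b is cast to bfloat16 (2 bytes/element) inside the compiled functions:
# 1490 elements -> 2980 bytes -> rounded up to 3072 bytes -> 1536 elements,
# which matches the "stride_bk = 1536" the test checks for.
print(padded_stride(1490, torch.bfloat16))  # 1536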

View File

@@ -777,7 +777,7 @@ class TritonTemplateKernel(TritonKernel):
             val = self.output_node.get_stride()
         else:
             assert isinstance(name, str)
-            val = self.named_input_nodes[name].get_stride()
+            val = self.get_stride_and_maybe_freeze_layout(self.named_input_nodes[name])
 
         if isinstance(index, int):
             return texpr(self.rename_indexing(val[index]))
@@ -955,7 +955,6 @@ class TritonTemplateKernel(TritonKernel):
         self.template_mask = mask if mask is not None else "None"
         self.template_out_shape = index_shape if index_shape else "xindex"
         self.template_indices = indices
-        self.named_input_nodes[input_name].data.freeze_layout()
         self.cse.invalidate(OrderedSet())
 
         template_mask = self.template_mask
@@ -1412,7 +1411,7 @@ class TritonTemplateKernel(TritonKernel):
         assert isinstance(indices, (list, tuple))
         assert isinstance(name, str)
         assert isinstance(mask, str)
-        stride = self.named_input_nodes[name].get_stride()
+        stride = self.get_stride_and_maybe_freeze_layout(self.named_input_nodes[name])
         indices = list(map(OpOverrides.paren, indices))
         assert len(indices) == len(stride)
         index = " + ".join(
@@ -1502,6 +1501,10 @@ class TritonTemplateKernel(TritonKernel):
             )
         ]
 
+    def get_stride_and_maybe_freeze_layout(self, node) -> list[int]:
+        node.data.freeze_layout()
+        return node.get_stride()
+
 
 @functools.cache
 def _jinja2_env():
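Taken together, these hunks replace a single freeze_layout() call (removed in the -955 hunk) with get_stride_and_maybe_freeze_layout, which freezes a node's layout before returning its strides at every site that emits strides into the template. The hazard this presumably avoids: strides read from a still-flexible layout can change when the layout is later frozen with padding, so the generated template would bake in stale, unpadded values. Below is a minimal sketch of that ordering issue; FlexibleNode is made up for the example and stands in for Inductor's flexible-layout machinery, and align=64 is an assumption (128-byte alignment at 2 bytes per bfloat16 element).

# Illustrative only: FlexibleNode and its padding rule are invented for this
# sketch; the real behavior lives in Inductor's layout classes.
class FlexibleNode:
    def __init__(self, size, align=64):
        self.size = size        # e.g. (4608, 1490)
        self.align = align      # element alignment applied when the layout is frozen
        self.frozen = False

    def freeze_layout(self):
        self.frozen = True

    def get_stride(self):
        cols = self.size[-1]
        if self.frozen:
            # Frozen layout: innermost dimension padded up to the alignment.
            padded = (cols + self.align - 1) // self.align * self.align
            return [padded, 1]
        # Flexible layout: contiguous strides that may still change later.
        return [cols, 1]

node = FlexibleNode((4608, 1490))

stale = node.get_stride()    # [1490, 1]: read before freezing, invalidated by padding
node.freeze_layout()
final = node.get_stride()    # [1536, 1]: what the generated template must record

print(stale, final)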