Add cse for make_block_ptr in Triton codegen (#163399)

Summary: Add common-subexpression elimination (CSE) for `tl.make_block_ptr` calls in the Triton codegen, so identical block-pointer definitions are emitted once and reused.

Test Plan: added test cases

Differential Revision: D82648215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163399
Approved by: https://github.com/jansel, https://github.com/njriasan
This commit is contained in:
Nan Zhang
2025-10-16 05:29:48 +00:00
committed by PyTorch MergeBot
parent 5d0b22008d
commit 00afa06800
3 changed files with 63 additions and 5 deletions

View File

@@ -32,6 +32,8 @@ class CodegenInductorTest(InductorTestCase):
*args,
compile_kwargs: Optional[dict] = None,
config_patches: Optional[dict] = None,
atol: float | None = 1e-05,
rtol: float | None = 1e-08,
):
"""
Runs the module through Inductor, comparing to eager reference.
@@ -53,7 +55,7 @@ class CodegenInductorTest(InductorTestCase):
ref_tensors = flatten_tensors(func(*args))
actual_tensors = flatten_tensors(result)
for ref, actual in zip(ref_tensors, actual_tensors):
self.assertTrue(torch.allclose(ref, actual))
self.assertTrue(torch.allclose(ref, actual, atol=atol, rtol=rtol))
return result, code
@@ -89,6 +91,34 @@ class CodegenInductorTest(InductorTestCase):
else:
self.count_code(reinterpret_call, code, 2)
@requires_gpu()
@skipIf(GPU_TYPE == "mps", "Triton is not available for MPS")
def test_cse_make_block_ptr_reduction(self):
    """CSE should deduplicate identical `tl.make_block_ptr` definitions.

    The compiled kernel reads each of the two inputs twice (once for the
    product, once for the sum), so without CSE the generated Triton code
    would contain four `make_block_ptr` calls. With CSE there should be
    exactly one block pointer (and one load) per input buffer.
    """

    def func(a, b):
        # Two elementwise ops that each read both inputs, combined and
        # reduced so block-pointer codegen kicks in on the reduction path.
        product = a * b
        total = a + b
        return (product + total).sum(dim=0)

    # Force the block-pointer reduction codegen path.
    config_patches = {
        "triton.use_block_ptr": True,
        "triton.tile_reductions": True,
        "triton.prefer_nd_tiling": True,
        "triton.max_tiles": 3,
        "split_reductions": False,
    }
    device = torch.device(GPU_TYPE)
    lhs = torch.randn((512, 4096), device=device)
    rhs = torch.randn((512, 4096), device=device)
    _, code = self.run_and_compare(
        func,
        lhs,
        rhs,
        config_patches=config_patches,
        atol=1e-4,
    )
    # One make_block_ptr / load per input buffer, not per use.
    self.count_code("= tl.make_block_ptr(in_ptr", code, 2)
    self.count_code("= tl.load(block_ptr", code, 2)
@requires_gpu()
@skipIf(GPU_TYPE == "mps", "Triton is not available for MPS")
def test_kernel_fusion_thresholds(self):

View File

@@ -824,7 +824,7 @@ class CommonTemplate:
[
((8, 8), 1, 1, True), # Persistent Welford fallback
subtest(
((128, 128), 9, 2, False), decorators=[xfail_if_use_tensor_descriptor]
((128, 128), 7, 2, False), decorators=[xfail_if_use_tensor_descriptor]
), # Looped Welford reduction
],
)
@@ -924,7 +924,7 @@ class CommonTemplate:
result, (code,) = self._run_and_compare(
foo,
view,
expected_num_block_pointers=6,
expected_num_block_pointers=5,
expected_num_triton_kernels=2,
config_patches={
"triton.multi_kernel": True,

View File

@@ -2865,6 +2865,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
indexing: Union[BlockPtrOptions, TensorDescriptorOptions],
other="",
) -> tuple[str, str]:
"""Generate a block pointer or tensor descriptor for Triton kernel operations.
This method creates either a block pointer (for regular Triton operations) or
a tensor descriptor (for TMA operations) based on the indexing type. It handles
caching and reuse of descriptors for performance optimization.
Args:
name: The name of the buffer/tensor being accessed
var: The variable name for the pointer
indexing: Block pointer options or tensor descriptor options containing
indexing information and boundary check settings
other: Additional parameters string (e.g., padding options)
Returns:
A tuple containing:
- block_descriptor: The generated block pointer or tensor descriptor variable name
- other: Modified additional parameters string with boundary check options
"""
check = indexing.boundary_check()
if isinstance(indexing, TensorDescriptorOptions):
if check and other:
@@ -2892,14 +2910,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
# tensor descriptor.
block_descriptor = self.prologue_cache[var]
else:
block_ptr_line = indexing.format(var, roffset=False)
block_var = self.cse.try_get(block_ptr_line)
# Early return if block descriptor already exists
if block_var:
return str(block_var), other
block_descriptor_id = next(self.block_ptr_id)
if isinstance(indexing, BlockPtrOptions):
block_descriptor = f"block_ptr{block_descriptor_id}"
else:
block_descriptor = f"tma_descriptor{block_descriptor_id}"
line_body = DeferredLine(
name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
named_var = self.cse.namedvar(
block_descriptor, dtype=torch.uint64, shape=[]
)
self.cse.put(block_ptr_line, named_var)
line_body = DeferredLine(name, f"{block_descriptor} = {block_ptr_line}")
if indexing.can_lift:
self.prologue.writeline(line_body)
# Cache the descriptor for epilogue subtiling