Add cse for make_block_ptr in Triton codegen (#163399)

Summary: Add common subexpression elimination (CSE) for `make_block_ptr` in the Triton codegen path, so that identical block-pointer constructions are emitted once and reused instead of being duplicated.

Test Plan: added test cases

Differential Revision: D82648215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163399
Approved by: https://github.com/jansel, https://github.com/njriasan
Author: Nan Zhang
Date: 2025-10-16 05:29:48 +00:00
Committed by: PyTorch MergeBot
Parent: 5d0b22008d
Commit: 00afa06800
3 changed files with 63 additions and 5 deletions
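For context, this change makes Inductor's Triton backend reuse an existing block pointer whenever the exact same `tl.make_block_ptr(...)` line would otherwise be emitted again, rather than minting a fresh `block_ptr{N}` per construction. A minimal sketch for observing the effect, assuming a CUDA device, Triton, and a build that includes this patch (`run_and_get_code` is Inductor's test utility for capturing generated kernel source; the function and config mirror the new test below):

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code

def func(a, b):
    tmp0 = a * b
    tmp1 = a + b
    return (tmp0 + tmp1).sum(dim=0)

a = torch.randn(512, 4096, device="cuda")
b = torch.randn(512, 4096, device="cuda")

# Same configuration the new test pins: block pointers plus a tiled,
# unsplit reduction, which is the setup the test uses to exercise the CSE.
config_patches = {
    "triton.use_block_ptr": True,
    "triton.tile_reductions": True,
    "triton.prefer_nd_tiling": True,
    "triton.max_tiles": 3,
    "split_reductions": False,
}
with config.patch(config_patches):
    _, code_list = run_and_get_code(torch.compile(func), a, b)

# With the CSE, each input buffer contributes exactly one block pointer.
print(sum(code.count("tl.make_block_ptr(in_ptr") for code in code_list))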

[File 1 of 3]

@@ -32,6 +32,8 @@ class CodegenInductorTest(InductorTestCase):
         *args,
         compile_kwargs: Optional[dict] = None,
         config_patches: Optional[dict] = None,
+        atol: float | None = 1e-05,
+        rtol: float | None = 1e-08,
     ):
         """
         Runs the module through Inductor, comparing to eager reference.
@@ -53,7 +55,7 @@ class CodegenInductorTest(InductorTestCase):
         ref_tensors = flatten_tensors(func(*args))
         actual_tensors = flatten_tensors(result)
         for ref, actual in zip(ref_tensors, actual_tensors):
-            self.assertTrue(torch.allclose(ref, actual))
+            self.assertTrue(torch.allclose(ref, actual, atol=atol, rtol=rtol))
         return result, code
@@ -89,6 +91,34 @@ class CodegenInductorTest(InductorTestCase):
         else:
             self.count_code(reinterpret_call, code, 2)
 
+    @requires_gpu()
+    @skipIf(GPU_TYPE == "mps", "Triton is not available for MPS")
+    def test_cse_make_block_ptr_reduction(self):
+        def func(a, b):
+            tmp0 = a * b
+            tmp1 = a + b
+            c = tmp0 + tmp1
+            return c.sum(dim=0)
+
+        config_patches = {
+            "triton.use_block_ptr": True,
+            "triton.tile_reductions": True,
+            "triton.prefer_nd_tiling": True,
+            "triton.max_tiles": 3,
+            "split_reductions": False,
+        }
+        a = torch.randn((512, 4096), device=torch.device(GPU_TYPE))
+        b = torch.randn((512, 4096), device=torch.device(GPU_TYPE))
+        _, code = self.run_and_compare(
+            func,
+            a,
+            b,
+            config_patches=config_patches,
+            atol=1e-4,
+        )
+        self.count_code("= tl.make_block_ptr(in_ptr", code, 2)
+        self.count_code("= tl.load(block_ptr", code, 2)
+
     @requires_gpu()
     @skipIf(GPU_TYPE == "mps", "Triton is not available for MPS")
     def test_kernel_fusion_thresholds(self):
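The two `count_code` assertions pin down the post-CSE kernel shape: one `tl.make_block_ptr` and one `tl.load` per input buffer, even though each input feeds two subexpressions. `count_code` is this test class's own helper; a rough standalone equivalent, shown only to make the assertion semantics concrete (hypothetical function, not the class's actual implementation):

def count_code(pattern: str, code_list: list[str], expected: int) -> None:
    # Count occurrences of `pattern` across all generated kernel sources
    # and fail if the total differs from `expected`.
    found = sum(code.count(pattern) for code in code_list)
    assert found == expected, f"{pattern!r}: found {found}, expected {expected}"

# Mirroring the assertions above:
# count_code("= tl.make_block_ptr(in_ptr", code_list, 2)
# count_code("= tl.load(block_ptr", code_list, 2)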

[File 2 of 3]

@@ -824,7 +824,7 @@ class CommonTemplate:
         [
             ((8, 8), 1, 1, True),  # Persistent Welford fallback
             subtest(
-                ((128, 128), 9, 2, False), decorators=[xfail_if_use_tensor_descriptor]
+                ((128, 128), 7, 2, False), decorators=[xfail_if_use_tensor_descriptor]
             ),  # Looped Welford reduction
         ],
     )
@@ -924,7 +924,7 @@ class CommonTemplate:
         result, (code,) = self._run_and_compare(
             foo,
             view,
-            expected_num_block_pointers=6,
+            expected_num_block_pointers=5,
             expected_num_triton_kernels=2,
             config_patches={
                 "triton.multi_kernel": True,

[File 3 of 3]

@@ -2865,6 +2865,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
         indexing: Union[BlockPtrOptions, TensorDescriptorOptions],
         other="",
     ) -> tuple[str, str]:
+        """Generate a block pointer or tensor descriptor for Triton kernel operations.
+
+        This method creates either a block pointer (for regular Triton operations) or
+        a tensor descriptor (for TMA operations) based on the indexing type. It handles
+        caching and reuse of descriptors for performance optimization.
+
+        Args:
+            name: The name of the buffer/tensor being accessed
+            var: The variable name for the pointer
+            indexing: Block pointer options or tensor descriptor options containing
+                indexing information and boundary check settings
+            other: Additional parameters string (e.g., padding options)
+
+        Returns:
+            A tuple containing:
+                - block_descriptor: The generated block pointer or tensor descriptor variable name
+                - other: Modified additional parameters string with boundary check options
+        """
         check = indexing.boundary_check()
         if isinstance(indexing, TensorDescriptorOptions):
             if check and other:
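For readers unfamiliar with the construct being cached here: a Triton block pointer packages a tensor's base address, logical shape, strides, per-program offsets, tile shape, and axis order, and is consumed by `tl.load`/`tl.store` with optional `boundary_check` padding (the `other`/padding options the docstring mentions). A minimal standalone kernel, independent of Inductor, assuming Triton and a CUDA device (names and tile sizes are arbitrary):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(in_ptr, out_ptr, M, N, stride_m,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    src = tl.make_block_ptr(base=in_ptr, shape=(M, N), strides=(stride_m, 1),
                            offsets=(pid * BLOCK_M, 0),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    dst = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(stride_m, 1),
                            offsets=(pid * BLOCK_M, 0),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    # boundary_check masks out-of-bounds rows and columns of the tile.
    tile = tl.load(src, boundary_check=(0, 1))
    tl.store(dst, tile, boundary_check=(0, 1))

x = torch.randn(128, 64, device="cuda")
y = torch.empty_like(x)
copy_kernel[(2,)](x, y, 128, 64, x.stride(0), BLOCK_M=64, BLOCK_N=64)
assert torch.equal(x, y)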
@@ -2892,14 +2910,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
             # tensor descriptor.
             block_descriptor = self.prologue_cache[var]
         else:
+            block_ptr_line = indexing.format(var, roffset=False)
+            block_var = self.cse.try_get(block_ptr_line)
+            # Early return if block descriptor already exists
+            if block_var:
+                return str(block_var), other
             block_descriptor_id = next(self.block_ptr_id)
             if isinstance(indexing, BlockPtrOptions):
                 block_descriptor = f"block_ptr{block_descriptor_id}"
             else:
                 block_descriptor = f"tma_descriptor{block_descriptor_id}"
-            line_body = DeferredLine(
-                name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
-            )
+            named_var = self.cse.namedvar(
+                block_descriptor, dtype=torch.uint64, shape=[]
+            )
+            self.cse.put(block_ptr_line, named_var)
+            line_body = DeferredLine(name, f"{block_descriptor} = {block_ptr_line}")
             if indexing.can_lift:
                 self.prologue.writeline(line_body)
                 # Cache the descriptor for epilogue subtiling
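The hunk above is the whole mechanism: before emitting a new descriptor, look the formatted `make_block_ptr` line up in the CSE table (`try_get`); on a hit, return the existing variable and emit nothing; on a miss, mint `block_ptr{N}`, record it (`put`), and emit the assignment. A self-contained sketch of that table with a toy cache class (the real `self.cse` is Inductor's CSE helper, which has a much richer interface):

class BlockPtrCSE:
    """Toy CSE table: maps an expression's exact text to its variable name."""

    def __init__(self) -> None:
        self._cache: dict[str, str] = {}
        self._next_id = 0

    def try_get(self, expr: str) -> str | None:
        return self._cache.get(expr)

    def put(self, expr: str, var: str) -> None:
        self._cache[expr] = var

    def define(self, expr: str) -> str:
        # Mirrors generate_block_ptr: reuse on a hit, otherwise mint a
        # fresh name and record it for later hits.
        if (hit := self.try_get(expr)) is not None:
            return hit
        var = f"block_ptr{self._next_id}"
        self._next_id += 1
        self.put(expr, var)
        return var

cse = BlockPtrCSE()
line = 'tl.make_block_ptr(in_ptr0, shape=[512, 4096], strides=[4096, 1], ...)'
assert cse.define(line) == "block_ptr0"  # first sight: new variable, line emitted
assert cse.define(line) == "block_ptr0"  # identical text: reused, nothing emitted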