mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add cse for make_block_ptr in Triton codegen (#163399)
Summary: per title Test Plan: added test cases Differential Revision: D82648215 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163399 Approved by: https://github.com/jansel, https://github.com/njriasan
This commit is contained in:
committed by
PyTorch MergeBot
parent
5d0b22008d
commit
00afa06800
@ -32,6 +32,8 @@ class CodegenInductorTest(InductorTestCase):
|
||||
*args,
|
||||
compile_kwargs: Optional[dict] = None,
|
||||
config_patches: Optional[dict] = None,
|
||||
atol: float | None = 1e-05,
|
||||
rtol: float | None = 1e-08,
|
||||
):
|
||||
"""
|
||||
Runs the module through Inductor, comparing to eager reference.
|
||||
@ -53,7 +55,7 @@ class CodegenInductorTest(InductorTestCase):
|
||||
ref_tensors = flatten_tensors(func(*args))
|
||||
actual_tensors = flatten_tensors(result)
|
||||
for ref, actual in zip(ref_tensors, actual_tensors):
|
||||
self.assertTrue(torch.allclose(ref, actual))
|
||||
self.assertTrue(torch.allclose(ref, actual, atol=atol, rtol=rtol))
|
||||
|
||||
return result, code
|
||||
|
||||
@ -89,6 +91,34 @@ class CodegenInductorTest(InductorTestCase):
|
||||
else:
|
||||
self.count_code(reinterpret_call, code, 2)
|
||||
|
||||
@requires_gpu()
|
||||
@skipIf(GPU_TYPE == "mps", "Triton is not available for MPS")
|
||||
def test_cse_make_block_ptr_reduction(self):
|
||||
def func(a, b):
|
||||
tmp0 = a * b
|
||||
tmp1 = a + b
|
||||
c = tmp0 + tmp1
|
||||
return c.sum(dim=0)
|
||||
|
||||
config_patches = {
|
||||
"triton.use_block_ptr": True,
|
||||
"triton.tile_reductions": True,
|
||||
"triton.prefer_nd_tiling": True,
|
||||
"triton.max_tiles": 3,
|
||||
"split_reductions": False,
|
||||
}
|
||||
a = torch.randn((512, 4096), device=torch.device(GPU_TYPE))
|
||||
b = torch.randn((512, 4096), device=torch.device(GPU_TYPE))
|
||||
_, code = self.run_and_compare(
|
||||
func,
|
||||
a,
|
||||
b,
|
||||
config_patches=config_patches,
|
||||
atol=1e-4,
|
||||
)
|
||||
self.count_code("= tl.make_block_ptr(in_ptr", code, 2)
|
||||
self.count_code("= tl.load(block_ptr", code, 2)
|
||||
|
||||
@requires_gpu()
|
||||
@skipIf(GPU_TYPE == "mps", "Triton is not available for MPS")
|
||||
def test_kernel_fusion_thresholds(self):
|
||||
|
@ -824,7 +824,7 @@ class CommonTemplate:
|
||||
[
|
||||
((8, 8), 1, 1, True), # Persistent Welford fallback
|
||||
subtest(
|
||||
((128, 128), 9, 2, False), decorators=[xfail_if_use_tensor_descriptor]
|
||||
((128, 128), 7, 2, False), decorators=[xfail_if_use_tensor_descriptor]
|
||||
), # Looped Welford reduction
|
||||
],
|
||||
)
|
||||
@ -924,7 +924,7 @@ class CommonTemplate:
|
||||
result, (code,) = self._run_and_compare(
|
||||
foo,
|
||||
view,
|
||||
expected_num_block_pointers=6,
|
||||
expected_num_block_pointers=5,
|
||||
expected_num_triton_kernels=2,
|
||||
config_patches={
|
||||
"triton.multi_kernel": True,
|
||||
|
@ -2865,6 +2865,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
indexing: Union[BlockPtrOptions, TensorDescriptorOptions],
|
||||
other="",
|
||||
) -> tuple[str, str]:
|
||||
"""Generate a block pointer or tensor descriptor for Triton kernel operations.
|
||||
|
||||
This method creates either a block pointer (for regular Triton operations) or
|
||||
a tensor descriptor (for TMA operations) based on the indexing type. It handles
|
||||
caching and reuse of descriptors for performance optimization.
|
||||
|
||||
Args:
|
||||
name: The name of the buffer/tensor being accessed
|
||||
var: The variable name for the pointer
|
||||
indexing: Block pointer options or tensor descriptor options containing
|
||||
indexing information and boundary check settings
|
||||
other: Additional parameters string (e.g., padding options)
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- block_descriptor: The generated block pointer or tensor descriptor variable name
|
||||
- other: Modified additional parameters string with boundary check options
|
||||
"""
|
||||
check = indexing.boundary_check()
|
||||
if isinstance(indexing, TensorDescriptorOptions):
|
||||
if check and other:
|
||||
@ -2892,14 +2910,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
# tensor descriptor.
|
||||
block_descriptor = self.prologue_cache[var]
|
||||
else:
|
||||
block_ptr_line = indexing.format(var, roffset=False)
|
||||
block_var = self.cse.try_get(block_ptr_line)
|
||||
|
||||
# Early return if block descriptor already exists
|
||||
if block_var:
|
||||
return str(block_var), other
|
||||
|
||||
block_descriptor_id = next(self.block_ptr_id)
|
||||
if isinstance(indexing, BlockPtrOptions):
|
||||
block_descriptor = f"block_ptr{block_descriptor_id}"
|
||||
else:
|
||||
block_descriptor = f"tma_descriptor{block_descriptor_id}"
|
||||
line_body = DeferredLine(
|
||||
name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
|
||||
named_var = self.cse.namedvar(
|
||||
block_descriptor, dtype=torch.uint64, shape=[]
|
||||
)
|
||||
self.cse.put(block_ptr_line, named_var)
|
||||
|
||||
line_body = DeferredLine(name, f"{block_descriptor} = {block_ptr_line}")
|
||||
if indexing.can_lift:
|
||||
self.prologue.writeline(line_body)
|
||||
# Cache the descriptor for epilogue subtiling
|
||||
|
Reference in New Issue
Block a user