Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add cse for make_block_ptr in Triton codegen (#163399)
Summary: per title

Test Plan: added test cases

Differential Revision: D82648215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163399
Approved by: https://github.com/jansel, https://github.com/njriasan
Committed by: PyTorch MergeBot
Parent: 5d0b22008d
Commit: 00afa06800
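For orientation, the sketch below shows the caching pattern the change introduces: the formatted make_block_ptr line serves as the CSE key, so a request for an identical block pointer reuses the previously defined variable instead of emitting a second definition. This is a minimal standalone sketch with hypothetical names (BlockPtrCSE, block_ptr), not Inductor's actual CSE classes; the real code path goes through self.cse.try_get / self.cse.put as shown in the diff below.

from itertools import count


class BlockPtrCSE:
    """Hypothetical stand-in for the CSE cache used by TritonKernel."""

    def __init__(self) -> None:
        self._cache: dict[str, str] = {}  # formatted make_block_ptr line -> variable name
        self._ids = count()               # mirrors TritonKernel.block_ptr_id
        self.lines: list[str] = []        # lines that would be emitted into the kernel

    def block_ptr(self, formatted: str) -> str:
        cached = self._cache.get(formatted)  # analogous to self.cse.try_get(block_ptr_line)
        if cached is not None:
            return cached                    # early return: reuse, no duplicate definition
        name = f"block_ptr{next(self._ids)}"
        self._cache[formatted] = name        # analogous to self.cse.put(block_ptr_line, named_var)
        self.lines.append(f"{name} = {formatted}")
        return name


cse = BlockPtrCSE()
first = cse.block_ptr("tl.make_block_ptr(in_ptr0, ...)")
second = cse.block_ptr("tl.make_block_ptr(in_ptr0, ...)")  # identical indexing -> same variable
assert first == second and len(cse.lines) == 1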
@@ -32,6 +32,8 @@ class CodegenInductorTest(InductorTestCase):
         *args,
         compile_kwargs: Optional[dict] = None,
         config_patches: Optional[dict] = None,
+        atol: float | None = 1e-05,
+        rtol: float | None = 1e-08,
     ):
         """
         Runs the module through Inductor, comparing to eager reference.
@@ -53,7 +55,7 @@ class CodegenInductorTest(InductorTestCase):
         ref_tensors = flatten_tensors(func(*args))
         actual_tensors = flatten_tensors(result)
         for ref, actual in zip(ref_tensors, actual_tensors):
-            self.assertTrue(torch.allclose(ref, actual))
+            self.assertTrue(torch.allclose(ref, actual, atol=atol, rtol=rtol))

         return result, code

@@ -89,6 +91,34 @@ class CodegenInductorTest(InductorTestCase):
         else:
             self.count_code(reinterpret_call, code, 2)

+    @requires_gpu()
+    @skipIf(GPU_TYPE == "mps", "Triton is not available for MPS")
+    def test_cse_make_block_ptr_reduction(self):
+        def func(a, b):
+            tmp0 = a * b
+            tmp1 = a + b
+            c = tmp0 + tmp1
+            return c.sum(dim=0)
+
+        config_patches = {
+            "triton.use_block_ptr": True,
+            "triton.tile_reductions": True,
+            "triton.prefer_nd_tiling": True,
+            "triton.max_tiles": 3,
+            "split_reductions": False,
+        }
+        a = torch.randn((512, 4096), device=torch.device(GPU_TYPE))
+        b = torch.randn((512, 4096), device=torch.device(GPU_TYPE))
+        _, code = self.run_and_compare(
+            func,
+            a,
+            b,
+            config_patches=config_patches,
+            atol=1e-4,
+        )
+        self.count_code("= tl.make_block_ptr(in_ptr", code, 2)
+        self.count_code("= tl.load(block_ptr", code, 2)
+
     @requires_gpu()
     @skipIf(GPU_TYPE == "mps", "Triton is not available for MPS")
     def test_kernel_fusion_thresholds(self):
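As a quick orientation on the new test's semantics, the eager reference that run_and_compare checks against is just the function above; the CPU-only sanity check below uses smaller shapes than the (512, 4096) GPU tensors in the test and is purely illustrative.

import torch


def func(a, b):
    tmp0 = a * b
    tmp1 = a + b
    c = tmp0 + tmp1
    return c.sum(dim=0)


a = torch.randn((8, 16))
b = torch.randn((8, 16))
out = func(a, b)
assert out.shape == (16,)  # summing over dim 0 keeps the last dimension
torch.testing.assert_close(out, (a * b + (a + b)).sum(dim=0))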
@@ -824,7 +824,7 @@ class CommonTemplate:
         [
             ((8, 8), 1, 1, True),  # Persistent Welford fallback
             subtest(
-                ((128, 128), 9, 2, False), decorators=[xfail_if_use_tensor_descriptor]
+                ((128, 128), 7, 2, False), decorators=[xfail_if_use_tensor_descriptor]
             ),  # Looped Welford reduction
         ],
     )
@@ -924,7 +924,7 @@ class CommonTemplate:
         result, (code,) = self._run_and_compare(
             foo,
             view,
-            expected_num_block_pointers=6,
+            expected_num_block_pointers=5,
             expected_num_triton_kernels=2,
             config_patches={
                 "triton.multi_kernel": True,
@@ -2865,6 +2865,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
         indexing: Union[BlockPtrOptions, TensorDescriptorOptions],
         other="",
     ) -> tuple[str, str]:
+        """Generate a block pointer or tensor descriptor for Triton kernel operations.
+
+        This method creates either a block pointer (for regular Triton operations) or
+        a tensor descriptor (for TMA operations) based on the indexing type. It handles
+        caching and reuse of descriptors for performance optimization.
+
+        Args:
+            name: The name of the buffer/tensor being accessed
+            var: The variable name for the pointer
+            indexing: Block pointer options or tensor descriptor options containing
+                indexing information and boundary check settings
+            other: Additional parameters string (e.g., padding options)
+
+        Returns:
+            A tuple containing:
+            - block_descriptor: The generated block pointer or tensor descriptor variable name
+            - other: Modified additional parameters string with boundary check options
+        """
         check = indexing.boundary_check()
         if isinstance(indexing, TensorDescriptorOptions):
             if check and other:
@@ -2892,14 +2910,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
             # tensor descriptor.
             block_descriptor = self.prologue_cache[var]
         else:
+            block_ptr_line = indexing.format(var, roffset=False)
+            block_var = self.cse.try_get(block_ptr_line)
+
+            # Early return if block descriptor already exists
+            if block_var:
+                return str(block_var), other
+
             block_descriptor_id = next(self.block_ptr_id)
             if isinstance(indexing, BlockPtrOptions):
                 block_descriptor = f"block_ptr{block_descriptor_id}"
             else:
                 block_descriptor = f"tma_descriptor{block_descriptor_id}"
-            line_body = DeferredLine(
-                name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
+            named_var = self.cse.namedvar(
+                block_descriptor, dtype=torch.uint64, shape=[]
             )
+            self.cse.put(block_ptr_line, named_var)
+
+            line_body = DeferredLine(name, f"{block_descriptor} = {block_ptr_line}")
             if indexing.can_lift:
                 self.prologue.writeline(line_body)
                 # Cache the descriptor for epilogue subtiling
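Illustrative only: a hand-written stand-in for the kind of kernel body the new test expects once CSE is applied, mirroring its two count_code assertions (one tl.make_block_ptr per input pointer, each loaded once, even though the graph reads a and b several times). Real Inductor output differs in its shapes, strides, and variable numbering.

cse_kernel_body = """\
block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[...], strides=[...], offsets=[...], block_shape=[...], order=[...])
tmp0 = tl.load(block_ptr0, boundary_check=[0, 1])
block_ptr1 = tl.make_block_ptr(in_ptr1, shape=[...], strides=[...], offsets=[...], block_shape=[...], order=[...])
tmp1 = tl.load(block_ptr1, boundary_check=[0, 1])
"""
# These mirror the test's count_code checks on the generated source string.
assert cse_kernel_body.count("= tl.make_block_ptr(in_ptr") == 2
assert cse_kernel_body.count("= tl.load(block_ptr") == 2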