[Inductor] Test ND block pointers with dynamic shapes (#151646)

With ND tiling, Inductor can emit multi-dimensional block pointers even when shapes are dynamic. This is an important capability, but I couldn't find any CI tests covering it. This PR adds a couple of tests checking that we get the expected block pointers with dynamic shapes, for both pointwise and reduction kernels.

Example kernels:
```
@triton.jit
def triton_poi_fused_div_0(in_ptr0, out_ptr0, ks0, ks1, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[ks0, ks0], strides=[ks1, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
    tmp1 = (tmp0 / tmp0)
    tl.store(tl.make_block_ptr(out_ptr0, shape=[ks0, ks0], strides=[ks0, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp1, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])

@triton.jit
def triton_red_fused_prod_0(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr, R1_BLOCK : tl.constexpr):
    xnumel = 1
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :, None]
    r1_base = tl.arange(0, R1_BLOCK)[None, None, :]
    rbase = r1_base + r0_base*r1_numel
    block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[ks0, ks0], strides=[ks1, 1], block_shape=[R0_BLOCK, R1_BLOCK], order=[1, 0], offsets=[0, 0])
    _tmp2 = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], 1, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        for r1_offset in range(0, r1_numel, R1_BLOCK):
            r1_index = r1_offset + r1_base
            r1_mask = r1_index < r1_numel
            roffset = r1_offset + r0_offset*r1_numel
            rindex = r1_index + r0_index*r1_numel
            r0_0 = r0_index
            r1_1 = r1_index
            tmp0 = tl.load(block_ptr0, boundary_check=[0, 1], padding_option='zero', eviction_policy='evict_first')[None, :, :]
            tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
            tmp3 = _tmp2 * tmp1
            _tmp2 = tl.where(r0_mask & r1_mask, tmp3, _tmp2)
            block_ptr0 = tl.advance(block_ptr0, [0, R1_BLOCK])
        block_ptr0 = tl.advance(block_ptr0, [R0_BLOCK, (-1)*R1_BLOCK*(triton_helpers.div_floor_integer((-1) + ks0 + R1_BLOCK,  R1_BLOCK))])
    tmp4 = tl.reshape(_tmp2, [XBLOCK, RBLOCK])
    tmp2 = triton_helpers.prod(tmp4, 1)[:, None, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1, 1], 0, tl.int32)), tmp2, None)
```
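For context, kernels like these can be reproduced outside the test suite. The sketch below is a rough standalone approximation of the new pointwise test, not the suite's `run_and_compare` helper: it assumes a CUDA device, uses `torch._inductor.utils.run_and_get_code` to capture the generated source, and simply checks that a block pointer shows up when ND tiling is enabled with dynamic shapes.

```
# Rough standalone sketch (assumes CUDA; not the helper used by the PR's tests):
# compile a pointwise op on a discontiguous view with dynamic shapes and ND
# tiling, then check the generated Triton source for a block pointer.
import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_code


def div(x, y):
    return x / y


full = torch.randn(8, 8, device="cuda")
# Discontiguous view, analogous to the tensors used in the new tests.
view = torch.as_strided(full, (4, 4), full.stride())

with inductor_config.patch({"triton.prefer_nd_tiling": True}):
    compiled = torch.compile(div, dynamic=True)
    result, codes = run_and_get_code(compiled, view, view)

# With ND tiling, the dynamic-shape loads/stores should go through
# tl.make_block_ptr, as in the example kernels above.
assert any("tl.make_block_ptr" in code for code in codes)
```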

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151646
Approved by: https://github.com/jansel, https://github.com/eellison, https://github.com/shunting314

```
@@ -510,22 +510,60 @@ class CommonTemplate:
         # Expect 2 block pointers: input and output
         run_and_compare(self, foo, view, expected_num_block_pointers=2)
 
-    def test_dynamic_shapes_generic(self):
+    @parametrize(
+        "nd_tiling,num_block_pointers",
+        [
+            (True, 2),  # With tiling, the index is affine.
+            (False, 1),  # We can't infer that the load is a power of 2.
+        ],
+    )
+    def test_dynamic_shapes_pointwise(self, nd_tiling: bool, num_block_pointers: int):
         """
-        Test a generic strided block with dynamic shapes. Block pointers are not
-        expected. This only checks that the analysis doesn't break this case.
+        Test a pointwise kernel with dynamic shapes.
         """
-        device = torch.device(self.device)
-        full_size = (8, 8)
         view_size = (4, 4)
-        full = torch.randn(full_size).to(device)
-        view = torch.as_strided(full, view_size, full.stride())
+        view = self._discontiguous_tensor(view_size, self.device)
 
-        run_and_compare(self, torch.div, view, view, compile_kwargs={"dynamic": True})
+        run_and_compare(
+            self,
+            torch.div,
+            view,
+            view,
+            expected_num_block_pointers=num_block_pointers,
+            config_patches={"triton.prefer_nd_tiling": nd_tiling},
+            compile_kwargs={"dynamic": True},
+        )
+
+    @parametrize(
+        "with_tiling,num_block_pointers",
+        [
+            (True, 1),  # With tiling, the index is affine.
+            (False, 0),  # We can't infer that the load is a power of 2.
+        ],
+    )
+    def test_dynamic_shapes_reduction(self, with_tiling: bool, num_block_pointers: int):
+        """
+        Test a reduction kernel with dynamic shapes.
+        """
+        view_size = (4, 4)
+        view = self._discontiguous_tensor(view_size, self.device)
+
+        run_and_compare(
+            self,
+            torch.prod,
+            view,
+            expected_num_block_pointers=num_block_pointers,
+            config_patches={
+                "triton.prefer_nd_tiling": with_tiling,
+                "triton.tile_reductions": with_tiling,
+            },
+            compile_kwargs={"dynamic": True},
+        )
 
     @unittest.skip(reason="Dynamo tracing error")
-    def test_dynamic_shapes_multiple_max_block(self):
+    def test_dynamic_shapes_pointwise_multiple_max_block(self):
         """
         Test dynamic shapes, where we know the shape is a multiple of the max block
         size. We should be able to generate a block pointer for this case.
```