[Inductor] Test ND block pointers with dynamic shapes (#151646)
With ND tiling, we can get multi-dimensional block pointers with dynamic shapes. This is an important capability, but I couldn't find any CI tests for it. This PR adds a couple of tests checking that we get the expected block pointers with dynamic shapes, both for pointwise and reduction kernels.

Example kernels:

```
@triton.jit
def triton_poi_fused_div_0(in_ptr0, out_ptr0, ks0, ks1, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[ks0, ks0], strides=[ks1, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
    tmp1 = (tmp0 / tmp0)
    tl.store(tl.make_block_ptr(out_ptr0, shape=[ks0, ks0], strides=[ks0, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp1, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])


@triton.jit
def triton_red_fused_prod_0(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr, R1_BLOCK : tl.constexpr):
    xnumel = 1
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :, None]
    r1_base = tl.arange(0, R1_BLOCK)[None, None, :]
    rbase = r1_base + r0_base*r1_numel
    block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[ks0, ks0], strides=[ks1, 1], block_shape=[R0_BLOCK, R1_BLOCK], order=[1, 0], offsets=[0, 0])
    _tmp2 = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], 1, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        for r1_offset in range(0, r1_numel, R1_BLOCK):
            r1_index = r1_offset + r1_base
            r1_mask = r1_index < r1_numel
            roffset = r1_offset + r0_offset*r1_numel
            rindex = r1_index + r0_index*r1_numel
            r0_0 = r0_index
            r1_1 = r1_index
            tmp0 = tl.load(block_ptr0, boundary_check=[0, 1], padding_option='zero', eviction_policy='evict_first')[None, :, :]
            tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
            tmp3 = _tmp2 * tmp1
            _tmp2 = tl.where(r0_mask & r1_mask, tmp3, _tmp2)
            block_ptr0 = tl.advance(block_ptr0, [0, R1_BLOCK])
        block_ptr0 = tl.advance(block_ptr0, [R0_BLOCK, (-1)*R1_BLOCK*(triton_helpers.div_floor_integer((-1) + ks0 + R1_BLOCK, R1_BLOCK))])
    tmp4 = tl.reshape(_tmp2, [XBLOCK, RBLOCK])
    tmp2 = triton_helpers.prod(tmp4, 1)[:, None, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1, 1], 0, tl.int32)), tmp2, None)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151646
Approved by: https://github.com/jansel, https://github.com/eellison, https://github.com/shunting314
Committed by: PyTorch MergeBot
Parent: ee81fe40c1
Commit: 62b5649b76
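For readers who want to reproduce kernels like `triton_poi_fused_div_0` above, the following is a minimal sketch, not part of this PR: it compiles a pointwise op over a discontiguous view with dynamic shapes and ND tiling enabled. The `triton.prefer_nd_tiling` config key also appears in the tests below; the CUDA device and the logging suggestion are assumptions.

```
import torch
from torch._inductor import config

# Minimal sketch, assuming a CUDA device and a recent PyTorch build.
# Compiling a pointwise op over a strided (discontiguous) view with dynamic
# shapes and ND tiling enabled yields kernels like triton_poi_fused_div_0;
# run with TORCH_LOGS="output_code" to print the generated Triton source.
full = torch.randn(8, 8, device="cuda")
view = torch.as_strided(full, size=(4, 4), stride=full.stride())

with config.patch({"triton.prefer_nd_tiling": True}):
    compiled_div = torch.compile(torch.div, dynamic=True)
    out = compiled_div(view, view)
```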
```
@@ -510,22 +510,60 @@ class CommonTemplate:
         # Expect 2 block pointers: input and output
         run_and_compare(self, foo, view, expected_num_block_pointers=2)
 
-    def test_dynamic_shapes_generic(self):
+    @parametrize(
+        "nd_tiling,num_block_pointers",
+        [
+            (True, 2),  # With tiling, the index is affine.
+            (False, 1),  # We can't infer that the load is a power of 2.
+        ],
+    )
+    def test_dynamic_shapes_pointwise(self, nd_tiling: bool, num_block_pointers: int):
         """
-        Test a generic strided block with dynamic shapes. Block pointers are not
-        expected. This only checks that the analysis doesn't break this case.
+        Test a pointwise kernel with dynamic shapes.
         """
 
-        device = torch.device(self.device)
-        full_size = (8, 8)
         view_size = (4, 4)
-        full = torch.randn(full_size).to(device)
-        view = torch.as_strided(full, view_size, full.stride())
+        view = self._discontiguous_tensor(view_size, self.device)
 
-        run_and_compare(self, torch.div, view, view, compile_kwargs={"dynamic": True})
+        run_and_compare(
+            self,
+            torch.div,
+            view,
+            view,
+            expected_num_block_pointers=num_block_pointers,
+            config_patches={"triton.prefer_nd_tiling": nd_tiling},
+            compile_kwargs={"dynamic": True},
+        )
+
+    @parametrize(
+        "with_tiling,num_block_pointers",
+        [
+            (True, 1),  # With tiling, the index is affine.
+            (False, 0),  # We can't infer that the load is a power of 2.
+        ],
+    )
+    def test_dynamic_shapes_reduction(self, with_tiling: bool, num_block_pointers: int):
+        """
+        Test a reduction kernel with dynamic shapes.
+        """
+
+        view_size = (4, 4)
+        view = self._discontiguous_tensor(view_size, self.device)
+
+        run_and_compare(
+            self,
+            torch.prod,
+            view,
+            expected_num_block_pointers=num_block_pointers,
+            config_patches={
+                "triton.prefer_nd_tiling": with_tiling,
+                "triton.tile_reductions": with_tiling,
+            },
+            compile_kwargs={"dynamic": True},
+        )
 
     @unittest.skip(reason="Dynamo tracing error")
-    def test_dynamic_shapes_multiple_max_block(self):
+    def test_dynamic_shapes_pointwise_multiple_max_block(self):
         """
         Test dynamic shapes, where we know the shape is a multiple of the max block
         size. We should be able to generate a block pointer for this case.
```
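`run_and_compare` is a helper defined elsewhere in this test file. As a rough sketch of the kind of check it performs, assuming Inductor's internal `run_and_get_code` utility (this is not the actual helper), one could compare eager and compiled results and count the emitted block pointers like this:

```
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code

# Illustrative sketch only: not the run_and_compare helper from the test file.
# It compiles the op with dynamic shapes, captures the generated source, and
# counts tl.make_block_ptr occurrences, which is what the tests above assert
# on via expected_num_block_pointers.
def count_block_pointers(fn, args, config_patches):
    with config.patch(config_patches):
        compiled = torch.compile(fn, dynamic=True)
        result, code_list = run_and_get_code(compiled, *args)
    num_block_pointers = sum(code.count("tl.make_block_ptr") for code in code_list)
    return result, num_block_pointers

view = torch.as_strided(torch.randn(8, 8, device="cuda"), (4, 4), (8, 1))
result, num_ptrs = count_block_pointers(
    torch.prod,
    (view,),
    {"triton.prefer_nd_tiling": True, "triton.tile_reductions": True},
)
assert torch.allclose(result, torch.prod(view))
print(num_ptrs)  # the reduction test above expects 1 block pointer with tiling
```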