[Inductor] Test ND block pointers with dynamic shapes (#151646)
With ND tiling, we can get multi-dimensional block pointers with dynamic shapes. This is an important capability, but I couldn't find any CI tests for it. This PR adds a couple of tests checking that we get the expected block pointers with dynamic shapes, for both pointwise and reduction kernels.
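For context, a minimal sketch of the kind of user code that exercises this path (this is not the PR's test harness: the device, shapes, and the `run_and_get_code` check are illustrative, and the tests themselves use the suite's `run_and_compare` helper instead):

```
# Hedged sketch: compile a pointwise op on a discontiguous view with
# dynamic shapes and ND tiling enabled, then look for block pointers in
# the generated Triton source. Device and shapes are illustrative.
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code

full = torch.randn(8, 8, device="cuda")
view = torch.as_strided(full, (4, 4), full.stride())  # discontiguous view

with config.patch({"triton.prefer_nd_tiling": True}):
    compiled = torch.compile(lambda x, y: x / y, dynamic=True)
    result, code = run_and_get_code(compiled, view, view)

assert torch.allclose(result, view / view)
assert "tl.make_block_ptr" in code[0]  # an ND block pointer was emitted
```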
Example kernels:
```
@triton.jit
def triton_poi_fused_div_0(in_ptr0, out_ptr0, ks0, ks1, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[ks0, ks0], strides=[ks1, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
    tmp1 = (tmp0 / tmp0)
    tl.store(tl.make_block_ptr(out_ptr0, shape=[ks0, ks0], strides=[ks0, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp1, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])

@triton.jit
def triton_red_fused_prod_0(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr, R1_BLOCK : tl.constexpr):
    xnumel = 1
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :, None]
    r1_base = tl.arange(0, R1_BLOCK)[None, None, :]
    rbase = r1_base + r0_base*r1_numel
    block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[ks0, ks0], strides=[ks1, 1], block_shape=[R0_BLOCK, R1_BLOCK], order=[1, 0], offsets=[0, 0])
    _tmp2 = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], 1, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        for r1_offset in range(0, r1_numel, R1_BLOCK):
            r1_index = r1_offset + r1_base
            r1_mask = r1_index < r1_numel
            roffset = r1_offset + r0_offset*r1_numel
            rindex = r1_index + r0_index*r1_numel
            r0_0 = r0_index
            r1_1 = r1_index
            tmp0 = tl.load(block_ptr0, boundary_check=[0, 1], padding_option='zero', eviction_policy='evict_first')[None, :, :]
            tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
            tmp3 = _tmp2 * tmp1
            _tmp2 = tl.where(r0_mask & r1_mask, tmp3, _tmp2)
            block_ptr0 = tl.advance(block_ptr0, [0, R1_BLOCK])
        block_ptr0 = tl.advance(block_ptr0, [R0_BLOCK, (-1)*R1_BLOCK*(triton_helpers.div_floor_integer((-1) + ks0 + R1_BLOCK, R1_BLOCK))])
    tmp4 = tl.reshape(_tmp2, [XBLOCK, RBLOCK])
    tmp2 = triton_helpers.prod(tmp4, 1)[:, None, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1, 1], 0, tl.int32)), tmp2, None)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151646
Approved by: https://github.com/jansel, https://github.com/eellison, https://github.com/shunting314
Commit 62b5649b76 (parent ee81fe40c1), committed by PyTorch MergeBot.
```
@@ -510,22 +510,60 @@ class CommonTemplate:
         # Expect 2 block pointers: input and output
         run_and_compare(self, foo, view, expected_num_block_pointers=2)
 
-    def test_dynamic_shapes_generic(self):
+    @parametrize(
+        "nd_tiling,num_block_pointers",
+        [
+            (True, 2),  # With tiling, the index is affine.
+            (False, 1),  # We can't infer that the load is a power of 2.
+        ],
+    )
+    def test_dynamic_shapes_pointwise(self, nd_tiling: bool, num_block_pointers: int):
         """
-        Test a generic strided block with dynamic shapes. Block pointers are not
-        expected. This only checks that the analysis doesn't break this case.
+        Test a pointwise kernel with dynamic shapes.
         """
 
-        device = torch.device(self.device)
-        full_size = (8, 8)
         view_size = (4, 4)
-        full = torch.randn(full_size).to(device)
-        view = torch.as_strided(full, view_size, full.stride())
+        view = self._discontiguous_tensor(view_size, self.device)
 
-        run_and_compare(self, torch.div, view, view, compile_kwargs={"dynamic": True})
+        run_and_compare(
+            self,
+            torch.div,
+            view,
+            view,
+            expected_num_block_pointers=num_block_pointers,
+            config_patches={"triton.prefer_nd_tiling": nd_tiling},
+            compile_kwargs={"dynamic": True},
+        )
+
+    @parametrize(
+        "with_tiling,num_block_pointers",
+        [
+            (True, 1),  # With tiling, the index is affine.
+            (False, 0),  # We can't infer that the load is a power of 2.
+        ],
+    )
+    def test_dynamic_shapes_reduction(self, with_tiling: bool, num_block_pointers: int):
+        """
+        Test a reduction kernel with dynamic shapes.
+        """
+
+        view_size = (4, 4)
+        view = self._discontiguous_tensor(view_size, self.device)
+
+        run_and_compare(
+            self,
+            torch.prod,
+            view,
+            expected_num_block_pointers=num_block_pointers,
+            config_patches={
+                "triton.prefer_nd_tiling": with_tiling,
+                "triton.tile_reductions": with_tiling,
+            },
+            compile_kwargs={"dynamic": True},
+        )
 
     @unittest.skip(reason="Dynamo tracing error")
-    def test_dynamic_shapes_multiple_max_block(self):
+    def test_dynamic_shapes_pointwise_multiple_max_block(self):
         """
         Test dynamic shapes, where we know the shape is a multiple of the max block
         size. We should be able to generate a block pointer for this case.
```
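The reduction test also flips `triton.tile_reductions` alongside `triton.prefer_nd_tiling`. A rough standalone sketch of that path, with the same caveats as the pointwise sketch above (device, shapes, and the `run_and_get_code` check are illustrative rather than the test's own `run_and_compare` helper):

```
# Hedged sketch of the reduction case: with ND tiling and tiled reductions
# enabled, torch.prod over a discontiguous view with dynamic shapes should
# produce a looped reduction kernel that advances a block pointer, as in
# triton_red_fused_prod_0 above.
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_code

full = torch.randn(8, 8, device="cuda")
view = torch.as_strided(full, (4, 4), full.stride())  # discontiguous view

with config.patch(
    {"triton.prefer_nd_tiling": True, "triton.tile_reductions": True}
):
    compiled = torch.compile(torch.prod, dynamic=True)
    result, code = run_and_get_code(compiled, view)

assert torch.allclose(result, torch.prod(view))
assert "tl.advance" in code[0]  # block pointer advanced across reduction tiles
```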