[inductor] handle offset in ReinterpretView for alignment (#151859)

Fix https://github.com/pytorch/pytorch/issues/151589

It's interesting that the Q4_K dequantization example in the referenced GH issue does not crash even though Inductor passes Triton the wrong alignment information. I dug into this a bit. The main reason is that there are two things in Triton that decide the vectorization size:
1. alignment
2. the maximum number of contiguous elements a thread needs to process

Here is the Triton code that decides the vectorization size [link](c5fed8e1ca/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp (L147-L157)), and here is the Triton code that considers contiguity for vectorization [link](c5fed8e1ca/lib/Analysis/AxisInfo.cpp (L1250-L1269)).

When Inductor wrongly tells Triton that an unaligned tensor is aligned, Triton may still not vectorize (or not fully vectorize) because of the second restriction.
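
To make the interplay concrete, here is a minimal sketch of how the effective vector width ends up being the minimum of both limits. This is an illustration, not Triton's actual code; the helper name, its parameters, and the 16-byte cap are assumptions for the example.

```
def effective_vector_width(
    reported_alignment_bytes,     # alignment Inductor reports for the pointer
    contiguous_elems_per_thread,  # contiguity Triton derives from its AxisInfo analysis
    elem_size_bytes,
    max_vec_bytes=16,             # e.g. ld.global.v4.b32 loads 16 bytes at once
):
    """Number of elements a single (possibly vectorized) load covers."""
    # Restriction 1: the reported alignment caps how wide a single load may be.
    align_limit = max(1, reported_alignment_bytes // elem_size_bytes)
    # Restriction 2: a thread can only vectorize over elements it owns contiguously.
    contig_limit = contiguous_elems_per_thread
    hw_limit = max_vec_bytes // elem_size_bytes
    return max(1, min(align_limit, contig_limit, hw_limit))


# Even with (wrongly) reported 16-byte alignment, a thread that only owns one
# contiguous float32 element still gets a scalar ld.global.b32:
assert effective_vector_width(16, 1, 4) == 1
# With 4 contiguous elements per thread, the same metadata yields ld.global.v4.b32:
assert effective_vector_width(16, 4, 4) == 4
```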

Check this test:
```
    @parametrize(
        "size",
        (
            128,
            1024,
            1024 * 1024,
        ),
    )
    def test_slice_view_dtype(self, size):
        offset = 1

        def f(x):
            return x[2:].view(dtype=torch.float32) + 1

        x = torch.randn((size + offset) * 2, dtype=torch.bfloat16, device=self.device)
        self.common(f, (x,), reference_in_float=False)
```
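
For context on why this input is misaligned: skipping two bfloat16 elements moves the data pointer 4 bytes past the start of the allocation, so the float32 view is 4-byte aligned but not 16-byte aligned (assuming the base allocation itself is 16-byte aligned, which the caching allocator normally guarantees). A quick sketch of that arithmetic:

```
import torch

x = torch.randn(258, dtype=torch.bfloat16)
y = x[2:].view(dtype=torch.float32)

# The slice skips 2 bfloat16 elements, i.e. 2 * 2 = 4 bytes.
print(y.data_ptr() - x.data_ptr())  # 4
# 4 is not a multiple of 16, so the view must not be reported as 16-byte aligned.
print((y.data_ptr() - x.data_ptr()) % 16)  # 4
```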

Before the fix, Inductor would tell Triton that the output of the aten.view.dtype op is aligned even though it's not. That tensor is then passed to the Triton kernel for the aten.add. Triton may make different vectorization decisions depending on the tensor size:
1. when size = 128, Triton picks ld.global.b32 to load data from global memory
2. when size = 1024, Triton uses ld.global.v2.b32
3. when size = 1024 * 1024, Triton uses ld.global.v4.b32

So whether the wrong alignment metadata causes an issue depends on whether Triton picks vectorized instructions, which in turn depends on the Triton config (block size) decided by Inductor and on Triton's internal logic (how elements are assigned to each thread). We'd better make sure Inductor always generates correct metadata so that such hidden issues don't turn into crashes later.
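
A repro sketch along those lines (assuming a CUDA device; the emitted PTX can be inspected with, for example, TORCH_COMPILE_DEBUG=1 or TORCH_LOGS="output_code" to check which ld.global variant each size gets):

```
import torch

def f(x):
    return x[2:].view(dtype=torch.float32) + 1

compiled = torch.compile(f, dynamic=False)  # keep one specialized kernel per size
for size in (128, 1024, 1024 * 1024):
    x = torch.randn((size + 1) * 2, dtype=torch.bfloat16, device="cuda")
    # Correctness check; the generated kernel for each size can then be inspected
    # to see whether it uses ld.global.b32, ld.global.v2.b32, or ld.global.v4.b32.
    torch.testing.assert_close(compiled(x), f(x))
```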

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151859
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #151841
Author: Shunting Zhang
Date: 2025-04-22 10:50:14 -07:00
Committed by: PyTorch MergeBot
Parent: 68a7501dab
Commit: 74074fe8d8
2 changed files with 125 additions and 4 deletions


@@ -5,7 +5,11 @@ import unittest
 import torch
 from torch._inductor import config
-from torch.testing._internal.common_utils import MACOS_VERSION
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    MACOS_VERSION,
+    parametrize,
+)
 from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU
@@ -27,6 +31,7 @@ copy_tests = test_torchinductor.copy_tests
 define_custom_op_for_test = test_torchinductor.define_custom_op_for_test
 
+@instantiate_parametrized_tests
 class CommonTemplate:
     def test_unaligned_input(self):
         def fn(x):
@@ -121,6 +126,104 @@ class CommonTemplate:
         with self.assertRaisesRegex(AssertionError, expected_error):
             self.common(fn, (x,), check_lowp=False)
 
+    def test_slice(self):
+        def f(x):
+            return x[1:] + 1
+
+        x = torch.randn(1025, device=self.device)
+        self.common(f, (x,))
+
+    def test_view_dtype_slice(self):
+        def f(x):
+            return x.view(dtype=torch.float32)[1:] + 1
+
+        x = torch.randn(1025 * 2, dtype=torch.bfloat16, device=self.device)
+        self.common(f, (x,), reference_in_float=False)
+
+    @parametrize(
+        "size",
+        (
+            # wrapper for size = 128: https://gist.github.com/shunting314/88f1e72957b9fc5e9826aaa346a0e652
+            # ptx: https://gist.github.com/shunting314/eb657ee8821eef9f0685b7b91e2ad5c2
+            # the ptx file uses ld.global.b32 to load input buffer
+            128,
+            # wrapper for size = 1024: https://gist.github.com/shunting314/d7f64e1f52f6b1e2ec25e1a51052ce43
+            # ptx: https://gist.github.com/shunting314/a24ff7563bb6b04523d11b119ab0f2b2
+            # the ptx file uses ld.global.v2.b32 to load input buffer
+            1024,
+            # wrapper for size = 1024 * 1024: https://gist.github.com/shunting314/016b95cf0b6e9a75c25f5c9d5ed0a2ba
+            # ptx: https://gist.github.com/shunting314/360112a4893c759b114c12fc99958297
+            # the ptx file uses ld.global.v4.b32 to load input buffer
+            1024 * 1024,
+        ),
+    )
+    def test_slice_view_dtype(self, size):
+        offset = 1
+
+        def f(x):
+            return x[2:].view(dtype=torch.float32) + 1
+
+        x = torch.randn((size + offset) * 2, dtype=torch.bfloat16, device=self.device)
+        self.common(f, (x,), reference_in_float=False)
+
+    def test_Q4_K_dequantization(self):
+        """
+        Test the alignment issue for Q4_K dequantization.
+        """
+        QK_K = 256
+        K_SCALE_SIZE = 12
+
+        def get_scale_min(scales):
+            n_blocks = scales.shape[0]
+            scales = scales.view(torch.uint8)
+            scales = scales.reshape((n_blocks, 3, 4))
+
+            d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
+
+            sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
+            min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
+
+            return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
+
+        def split_block_dims(blocks, *args):
+            n_max = blocks.shape[1]
+            dims = list(args) + [n_max - sum(args)]
+            return torch.split(blocks, dims, dim=1)
+
+        def dequantize_blocks_Q4_K(blocks, block_size, type_size):
+            n_blocks = blocks.shape[0]
+
+            d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
+            d = d.view(torch.float16)
+            dmin = dmin.view(torch.float16)
+            sc, m = get_scale_min(scales)
+
+            d = (d * sc).reshape((n_blocks, -1, 1))
+            dm = (dmin * m).reshape((n_blocks, -1, 1))
+
+            qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(
+                [0, 4], device=d.device, dtype=torch.uint8
+            ).reshape((1, 1, 2, 1))
+            qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
+
+            return (d * qs - dm).reshape((n_blocks, QK_K))
+
+        data = torch.randint(
+            0, 16, (18432, 1728), device=self.device, dtype=torch.uint8
+        )
+
+        def dequantize(data):
+            block_size, type_size = 256, 144
+            rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
+            n_blocks = rows.numel() // type_size
+            blocks = rows.reshape((n_blocks, type_size))
+            blocks = dequantize_blocks_Q4_K(blocks, block_size, type_size)
+            return blocks.reshape(18432, 3072)
+
+        self.common(dequantize, (data,), check_lowp=False, atol=1e-3, rtol=1e-3)
+
 
 if RUN_CPU:


@@ -55,6 +55,7 @@ from torch.fx.experimental.symbolic_shapes import (
     rebind_unbacked,
     resolve_unbacked_bindings,
     ShapeEnv,
+    statically_known_true,
     SymTypes,
 )
 from torch.utils._ordered_set import OrderedSet
@@ -86,6 +87,7 @@ from .utils import (
     convert_shape_to_inductor,
     convert_shape_to_symint,
     developer_warning,
+    get_dtype_size,
     get_kernel_metadata,
     GPU_ALIGN_BYTES,
     ir_dataclass,
@@ -2595,6 +2597,24 @@ def is_stride_order_storage_and_layout(
     return False
 
 
+def is_unaligned(node: IRNode) -> bool:
+    if isinstance(node, (TensorBox, StorageBox)):
+        return is_unaligned(node.data)
+
+    if isinstance(node, ReinterpretView):
+        layout = node.layout
+        has_unaligned_layout = not statically_known_true(
+            layout.offset * get_dtype_size(layout.dtype) % GPU_ALIGN_BYTES == 0
+        )
+        return is_unaligned(node.data) or has_unaligned_layout
+
+    if isinstance(node, Buffer):
+        return node.get_name() in V.graph.unaligned_buffers
+
+    # assume to be aligned otherwise
+    return False
+
+
 @ir_dataclass
 class BaseView(IRNode):
     data: IRNode
@@ -6990,9 +7010,7 @@ class FallbackKernel(ExternKernelAlloc):
         # We need this extra check for input alignment since the example
         # inputs we created are always aligned.
-        has_unaligned_input = any(
-            arg.get_name() in V.graph.unaligned_buffers for arg in tensor_args
-        )
+        has_unaligned_input = any(is_unaligned(arg) for arg in tensor_args)
 
         device = cls.find_device(tensor_args, example_output)