mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] handle offset in ReinterpretView for alignment (#151859)
Fix https://github.com/pytorch/pytorch/issues/151589 It's interesting that the Q4_K dequantization example in the referred GH issue does not crash even if Inductor pass triton the wrong alignment information. I dig this a bit. The main reason is, there are 2 things in triton that decides the vectorization size 1. alignement 2. max number of contiguous elements a thread need to process Here is the triton code that decides vectorization size [link](c5fed8e1ca/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp (L147-L157)
), and here is the triton code that considers contiguity for vectorization [link](c5fed8e1ca/lib/Analysis/AxisInfo.cpp (L1250-L1269)
) When Inductor wrongly tell triton that a unaligned tensor is aligned, Triton may not do vectorization (or not do full vectorization) because of the second restriction. Check this test: ``` @parametrize( "size", ( 128, 1024, 1024 * 1024, ), ) def test_slice_view_dtype(self, size): offset = 1 def f(x): return x[2:].view(dtype=torch.float32) + 1 x = torch.randn((size + offset) * 2, dtype=torch.bfloat16, device=self.device) self.common(f, (x,), reference_in_float=False) ``` Before the fix, Inductor would tell Triton that the output of aten.view.dtype tensor is aligned even though it's not. That tensor will be passed to the triton kernel for the aten.add. Triton may do different vectorization decision depending on the tensor size 1. when size = 128, triton pick ld.global.b32 to load data from global memory 2. when size = 1024, triton uses ld.global.v2.b32 4. when size = 1024 * 1024, triton uses ld.global.v4.b32 So whether wrong alignment metadata causes issue depends on if triton picks the vectorized instructions. The latter depends on the triton config (block size) decided by inductor and triton internal logic (how they assign elements to each thread). We'd better to make sure Inductor always generate correct metadata to make sure such hidden issues does not turn into crash later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/151859 Approved by: https://github.com/jansel, https://github.com/eellison ghstack dependencies: #151841
This commit is contained in:
committed by
PyTorch MergeBot
parent
68a7501dab
commit
74074fe8d8
@ -5,7 +5,11 @@ import unittest
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch.testing._internal.common_utils import MACOS_VERSION
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
MACOS_VERSION,
|
||||
parametrize,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU
|
||||
|
||||
|
||||
@ -27,6 +31,7 @@ copy_tests = test_torchinductor.copy_tests
|
||||
define_custom_op_for_test = test_torchinductor.define_custom_op_for_test
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class CommonTemplate:
|
||||
def test_unaligned_input(self):
|
||||
def fn(x):
|
||||
@ -121,6 +126,104 @@ class CommonTemplate:
|
||||
with self.assertRaisesRegex(AssertionError, expected_error):
|
||||
self.common(fn, (x,), check_lowp=False)
|
||||
|
||||
def test_slice(self):
|
||||
def f(x):
|
||||
return x[1:] + 1
|
||||
|
||||
x = torch.randn(1025, device=self.device)
|
||||
self.common(f, (x,))
|
||||
|
||||
def test_view_dtype_slice(self):
|
||||
def f(x):
|
||||
return x.view(dtype=torch.float32)[1:] + 1
|
||||
|
||||
x = torch.randn(1025 * 2, dtype=torch.bfloat16, device=self.device)
|
||||
self.common(f, (x,), reference_in_float=False)
|
||||
|
||||
@parametrize(
|
||||
"size",
|
||||
(
|
||||
# wrapper for size = 128: https://gist.github.com/shunting314/88f1e72957b9fc5e9826aaa346a0e652
|
||||
# ptx: https://gist.github.com/shunting314/eb657ee8821eef9f0685b7b91e2ad5c2
|
||||
# the ptx file uses ld.global.b32 to load input buffer
|
||||
128,
|
||||
# wrapper for size = 1024: https://gist.github.com/shunting314/d7f64e1f52f6b1e2ec25e1a51052ce43
|
||||
# ptx: https://gist.github.com/shunting314/a24ff7563bb6b04523d11b119ab0f2b2
|
||||
# the ptx file uses ld.global.v2.b32 to load input buffer
|
||||
1024,
|
||||
# wrapper for size = 1024 * 1024: https://gist.github.com/shunting314/016b95cf0b6e9a75c25f5c9d5ed0a2ba
|
||||
# ptx: https://gist.github.com/shunting314/360112a4893c759b114c12fc99958297
|
||||
# the ptx file uses ld.global.v4.b32 to load input buffer
|
||||
1024 * 1024,
|
||||
),
|
||||
)
|
||||
def test_slice_view_dtype(self, size):
|
||||
offset = 1
|
||||
|
||||
def f(x):
|
||||
return x[2:].view(dtype=torch.float32) + 1
|
||||
|
||||
x = torch.randn((size + offset) * 2, dtype=torch.bfloat16, device=self.device)
|
||||
self.common(f, (x,), reference_in_float=False)
|
||||
|
||||
def test_Q4_K_dequantization(self):
|
||||
"""
|
||||
Test the alignment issue for Q4_K dequantization.
|
||||
"""
|
||||
|
||||
QK_K = 256
|
||||
K_SCALE_SIZE = 12
|
||||
|
||||
def get_scale_min(scales):
|
||||
n_blocks = scales.shape[0]
|
||||
scales = scales.view(torch.uint8)
|
||||
scales = scales.reshape((n_blocks, 3, 4))
|
||||
|
||||
d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
|
||||
|
||||
sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
|
||||
min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
|
||||
|
||||
return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
|
||||
|
||||
def split_block_dims(blocks, *args):
|
||||
n_max = blocks.shape[1]
|
||||
dims = list(args) + [n_max - sum(args)]
|
||||
return torch.split(blocks, dims, dim=1)
|
||||
|
||||
def dequantize_blocks_Q4_K(blocks, block_size, type_size):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
|
||||
d = d.view(torch.float16)
|
||||
dmin = dmin.view(torch.float16)
|
||||
|
||||
sc, m = get_scale_min(scales)
|
||||
|
||||
d = (d * sc).reshape((n_blocks, -1, 1))
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1))
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(
|
||||
[0, 4], device=d.device, dtype=torch.uint8
|
||||
).reshape((1, 1, 2, 1))
|
||||
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
|
||||
|
||||
return (d * qs - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
data = torch.randint(
|
||||
0, 16, (18432, 1728), device=self.device, dtype=torch.uint8
|
||||
)
|
||||
|
||||
def dequantize(data):
|
||||
block_size, type_size = 256, 144
|
||||
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
|
||||
n_blocks = rows.numel() // type_size
|
||||
blocks = rows.reshape((n_blocks, type_size))
|
||||
blocks = dequantize_blocks_Q4_K(blocks, block_size, type_size)
|
||||
return blocks.reshape(18432, 3072)
|
||||
|
||||
self.common(dequantize, (data,), check_lowp=False, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
if RUN_CPU:
|
||||
|
||||
|
@ -55,6 +55,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
rebind_unbacked,
|
||||
resolve_unbacked_bindings,
|
||||
ShapeEnv,
|
||||
statically_known_true,
|
||||
SymTypes,
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
@ -86,6 +87,7 @@ from .utils import (
|
||||
convert_shape_to_inductor,
|
||||
convert_shape_to_symint,
|
||||
developer_warning,
|
||||
get_dtype_size,
|
||||
get_kernel_metadata,
|
||||
GPU_ALIGN_BYTES,
|
||||
ir_dataclass,
|
||||
@ -2595,6 +2597,24 @@ def is_stride_order_storage_and_layout(
|
||||
return False
|
||||
|
||||
|
||||
def is_unaligned(node: IRNode) -> bool:
|
||||
if isinstance(node, (TensorBox, StorageBox)):
|
||||
return is_unaligned(node.data)
|
||||
|
||||
if isinstance(node, ReinterpretView):
|
||||
layout = node.layout
|
||||
has_unaligned_layout = not statically_known_true(
|
||||
layout.offset * get_dtype_size(layout.dtype) % GPU_ALIGN_BYTES == 0
|
||||
)
|
||||
return is_unaligned(node.data) or has_unaligned_layout
|
||||
|
||||
if isinstance(node, Buffer):
|
||||
return node.get_name() in V.graph.unaligned_buffers
|
||||
|
||||
# assume to be aligned otherwise
|
||||
return False
|
||||
|
||||
|
||||
@ir_dataclass
|
||||
class BaseView(IRNode):
|
||||
data: IRNode
|
||||
@ -6990,9 +7010,7 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
|
||||
# We need this extra check for input alignment since the example
|
||||
# inputs we created are always aligned.
|
||||
has_unaligned_input = any(
|
||||
arg.get_name() in V.graph.unaligned_buffers for arg in tensor_args
|
||||
)
|
||||
has_unaligned_input = any(is_unaligned(arg) for arg in tensor_args)
|
||||
|
||||
device = cls.find_device(tensor_args, example_output)
|
||||
|
||||
|
Reference in New Issue
Block a user