Fix https://github.com/pytorch/pytorch/issues/151589

It's interesting that the Q4_K dequantization example in the referenced GitHub issue does not crash even when Inductor passes Triton the wrong alignment information. I dug into this a bit. The main reason is that two things in Triton decide the vectorization size:

1. alignment
2. the maximum number of contiguous elements a thread needs to process

Here is the Triton code that decides the vectorization size [link](c5fed8e1ca/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp (L147-L157)), and here is the Triton code that considers contiguity for vectorization [link](c5fed8e1ca/lib/Analysis/AxisInfo.cpp (L1250-L1269)).

When Inductor wrongly tells Triton that an unaligned tensor is aligned, Triton may still not vectorize (or not fully vectorize) because of the second restriction. Consider this test:

```
@parametrize(
    "size",
    (
        128,
        1024,
        1024 * 1024,
    ),
)
def test_slice_view_dtype(self, size):
    offset = 1

    def f(x):
        return x[2:].view(dtype=torch.float32) + 1

    x = torch.randn((size + offset) * 2, dtype=torch.bfloat16, device=self.device)
    self.common(f, (x,), reference_in_float=False)
```

Before the fix, Inductor would tell Triton that the output of aten.view.dtype is aligned even though it is not. That tensor is then passed to the Triton kernel for the aten.add. Triton may make a different vectorization decision depending on the tensor size:

1. when size = 128, Triton picks ld.global.b32 to load data from global memory
2. when size = 1024, Triton uses ld.global.v2.b32
3. when size = 1024 * 1024, Triton uses ld.global.v4.b32

So whether wrong alignment metadata causes an issue depends on whether Triton picks the vectorized instructions, which in turn depends on the Triton config (block size) decided by Inductor and on Triton's internal logic (how elements are assigned to each thread). We'd better make sure Inductor always generates correct metadata so that such hidden issues do not turn into crashes later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151859
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #151841
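As an aside, the 16-byte alignment criterion that the assertions in the test file below rely on can be reproduced directly from tensor metadata. Here is a minimal standalone sketch (the helper `is_16_byte_aligned` is made up for illustration and is not an Inductor API), assuming alignment is judged by the byte offset of the first element, i.e. `storage_offset() * element_size()`:

```
import torch

ALIGNMENT = 16  # bytes; the granularity the alignment assertions in this file use


def is_16_byte_aligned(t: torch.Tensor) -> bool:
    # Byte offset of the tensor's first element into its storage.
    return (t.storage_offset() * t.element_size()) % ALIGNMENT == 0


x = torch.randn(1025 * 2, dtype=torch.bfloat16)
y = x[2:].view(dtype=torch.float32)  # storage_offset=1, itemsize=4 -> 4-byte offset
print(is_16_byte_aligned(x), is_16_byte_aligned(y))  # True False
```

The error message checked in test_incorrect_meta_for_custom_op_2d below ("storage_offset=1 itemsize=4") corresponds to exactly this computation: 1 * 4 = 4 bytes, which is not a multiple of 16.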
# Owner(s): ["module: inductor"]
import contextlib
import sys
import unittest

import torch
from torch._inductor import config
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    MACOS_VERSION,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU


try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise

TestCase = test_torchinductor.TestCase
check_model = test_torchinductor.check_model
check_model_gpu = test_torchinductor.check_model_gpu
skip_if_cpp_wrapper = test_torchinductor.skip_if_cpp_wrapper
copy_tests = test_torchinductor.copy_tests
define_custom_op_for_test = test_torchinductor.define_custom_op_for_test

@instantiate_parametrized_tests
class CommonTemplate:
    def test_unaligned_input(self):
        def fn(x):
            return torch.nn.functional.relu(x)

        x = torch.randn(1024 + 16, device=self.device)[1:-15]
        # TODO (malfet): Investigate failures on MacOS-14
        with (
            contextlib.nullcontext()
            if self.device != "mps" or MACOS_VERSION >= 15.0
            else self.assertRaises(AssertionError)
        ):
            self.common(fn, (x,), check_lowp=False)

    def test_unaligned_input_2d(self):
        def fn(x):
            return torch.nn.functional.relu(x)

        x = torch.randn(1024, 1024 + 16, device=self.device)[:, 1:-15]
        self.common(fn, (x,), check_lowp=False)

    def test_alignment_without_custom_op(self):
        def fn(x):
            a = torch.nn.functional.relu(x)
            b = (3 * a)[1:-15]
            c = torch.cos(b)
            return c

        x = torch.randn(1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True)
    def test_no_align_for_custom_op(self):
        def slice1d(x):
            return (3 * x)[1:-15]

        def slice1d_meta(x):
            return torch.empty_like(x)[1:-15]

        define_custom_op_for_test("slice1d", slice1d, slice1d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice1d(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True)
    def test_no_align_for_custom_op_2d(self):
        def slice2d(x):
            return (3 * x)[..., 1:-15]

        def slice2d_meta(x):
            return torch.empty_like(x)[..., 1:-15]

        define_custom_op_for_test("slice2d", slice2d, slice2d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice2d(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024, 1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True, alignment_asserts=True)
    @skip_if_cpp_wrapper(
        "Inductor does not generate alignment assertion for cpp_wrapper right now"
    )
    def test_incorrect_meta_for_custom_op_2d(self):
        def slice2d(x):
            return (3 * x)[..., 1:-15]

        def slice2d_meta(x):
            return torch.empty_like(x)[..., 0:-16]

        define_custom_op_for_test("slice2d_incorrect_meta", slice2d, slice2d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice2d_incorrect_meta(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024, 1024 + 16, device=self.device)

        expected_error = "Expect the tensor to be 16 bytes aligned. Fail due to storage_offset=1 itemsize=4"
        with self.assertRaisesRegex(AssertionError, expected_error):
            self.common(fn, (x,), check_lowp=False)

    def test_slice(self):
        def f(x):
            return x[1:] + 1

        x = torch.randn(1025, device=self.device)
        self.common(f, (x,))

    def test_view_dtype_slice(self):
        def f(x):
            return x.view(dtype=torch.float32)[1:] + 1

        x = torch.randn(1025 * 2, dtype=torch.bfloat16, device=self.device)
        self.common(f, (x,), reference_in_float=False)

    @parametrize(
        "size",
        (
            # wrapper for size = 128: https://gist.github.com/shunting314/88f1e72957b9fc5e9826aaa346a0e652
            # ptx: https://gist.github.com/shunting314/eb657ee8821eef9f0685b7b91e2ad5c2
            # the ptx file uses ld.global.b32 to load input buffer
            128,
            # wrapper for size = 1024: https://gist.github.com/shunting314/d7f64e1f52f6b1e2ec25e1a51052ce43
            # ptx: https://gist.github.com/shunting314/a24ff7563bb6b04523d11b119ab0f2b2
            # the ptx file uses ld.global.v2.b32 to load input buffer
            1024,
            # wrapper for size = 1024 * 1024: https://gist.github.com/shunting314/016b95cf0b6e9a75c25f5c9d5ed0a2ba
            # ptx: https://gist.github.com/shunting314/360112a4893c759b114c12fc99958297
            # the ptx file uses ld.global.v4.b32 to load input buffer
            1024 * 1024,
        ),
    )
    def test_slice_view_dtype(self, size):
        offset = 1

        def f(x):
            return x[2:].view(dtype=torch.float32) + 1

        x = torch.randn((size + offset) * 2, dtype=torch.bfloat16, device=self.device)
        self.common(f, (x,), reference_in_float=False)

    def test_Q4_K_dequantization(self):
        """
        Test the alignment issue for Q4_K dequantization.
        """

        QK_K = 256
        K_SCALE_SIZE = 12

        def get_scale_min(scales):
            n_blocks = scales.shape[0]
            scales = scales.view(torch.uint8)
            scales = scales.reshape((n_blocks, 3, 4))

            d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

            sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
            min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

            return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))

        def split_block_dims(blocks, *args):
            n_max = blocks.shape[1]
            dims = list(args) + [n_max - sum(args)]
            return torch.split(blocks, dims, dim=1)

        def dequantize_blocks_Q4_K(blocks, block_size, type_size):
            n_blocks = blocks.shape[0]

            d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
            d = d.view(torch.float16)
            dmin = dmin.view(torch.float16)

            sc, m = get_scale_min(scales)

            d = (d * sc).reshape((n_blocks, -1, 1))
            dm = (dmin * m).reshape((n_blocks, -1, 1))

            qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(
                [0, 4], device=d.device, dtype=torch.uint8
            ).reshape((1, 1, 2, 1))
            qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

            return (d * qs - dm).reshape((n_blocks, QK_K))

        data = torch.randint(
            0, 16, (18432, 1728), device=self.device, dtype=torch.uint8
        )

        def dequantize(data):
            block_size, type_size = 256, 144
            rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
            n_blocks = rows.numel() // type_size
            blocks = rows.reshape((n_blocks, type_size))
            blocks = dequantize_blocks_Q4_K(blocks, block_size, type_size)
            return blocks.reshape(18432, 3072)

        self.common(dequantize, (data,), check_lowp=False, atol=1e-3, rtol=1e-3)


if RUN_CPU:

    class CpuTests(TestCase):
        common = check_model
        device = "cpu"

    copy_tests(CommonTemplate, CpuTests, "cpu")

if RUN_GPU:

    class GPUTests(TestCase):
        common = check_model_gpu
        device = GPU_TYPE

    copy_tests(CommonTemplate, GPUTests, GPU_TYPE)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if RUN_CPU or RUN_GPU:
        run_tests()