Add int8 support to bsr_dense_addmm and bsr_dense_mm Triton kernels (#133855)

As in the title. In addition, the PR introduces `_int_bsr_dense_addmm`, which is equivalent to `bsr_dense_addmm` except that for int8 inputs the result is an int32 tensor (similar to the existing `_int_mm`).
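For illustration, a minimal usage sketch (not part of the PR; the shapes, the 32x32 blocksize, and the test data are assumptions, chosen to respect the kernel's requirement that int8 blocks be at least 32, and a CUDA device with Triton is assumed to be available):

    import torch
    from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm

    # Assumed setup: int8 operands on a CUDA device with Triton installed.
    M = K = N = 64
    bsr = torch.randint(-128, 128, (M, K), dtype=torch.int8, device='cuda').to_sparse_bsr((32, 32))
    dense = torch.randint(-128, 128, (K, N), dtype=torch.int8, device='cuda')
    inp = torch.randint(-128, 128, (M, N), dtype=torch.int8, device='cuda')

    out8 = bsr_dense_addmm(inp, bsr, dense)        # int8 inputs -> int8 result
    out32 = _int_bsr_dense_addmm(inp, bsr, dense)  # int8 inputs -> int32 result, like torch._int_mm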

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133855
Approved by: https://github.com/cpuhrsch
Author:    Pearu Peterson
Date:      2024-08-20 01:12:29 +03:00
Committed: PyTorch MergeBot
Parent:    a3e1416c05
Commit:    345578afb4

3 changed files with 144 additions and 30 deletions

@@ -11,7 +11,8 @@ from torch.testing import make_tensor, FileCheck
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
 from torch.testing._internal.common_utils import \
     (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests,
-     load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU)
+     load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
+     suppress_warnings)
 from torch.testing._internal.common_device_type import \
     (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
      precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
@@ -3978,28 +3979,44 @@ class TestSparseCompressedTritonKernels(TestCase):
         # but key is still valid:
         self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
 
-    @parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear'])
+    @suppress_warnings
+    @parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear', '_int_bsr_dense_addmm'])
     @parametrize("blocksize", [16, '16x32', 32])
     @onlyCUDA
     @skipIfRocm
-    @dtypes(torch.half, torch.bfloat16, torch.float)
-    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
+    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
     def test_triton_kernel(self, op, device, dtype, blocksize):
-        from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm
+        from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm
         from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta,
                                                    optimize_bsr_dense_addmm, dump)
 
         def bsr_dense_linear(input, weights, bias=None):
             return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2)
 
-        operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear)[op]
+        operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear,
+                         _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
 
-        def reference(input, mat1, mat2, beta=1, alpha=1):
+        def reference(input, mat1, mat2, beta=1, alpha=1, op=op):
             assert mat1.layout is torch.strided
             assert mat2.layout is torch.strided
+            if dtype is torch.int8:
+                if op == '_int_bsr_dense_addmm':
+                    return beta * input + alpha * torch._int_mm(mat1, mat2)
+                # workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
+                return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8)
             return beta * input + alpha * (mat1 @ mat2)
 
+        if op == '_int_bsr_dense_addmm':
+            # _int_bsr_dense_addmm is the same as bsr_dense_addmm
+            # except that with int8 inputs _int_bsr_dense_addmm returns
+            # an int32 result. This is covered by the operation and
+            # reference definitions above; all other definitions below
+            # are identical between _int_bsr_dense_addmm and
+            # bsr_dense_addmm.
+            op = 'bsr_dense_addmm'
+
         def nc_copy(t, axes=(-1,)):
             """Return a copy of input.
@@ -4030,6 +4047,13 @@ class TestSparseCompressedTritonKernels(TestCase):
             # todo: eliminate this skip
             self.skipTest(f"{op} does not support non-square blocks")
 
+        if op in {"bsr_dense_linear"} and dtype is torch.int8:
+            # todo: eliminate this skip
+            self.skipTest(f"{op} does not support int8")
+
+        if dtype is torch.int8 and min(BM, BK) < 32:
+            self.skipTest("triton kernel does not support int8 blocks smaller than 32")
+
         beta_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[0], bsr_dense_linear=[1])[op]
         alpha_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[1], bsr_dense_linear=[1])[op]
         sparsity_lst = [0, 0.5, 1]
@@ -4096,28 +4120,33 @@ class TestSparseCompressedTritonKernels(TestCase):
             result = operation(*args, **kwargs)
             self.assertEqual(result, expected)
 
-    @parametrize("op", ['bsr_dense_addmm'])
+    @parametrize("op", ['bsr_dense_addmm', '_int_bsr_dense_addmm'])
     @onlyCUDA
     @skipIfRocm
-    @dtypes(torch.half, torch.bfloat16, torch.float)
-    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
+    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
     def test_triton_tune(self, op, device, dtype):
-        from torch.sparse._triton_ops import bsr_dense_addmm
-        from torch.sparse._triton_ops_meta import (create_blocked_tensor, tune_bsr_dense_addmm, get_meta)
+        from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm
+        from torch.sparse._triton_ops_meta import (create_blocked_tensor, tune_bsr_dense_addmm, tune__int_bsr_dense_addmm, get_meta)
 
-        operation = dict(bsr_dense_addmm=bsr_dense_addmm)[op]
-        tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm)[op]
+        operation = dict(bsr_dense_addmm=bsr_dense_addmm, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
+        tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm,
+                     _int_bsr_dense_addmm=tune__int_bsr_dense_addmm)[op]
 
-        M, K, N = 16, 16, 32
+        if op == '_int_bsr_dense_addmm':
+            M, K, N = 32, 32, 32
+            blocksize = (32, 32)
+        else:
+            M, K, N = 16, 16, 32
+            blocksize = (16, 16)
         sparsity = 1.0
-        blocksize = (16, 16)
         bsr = create_blocked_tensor(0, M, K, blocksize, sparsity, dtype, device).to_sparse_bsr(blocksize)
         sparsity = 1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K)
         input = make_tensor(K, N, dtype=dtype, device=device)
         dense = make_tensor(K, N, dtype=dtype, device=device)
 
-        if op == 'bsr_dense_addmm':
+        if op in {'bsr_dense_addmm', '_int_bsr_dense_addmm'}:
             args = (input, bsr, dense)
 
             def get_current_meta():