Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add int8 support to bsr_dense_addmm and bsr_dense_mm Triton kernels (#133855)
As in the title. In addition, this PR introduces `_int_bsr_dense_addmm`, which is equivalent to `bsr_dense_addmm` except that for int8 inputs the result is an int32 tensor (similar to the existing `_int_mm`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133855
Approved by: https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: a3e1416c05
Commit: 345578afb4
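
For illustration, a minimal usage sketch of the new int8 path (not part of the PR; it assumes the `(input, bsr, dense)` call pattern exercised by the tests below, a CUDA device with Triton available, and 32x32 blocks because the Triton kernel rejects int8 blocks smaller than 32):

    # Hedged sketch: exact requirements depend on the installed PyTorch/Triton build.
    import torch
    from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm

    device = "cuda"
    M = K = N = 32
    blocksize = (32, 32)  # int8 blocks smaller than 32 are not supported by the kernel

    # Random int8 operands; `inp` plays the role of the addmm `input` argument.
    bsr = torch.randint(-128, 127, (M, K), dtype=torch.int8, device=device).to_sparse_bsr(blocksize)
    dense = torch.randint(-128, 127, (K, N), dtype=torch.int8, device=device)
    inp = torch.randint(-128, 127, (M, N), dtype=torch.int8, device=device)

    out_int8 = bsr_dense_addmm(inp, bsr, dense)        # result stays int8
    out_int32 = _int_bsr_dense_addmm(inp, bsr, dense)  # result is int32, like torch._int_mm
    assert out_int32.dtype == torch.int32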
@@ -11,7 +11,8 @@ from torch.testing import make_tensor, FileCheck
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
 from torch.testing._internal.common_utils import \
     (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests,
-     load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU)
+     load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
+     suppress_warnings)
 from torch.testing._internal.common_device_type import \
     (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
      precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
@@ -3978,28 +3979,44 @@ class TestSparseCompressedTritonKernels(TestCase):
         # but key is still valid:
         self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)

-    @parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear'])
+    @suppress_warnings
+    @parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear', '_int_bsr_dense_addmm'])
     @parametrize("blocksize", [16, '16x32', 32])
     @onlyCUDA
     @skipIfRocm
-    @dtypes(torch.half, torch.bfloat16, torch.float)
-    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
+    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
     def test_triton_kernel(self, op, device, dtype, blocksize):
-        from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm
+        from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm
         from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta,
                                                    optimize_bsr_dense_addmm, dump)

         def bsr_dense_linear(input, weights, bias=None):
             return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2)

-        operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear)[op]
+        operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear,
+                         _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]

-        def reference(input, mat1, mat2, beta=1, alpha=1):
+        def reference(input, mat1, mat2, beta=1, alpha=1, op=op):
             assert mat1.layout is torch.strided
             assert mat2.layout is torch.strided
+            if dtype is torch.int8:
+                if op == '_int_bsr_dense_addmm':
+                    return beta * input + alpha * torch._int_mm(mat1, mat2)
+                # workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
+                return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8)
             return beta * input + alpha * (mat1 @ mat2)

+        if op == '_int_bsr_dense_addmm':
+            # _int_bsr_dense_addmm is same as bsr_dense_addmm except
+            # with int8 inputs, _int_bsr_dense_addmm returns int32
+            # result. This is covered by operation and reference
+            # definitions above and all other definitions below are
+            # identical between _int_bsr_dense_addmm and
+            # bsr_dense_addmm.
+            op = 'bsr_dense_addmm'
+
         def nc_copy(t, axes=(-1,)):
             """Return a copy of input.

@@ -4030,6 +4047,13 @@ class TestSparseCompressedTritonKernels(TestCase):
             # todo: eliminate this skip
             self.skipTest(f"{op} does not support non-square blocks")

+        if op in {"bsr_dense_linear"} and dtype is torch.int8:
+            # todo: eliminate this skip
+            self.skipTest(f"{op} does not support int8")
+
+        if dtype is torch.int8 and min(BM, BK) < 32:
+            self.skipTest("triton kernel does not support support int8 blocks smaller than 32")
+
         beta_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[0], bsr_dense_linear=[1])[op]
         alpha_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[1], bsr_dense_linear=[1])[op]
         sparsity_lst = [0, 0.5, 1]
@@ -4096,28 +4120,33 @@ class TestSparseCompressedTritonKernels(TestCase):
                 result = operation(*args, **kwargs)
                 self.assertEqual(result, expected)

-    @parametrize("op", ['bsr_dense_addmm'])
+    @parametrize("op", ['bsr_dense_addmm', '_int_bsr_dense_addmm'])
     @onlyCUDA
     @skipIfRocm
-    @dtypes(torch.half, torch.bfloat16, torch.float)
-    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
+    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
     def test_triton_tune(self, op, device, dtype):
-        from torch.sparse._triton_ops import bsr_dense_addmm
-        from torch.sparse._triton_ops_meta import (create_blocked_tensor, tune_bsr_dense_addmm, get_meta)
+        from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm
+        from torch.sparse._triton_ops_meta import (create_blocked_tensor, tune_bsr_dense_addmm, tune__int_bsr_dense_addmm, get_meta)

-        operation = dict(bsr_dense_addmm=bsr_dense_addmm)[op]
-        tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm)[op]
+        operation = dict(bsr_dense_addmm=bsr_dense_addmm, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
+        tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm,
+                     _int_bsr_dense_addmm=tune__int_bsr_dense_addmm)[op]

-        M, K, N = 16, 16, 32
+        if op == '_int_bsr_dense_addmm':
+            M, K, N = 32, 32, 32
+            blocksize = (32, 32)
+        else:
+            M, K, N = 16, 16, 32
+            blocksize = (16, 16)
         sparsity = 1.0
-        blocksize = (16, 16)
         bsr = create_blocked_tensor(0, M, K, blocksize, sparsity, dtype, device).to_sparse_bsr(blocksize)
         sparsity = 1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K)
         input = make_tensor(K, N, dtype=dtype, device=device)
         dense = make_tensor(K, N, dtype=dtype, device=device)

-        if op == 'bsr_dense_addmm':
+        if op in {'bsr_dense_addmm', '_int_bsr_dense_addmm'}:
             args = (input, bsr, dense)

         def get_current_meta():
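
Side note on the `test_triton_tune` changes: the tuners search for Triton kernel meta-parameters and record them so that `get_meta` can pick them up later. A rough sketch of that flow, assuming the tuner takes the same `(input, bsr, dense)` operands as in the test and a `store=True` keyword for persisting the result (the actual keyword arguments are defined in `torch/sparse/_triton_ops_meta.py` and may differ):

    # Assumption-laden sketch of the tuning flow; verify the tuner's signature
    # against torch/sparse/_triton_ops_meta.py before relying on it.
    import torch
    from torch.sparse._triton_ops_meta import create_blocked_tensor, tune__int_bsr_dense_addmm

    device, dtype = "cuda", torch.int8
    M, K, N = 32, 32, 32
    blocksize = (32, 32)

    # Same operand construction as in test_triton_tune above.
    bsr = create_blocked_tensor(0, M, K, blocksize, 1.0, dtype, device).to_sparse_bsr(blocksize)
    inp = torch.zeros(M, N, dtype=dtype, device=device)
    dense = torch.zeros(K, N, dtype=dtype, device=device)

    # Returns the tuned meta-parameters; store=True (assumed) caches them for get_meta.
    meta = tune__int_bsr_dense_addmm(inp, bsr, dense, store=True)
    print(meta)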