Add scaling arguments to bsr_dense_addmm (#136104)

As in the title.

Tackles https://github.com/pytorch/ao/pull/821/files#r1759821413

The PR assumes that the existing tuning parameters are also good when scaling arguments are used. This assumption needs to be verified as a follow-up task.

Also, this PR redefines triton-contiguous tensors: the minimal stride of the tensor must not be larger than 1. This now allows zero strides, which previously triggered a `contiguous` call even though the underlying memory buffer was contiguous.

Re: "a considerable slow-down occurs because tensor data is copied element-wise rather than chunk-wise" - this note should refer to a code (torch or triton?) that implements the element/chunk-wise copy so that we could verify that allowing zero strides indeed would not trigger element-wise copies. Atm, the performance increase in ViT-H benchmarks (that involve using 0 strides) is an evidence that allowing zero strides does not lead to slow-downs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136104
Approved by: https://github.com/cpuhrsch
Author: Pearu Peterson
Date: 2024-09-16 14:51:47 +03:00
Committer: PyTorch MergeBot
Parent: bfbcdf4967
Commit: b76d1b79e6
3 changed files with 173 additions and 42 deletions

test/test_sparse_csr.py

@@ -4027,6 +4027,7 @@ class TestSparseCompressedTritonKernels(TestCase):
     @skipIfRocm
     @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
     @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
+    @precisionOverride({torch.float16: 6e-1})
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
     def test_triton_kernel(self, op, device, dtype, blocksize):
         from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm
@@ -4039,15 +4040,24 @@ class TestSparseCompressedTritonKernels(TestCase):
         operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear,
                          _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]

-        def reference(input, mat1, mat2, beta=1, alpha=1, op=op):
+        def reference(input, mat1, mat2, beta=1, alpha=1, left_alpha=None, right_alpha=None, op=op):
             assert mat1.layout is torch.strided
             assert mat2.layout is torch.strided
             if dtype is torch.int8:
                 if op == '_int_bsr_dense_addmm':
-                    return beta * input + alpha * torch._int_mm(mat1, mat2)
-                # workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
-                return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8)
-            return beta * input + alpha * (mat1 @ mat2)
+                    mat12 = torch._int_mm(mat1, mat2)
+                else:
+                    # workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
+                    mat12 = torch._int_mm(mat1, mat2).to(torch.int8)
+            else:
+                mat12 = mat1 @ mat2
+            if alpha != 1:
+                mat12 *= alpha
+            if left_alpha is not None:
+                mat12 = left_alpha.reshape(*left_alpha.shape[:-1], -1, 1) * mat12
+            if right_alpha is not None:
+                mat12 = mat12 * right_alpha.reshape(*right_alpha.shape[:-1], 1, -1)
+            return beta * input + mat12

         if op == '_int_bsr_dense_addmm':
             # _int_bsr_dense_addmm is same as bsr_dense_addmm except
@@ -4056,6 +4066,8 @@ class TestSparseCompressedTritonKernels(TestCase):
             # definitions above and all other definitions below are
             # identical between _int_bsr_dense_addmm and
             # bsr_dense_addmm.
+            if dtype.is_floating_point or dtype.is_complex:
+                self.skipTest(f"Redundant test: {op} on {dtype} tensors")
             op = 'bsr_dense_addmm'

         def nc_copy(t, axes=(-1,)):
@@ -4101,14 +4113,21 @@ class TestSparseCompressedTritonKernels(TestCase):
         blocks_per_row_lst = [1, 2]
         blocks_per_col_lst = [1, 2]
         result_cols_lst = [16, 32, 64]
-        for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product(
-                beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst):
+        has_left_alpha_lst = dict(bsr_dense_addmm=[False, True], bsr_dense_mm=[False], bsr_dense_linear=[False])[op]
+        has_right_alpha_lst = dict(bsr_dense_addmm=[False, True], bsr_dense_mm=[False], bsr_dense_linear=[False])[op]
+        high = 1.5 + int(dtype is torch.int8)
+        for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N, has_left_alpha, has_right_alpha in itertools.product(
+                beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst,
+                has_left_alpha_lst, has_right_alpha_lst):
             M = BM * blocks_per_row
             K = BK * blocks_per_col
             mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device)
             bsr = mat1.to_sparse_bsr((BM, BK))
-            mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5)
-            input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5)
+            mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=high)
+            input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=high)
+            left_alpha = make_tensor(M, dtype=dtype, device=device, low=0.5, high=high) if has_left_alpha else None
+            right_alpha = make_tensor(N, dtype=dtype, device=device, low=0.5, high=high) if has_right_alpha else None

             if 0 and op == "bsr_dense_addmm":
                 # Find optimal kernel parameters, the speed-up is
@@ -4121,12 +4140,12 @@ class TestSparseCompressedTritonKernels(TestCase):
                 meta = get_meta(op, key, version=(0, dtype, 0.5))
                 if meta is None:
                     optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5)
                     meta = get_meta(op, key, version=(0, dtype, 0.5))
                     assert meta is not None
                     dump()  # this will update torch/sparse/_triton_ops_meta.py

-            expected = reference(input, mat1, mat2, beta=beta, alpha=alpha)
-            kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm={},
+            expected = reference(input, mat1, mat2, beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha)
+            kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha,
+                                               left_alpha=left_alpha, right_alpha=right_alpha), bsr_dense_mm={},
                           bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op]

             args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2),
@@ -4156,7 +4175,7 @@ class TestSparseCompressedTritonKernels(TestCase):
             if op in {'bsr_dense_addmm', 'bsr_dense_linear'}:
                 args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2),
                             bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
-                kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha),
+                kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha),
                               bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op]
                 result = operation(*args, **kwargs)
                 self.assertEqual(result, expected)

torch/sparse/_triton_ops.py

@@ -89,14 +89,14 @@ def make_triton_contiguous(t):
     """Return input as a triton-contiguous tensor.

     A triton-contiguous tensor is defined as a tensor that has strides
-    with minimal value equal to 1.
+    with minimal value smaller than or equal to 1.

     While triton kernels support triton-non-contiguous tensors (all
-    strides being greater than 1 or having 0 strides) arguments, a
-    considerable slow-down occurs because tensor data is copied
-    element-wise rather than chunk-wise.
+    strides being greater than 1) arguments, a considerable slow-down
+    occurs because tensor data is copied element-wise rather than
+    chunk-wise. Zero strides is assumed to not have this defect.
     """
-    if min(t.stride()) != 1:
+    if min(t.stride()) > 1:
         # TODO: investigate if contiguity along other axes than the
         # last one can be beneficial for performance
         return t.contiguous()
@@ -1097,6 +1097,8 @@ def _int_bsr_dense_addmm(
     *,
     beta=1,
     alpha=1,
+    left_alpha: Optional[torch.Tensor] = None,
+    right_alpha: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     skip_checks: bool = False,
     max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
@@ -1120,6 +1122,8 @@ def _int_bsr_dense_addmm(
         dense,
         beta=beta,
         alpha=alpha,
+        left_alpha=left_alpha,
+        right_alpha=right_alpha,
         out=out,
         skip_checks=skip_checks,
         max_grid=max_grid,
@@ -1134,11 +1138,21 @@ def bsr_dense_addmm(
     *,
     beta=1,
     alpha=1,
+    left_alpha: Optional[torch.Tensor] = None,
+    right_alpha: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     skip_checks: bool = False,
     max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
     meta: Optional[dict] = None,
 ):
+    """Compute
+
+        out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1)
+
+    where left_alpha, right_alpha are (* + 1)-D tensors when
+    specified, otherwise, these are treated as tensors filled with
+    ones.
+    """
     f_name = "bsr_dense_addmm"
     values = bsr.values()
     crow_indices = bsr.crow_indices()
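
In dense terms, the docstring formula above amounts to the following plain PyTorch computation (a minimal 2-D sketch; the shapes and values are illustrative assumptions, not part of the diff):

import torch

M, K, N = 8, 16, 32
A = torch.randn(M, K)            # stands in for bsr.to_dense()
B = torch.randn(K, N)            # dense
inp = torch.randn(M, N)          # input
left_alpha = torch.randn(M)      # per-row scaling of the product
right_alpha = torch.randn(N)     # per-column scaling of the product
beta, alpha = 0.5, 2.0

out = beta * inp + left_alpha.reshape(-1, 1) * (alpha * (A @ B)) * right_alpha.reshape(1, -1)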
@@ -1150,8 +1164,8 @@
     # todo: implement checks

+    original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
     if out is None:
-        original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
         out = dense.new_empty(original_batch_dims_broadcasted + (M, N))

     if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0:
@@ -1163,6 +1177,30 @@
             out.mul_(beta)
         return out

+    left_alpha_is_one = False
+    right_alpha_is_one = False
+    if left_alpha is None:
+        left_alpha_is_one = True
+        left_alpha = dense.new_empty(()).expand(
+            *original_batch_dims_broadcasted, M, N
+        )  # not referenced
+    else:
+        left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand(
+            *original_batch_dims_broadcasted, M, N
+        )
+
+    if right_alpha is None:
+        right_alpha_is_one = True
+        right_alpha = dense.new_empty(()).expand(
+            *original_batch_dims_broadcasted, M, N
+        )  # not referenced
+    else:
+        right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand(
+            *original_batch_dims_broadcasted, M, N
+        )
+    assert left_alpha.stride()[-1] == 0
+    assert right_alpha.stride()[-2] == 0
+
     if meta is None:
         sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2)
         meta = bsr_dense_addmm_meta(
@@ -1178,9 +1216,16 @@
     )

     out_backup = out

-    crow_indices, col_indices, values, input, dense, out = prepare_inputs(
-        bsr, input, dense, out
-    )
+    (
+        crow_indices,
+        col_indices,
+        values,
+        input,
+        dense,
+        left_alpha,
+        right_alpha,
+        out,
+    ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out)

     BM, BK = blocksize
     SPLIT_N = meta.get("SPLIT_N", N // BM)
@@ -1191,6 +1236,9 @@
     dense = tile_to_blocksize(dense, (BK, BN))
     input = tile_to_blocksize(input, (BM, BN))
+    left_alpha = tile_to_blocksize(left_alpha, (BM, BN))
+    right_alpha = tile_to_blocksize(right_alpha, (BM, BN))
+
     dot_out_dtype = {
         torch.float16: tl.float32,
         torch.bfloat16: tl.float32,
@@ -1216,6 +1264,8 @@
         col_indices: (0, None, None),
         input: (0, -3, -4),
         dense: (0, -3, None),
+        left_alpha: (0, -3, -4),
+        right_alpha: (0, -3, -4),
         out: (0, -3, -4),
     }
@@ -1229,6 +1279,8 @@
         beta_is_one=beta == 1,
         beta_is_nonzero=beta != 0,
         alpha_is_one=alpha == 1,
+        left_alpha_is_one=left_alpha_is_one,
+        right_alpha_is_one=right_alpha_is_one,
         BLOCKSIZE_ROW=BM,
         BLOCKSIZE_INNER=BK,
         BLOCKSIZE_COL=BN,
@@ -2278,6 +2330,22 @@ if has_triton():
         dense_row_block_stride,
         dense_col_block_stride,
         # dense epilogue
+        # left_alpha prologue
+        left_alpha_ptr,
+        left_alpha_batch_stride,
+        left_alpha_tiled_row_stride,
+        left_alpha_tiled_col_stride: tl.constexpr,
+        left_alpha_row_block_stride,
+        left_alpha_col_block_stride: tl.constexpr,
+        # left_alpha epilogue
+        # right_alpha prologue
+        right_alpha_ptr,
+        right_alpha_batch_stride,
+        right_alpha_tiled_row_stride: tl.constexpr,
+        right_alpha_tiled_col_stride,
+        right_alpha_row_block_stride: tl.constexpr,
+        right_alpha_col_block_stride,
+        # right_alpha epilogue
         # output prologue
         output_ptr,
         output_batch_stride,
@@ -2291,6 +2359,8 @@
         beta_is_one: tl.constexpr,
         beta_is_nonzero: tl.constexpr,
         alpha_is_one: tl.constexpr,
+        left_alpha_is_one: tl.constexpr,
+        right_alpha_is_one: tl.constexpr,
         BLOCKSIZE_ROW: tl.constexpr,
         BLOCKSIZE_COL: tl.constexpr,
         BLOCKSIZE_INNER: tl.constexpr,
@@ -2299,6 +2369,12 @@
         GROUP_SIZE_ROW: tl.constexpr,
         SPLIT_N: tl.constexpr,
     ):
+        # left/right_alpha tensors are originally (* + 1)-dimensional
+        assert left_alpha_tiled_col_stride == 0
+        assert left_alpha_col_block_stride == 0
+        assert right_alpha_tiled_row_stride == 0
+        assert right_alpha_row_block_stride == 0
+
         batch_pid = tl.program_id(axis=2)
         row_block_pid = tl.program_id(axis=0)
         col_block_pid = tl.program_id(axis=1)
@@ -2324,17 +2400,6 @@
         inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)
         col_block_arange = tl.arange(0, BLOCKSIZE_COL)

-        if beta_is_nonzero:
-            # Pointers are set to exact write-to locations
-            input_ptrs = (
-                input_ptr
-                + input_batch_stride * batch_pid
-                + input_tiled_row_stride * row_block_pid
-                + input_tiled_col_stride * col_block_pid
-                + input_row_block_stride * row_block_arange[:, None]
-                + input_col_block_stride * col_block_arange[None, :]
-            )
-
         # Pointers are set to the first block of the current row.
         values_block_ptrs = (
             values_ptr
@@ -2371,14 +2436,7 @@
             + col_indices_stride * nnz_offset
         )

-        # alpha is never 0
-        if beta_is_nonzero:
-            output_acc_block = tl.load(input_ptrs).to(acc_dtype)  # type: ignore[possibly-undefined]
-            if not (beta_is_one and alpha_is_one):
-                beta_alpha = beta / alpha
-                output_acc_block *= beta_alpha
-        else:
-            output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
+        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)

         for _ in range(row_nnz):
             values_block = tl.load(values_block_ptrs)
@@ -2402,6 +2460,42 @@
         if not alpha_is_one:
             output_acc_block *= alpha

+        if not left_alpha_is_one:
+            left_alpha_ptrs = (
+                left_alpha_ptr
+                + left_alpha_batch_stride * batch_pid
+                + left_alpha_tiled_row_stride * row_block_pid
+                + left_alpha_tiled_col_stride * col_block_pid
+                + left_alpha_row_block_stride * row_block_arange[:, None]
+                + left_alpha_col_block_stride * col_block_arange[None, :]
+            )
+            output_acc_block *= tl.load(left_alpha_ptrs)
+
+        if not right_alpha_is_one:
+            right_alpha_ptrs = (
+                right_alpha_ptr
+                + right_alpha_batch_stride * batch_pid
+                + right_alpha_tiled_row_stride * row_block_pid
+                + right_alpha_tiled_col_stride * col_block_pid
+                + right_alpha_row_block_stride * row_block_arange[:, None]
+                + right_alpha_col_block_stride * col_block_arange[None, :]
+            )
+            output_acc_block *= tl.load(right_alpha_ptrs)
+
+        if beta_is_nonzero:
+            input_ptrs = (
+                input_ptr
+                + input_batch_stride * batch_pid
+                + input_tiled_row_stride * row_block_pid
+                + input_tiled_col_stride * col_block_pid
+                + input_row_block_stride * row_block_arange[:, None]
+                + input_col_block_stride * col_block_arange[None, :]
+            )
+            if beta_is_one:
+                output_acc_block += tl.load(input_ptrs)
+            else:
+                output_acc_block += beta * tl.load(input_ptrs)
+
         # write back the result
         tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
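
In plain PyTorch terms, the kernel epilogue added above applies the factors in a fixed order: alpha first, then the left/right scaling tiles, then the beta * input term. A per-tile sketch (names and tile shapes are illustrative assumptions, not the kernel's API):

import torch

def tile_epilogue(acc, input_tile, left_tile, right_tile, alpha, beta):
    # acc: (BM, BN) accumulator of bsr @ dense for this tile.
    # left_tile: (BM, 1) column of left_alpha; right_tile: (1, BN) row of
    # right_alpha; either may be None, mirroring left/right_alpha_is_one.
    if alpha != 1:
        acc = acc * alpha
    if left_tile is not None:
        acc = acc * left_tile
    if right_tile is not None:
        acc = acc * right_tile
    if beta != 0:
        acc = acc + (input_tile if beta == 1 else beta * input_tile)
    return acc

acc = torch.ones(2, 3)
print(tile_epilogue(acc, torch.full((2, 3), 2.0), torch.tensor([[2.0], [3.0]]),
                    torch.tensor([[1.0, 2.0, 3.0]]), alpha=2.0, beta=0.5))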

torch/sparse/_triton_ops_meta.py

@@ -599,6 +599,8 @@ def tune_bsr_dense_addmm(
     *,
     beta=1,
     alpha=1,
+    left_alpha=None,
+    right_alpha=None,
     out=None,
     store=False,
     verbose=False,
@@ -658,7 +660,15 @@
     def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out):
         def test_func():
             return bsr_dense_addmm(
-                input, bsr, dense, beta=beta, alpha=alpha, meta=meta, out=out
+                input,
+                bsr,
+                dense,
+                beta=beta,
+                alpha=alpha,
+                left_alpha=left_alpha,
+                right_alpha=right_alpha,
+                meta=meta,
+                out=out,
             )

         return triton.testing.do_bench(test_func, warmup=500, rep=100)
@@ -723,6 +733,8 @@ def optimize_bsr_dense_addmm(
     bk,
     beta=1,
     alpha=1,
+    use_left_alpha=False,
+    use_right_alpha=False,
     dtype=torch.float16,
     device="cuda",
     sparsity=0.5,
@@ -736,12 +748,18 @@
     ).to_sparse_bsr((bm, bk))
     dense = make_tensor(k, n, dtype=dtype, device=device)
     input = make_tensor(m, n, dtype=dtype, device=device)
+    left_alpha = make_tensor(m, dtype=dtype, device=device) if use_left_alpha else None
+    right_alpha = (
+        make_tensor(n, dtype=dtype, device=device) if use_right_alpha else None
+    )
     tune_bsr_dense_addmm(
         input,
         bsr,
         dense,
         beta=beta,
         alpha=alpha,
+        left_alpha=left_alpha,
+        right_alpha=right_alpha,
         store=True,
         force=force,
         verbose=verbose,
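
With the new use_left_alpha/use_right_alpha flags, tuning parameters can be regenerated for the scaled variant. A hypothetical invocation (problem sizes, block sizes, and dtype are placeholders; requires CUDA and Triton):

import torch
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm

# Tune for an M=K=N=4096 problem with 64x64 blocks, benchmarking the
# kernel with both scaling vectors present, and store the result in
# torch/sparse/_triton_ops_meta.py.
optimize_bsr_dense_addmm(
    4096, 4096, 4096, 64, 64,
    use_left_alpha=True, use_right_alpha=True,
    dtype=torch.bfloat16, sparsity=0.9, verbose=True,
)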