mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add scaling arguments to bsr_dense_addmm (#136104)
As in the title. Tackles https://github.com/pytorch/ao/pull/821/files#r1759821413 The PR assumes that the existing tuning parameters are good also when using scaling arguments. This needs to be verified as a follow-up task. Also, this PR redefines triton-contiguous tensors: the tensor must have strides not larger than 1. This will now allow zero strides that previously triggered `contiguous` call although the underlying memory buffer was contiguous. Re: "a considerable slow-down occurs because tensor data is copied element-wise rather than chunk-wise" - this note should refer to a code (torch or triton?) that implements the element/chunk-wise copy so that we could verify that allowing zero strides indeed would not trigger element-wise copies. Atm, the performance increase in ViT-H benchmarks (that involve using 0 strides) is evidence that allowing zero strides does not lead to slow-downs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136104 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
bfbcdf4967
commit
b76d1b79e6
@ -4027,6 +4027,7 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
@skipIfRocm
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
|
||||
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
|
||||
@precisionOverride({torch.float16: 6e-1})
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
|
||||
def test_triton_kernel(self, op, device, dtype, blocksize):
|
||||
from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm
|
||||
@ -4039,15 +4040,24 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear,
|
||||
_int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
|
||||
|
||||
def reference(input, mat1, mat2, beta=1, alpha=1, op=op):
|
||||
def reference(input, mat1, mat2, beta=1, alpha=1, left_alpha=None, right_alpha=None, op=op):
|
||||
assert mat1.layout is torch.strided
|
||||
assert mat2.layout is torch.strided
|
||||
if dtype is torch.int8:
|
||||
if op == '_int_bsr_dense_addmm':
|
||||
return beta * input + alpha * torch._int_mm(mat1, mat2)
|
||||
# workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
|
||||
return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8)
|
||||
return beta * input + alpha * (mat1 @ mat2)
|
||||
mat12 = torch._int_mm(mat1, mat2)
|
||||
else:
|
||||
# workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
|
||||
mat12 = torch._int_mm(mat1, mat2).to(torch.int8)
|
||||
else:
|
||||
mat12 = mat1 @ mat2
|
||||
if alpha != 1:
|
||||
mat12 *= alpha
|
||||
if left_alpha is not None:
|
||||
mat12 = left_alpha.reshape(*left_alpha.shape[:-1], -1, 1) * mat12
|
||||
if right_alpha is not None:
|
||||
mat12 = mat12 * right_alpha.reshape(*right_alpha.shape[:-1], 1, -1)
|
||||
return beta * input + mat12
|
||||
|
||||
if op == '_int_bsr_dense_addmm':
|
||||
# _int_bsr_dense_addmm is same as bsr_dense_addmm except
|
||||
@ -4056,6 +4066,8 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
# definitions above and all other definitions below are
|
||||
# identical between _int_bsr_dense_addmm and
|
||||
# bsr_dense_addmm.
|
||||
if dtype.is_floating_point or dtype.is_complex:
|
||||
self.skipTest(f"Redundant test: {op} on {dtype} tensors")
|
||||
op = 'bsr_dense_addmm'
|
||||
|
||||
def nc_copy(t, axes=(-1,)):
|
||||
@ -4101,14 +4113,21 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
blocks_per_row_lst = [1, 2]
|
||||
blocks_per_col_lst = [1, 2]
|
||||
result_cols_lst = [16, 32, 64]
|
||||
for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product(
|
||||
beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst):
|
||||
has_left_alpha_lst = dict(bsr_dense_addmm=[False, True], bsr_dense_mm=[False], bsr_dense_linear=[False])[op]
|
||||
has_right_alpha_lst = dict(bsr_dense_addmm=[False, True], bsr_dense_mm=[False], bsr_dense_linear=[False])[op]
|
||||
high = 1.5 + int(dtype is torch.int8)
|
||||
for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N, has_left_alpha, has_right_alpha in itertools.product(
|
||||
beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst,
|
||||
has_left_alpha_lst, has_right_alpha_lst):
|
||||
M = BM * blocks_per_row
|
||||
K = BK * blocks_per_col
|
||||
mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device)
|
||||
bsr = mat1.to_sparse_bsr((BM, BK))
|
||||
mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5)
|
||||
input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5)
|
||||
mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=high)
|
||||
input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=high)
|
||||
|
||||
left_alpha = make_tensor(M, dtype=dtype, device=device, low=0.5, high=high) if has_left_alpha else None
|
||||
right_alpha = make_tensor(N, dtype=dtype, device=device, low=0.5, high=high) if has_right_alpha else None
|
||||
|
||||
if 0 and op == "bsr_dense_addmm":
|
||||
# Find optimal kernel parameters, the speed-up is
|
||||
@ -4121,12 +4140,12 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
meta = get_meta(op, key, version=(0, dtype, 0.5))
|
||||
if meta is None:
|
||||
optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5)
|
||||
meta = get_meta(op, key, version=(0, dtype, 0.5))
|
||||
assert meta is not None
|
||||
dump() # this will update torch/sparse/_triton_ops_meta.py
|
||||
|
||||
expected = reference(input, mat1, mat2, beta=beta, alpha=alpha)
|
||||
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm={},
|
||||
expected = reference(input, mat1, mat2, beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha)
|
||||
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha,
|
||||
left_alpha=left_alpha, right_alpha=right_alpha), bsr_dense_mm={},
|
||||
bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op]
|
||||
|
||||
args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2),
|
||||
@ -4156,7 +4175,7 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
if op in {'bsr_dense_addmm', 'bsr_dense_linear'}:
|
||||
args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2),
|
||||
bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
|
||||
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha),
|
||||
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha),
|
||||
bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op]
|
||||
result = operation(*args, **kwargs)
|
||||
self.assertEqual(result, expected)
|
||||
|
@ -89,14 +89,14 @@ def make_triton_contiguous(t):
|
||||
"""Return input as a triton-contiguous tensor.
|
||||
|
||||
A triton-contiguous tensor is defined as a tensor that has strides
|
||||
with minimal value equal to 1.
|
||||
with minimal value smaller than or equal to 1.
|
||||
|
||||
While triton kernels support triton-non-contiguous tensors (all
|
||||
strides being greater than 1 or having 0 strides) arguments, a
|
||||
considerable slow-down occurs because tensor data is copied
|
||||
element-wise rather than chunk-wise.
|
||||
strides being greater than 1) arguments, a considerable slow-down
|
||||
occurs because tensor data is copied element-wise rather than
|
||||
chunk-wise. Zero strides is assumed to not have this defect.
|
||||
"""
|
||||
if min(t.stride()) != 1:
|
||||
if min(t.stride()) > 1:
|
||||
# TODO: investigate if contiguity along other axes than the
|
||||
# last one can be beneficial for performance
|
||||
return t.contiguous()
|
||||
@ -1097,6 +1097,8 @@ def _int_bsr_dense_addmm(
|
||||
*,
|
||||
beta=1,
|
||||
alpha=1,
|
||||
left_alpha: Optional[torch.Tensor] = None,
|
||||
right_alpha: Optional[torch.Tensor] = None,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
skip_checks: bool = False,
|
||||
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
@ -1120,6 +1122,8 @@ def _int_bsr_dense_addmm(
|
||||
dense,
|
||||
beta=beta,
|
||||
alpha=alpha,
|
||||
left_alpha=left_alpha,
|
||||
right_alpha=right_alpha,
|
||||
out=out,
|
||||
skip_checks=skip_checks,
|
||||
max_grid=max_grid,
|
||||
@ -1134,11 +1138,21 @@ def bsr_dense_addmm(
|
||||
*,
|
||||
beta=1,
|
||||
alpha=1,
|
||||
left_alpha: Optional[torch.Tensor] = None,
|
||||
right_alpha: Optional[torch.Tensor] = None,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
skip_checks: bool = False,
|
||||
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
|
||||
meta: Optional[dict] = None,
|
||||
):
|
||||
"""Compute
|
||||
|
||||
out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1)
|
||||
|
||||
where left_alpha, right_alpha are (* + 1)-D tensors when
|
||||
specified, otherwise, these are treated as tensors filled with
|
||||
ones.
|
||||
"""
|
||||
f_name = "bsr_dense_addmm"
|
||||
values = bsr.values()
|
||||
crow_indices = bsr.crow_indices()
|
||||
@ -1150,8 +1164,8 @@ def bsr_dense_addmm(
|
||||
|
||||
# todo: implement checks
|
||||
|
||||
original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
|
||||
if out is None:
|
||||
original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
|
||||
out = dense.new_empty(original_batch_dims_broadcasted + (M, N))
|
||||
|
||||
if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0:
|
||||
@ -1163,6 +1177,30 @@ def bsr_dense_addmm(
|
||||
out.mul_(beta)
|
||||
return out
|
||||
|
||||
left_alpha_is_one = False
|
||||
right_alpha_is_one = False
|
||||
if left_alpha is None:
|
||||
left_alpha_is_one = True
|
||||
left_alpha = dense.new_empty(()).expand(
|
||||
*original_batch_dims_broadcasted, M, N
|
||||
) # not referenced
|
||||
else:
|
||||
left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand(
|
||||
*original_batch_dims_broadcasted, M, N
|
||||
)
|
||||
|
||||
if right_alpha is None:
|
||||
right_alpha_is_one = True
|
||||
right_alpha = dense.new_empty(()).expand(
|
||||
*original_batch_dims_broadcasted, M, N
|
||||
) # not referenced
|
||||
else:
|
||||
right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand(
|
||||
*original_batch_dims_broadcasted, M, N
|
||||
)
|
||||
assert left_alpha.stride()[-1] == 0
|
||||
assert right_alpha.stride()[-2] == 0
|
||||
|
||||
if meta is None:
|
||||
sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2)
|
||||
meta = bsr_dense_addmm_meta(
|
||||
@ -1178,9 +1216,16 @@ def bsr_dense_addmm(
|
||||
)
|
||||
out_backup = out
|
||||
|
||||
crow_indices, col_indices, values, input, dense, out = prepare_inputs(
|
||||
bsr, input, dense, out
|
||||
)
|
||||
(
|
||||
crow_indices,
|
||||
col_indices,
|
||||
values,
|
||||
input,
|
||||
dense,
|
||||
left_alpha,
|
||||
right_alpha,
|
||||
out,
|
||||
) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out)
|
||||
|
||||
BM, BK = blocksize
|
||||
SPLIT_N = meta.get("SPLIT_N", N // BM)
|
||||
@ -1191,6 +1236,9 @@ def bsr_dense_addmm(
|
||||
dense = tile_to_blocksize(dense, (BK, BN))
|
||||
input = tile_to_blocksize(input, (BM, BN))
|
||||
|
||||
left_alpha = tile_to_blocksize(left_alpha, (BM, BN))
|
||||
right_alpha = tile_to_blocksize(right_alpha, (BM, BN))
|
||||
|
||||
dot_out_dtype = {
|
||||
torch.float16: tl.float32,
|
||||
torch.bfloat16: tl.float32,
|
||||
@ -1216,6 +1264,8 @@ def bsr_dense_addmm(
|
||||
col_indices: (0, None, None),
|
||||
input: (0, -3, -4),
|
||||
dense: (0, -3, None),
|
||||
left_alpha: (0, -3, -4),
|
||||
right_alpha: (0, -3, -4),
|
||||
out: (0, -3, -4),
|
||||
}
|
||||
|
||||
@ -1229,6 +1279,8 @@ def bsr_dense_addmm(
|
||||
beta_is_one=beta == 1,
|
||||
beta_is_nonzero=beta != 0,
|
||||
alpha_is_one=alpha == 1,
|
||||
left_alpha_is_one=left_alpha_is_one,
|
||||
right_alpha_is_one=right_alpha_is_one,
|
||||
BLOCKSIZE_ROW=BM,
|
||||
BLOCKSIZE_INNER=BK,
|
||||
BLOCKSIZE_COL=BN,
|
||||
@ -2278,6 +2330,22 @@ if has_triton():
|
||||
dense_row_block_stride,
|
||||
dense_col_block_stride,
|
||||
# dense epilogue
|
||||
# left_alpha prologue
|
||||
left_alpha_ptr,
|
||||
left_alpha_batch_stride,
|
||||
left_alpha_tiled_row_stride,
|
||||
left_alpha_tiled_col_stride: tl.constexpr,
|
||||
left_alpha_row_block_stride,
|
||||
left_alpha_col_block_stride: tl.constexpr,
|
||||
# left_alpha epilogue
|
||||
# right_alpha prologue
|
||||
right_alpha_ptr,
|
||||
right_alpha_batch_stride,
|
||||
right_alpha_tiled_row_stride: tl.constexpr,
|
||||
right_alpha_tiled_col_stride,
|
||||
right_alpha_row_block_stride: tl.constexpr,
|
||||
right_alpha_col_block_stride,
|
||||
# right_alpha epilogue
|
||||
# output prologue
|
||||
output_ptr,
|
||||
output_batch_stride,
|
||||
@ -2291,6 +2359,8 @@ if has_triton():
|
||||
beta_is_one: tl.constexpr,
|
||||
beta_is_nonzero: tl.constexpr,
|
||||
alpha_is_one: tl.constexpr,
|
||||
left_alpha_is_one: tl.constexpr,
|
||||
right_alpha_is_one: tl.constexpr,
|
||||
BLOCKSIZE_ROW: tl.constexpr,
|
||||
BLOCKSIZE_COL: tl.constexpr,
|
||||
BLOCKSIZE_INNER: tl.constexpr,
|
||||
@ -2299,6 +2369,12 @@ if has_triton():
|
||||
GROUP_SIZE_ROW: tl.constexpr,
|
||||
SPLIT_N: tl.constexpr,
|
||||
):
|
||||
# left/right_alpha tensors are originally (* + 1)-dimensional
|
||||
assert left_alpha_tiled_col_stride == 0
|
||||
assert left_alpha_col_block_stride == 0
|
||||
assert right_alpha_tiled_row_stride == 0
|
||||
assert right_alpha_row_block_stride == 0
|
||||
|
||||
batch_pid = tl.program_id(axis=2)
|
||||
row_block_pid = tl.program_id(axis=0)
|
||||
col_block_pid = tl.program_id(axis=1)
|
||||
@ -2324,17 +2400,6 @@ if has_triton():
|
||||
inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)
|
||||
col_block_arange = tl.arange(0, BLOCKSIZE_COL)
|
||||
|
||||
if beta_is_nonzero:
|
||||
# Pointers are set to exact write-to locations
|
||||
input_ptrs = (
|
||||
input_ptr
|
||||
+ input_batch_stride * batch_pid
|
||||
+ input_tiled_row_stride * row_block_pid
|
||||
+ input_tiled_col_stride * col_block_pid
|
||||
+ input_row_block_stride * row_block_arange[:, None]
|
||||
+ input_col_block_stride * col_block_arange[None, :]
|
||||
)
|
||||
|
||||
# Pointers are set to the first block of the current row.
|
||||
values_block_ptrs = (
|
||||
values_ptr
|
||||
@ -2371,14 +2436,7 @@ if has_triton():
|
||||
+ col_indices_stride * nnz_offset
|
||||
)
|
||||
|
||||
# alpha is never 0
|
||||
if beta_is_nonzero:
|
||||
output_acc_block = tl.load(input_ptrs).to(acc_dtype) # type: ignore[possibly-undefined]
|
||||
if not (beta_is_one and alpha_is_one):
|
||||
beta_alpha = beta / alpha
|
||||
output_acc_block *= beta_alpha
|
||||
else:
|
||||
output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
|
||||
output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
|
||||
|
||||
for _ in range(row_nnz):
|
||||
values_block = tl.load(values_block_ptrs)
|
||||
@ -2402,6 +2460,42 @@ if has_triton():
|
||||
if not alpha_is_one:
|
||||
output_acc_block *= alpha
|
||||
|
||||
if not left_alpha_is_one:
|
||||
left_alpha_ptrs = (
|
||||
left_alpha_ptr
|
||||
+ left_alpha_batch_stride * batch_pid
|
||||
+ left_alpha_tiled_row_stride * row_block_pid
|
||||
+ left_alpha_tiled_col_stride * col_block_pid
|
||||
+ left_alpha_row_block_stride * row_block_arange[:, None]
|
||||
+ left_alpha_col_block_stride * col_block_arange[None, :]
|
||||
)
|
||||
output_acc_block *= tl.load(left_alpha_ptrs)
|
||||
|
||||
if not right_alpha_is_one:
|
||||
right_alpha_ptrs = (
|
||||
right_alpha_ptr
|
||||
+ right_alpha_batch_stride * batch_pid
|
||||
+ right_alpha_tiled_row_stride * row_block_pid
|
||||
+ right_alpha_tiled_col_stride * col_block_pid
|
||||
+ right_alpha_row_block_stride * row_block_arange[:, None]
|
||||
+ right_alpha_col_block_stride * col_block_arange[None, :]
|
||||
)
|
||||
output_acc_block *= tl.load(right_alpha_ptrs)
|
||||
|
||||
if beta_is_nonzero:
|
||||
input_ptrs = (
|
||||
input_ptr
|
||||
+ input_batch_stride * batch_pid
|
||||
+ input_tiled_row_stride * row_block_pid
|
||||
+ input_tiled_col_stride * col_block_pid
|
||||
+ input_row_block_stride * row_block_arange[:, None]
|
||||
+ input_col_block_stride * col_block_arange[None, :]
|
||||
)
|
||||
if beta_is_one:
|
||||
output_acc_block += tl.load(input_ptrs)
|
||||
else:
|
||||
output_acc_block += beta * tl.load(input_ptrs)
|
||||
|
||||
# write back the result
|
||||
tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
|
||||
|
||||
|
@ -599,6 +599,8 @@ def tune_bsr_dense_addmm(
|
||||
*,
|
||||
beta=1,
|
||||
alpha=1,
|
||||
left_alpha=None,
|
||||
right_alpha=None,
|
||||
out=None,
|
||||
store=False,
|
||||
verbose=False,
|
||||
@ -658,7 +660,15 @@ def tune_bsr_dense_addmm(
|
||||
def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out):
|
||||
def test_func():
|
||||
return bsr_dense_addmm(
|
||||
input, bsr, dense, beta=beta, alpha=alpha, meta=meta, out=out
|
||||
input,
|
||||
bsr,
|
||||
dense,
|
||||
beta=beta,
|
||||
alpha=alpha,
|
||||
left_alpha=left_alpha,
|
||||
right_alpha=right_alpha,
|
||||
meta=meta,
|
||||
out=out,
|
||||
)
|
||||
|
||||
return triton.testing.do_bench(test_func, warmup=500, rep=100)
|
||||
@ -723,6 +733,8 @@ def optimize_bsr_dense_addmm(
|
||||
bk,
|
||||
beta=1,
|
||||
alpha=1,
|
||||
use_left_alpha=False,
|
||||
use_right_alpha=False,
|
||||
dtype=torch.float16,
|
||||
device="cuda",
|
||||
sparsity=0.5,
|
||||
@ -736,12 +748,18 @@ def optimize_bsr_dense_addmm(
|
||||
).to_sparse_bsr((bm, bk))
|
||||
dense = make_tensor(k, n, dtype=dtype, device=device)
|
||||
input = make_tensor(m, n, dtype=dtype, device=device)
|
||||
left_alpha = make_tensor(m, dtype=dtype, device=device) if use_left_alpha else None
|
||||
right_alpha = (
|
||||
make_tensor(n, dtype=dtype, device=device) if use_right_alpha else None
|
||||
)
|
||||
tune_bsr_dense_addmm(
|
||||
input,
|
||||
bsr,
|
||||
dense,
|
||||
beta=beta,
|
||||
alpha=alpha,
|
||||
left_alpha=left_alpha,
|
||||
right_alpha=right_alpha,
|
||||
store=True,
|
||||
force=force,
|
||||
verbose=verbose,
|
||||
|
Reference in New Issue
Block a user