Add tests for bsr_dense_addmm and bsr_dense_mm triton kernels (#114800)

As in the title.

In addition,
- resolve https://github.com/pytorch/pytorch/pull/114757#discussion_r1409547917 re triton-contiguous inputs
- support non-contiguous inputs and outputs in triton kernels
- fix a couple of minor bugs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114800
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2023-12-03 12:55:38 +00:00
committed by PyTorch MergeBot
parent aafa8233a4
commit 4ba37e1804
3 changed files with 830 additions and 25 deletions

View File

@ -3878,6 +3878,130 @@ class TestSparseCompressedTritonKernels(TestCase):
# but key is still valid:
self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
@parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear'])
@parametrize("blocksize", [16, '16x32', 32])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
def test_triton_kernel(self, op, device, dtype, blocksize):
from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm
from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta,
optimize_bsr_dense_addmm, optimize_bsr_dense_mm, dump)
def bsr_dense_linear(input, weights, bias=None):
return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2)
operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear)[op]
def reference(input, mat1, mat2, beta=1, alpha=1):
assert mat1.layout is torch.strided
assert mat2.layout is torch.strided
return beta * input + alpha * (mat1 @ mat2)
def nc_copy(t, axes=(-1,)):
"""Return a copy of input.
The returned copy will be a non-contiguous tensor.
"""
if t.layout is torch.strided:
shape = list(t.shape)
for a in axes:
shape[a] *= 2
r = torch.empty(shape, dtype=t.dtype, device=t.device)
s = r[tuple(slice(None, None, 2 if t.shape[i] != r.shape[i] else None) for i in range(t.ndim))]
s.copy_(t)
return s
elif t.layout is torch.sparse_bsr:
compressed_indices = t.crow_indices()
plain_indices = t.col_indices()
return torch.sparse_compressed_tensor(compressed_indices, plain_indices, nc_copy(t.values()),
t.shape, layout=t.layout)
else:
raise NotImplementedError(t.layout)
if isinstance(blocksize, str):
BM, BK = tuple(map(int, blocksize.split('x')))
else:
BM, BK = (blocksize,) * 2
if op in {"bsr_dense_mm", "bsr_dense_linear"} and BM != BK:
# todo: eliminate this skip
self.skipTest(f"{op} does not support non-square blocks")
beta_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[0], bsr_dense_linear=[1])[op]
alpha_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[1], bsr_dense_linear=[1])[op]
sparsity_lst = [0, 0.5, 1]
# todo: eliminate `[:1]`
blocks_per_row_lst = dict(bsr_dense_addmm=[1, 2], bsr_dense_mm=[1, 2][:1], bsr_dense_linear=[1, 2])[op]
blocks_per_col_lst = [1, 2]
# todo: eliminate `[1:]`
result_cols_lst = dict(bsr_dense_addmm=[16, 32, 64], bsr_dense_mm=[16, 32, 64][1:],
bsr_dense_linear=[16, 32, 64])[op]
for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product(
beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst):
M = BM * blocks_per_row
K = BK * blocks_per_col
mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device)
bsr = mat1.to_sparse_bsr((BM, BK))
mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5)
input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5)
if 0 and op != "bsr_dense_linear":
# Find optimal kernel parameters, the speed-up is
# about 10x for running this test.
#
# Enable this if-block when the test method is
# updated, run the test, and finally, disable the
# if-block.
key = dict(bsr_dense_addmm=(M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1),
bsr_dense_mm=(M, K, N, BM, BK))[op]
meta = get_meta(op, key, version=(0, dtype, 0.5))
if meta is None:
if op == 'bsr_dense_addmm':
optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5)
elif op == 'bsr_dense_mm':
optimize_bsr_dense_mm(M, K, N, BM, BK, dtype=dtype, sparsity=0.5)
else:
raise NotImplementedError(op)
meta = get_meta(op, key, version=(0, dtype, 0.5))
assert meta is not None
dump() # this will update torch/sparse/_triton_ops_meta.py
expected = reference(input, mat1, mat2, beta=beta, alpha=alpha)
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm=dict(),
bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op]
args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2),
bsr_dense_linear=(mat2.transpose(-1, -2), bsr))[op]
result = operation(*args, **kwargs)
self.assertEqual(result, expected)
# Test non-contiguous input tensors:
nc_mat2 = nc_copy(mat2)
nc_input = nc_copy(input)
nc_bsr = nc_copy(bsr)
args = dict(bsr_dense_addmm=(input, bsr, nc_mat2), bsr_dense_mm=(bsr, nc_mat2),
bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
result = operation(*args, **kwargs)
self.assertEqual(result, expected)
# todo: add bsr_dense_linear to the set below
if op in {'bsr_dense_addmm', 'bsr_dense_mm'}:
args = dict(bsr_dense_addmm=(input, nc_bsr, mat2), bsr_dense_mm=(nc_bsr, mat2),
bsr_dense_linear=(mat2.transpose(-1, -2), nc_bsr))[op]
result = operation(*args, **kwargs)
self.assertEqual(result, expected)
if op in {'bsr_dense_addmm', 'bsr_dense_linear'}:
args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2),
bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha),
bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op]
result = operation(*args, **kwargs)
self.assertEqual(result, expected)
# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
instantiate_device_type_tests(TestSparseCSR, globals())

View File

@ -77,12 +77,19 @@ def check_blocksize(f_name, blocksize):
def make_triton_contiguous(t):
# TODO: Why do we need "triton contiguity" that is not defined by
# triton itself? It looks like it is required until
# openai/triton#1291 fixed a bug in processing triton kernel
# arguments. Unless triton comntiguity is required for
# performance, remove this function.
if (t.stride(-2) > 1 or t.dtype is torch.float32) and t.stride(-1) > 1:
"""Return input as a triton-contiguous tensor.
A triton-contiguous tensor is defined as a tensor that has strides
with minimal value equal to 1.
While triton kernels support triton-non-contiguous tensors (all
strides being greater than 1 or having 0 strides) arguments, a
considerable slow-down occurs because tensor data is copied
element-wise rather than chunk-wise.
"""
if min(t.stride()) != 1:
# TODO: investigate if contiguity along other axes than the
# last one can be beneficial for performance
return t.contiguous()
else:
return t
@ -162,16 +169,12 @@ def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
kernel(grid, *sliced_tensors)
def prepare_inputs(bsr, *dense_tensors, require_view=False):
def prepare_inputs(bsr, *dense_tensors):
# Introduce fake batch dimension if not present for convenience.
crow_indices = bsr.crow_indices().unsqueeze(0)
col_indices = bsr.col_indices().unsqueeze(0)
if require_view:
values = bsr.values().unsqueeze(0)
tensors = [t.unsqueeze(0) for t in dense_tensors]
else:
values = make_triton_contiguous(bsr.values().unsqueeze(0))
tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]
values = make_triton_contiguous(bsr.values().unsqueeze(0))
tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]
# Compute broadcasted batch dimension
batch_dims_broadcasted = torch.broadcast_shapes(values.shape[:-3], *(t.shape[:-2] for t in tensors))
@ -848,7 +851,7 @@ def bsr_dense_addmm(
out_backup = out
crow_indices, col_indices, values, input, dense, out = prepare_inputs(bsr, input, dense, out, require_view=True)
crow_indices, col_indices, values, input, dense, out = prepare_inputs(bsr, input, dense, out)
BM, BK = blocksize
SPLIT_N = meta.get('SPLIT_N', N // BM)
@ -856,12 +859,9 @@ def bsr_dense_addmm(
dense = tile_to_blocksize(dense, (BK, BN))
input = tile_to_blocksize(input, (BM, BN))
out_untiled = out
out = tile_to_blocksize(out, (BM, BN))
# out and out_backup may have different shapes/strides/offsets
# but they must share storage:
assert out.data_ptr() == out_backup.data_ptr()
dot_out_dtype = {torch.float16: tl.float32,
torch.bfloat16: tl.float32,
torch.float32: tl.float64,
@ -904,6 +904,11 @@ def bsr_dense_addmm(
launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
if out.data_ptr() != out_backup.data_ptr():
# prepare_inputs has made a copy of out, copy its content back
# to out_backup:
out_backup.copy_(out_untiled.view(out_backup.shape))
return out_backup
@ -1408,7 +1413,7 @@ if has_triton():
out_backup = out
# prepare inputs by reshaping them to be kernel-compatible.
crow_indices, col_indices, values, dense, out = prepare_inputs(bsr, dense, out, require_view=True)
crow_indices, col_indices, values, dense, out = prepare_inputs(bsr, dense, out)
# "Blockify" the row dimension of dense with blocksize[1]
# since dense is on the rhs of matmul
@ -1420,15 +1425,17 @@ if has_triton():
# so it could be any value in [1, dense.shape[-1]).
# We need to probably use the largest possible blocksize
# so that it fits into SRAM.
out_untiled = out
out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))
# out and out_backup may have different shapes/strides/offsets
# but they must share storage:
assert out.data_ptr() == out_backup.data_ptr()
# Launch kernel
_run_dense_rowspace_kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid, meta)
if out.data_ptr() != out_backup.data_ptr():
# prepare_inputs has made a copy of out, copy its content
# back to out_backup:
out_backup.copy_(out_untiled.view(out_backup.shape))
return out_backup

View File

@ -518,7 +518,9 @@ def optimize_bsr_dense_addmm(
key = (m, k, n, bm, bk, beta == 0, beta == 1, alpha == 1)
version = (0, dtype, sparsity)
reference_meta = dict(GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=n // bm)
reference_meta = dict(
GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=max(n // bm, 1)
)
initial_meta = get_meta("bsr_dense_addmm", key, version=version, exact=True)
@ -556,7 +558,7 @@ def optimize_bsr_dense_addmm(
# value. The input value is assumed to be valid.
is_log = name in {"SPLIT_N", "num_warps"}
min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name]
max_value = dict(SPLIT_N=n // bm).get(name)
max_value = dict(SPLIT_N=max(n // bm, 1)).get(name)
value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
if is_log:
next_value = (
@ -736,6 +738,222 @@ _operation_device_version_data: Dict[Any, Dict] = {
#
# BEGIN GENERATED DATA
("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
(16, 16, 16, 16, 16, False, False, False): (2, 1, 1, 2),
(16, 16, 16, 16, 16, False, False, True): (1, 1, 1, 4),
(16, 16, 16, 16, 16, False, True, False): (1, 1, 3, 16),
(16, 16, 16, 16, 16, False, True, True): (1, 1, 1, 8),
(16, 16, 16, 16, 16, True, False, False): (2, 1, 1, 8),
(16, 16, 16, 16, 16, True, False, True): (1, 1, 1, 8),
(16, 16, 32, 16, 16, False, False, False): (1, 2, 1, 8),
(16, 16, 32, 16, 16, False, False, True): (1, 2, 2, 4),
(16, 16, 32, 16, 16, False, True, False): (1, 1, 2, 4),
(16, 16, 32, 16, 16, False, True, True): (1, 1, 2, 4),
(16, 16, 32, 16, 16, True, False, False): (1, 1, 2, 4),
(16, 16, 32, 16, 16, True, False, True): (2, 2, 1, 2),
(16, 16, 64, 16, 16, False, False, False): (1, 4, 2, 4),
(16, 16, 64, 16, 16, False, False, True): (1, 2, 1, 2),
(16, 16, 64, 16, 16, False, True, False): (2, 1, 1, 2),
(16, 16, 64, 16, 16, False, True, True): (1, 4, 1, 8),
(16, 16, 64, 16, 16, True, False, False): (1, 4, 1, 1),
(16, 16, 64, 16, 16, True, False, True): (1, 4, 2, 4),
(16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 2),
(16, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4),
(16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2),
(16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 1),
(16, 32, 16, 16, 16, True, False, False): (1, 1, 1, 2),
(16, 32, 16, 16, 16, True, False, True): (2, 1, 1, 2),
(16, 32, 16, 16, 32, False, False, False): (1, 1, 1, 4),
(16, 32, 16, 16, 32, False, False, True): (1, 1, 1, 8),
(16, 32, 16, 16, 32, False, True, False): (1, 1, 1, 8),
(16, 32, 16, 16, 32, False, True, True): (1, 1, 2, 4),
(16, 32, 16, 16, 32, True, False, False): (1, 1, 1, 2),
(16, 32, 16, 16, 32, True, False, True): (1, 1, 1, 1),
(16, 32, 32, 16, 16, False, False, False): (2, 2, 1, 4),
(16, 32, 32, 16, 16, False, False, True): (2, 2, 1, 2),
(16, 32, 32, 16, 16, False, True, False): (1, 1, 2, 8),
(16, 32, 32, 16, 16, False, True, True): (1, 2, 1, 1),
(16, 32, 32, 16, 16, True, False, False): (1, 1, 1, 8),
(16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4),
(16, 32, 32, 16, 32, False, False, False): (1, 1, 2, 8),
(16, 32, 32, 16, 32, False, False, True): (2, 1, 1, 8),
(16, 32, 32, 16, 32, False, True, False): (1, 1, 1, 4),
(16, 32, 32, 16, 32, False, True, True): (1, 1, 1, 4),
(16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 8),
(16, 32, 32, 16, 32, True, False, True): (1, 1, 1, 4),
(16, 32, 64, 16, 16, False, False, False): (1, 4, 3, 8),
(16, 32, 64, 16, 16, False, False, True): (1, 4, 1, 4),
(16, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4),
(16, 32, 64, 16, 16, False, True, True): (2, 4, 1, 4),
(16, 32, 64, 16, 16, True, False, False): (1, 2, 1, 4),
(16, 32, 64, 16, 16, True, False, True): (1, 2, 1, 4),
(16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 8),
(16, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4),
(16, 32, 64, 16, 32, False, True, False): (1, 4, 1, 2),
(16, 32, 64, 16, 32, False, True, True): (1, 2, 1, 4),
(16, 32, 64, 16, 32, True, False, False): (1, 2, 1, 4),
(16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 2),
(16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2),
(16, 64, 16, 16, 32, False, False, True): (1, 1, 2, 2),
(16, 64, 16, 16, 32, False, True, False): (1, 1, 2, 8),
(16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
(16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 8),
(16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4),
(16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 2),
(16, 64, 32, 16, 32, False, False, True): (1, 2, 1, 4),
(16, 64, 32, 16, 32, False, True, False): (1, 2, 1, 4),
(16, 64, 32, 16, 32, False, True, True): (2, 2, 1, 4),
(16, 64, 32, 16, 32, True, False, False): (1, 2, 1, 4),
(16, 64, 32, 16, 32, True, False, True): (1, 2, 1, 8),
(16, 64, 64, 16, 32, False, False, False): (1, 2, 1, 4),
(16, 64, 64, 16, 32, False, False, True): (1, 4, 2, 2),
(16, 64, 64, 16, 32, False, True, False): (1, 1, 1, 4),
(16, 64, 64, 16, 32, False, True, True): (1, 4, 1, 2),
(16, 64, 64, 16, 32, True, False, False): (1, 2, 1, 4),
(16, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4),
(32, 16, 16, 16, 16, False, False, False): (1, 1, 1, 8),
(32, 16, 16, 16, 16, False, False, True): (1, 1, 2, 4),
(32, 16, 16, 16, 16, False, True, False): (1, 1, 1, 4),
(32, 16, 16, 16, 16, False, True, True): (1, 1, 2, 4),
(32, 16, 16, 16, 16, True, False, False): (1, 1, 1, 2),
(32, 16, 16, 16, 16, True, False, True): (1, 1, 1, 4),
(32, 16, 32, 16, 16, False, False, False): (1, 1, 1, 4),
(32, 16, 32, 16, 16, False, False, True): (2, 2, 1, 4),
(32, 16, 32, 16, 16, False, True, False): (1, 2, 2, 2),
(32, 16, 32, 16, 16, False, True, True): (2, 2, 1, 4),
(32, 16, 32, 16, 16, True, False, False): (1, 2, 2, 8),
(32, 16, 32, 16, 16, True, False, True): (1, 2, 1, 2),
(32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4),
(32, 16, 64, 16, 16, False, False, True): (1, 4, 2, 4),
(32, 16, 64, 16, 16, False, True, False): (1, 2, 2, 2),
(32, 16, 64, 16, 16, False, True, True): (3, 4, 1, 4),
(32, 16, 64, 16, 16, True, False, False): (1, 2, 1, 2),
(32, 16, 64, 16, 16, True, False, True): (1, 2, 1, 4),
(32, 32, 16, 16, 16, False, False, False): (1, 1, 3, 4),
(32, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4),
(32, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2),
(32, 32, 16, 16, 16, False, True, True): (1, 1, 1, 4),
(32, 32, 16, 16, 16, True, False, False): (1, 1, 1, 4),
(32, 32, 16, 16, 16, True, False, True): (1, 1, 2, 2),
(32, 32, 16, 16, 32, False, False, False): (2, 1, 1, 4),
(32, 32, 16, 16, 32, False, False, True): (1, 1, 1, 4),
(32, 32, 16, 16, 32, False, True, False): (1, 1, 1, 4),
(32, 32, 16, 16, 32, False, True, True): (3, 1, 2, 4),
(32, 32, 16, 16, 32, True, False, False): (1, 1, 1, 4),
(32, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4),
(32, 32, 16, 32, 32, False, False, False): (1, 1, 1, 8),
(32, 32, 16, 32, 32, False, False, True): (1, 1, 1, 4),
(32, 32, 16, 32, 32, False, True, False): (1, 1, 2, 1),
(32, 32, 16, 32, 32, False, True, True): (2, 1, 2, 2),
(32, 32, 16, 32, 32, True, False, False): (1, 1, 1, 8),
(32, 32, 16, 32, 32, True, False, True): (2, 1, 3, 4),
(32, 32, 32, 16, 16, False, False, False): (1, 2, 1, 4),
(32, 32, 32, 16, 16, False, False, True): (2, 2, 1, 4),
(32, 32, 32, 16, 16, False, True, False): (1, 1, 1, 8),
(32, 32, 32, 16, 16, False, True, True): (2, 2, 1, 4),
(32, 32, 32, 16, 16, True, False, False): (1, 1, 1, 4),
(32, 32, 32, 16, 16, True, False, True): (2, 2, 2, 4),
(32, 32, 32, 16, 32, False, False, False): (2, 2, 1, 8),
(32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 2),
(32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4),
(32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 4),
(32, 32, 32, 16, 32, True, False, False): (1, 2, 1, 4),
(32, 32, 32, 16, 32, True, False, True): (1, 2, 1, 2),
(32, 32, 32, 32, 32, False, False, False): (1, 1, 3, 8),
(32, 32, 32, 32, 32, False, False, True): (1, 1, 1, 8),
(32, 32, 32, 32, 32, False, True, False): (2, 1, 3, 4),
(32, 32, 32, 32, 32, False, True, True): (2, 1, 1, 2),
(32, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2),
(32, 32, 32, 32, 32, True, False, True): (4, 1, 1, 1),
(32, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4),
(32, 32, 64, 16, 16, False, False, True): (1, 4, 1, 4),
(32, 32, 64, 16, 16, False, True, False): (1, 2, 1, 8),
(32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 2),
(32, 32, 64, 16, 16, True, False, False): (2, 4, 1, 2),
(32, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2),
(32, 32, 64, 16, 32, False, False, False): (1, 2, 1, 8),
(32, 32, 64, 16, 32, False, False, True): (1, 4, 2, 2),
(32, 32, 64, 16, 32, False, True, False): (1, 2, 1, 4),
(32, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4),
(32, 32, 64, 16, 32, True, False, False): (1, 4, 2, 2),
(32, 32, 64, 16, 32, True, False, True): (3, 4, 2, 2),
(32, 32, 64, 32, 32, False, False, False): (2, 2, 1, 4),
(32, 32, 64, 32, 32, False, False, True): (1, 2, 1, 4),
(32, 32, 64, 32, 32, False, True, False): (1, 1, 1, 8),
(32, 32, 64, 32, 32, False, True, True): (1, 1, 1, 4),
(32, 32, 64, 32, 32, True, False, False): (1, 2, 1, 2),
(32, 32, 64, 32, 32, True, False, True): (3, 2, 1, 8),
(32, 64, 16, 16, 32, False, False, False): (1, 1, 2, 2),
(32, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4),
(32, 64, 16, 16, 32, False, True, False): (1, 1, 2, 4),
(32, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
(32, 64, 16, 16, 32, True, False, False): (1, 1, 1, 2),
(32, 64, 16, 16, 32, True, False, True): (2, 1, 2, 2),
(32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 1),
(32, 64, 16, 32, 32, False, False, True): (2, 1, 1, 4),
(32, 64, 16, 32, 32, False, True, False): (1, 1, 1, 1),
(32, 64, 16, 32, 32, False, True, True): (1, 1, 2, 2),
(32, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4),
(32, 64, 16, 32, 32, True, False, True): (1, 1, 1, 4),
(32, 64, 32, 16, 32, False, False, False): (2, 2, 1, 4),
(32, 64, 32, 16, 32, False, False, True): (1, 2, 1, 4),
(32, 64, 32, 16, 32, False, True, False): (1, 1, 1, 4),
(32, 64, 32, 16, 32, False, True, True): (2, 2, 3, 4),
(32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 2),
(32, 64, 32, 16, 32, True, False, True): (1, 2, 1, 2),
(32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 2),
(32, 64, 32, 32, 32, False, False, True): (2, 1, 1, 4),
(32, 64, 32, 32, 32, False, True, False): (1, 1, 1, 8),
(32, 64, 32, 32, 32, False, True, True): (1, 1, 2, 4),
(32, 64, 32, 32, 32, True, False, False): (2, 1, 1, 4),
(32, 64, 32, 32, 32, True, False, True): (1, 1, 2, 4),
(32, 64, 64, 16, 32, False, False, False): (1, 4, 1, 4),
(32, 64, 64, 16, 32, False, False, True): (1, 4, 2, 4),
(32, 64, 64, 16, 32, False, True, False): (1, 4, 2, 2),
(32, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4),
(32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 8),
(32, 64, 64, 16, 32, True, False, True): (1, 4, 2, 1),
(32, 64, 64, 32, 32, False, False, False): (1, 1, 1, 4),
(32, 64, 64, 32, 32, False, False, True): (2, 2, 1, 4),
(32, 64, 64, 32, 32, False, True, False): (1, 1, 1, 4),
(32, 64, 64, 32, 32, False, True, True): (2, 2, 1, 4),
(32, 64, 64, 32, 32, True, False, False): (1, 2, 2, 4),
(32, 64, 64, 32, 32, True, False, True): (2, 2, 3, 4),
(64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 4),
(64, 32, 16, 32, 32, False, False, True): (1, 1, 1, 4),
(64, 32, 16, 32, 32, False, True, False): (1, 1, 1, 8),
(64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 4),
(64, 32, 16, 32, 32, True, False, False): (1, 1, 1, 16),
(64, 32, 16, 32, 32, True, False, True): (2, 1, 1, 4),
(64, 32, 32, 32, 32, False, False, False): (1, 1, 3, 4),
(64, 32, 32, 32, 32, False, False, True): (2, 1, 1, 4),
(64, 32, 32, 32, 32, False, True, False): (1, 1, 2, 4),
(64, 32, 32, 32, 32, False, True, True): (2, 1, 1, 4),
(64, 32, 32, 32, 32, True, False, False): (2, 1, 1, 16),
(64, 32, 32, 32, 32, True, False, True): (2, 1, 1, 4),
(64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4),
(64, 32, 64, 32, 32, False, False, True): (2, 2, 1, 4),
(64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 4),
(64, 32, 64, 32, 32, False, True, True): (2, 2, 1, 4),
(64, 32, 64, 32, 32, True, False, False): (1, 2, 1, 8),
(64, 32, 64, 32, 32, True, False, True): (2, 2, 3, 4),
(64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 16),
(64, 64, 16, 32, 32, False, False, True): (1, 1, 3, 4),
(64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 2),
(64, 64, 16, 32, 32, False, True, True): (2, 1, 1, 4),
(64, 64, 16, 32, 32, True, False, False): (2, 1, 3, 2),
(64, 64, 16, 32, 32, True, False, True): (1, 1, 2, 4),
(64, 64, 32, 32, 32, False, False, False): (1, 1, 1, 8),
(64, 64, 32, 32, 32, False, False, True): (2, 1, 2, 4),
(64, 64, 32, 32, 32, False, True, False): (2, 1, 1, 4),
(64, 64, 32, 32, 32, False, True, True): (1, 1, 2, 4),
(64, 64, 32, 32, 32, True, False, False): (2, 1, 1, 4),
(64, 64, 32, 32, 32, True, False, True): (1, 1, 2, 4),
(64, 64, 64, 32, 32, False, False, False): (1, 2, 2, 4),
(64, 64, 64, 32, 32, False, False, True): (1, 2, 2, 2),
(64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 2),
(64, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4),
(64, 64, 64, 32, 32, True, False, False): (1, 2, 1, 4),
(64, 64, 64, 32, 32, True, False, True): (1, 2, 1, 4),
(256, 256, 256, 16, 16, False, True, True): (4, 8, 5, 2),
(256, 256, 256, 32, 32, False, True, True): (1, 8, 5, 4),
(256, 256, 256, 64, 64, False, True, True): (2, 4, 4, 4),
@ -1018,6 +1236,222 @@ _operation_device_version_data: Dict[Any, Dict] = {
(16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 4),
},
("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): {
(16, 16, 16, 16, 16, False, False, False): (1, 1, 1, 1),
(16, 16, 16, 16, 16, False, False, True): (1, 1, 2, 2),
(16, 16, 16, 16, 16, False, True, False): (1, 1, 1, 1),
(16, 16, 16, 16, 16, False, True, True): (1, 1, 1, 8),
(16, 16, 16, 16, 16, True, False, False): (3, 1, 3, 4),
(16, 16, 16, 16, 16, True, False, True): (1, 1, 2, 1),
(16, 16, 32, 16, 16, False, False, False): (1, 2, 1, 8),
(16, 16, 32, 16, 16, False, False, True): (1, 2, 1, 2),
(16, 16, 32, 16, 16, False, True, False): (2, 1, 1, 4),
(16, 16, 32, 16, 16, False, True, True): (1, 2, 1, 4),
(16, 16, 32, 16, 16, True, False, False): (1, 1, 1, 4),
(16, 16, 32, 16, 16, True, False, True): (1, 2, 1, 2),
(16, 16, 64, 16, 16, False, False, False): (1, 4, 1, 1),
(16, 16, 64, 16, 16, False, False, True): (1, 2, 2, 4),
(16, 16, 64, 16, 16, False, True, False): (1, 4, 1, 4),
(16, 16, 64, 16, 16, False, True, True): (1, 2, 1, 4),
(16, 16, 64, 16, 16, True, False, False): (1, 4, 1, 2),
(16, 16, 64, 16, 16, True, False, True): (1, 1, 1, 2),
(16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 4),
(16, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4),
(16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2),
(16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 2),
(16, 32, 16, 16, 16, True, False, False): (1, 1, 2, 16),
(16, 32, 16, 16, 16, True, False, True): (1, 1, 1, 4),
(16, 32, 16, 16, 32, False, False, False): (2, 1, 1, 8),
(16, 32, 16, 16, 32, False, False, True): (2, 1, 1, 8),
(16, 32, 16, 16, 32, False, True, False): (1, 1, 2, 1),
(16, 32, 16, 16, 32, False, True, True): (1, 1, 1, 4),
(16, 32, 16, 16, 32, True, False, False): (2, 1, 1, 8),
(16, 32, 16, 16, 32, True, False, True): (1, 1, 2, 4),
(16, 32, 32, 16, 16, False, False, False): (1, 1, 1, 16),
(16, 32, 32, 16, 16, False, False, True): (1, 2, 1, 2),
(16, 32, 32, 16, 16, False, True, False): (1, 2, 1, 8),
(16, 32, 32, 16, 16, False, True, True): (3, 2, 1, 4),
(16, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4),
(16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 2),
(16, 32, 32, 16, 32, False, False, False): (1, 2, 1, 2),
(16, 32, 32, 16, 32, False, False, True): (1, 1, 1, 4),
(16, 32, 32, 16, 32, False, True, False): (1, 1, 2, 4),
(16, 32, 32, 16, 32, False, True, True): (1, 2, 1, 2),
(16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 2),
(16, 32, 32, 16, 32, True, False, True): (1, 2, 1, 16),
(16, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4),
(16, 32, 64, 16, 16, False, False, True): (2, 4, 1, 4),
(16, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4),
(16, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4),
(16, 32, 64, 16, 16, True, False, False): (3, 4, 1, 2),
(16, 32, 64, 16, 16, True, False, True): (1, 4, 1, 1),
(16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 16),
(16, 32, 64, 16, 32, False, False, True): (1, 2, 1, 2),
(16, 32, 64, 16, 32, False, True, False): (1, 4, 2, 2),
(16, 32, 64, 16, 32, False, True, True): (1, 4, 1, 8),
(16, 32, 64, 16, 32, True, False, False): (1, 4, 1, 8),
(16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 4),
(16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2),
(16, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4),
(16, 64, 16, 16, 32, False, True, False): (2, 1, 2, 4),
(16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
(16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 4),
(16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4),
(16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 2),
(16, 64, 32, 16, 32, False, False, True): (1, 1, 1, 4),
(16, 64, 32, 16, 32, False, True, False): (1, 1, 1, 4),
(16, 64, 32, 16, 32, False, True, True): (1, 2, 3, 2),
(16, 64, 32, 16, 32, True, False, False): (1, 1, 1, 4),
(16, 64, 32, 16, 32, True, False, True): (1, 1, 2, 4),
(16, 64, 64, 16, 32, False, False, False): (1, 4, 1, 8),
(16, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4),
(16, 64, 64, 16, 32, False, True, False): (1, 4, 1, 1),
(16, 64, 64, 16, 32, False, True, True): (2, 4, 1, 4),
(16, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4),
(16, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4),
(32, 16, 16, 16, 16, False, False, False): (2, 1, 2, 4),
(32, 16, 16, 16, 16, False, False, True): (2, 1, 1, 2),
(32, 16, 16, 16, 16, False, True, False): (1, 1, 2, 4),
(32, 16, 16, 16, 16, False, True, True): (1, 1, 1, 2),
(32, 16, 16, 16, 16, True, False, False): (1, 1, 1, 4),
(32, 16, 16, 16, 16, True, False, True): (2, 1, 1, 2),
(32, 16, 32, 16, 16, False, False, False): (1, 1, 1, 4),
(32, 16, 32, 16, 16, False, False, True): (1, 1, 1, 4),
(32, 16, 32, 16, 16, False, True, False): (1, 2, 1, 4),
(32, 16, 32, 16, 16, False, True, True): (2, 2, 1, 4),
(32, 16, 32, 16, 16, True, False, False): (2, 1, 1, 4),
(32, 16, 32, 16, 16, True, False, True): (2, 2, 1, 2),
(32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 2),
(32, 16, 64, 16, 16, False, False, True): (1, 4, 1, 4),
(32, 16, 64, 16, 16, False, True, False): (1, 2, 1, 4),
(32, 16, 64, 16, 16, False, True, True): (1, 4, 1, 2),
(32, 16, 64, 16, 16, True, False, False): (1, 4, 2, 8),
(32, 16, 64, 16, 16, True, False, True): (1, 4, 1, 1),
(32, 32, 16, 16, 16, False, False, False): (1, 1, 1, 4),
(32, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4),
(32, 32, 16, 16, 16, False, True, False): (1, 1, 2, 4),
(32, 32, 16, 16, 16, False, True, True): (1, 1, 2, 2),
(32, 32, 16, 16, 16, True, False, False): (1, 1, 1, 8),
(32, 32, 16, 16, 16, True, False, True): (1, 1, 1, 4),
(32, 32, 16, 16, 32, False, False, False): (1, 1, 3, 2),
(32, 32, 16, 16, 32, False, False, True): (2, 1, 1, 4),
(32, 32, 16, 16, 32, False, True, False): (3, 1, 1, 4),
(32, 32, 16, 16, 32, False, True, True): (1, 1, 1, 4),
(32, 32, 16, 16, 32, True, False, False): (2, 1, 1, 8),
(32, 32, 16, 16, 32, True, False, True): (1, 1, 3, 2),
(32, 32, 16, 32, 32, False, False, False): (1, 1, 1, 2),
(32, 32, 16, 32, 32, False, False, True): (2, 1, 1, 8),
(32, 32, 16, 32, 32, False, True, False): (1, 1, 1, 2),
(32, 32, 16, 32, 32, False, True, True): (1, 1, 1, 8),
(32, 32, 16, 32, 32, True, False, False): (1, 1, 2, 4),
(32, 32, 16, 32, 32, True, False, True): (1, 1, 1, 2),
(32, 32, 32, 16, 16, False, False, False): (1, 1, 1, 4),
(32, 32, 32, 16, 16, False, False, True): (1, 2, 1, 4),
(32, 32, 32, 16, 16, False, True, False): (1, 2, 1, 4),
(32, 32, 32, 16, 16, False, True, True): (1, 2, 1, 2),
(32, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4),
(32, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4),
(32, 32, 32, 16, 32, False, False, False): (1, 2, 1, 4),
(32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 2),
(32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4),
(32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 2),
(32, 32, 32, 16, 32, True, False, False): (1, 2, 1, 1),
(32, 32, 32, 16, 32, True, False, True): (1, 2, 1, 2),
(32, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4),
(32, 32, 32, 32, 32, False, False, True): (2, 1, 1, 4),
(32, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8),
(32, 32, 32, 32, 32, False, True, True): (1, 1, 1, 8),
(32, 32, 32, 32, 32, True, False, False): (1, 1, 3, 4),
(32, 32, 32, 32, 32, True, False, True): (1, 1, 1, 8),
(32, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4),
(32, 32, 64, 16, 16, False, False, True): (1, 4, 1, 2),
(32, 32, 64, 16, 16, False, True, False): (1, 1, 1, 4),
(32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4),
(32, 32, 64, 16, 16, True, False, False): (1, 4, 1, 8),
(32, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2),
(32, 32, 64, 16, 32, False, False, False): (1, 1, 1, 4),
(32, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4),
(32, 32, 64, 16, 32, False, True, False): (1, 1, 1, 4),
(32, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4),
(32, 32, 64, 16, 32, True, False, False): (2, 2, 1, 8),
(32, 32, 64, 16, 32, True, False, True): (1, 2, 1, 2),
(32, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4),
(32, 32, 64, 32, 32, False, False, True): (1, 2, 1, 1),
(32, 32, 64, 32, 32, False, True, False): (1, 2, 2, 8),
(32, 32, 64, 32, 32, False, True, True): (1, 1, 1, 4),
(32, 32, 64, 32, 32, True, False, False): (1, 2, 1, 4),
(32, 32, 64, 32, 32, True, False, True): (2, 2, 1, 4),
(32, 64, 16, 16, 32, False, False, False): (1, 1, 1, 8),
(32, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4),
(32, 64, 16, 16, 32, False, True, False): (2, 1, 1, 4),
(32, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
(32, 64, 16, 16, 32, True, False, False): (1, 1, 2, 4),
(32, 64, 16, 16, 32, True, False, True): (1, 1, 2, 2),
(32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 8),
(32, 64, 16, 32, 32, False, False, True): (2, 1, 1, 4),
(32, 64, 16, 32, 32, False, True, False): (1, 1, 1, 4),
(32, 64, 16, 32, 32, False, True, True): (1, 1, 2, 2),
(32, 64, 16, 32, 32, True, False, False): (1, 1, 1, 2),
(32, 64, 16, 32, 32, True, False, True): (2, 1, 2, 4),
(32, 64, 32, 16, 32, False, False, False): (1, 1, 1, 4),
(32, 64, 32, 16, 32, False, False, True): (1, 2, 1, 2),
(32, 64, 32, 16, 32, False, True, False): (1, 2, 3, 4),
(32, 64, 32, 16, 32, False, True, True): (2, 2, 1, 4),
(32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 4),
(32, 64, 32, 16, 32, True, False, True): (1, 2, 2, 1),
(32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 8),
(32, 64, 32, 32, 32, False, False, True): (1, 1, 1, 4),
(32, 64, 32, 32, 32, False, True, False): (1, 1, 2, 4),
(32, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4),
(32, 64, 32, 32, 32, True, False, False): (2, 1, 1, 2),
(32, 64, 32, 32, 32, True, False, True): (1, 1, 1, 4),
(32, 64, 64, 16, 32, False, False, False): (1, 4, 2, 1),
(32, 64, 64, 16, 32, False, False, True): (3, 4, 1, 4),
(32, 64, 64, 16, 32, False, True, False): (1, 1, 1, 8),
(32, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4),
(32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4),
(32, 64, 64, 16, 32, True, False, True): (2, 2, 3, 4),
(32, 64, 64, 32, 32, False, False, False): (1, 2, 1, 4),
(32, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4),
(32, 64, 64, 32, 32, False, True, False): (1, 2, 2, 8),
(32, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4),
(32, 64, 64, 32, 32, True, False, False): (1, 2, 2, 4),
(32, 64, 64, 32, 32, True, False, True): (1, 2, 1, 4),
(64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 1),
(64, 32, 16, 32, 32, False, False, True): (1, 1, 2, 4),
(64, 32, 16, 32, 32, False, True, False): (2, 1, 1, 8),
(64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 4),
(64, 32, 16, 32, 32, True, False, False): (2, 1, 1, 2),
(64, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4),
(64, 32, 32, 32, 32, False, False, False): (3, 1, 1, 4),
(64, 32, 32, 32, 32, False, False, True): (1, 1, 1, 4),
(64, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8),
(64, 32, 32, 32, 32, False, True, True): (1, 1, 1, 2),
(64, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2),
(64, 32, 32, 32, 32, True, False, True): (1, 1, 1, 4),
(64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 2),
(64, 32, 64, 32, 32, False, False, True): (3, 2, 1, 4),
(64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 1),
(64, 32, 64, 32, 32, False, True, True): (1, 2, 1, 4),
(64, 32, 64, 32, 32, True, False, False): (1, 1, 3, 4),
(64, 32, 64, 32, 32, True, False, True): (1, 2, 2, 4),
(64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 2),
(64, 64, 16, 32, 32, False, False, True): (1, 1, 3, 2),
(64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 8),
(64, 64, 16, 32, 32, False, True, True): (1, 1, 2, 4),
(64, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4),
(64, 64, 16, 32, 32, True, False, True): (2, 1, 2, 4),
(64, 64, 32, 32, 32, False, False, False): (1, 1, 2, 8),
(64, 64, 32, 32, 32, False, False, True): (1, 1, 2, 4),
(64, 64, 32, 32, 32, False, True, False): (1, 1, 1, 4),
(64, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4),
(64, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4),
(64, 64, 32, 32, 32, True, False, True): (2, 1, 2, 4),
(64, 64, 64, 32, 32, False, False, False): (1, 2, 1, 4),
(64, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4),
(64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4),
(64, 64, 64, 32, 32, False, True, True): (3, 2, 1, 4),
(64, 64, 64, 32, 32, True, False, False): (1, 2, 1, 8),
(64, 64, 64, 32, 32, True, False, True): (1, 2, 3, 4),
(256, 256, 256, 16, 16, False, True, True): (4, 8, 5, 2),
(256, 256, 256, 32, 32, False, True, True): (1, 8, 5, 4),
(256, 256, 256, 64, 64, False, True, True): (1, 4, 5, 4),
@ -1300,6 +1734,222 @@ _operation_device_version_data: Dict[Any, Dict] = {
(16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 4),
},
("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): {
(16, 16, 16, 16, 16, False, False, False): (2, 1, 1, 16),
(16, 16, 16, 16, 16, False, False, True): (1, 1, 2, 4),
(16, 16, 16, 16, 16, False, True, False): (1, 1, 2, 16),
(16, 16, 16, 16, 16, False, True, True): (2, 1, 2, 8),
(16, 16, 16, 16, 16, True, False, False): (1, 1, 1, 2),
(16, 16, 16, 16, 16, True, False, True): (2, 1, 1, 4),
(16, 16, 32, 16, 16, False, False, False): (1, 1, 1, 2),
(16, 16, 32, 16, 16, False, False, True): (1, 1, 2, 8),
(16, 16, 32, 16, 16, False, True, False): (1, 2, 1, 4),
(16, 16, 32, 16, 16, False, True, True): (1, 2, 2, 4),
(16, 16, 32, 16, 16, True, False, False): (1, 1, 2, 4),
(16, 16, 32, 16, 16, True, False, True): (1, 2, 2, 4),
(16, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4),
(16, 16, 64, 16, 16, False, False, True): (2, 2, 1, 4),
(16, 16, 64, 16, 16, False, True, False): (1, 4, 1, 4),
(16, 16, 64, 16, 16, False, True, True): (1, 4, 1, 8),
(16, 16, 64, 16, 16, True, False, False): (1, 2, 1, 4),
(16, 16, 64, 16, 16, True, False, True): (1, 4, 2, 8),
(16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 8),
(16, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4),
(16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 4),
(16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 4),
(16, 32, 16, 16, 16, True, False, False): (1, 1, 1, 4),
(16, 32, 16, 16, 16, True, False, True): (1, 1, 2, 8),
(16, 32, 16, 16, 32, False, False, False): (1, 1, 2, 4),
(16, 32, 16, 16, 32, False, False, True): (2, 1, 2, 2),
(16, 32, 16, 16, 32, False, True, False): (1, 1, 1, 8),
(16, 32, 16, 16, 32, False, True, True): (1, 1, 1, 2),
(16, 32, 16, 16, 32, True, False, False): (3, 1, 1, 4),
(16, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4),
(16, 32, 32, 16, 16, False, False, False): (1, 2, 1, 4),
(16, 32, 32, 16, 16, False, False, True): (2, 2, 1, 4),
(16, 32, 32, 16, 16, False, True, False): (1, 2, 1, 2),
(16, 32, 32, 16, 16, False, True, True): (1, 2, 1, 4),
(16, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4),
(16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4),
(16, 32, 32, 16, 32, False, False, False): (1, 1, 2, 4),
(16, 32, 32, 16, 32, False, False, True): (1, 2, 1, 4),
(16, 32, 32, 16, 32, False, True, False): (1, 2, 2, 8),
(16, 32, 32, 16, 32, False, True, True): (1, 2, 1, 1),
(16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 2),
(16, 32, 32, 16, 32, True, False, True): (1, 2, 1, 4),
(16, 32, 64, 16, 16, False, False, False): (1, 2, 1, 4),
(16, 32, 64, 16, 16, False, False, True): (2, 4, 1, 4),
(16, 32, 64, 16, 16, False, True, False): (1, 4, 2, 4),
(16, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4),
(16, 32, 64, 16, 16, True, False, False): (1, 2, 2, 8),
(16, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2),
(16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 4),
(16, 32, 64, 16, 32, False, False, True): (1, 4, 3, 4),
(16, 32, 64, 16, 32, False, True, False): (1, 2, 1, 4),
(16, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4),
(16, 32, 64, 16, 32, True, False, False): (1, 2, 1, 8),
(16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 4),
(16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2),
(16, 64, 16, 16, 32, False, False, True): (1, 1, 1, 8),
(16, 64, 16, 16, 32, False, True, False): (1, 1, 1, 8),
(16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
(16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 8),
(16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4),
(16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 4),
(16, 64, 32, 16, 32, False, False, True): (1, 1, 1, 4),
(16, 64, 32, 16, 32, False, True, False): (1, 2, 1, 1),
(16, 64, 32, 16, 32, False, True, True): (1, 2, 1, 8),
(16, 64, 32, 16, 32, True, False, False): (2, 2, 1, 4),
(16, 64, 32, 16, 32, True, False, True): (2, 2, 1, 4),
(16, 64, 64, 16, 32, False, False, False): (1, 2, 1, 4),
(16, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4),
(16, 64, 64, 16, 32, False, True, False): (1, 4, 1, 4),
(16, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4),
(16, 64, 64, 16, 32, True, False, False): (1, 4, 1, 2),
(16, 64, 64, 16, 32, True, False, True): (3, 4, 1, 4),
(32, 16, 16, 16, 16, False, False, False): (1, 1, 2, 4),
(32, 16, 16, 16, 16, False, False, True): (1, 1, 1, 2),
(32, 16, 16, 16, 16, False, True, False): (1, 1, 2, 4),
(32, 16, 16, 16, 16, False, True, True): (1, 1, 2, 4),
(32, 16, 16, 16, 16, True, False, False): (1, 1, 3, 8),
(32, 16, 16, 16, 16, True, False, True): (1, 1, 2, 4),
(32, 16, 32, 16, 16, False, False, False): (1, 2, 1, 4),
(32, 16, 32, 16, 16, False, False, True): (1, 2, 3, 4),
(32, 16, 32, 16, 16, False, True, False): (1, 1, 1, 8),
(32, 16, 32, 16, 16, False, True, True): (1, 2, 1, 4),
(32, 16, 32, 16, 16, True, False, False): (1, 1, 1, 2),
(32, 16, 32, 16, 16, True, False, True): (1, 1, 1, 4),
(32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4),
(32, 16, 64, 16, 16, False, False, True): (3, 4, 1, 4),
(32, 16, 64, 16, 16, False, True, False): (1, 4, 1, 1),
(32, 16, 64, 16, 16, False, True, True): (1, 4, 1, 4),
(32, 16, 64, 16, 16, True, False, False): (1, 4, 1, 4),
(32, 16, 64, 16, 16, True, False, True): (1, 4, 1, 4),
(32, 32, 16, 16, 16, False, False, False): (1, 1, 1, 2),
(32, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4),
(32, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2),
(32, 32, 16, 16, 16, False, True, True): (2, 1, 1, 4),
(32, 32, 16, 16, 16, True, False, False): (3, 1, 2, 4),
(32, 32, 16, 16, 16, True, False, True): (1, 1, 2, 4),
(32, 32, 16, 16, 32, False, False, False): (2, 1, 1, 2),
(32, 32, 16, 16, 32, False, False, True): (1, 1, 1, 4),
(32, 32, 16, 16, 32, False, True, False): (1, 1, 1, 4),
(32, 32, 16, 16, 32, False, True, True): (1, 1, 1, 8),
(32, 32, 16, 16, 32, True, False, False): (1, 1, 1, 8),
(32, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4),
(32, 32, 16, 32, 32, False, False, False): (2, 1, 1, 4),
(32, 32, 16, 32, 32, False, False, True): (1, 1, 2, 4),
(32, 32, 16, 32, 32, False, True, False): (2, 1, 1, 1),
(32, 32, 16, 32, 32, False, True, True): (2, 1, 2, 4),
(32, 32, 16, 32, 32, True, False, False): (1, 1, 1, 8),
(32, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4),
(32, 32, 32, 16, 16, False, False, False): (1, 1, 1, 4),
(32, 32, 32, 16, 16, False, False, True): (1, 2, 1, 2),
(32, 32, 32, 16, 16, False, True, False): (2, 2, 1, 4),
(32, 32, 32, 16, 16, False, True, True): (1, 2, 2, 4),
(32, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4),
(32, 32, 32, 16, 16, True, False, True): (2, 2, 1, 4),
(32, 32, 32, 16, 32, False, False, False): (1, 2, 1, 4),
(32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 4),
(32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4),
(32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 4),
(32, 32, 32, 16, 32, True, False, False): (2, 1, 1, 2),
(32, 32, 32, 16, 32, True, False, True): (2, 2, 2, 4),
(32, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4),
(32, 32, 32, 32, 32, False, False, True): (1, 1, 1, 2),
(32, 32, 32, 32, 32, False, True, False): (1, 1, 1, 4),
(32, 32, 32, 32, 32, False, True, True): (1, 1, 2, 2),
(32, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2),
(32, 32, 32, 32, 32, True, False, True): (1, 1, 2, 1),
(32, 32, 64, 16, 16, False, False, False): (2, 4, 1, 4),
(32, 32, 64, 16, 16, False, False, True): (1, 4, 2, 4),
(32, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4),
(32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4),
(32, 32, 64, 16, 16, True, False, False): (1, 2, 1, 4),
(32, 32, 64, 16, 16, True, False, True): (2, 4, 1, 4),
(32, 32, 64, 16, 32, False, False, False): (1, 4, 1, 8),
(32, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4),
(32, 32, 64, 16, 32, False, True, False): (1, 4, 1, 4),
(32, 32, 64, 16, 32, False, True, True): (2, 4, 1, 4),
(32, 32, 64, 16, 32, True, False, False): (1, 2, 2, 4),
(32, 32, 64, 16, 32, True, False, True): (2, 4, 1, 4),
(32, 32, 64, 32, 32, False, False, False): (2, 2, 1, 4),
(32, 32, 64, 32, 32, False, False, True): (1, 1, 1, 4),
(32, 32, 64, 32, 32, False, True, False): (1, 1, 1, 8),
(32, 32, 64, 32, 32, False, True, True): (2, 1, 1, 4),
(32, 32, 64, 32, 32, True, False, False): (1, 1, 1, 4),
(32, 32, 64, 32, 32, True, False, True): (1, 2, 1, 1),
(32, 64, 16, 16, 32, False, False, False): (1, 1, 2, 2),
(32, 64, 16, 16, 32, False, False, True): (2, 1, 1, 4),
(32, 64, 16, 16, 32, False, True, False): (1, 1, 1, 8),
(32, 64, 16, 16, 32, False, True, True): (1, 1, 3, 4),
(32, 64, 16, 16, 32, True, False, False): (1, 1, 1, 2),
(32, 64, 16, 16, 32, True, False, True): (1, 1, 2, 4),
(32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 2),
(32, 64, 16, 32, 32, False, False, True): (1, 1, 3, 4),
(32, 64, 16, 32, 32, False, True, False): (1, 1, 2, 4),
(32, 64, 16, 32, 32, False, True, True): (1, 1, 1, 8),
(32, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4),
(32, 64, 16, 32, 32, True, False, True): (1, 1, 1, 8),
(32, 64, 32, 16, 32, False, False, False): (1, 2, 1, 4),
(32, 64, 32, 16, 32, False, False, True): (1, 2, 3, 4),
(32, 64, 32, 16, 32, False, True, False): (1, 2, 1, 8),
(32, 64, 32, 16, 32, False, True, True): (3, 2, 1, 4),
(32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 8),
(32, 64, 32, 16, 32, True, False, True): (1, 2, 1, 4),
(32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 1),
(32, 64, 32, 32, 32, False, False, True): (1, 1, 1, 4),
(32, 64, 32, 32, 32, False, True, False): (1, 1, 1, 4),
(32, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4),
(32, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4),
(32, 64, 32, 32, 32, True, False, True): (1, 1, 2, 8),
(32, 64, 64, 16, 32, False, False, False): (2, 4, 1, 4),
(32, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4),
(32, 64, 64, 16, 32, False, True, False): (1, 4, 1, 4),
(32, 64, 64, 16, 32, False, True, True): (2, 4, 1, 4),
(32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4),
(32, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4),
(32, 64, 64, 32, 32, False, False, False): (2, 2, 1, 4),
(32, 64, 64, 32, 32, False, False, True): (1, 2, 1, 8),
(32, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4),
(32, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4),
(32, 64, 64, 32, 32, True, False, False): (2, 2, 1, 4),
(32, 64, 64, 32, 32, True, False, True): (1, 2, 3, 8),
(64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 4),
(64, 32, 16, 32, 32, False, False, True): (3, 1, 2, 4),
(64, 32, 16, 32, 32, False, True, False): (2, 1, 1, 2),
(64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 8),
(64, 32, 16, 32, 32, True, False, False): (1, 1, 1, 2),
(64, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4),
(64, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4),
(64, 32, 32, 32, 32, False, False, True): (1, 1, 2, 8),
(64, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8),
(64, 32, 32, 32, 32, False, True, True): (1, 1, 1, 4),
(64, 32, 32, 32, 32, True, False, False): (1, 1, 2, 4),
(64, 32, 32, 32, 32, True, False, True): (1, 1, 3, 8),
(64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4),
(64, 32, 64, 32, 32, False, False, True): (2, 2, 1, 4),
(64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 4),
(64, 32, 64, 32, 32, False, True, True): (1, 2, 1, 8),
(64, 32, 64, 32, 32, True, False, False): (2, 2, 1, 4),
(64, 32, 64, 32, 32, True, False, True): (1, 2, 1, 8),
(64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 8),
(64, 64, 16, 32, 32, False, False, True): (2, 1, 2, 4),
(64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 2),
(64, 64, 16, 32, 32, False, True, True): (1, 1, 2, 4),
(64, 64, 16, 32, 32, True, False, False): (1, 1, 1, 2),
(64, 64, 16, 32, 32, True, False, True): (1, 1, 2, 4),
(64, 64, 32, 32, 32, False, False, False): (1, 1, 1, 4),
(64, 64, 32, 32, 32, False, False, True): (2, 1, 1, 4),
(64, 64, 32, 32, 32, False, True, False): (1, 1, 1, 8),
(64, 64, 32, 32, 32, False, True, True): (2, 1, 1, 4),
(64, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4),
(64, 64, 32, 32, 32, True, False, True): (1, 1, 1, 8),
(64, 64, 64, 32, 32, False, False, False): (2, 2, 1, 4),
(64, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4),
(64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4),
(64, 64, 64, 32, 32, False, True, True): (2, 2, 1, 4),
(64, 64, 64, 32, 32, True, False, False): (1, 1, 1, 8),
(64, 64, 64, 32, 32, True, False, True): (1, 2, 2, 4),
(256, 256, 256, 16, 16, False, True, True): (1, 16, 3, 4),
(256, 256, 256, 32, 32, False, True, True): (1, 8, 5, 4),
(256, 256, 256, 64, 64, False, True, True): (3, 4, 4, 8),
@ -1582,6 +2232,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(16384, 16384, 131072, 128, 128, False, True, True): (2, 1024, 1, 32),
},
("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
(16, 16, 32, 16, 16): (1, 1, 1),
(16, 16, 64, 16, 16): (1, 2, 8),
(16, 32, 32, 16, 16): (1, 1, 4),
(16, 32, 64, 16, 16): (2, 1, 2),
(32, 32, 32, 32, 32): (1, 2, 4),
(32, 32, 64, 32, 32): (1, 1, 2),
(32, 64, 32, 32, 32): (1, 2, 8),
(32, 64, 64, 32, 32): (1, 1, 4),
(256, 256, 256, 16, 16): (3, 3, 1),
(256, 256, 256, 32, 32): (1, 5, 1),
(256, 256, 256, 64, 64): (2, 3, 2),
@ -2146,6 +2804,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(16384, 16384, 131072, 128, 128): (4, 1, 4),
},
("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): {
(16, 16, 32, 16, 16): (4, 3, 4),
(16, 16, 64, 16, 16): (1, 2, 4),
(16, 32, 32, 16, 16): (1, 1, 2),
(16, 32, 64, 16, 16): (2, 1, 2),
(32, 32, 32, 32, 32): (1, 1, 4),
(32, 32, 64, 32, 32): (1, 1, 1),
(32, 64, 32, 32, 32): (1, 1, 2),
(32, 64, 64, 32, 32): (2, 1, 4),
(256, 256, 256, 16, 16): (8, 6, 1),
(256, 256, 256, 32, 32): (4, 5, 2),
(256, 256, 256, 64, 64): (3, 3, 4),
@ -2710,6 +3376,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(16384, 16384, 131072, 128, 128): (4, 1, 4),
},
("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): {
(16, 16, 32, 16, 16): (1, 1, 4),
(16, 16, 64, 16, 16): (1, 1, 1),
(16, 32, 32, 16, 16): (1, 2, 4),
(16, 32, 64, 16, 16): (2, 1, 4),
(32, 32, 32, 32, 32): (1, 1, 8),
(32, 32, 64, 32, 32): (2, 1, 4),
(32, 64, 32, 32, 32): (1, 1, 8),
(32, 64, 64, 32, 32): (1, 2, 4),
(256, 256, 256, 16, 16): (1, 1, 8),
(256, 256, 256, 32, 32): (1, 3, 4),
(256, 256, 256, 64, 64): (1, 1, 8),