mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add tests for bsr_dense_addmm and bsr_dense_mm triton kernels (#114800)
As in the title. In addition, - resolve https://github.com/pytorch/pytorch/pull/114757#discussion_r1409547917 re triton-contiguous inputs - support non-contiguous inputs and outputs in triton kernels - fix a couple of minor bugs Pull Request resolved: https://github.com/pytorch/pytorch/pull/114800 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
aafa8233a4
commit
4ba37e1804
@ -3878,6 +3878,130 @@ class TestSparseCompressedTritonKernels(TestCase):
|
||||
# but key is still valid:
|
||||
self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
|
||||
|
||||
@parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear'])
|
||||
@parametrize("blocksize", [16, '16x32', 32])
|
||||
@onlyCUDA
|
||||
@skipIfRocm
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float)
|
||||
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
|
||||
def test_triton_kernel(self, op, device, dtype, blocksize):
|
||||
from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm
|
||||
from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta,
|
||||
optimize_bsr_dense_addmm, optimize_bsr_dense_mm, dump)
|
||||
|
||||
def bsr_dense_linear(input, weights, bias=None):
|
||||
return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2)
|
||||
|
||||
operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear)[op]
|
||||
|
||||
def reference(input, mat1, mat2, beta=1, alpha=1):
|
||||
assert mat1.layout is torch.strided
|
||||
assert mat2.layout is torch.strided
|
||||
return beta * input + alpha * (mat1 @ mat2)
|
||||
|
||||
def nc_copy(t, axes=(-1,)):
|
||||
"""Return a copy of input.
|
||||
|
||||
The returned copy will be a non-contiguous tensor.
|
||||
"""
|
||||
if t.layout is torch.strided:
|
||||
shape = list(t.shape)
|
||||
for a in axes:
|
||||
shape[a] *= 2
|
||||
r = torch.empty(shape, dtype=t.dtype, device=t.device)
|
||||
s = r[tuple(slice(None, None, 2 if t.shape[i] != r.shape[i] else None) for i in range(t.ndim))]
|
||||
s.copy_(t)
|
||||
return s
|
||||
elif t.layout is torch.sparse_bsr:
|
||||
compressed_indices = t.crow_indices()
|
||||
plain_indices = t.col_indices()
|
||||
return torch.sparse_compressed_tensor(compressed_indices, plain_indices, nc_copy(t.values()),
|
||||
t.shape, layout=t.layout)
|
||||
else:
|
||||
raise NotImplementedError(t.layout)
|
||||
|
||||
if isinstance(blocksize, str):
|
||||
BM, BK = tuple(map(int, blocksize.split('x')))
|
||||
else:
|
||||
BM, BK = (blocksize,) * 2
|
||||
|
||||
if op in {"bsr_dense_mm", "bsr_dense_linear"} and BM != BK:
|
||||
# todo: eliminate this skip
|
||||
self.skipTest(f"{op} does not support non-square blocks")
|
||||
|
||||
beta_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[0], bsr_dense_linear=[1])[op]
|
||||
alpha_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[1], bsr_dense_linear=[1])[op]
|
||||
sparsity_lst = [0, 0.5, 1]
|
||||
# todo: eliminate `[:1]`
|
||||
blocks_per_row_lst = dict(bsr_dense_addmm=[1, 2], bsr_dense_mm=[1, 2][:1], bsr_dense_linear=[1, 2])[op]
|
||||
blocks_per_col_lst = [1, 2]
|
||||
# todo: eliminate `[1:]`
|
||||
result_cols_lst = dict(bsr_dense_addmm=[16, 32, 64], bsr_dense_mm=[16, 32, 64][1:],
|
||||
bsr_dense_linear=[16, 32, 64])[op]
|
||||
for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product(
|
||||
beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst):
|
||||
M = BM * blocks_per_row
|
||||
K = BK * blocks_per_col
|
||||
mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device)
|
||||
bsr = mat1.to_sparse_bsr((BM, BK))
|
||||
mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5)
|
||||
input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5)
|
||||
|
||||
if 0 and op != "bsr_dense_linear":
|
||||
# Find optimal kernel parameters, the speed-up is
|
||||
# about 10x for running this test.
|
||||
#
|
||||
# Enable this if-block when the test method is
|
||||
# updated, run the test, and finally, disable the
|
||||
# if-block.
|
||||
key = dict(bsr_dense_addmm=(M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1),
|
||||
bsr_dense_mm=(M, K, N, BM, BK))[op]
|
||||
meta = get_meta(op, key, version=(0, dtype, 0.5))
|
||||
if meta is None:
|
||||
if op == 'bsr_dense_addmm':
|
||||
optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5)
|
||||
elif op == 'bsr_dense_mm':
|
||||
optimize_bsr_dense_mm(M, K, N, BM, BK, dtype=dtype, sparsity=0.5)
|
||||
else:
|
||||
raise NotImplementedError(op)
|
||||
meta = get_meta(op, key, version=(0, dtype, 0.5))
|
||||
assert meta is not None
|
||||
dump() # this will update torch/sparse/_triton_ops_meta.py
|
||||
|
||||
expected = reference(input, mat1, mat2, beta=beta, alpha=alpha)
|
||||
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm=dict(),
|
||||
bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op]
|
||||
|
||||
args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2),
|
||||
bsr_dense_linear=(mat2.transpose(-1, -2), bsr))[op]
|
||||
result = operation(*args, **kwargs)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# Test non-contiguous input tensors:
|
||||
nc_mat2 = nc_copy(mat2)
|
||||
nc_input = nc_copy(input)
|
||||
nc_bsr = nc_copy(bsr)
|
||||
|
||||
args = dict(bsr_dense_addmm=(input, bsr, nc_mat2), bsr_dense_mm=(bsr, nc_mat2),
|
||||
bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
|
||||
result = operation(*args, **kwargs)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# todo: add bsr_dense_linear to the set below
|
||||
if op in {'bsr_dense_addmm', 'bsr_dense_mm'}:
|
||||
args = dict(bsr_dense_addmm=(input, nc_bsr, mat2), bsr_dense_mm=(nc_bsr, mat2),
|
||||
bsr_dense_linear=(mat2.transpose(-1, -2), nc_bsr))[op]
|
||||
result = operation(*args, **kwargs)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
if op in {'bsr_dense_addmm', 'bsr_dense_linear'}:
|
||||
args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2),
|
||||
bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
|
||||
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha),
|
||||
bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op]
|
||||
result = operation(*args, **kwargs)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
|
||||
instantiate_device_type_tests(TestSparseCSR, globals())
|
||||
|
@ -77,12 +77,19 @@ def check_blocksize(f_name, blocksize):
|
||||
|
||||
|
||||
def make_triton_contiguous(t):
|
||||
# TODO: Why do we need "triton contiguity" that is not defined by
|
||||
# triton itself? It looks like it is required until
|
||||
# openai/triton#1291 fixed a bug in processing triton kernel
|
||||
# arguments. Unless triton comntiguity is required for
|
||||
# performance, remove this function.
|
||||
if (t.stride(-2) > 1 or t.dtype is torch.float32) and t.stride(-1) > 1:
|
||||
"""Return input as a triton-contiguous tensor.
|
||||
|
||||
A triton-contiguous tensor is defined as a tensor that has strides
|
||||
with minimal value equal to 1.
|
||||
|
||||
While triton kernels support triton-non-contiguous tensors (all
|
||||
strides being greater than 1 or having 0 strides) arguments, a
|
||||
considerable slow-down occurs because tensor data is copied
|
||||
element-wise rather than chunk-wise.
|
||||
"""
|
||||
if min(t.stride()) != 1:
|
||||
# TODO: investigate if contiguity along other axes than the
|
||||
# last one can be beneficial for performance
|
||||
return t.contiguous()
|
||||
else:
|
||||
return t
|
||||
@ -162,16 +169,12 @@ def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
|
||||
kernel(grid, *sliced_tensors)
|
||||
|
||||
|
||||
def prepare_inputs(bsr, *dense_tensors, require_view=False):
|
||||
def prepare_inputs(bsr, *dense_tensors):
|
||||
# Introduce fake batch dimension if not present for convenience.
|
||||
crow_indices = bsr.crow_indices().unsqueeze(0)
|
||||
col_indices = bsr.col_indices().unsqueeze(0)
|
||||
if require_view:
|
||||
values = bsr.values().unsqueeze(0)
|
||||
tensors = [t.unsqueeze(0) for t in dense_tensors]
|
||||
else:
|
||||
values = make_triton_contiguous(bsr.values().unsqueeze(0))
|
||||
tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]
|
||||
values = make_triton_contiguous(bsr.values().unsqueeze(0))
|
||||
tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]
|
||||
|
||||
# Compute broadcasted batch dimension
|
||||
batch_dims_broadcasted = torch.broadcast_shapes(values.shape[:-3], *(t.shape[:-2] for t in tensors))
|
||||
@ -848,7 +851,7 @@ def bsr_dense_addmm(
|
||||
|
||||
out_backup = out
|
||||
|
||||
crow_indices, col_indices, values, input, dense, out = prepare_inputs(bsr, input, dense, out, require_view=True)
|
||||
crow_indices, col_indices, values, input, dense, out = prepare_inputs(bsr, input, dense, out)
|
||||
|
||||
BM, BK = blocksize
|
||||
SPLIT_N = meta.get('SPLIT_N', N // BM)
|
||||
@ -856,12 +859,9 @@ def bsr_dense_addmm(
|
||||
|
||||
dense = tile_to_blocksize(dense, (BK, BN))
|
||||
input = tile_to_blocksize(input, (BM, BN))
|
||||
out_untiled = out
|
||||
out = tile_to_blocksize(out, (BM, BN))
|
||||
|
||||
# out and out_backup may have different shapes/strides/offsets
|
||||
# but they must share storage:
|
||||
assert out.data_ptr() == out_backup.data_ptr()
|
||||
|
||||
dot_out_dtype = {torch.float16: tl.float32,
|
||||
torch.bfloat16: tl.float32,
|
||||
torch.float32: tl.float64,
|
||||
@ -904,6 +904,11 @@ def bsr_dense_addmm(
|
||||
|
||||
launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
|
||||
|
||||
if out.data_ptr() != out_backup.data_ptr():
|
||||
# prepare_inputs has made a copy of out, copy its content back
|
||||
# to out_backup:
|
||||
out_backup.copy_(out_untiled.view(out_backup.shape))
|
||||
|
||||
return out_backup
|
||||
|
||||
|
||||
@ -1408,7 +1413,7 @@ if has_triton():
|
||||
out_backup = out
|
||||
|
||||
# prepare inputs by reshaping them to be kernel-compatible.
|
||||
crow_indices, col_indices, values, dense, out = prepare_inputs(bsr, dense, out, require_view=True)
|
||||
crow_indices, col_indices, values, dense, out = prepare_inputs(bsr, dense, out)
|
||||
|
||||
# "Blockify" the row dimension of dense with blocksize[1]
|
||||
# since dense is on the rhs of matmul
|
||||
@ -1420,15 +1425,17 @@ if has_triton():
|
||||
# so it could be any value in [1, dense.shape[-1]).
|
||||
# We need to probably use the largest possible blocksize
|
||||
# so that it fits into SRAM.
|
||||
out_untiled = out
|
||||
out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))
|
||||
|
||||
# out and out_backup may have different shapes/strides/offsets
|
||||
# but they must share storage:
|
||||
assert out.data_ptr() == out_backup.data_ptr()
|
||||
|
||||
# Launch kernel
|
||||
_run_dense_rowspace_kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid, meta)
|
||||
|
||||
if out.data_ptr() != out_backup.data_ptr():
|
||||
# prepare_inputs has made a copy of out, copy its content
|
||||
# back to out_backup:
|
||||
out_backup.copy_(out_untiled.view(out_backup.shape))
|
||||
|
||||
return out_backup
|
||||
|
||||
|
||||
|
@ -518,7 +518,9 @@ def optimize_bsr_dense_addmm(
|
||||
key = (m, k, n, bm, bk, beta == 0, beta == 1, alpha == 1)
|
||||
version = (0, dtype, sparsity)
|
||||
|
||||
reference_meta = dict(GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=n // bm)
|
||||
reference_meta = dict(
|
||||
GROUP_SIZE_ROW=1, num_stages=1, num_warps=4, SPLIT_N=max(n // bm, 1)
|
||||
)
|
||||
|
||||
initial_meta = get_meta("bsr_dense_addmm", key, version=version, exact=True)
|
||||
|
||||
@ -556,7 +558,7 @@ def optimize_bsr_dense_addmm(
|
||||
# value. The input value is assumed to be valid.
|
||||
is_log = name in {"SPLIT_N", "num_warps"}
|
||||
min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name]
|
||||
max_value = dict(SPLIT_N=n // bm).get(name)
|
||||
max_value = dict(SPLIT_N=max(n // bm, 1)).get(name)
|
||||
value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
|
||||
if is_log:
|
||||
next_value = (
|
||||
@ -736,6 +738,222 @@ _operation_device_version_data: Dict[Any, Dict] = {
|
||||
#
|
||||
# BEGIN GENERATED DATA
|
||||
("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
|
||||
(16, 16, 16, 16, 16, False, False, False): (2, 1, 1, 2),
|
||||
(16, 16, 16, 16, 16, False, False, True): (1, 1, 1, 4),
|
||||
(16, 16, 16, 16, 16, False, True, False): (1, 1, 3, 16),
|
||||
(16, 16, 16, 16, 16, False, True, True): (1, 1, 1, 8),
|
||||
(16, 16, 16, 16, 16, True, False, False): (2, 1, 1, 8),
|
||||
(16, 16, 16, 16, 16, True, False, True): (1, 1, 1, 8),
|
||||
(16, 16, 32, 16, 16, False, False, False): (1, 2, 1, 8),
|
||||
(16, 16, 32, 16, 16, False, False, True): (1, 2, 2, 4),
|
||||
(16, 16, 32, 16, 16, False, True, False): (1, 1, 2, 4),
|
||||
(16, 16, 32, 16, 16, False, True, True): (1, 1, 2, 4),
|
||||
(16, 16, 32, 16, 16, True, False, False): (1, 1, 2, 4),
|
||||
(16, 16, 32, 16, 16, True, False, True): (2, 2, 1, 2),
|
||||
(16, 16, 64, 16, 16, False, False, False): (1, 4, 2, 4),
|
||||
(16, 16, 64, 16, 16, False, False, True): (1, 2, 1, 2),
|
||||
(16, 16, 64, 16, 16, False, True, False): (2, 1, 1, 2),
|
||||
(16, 16, 64, 16, 16, False, True, True): (1, 4, 1, 8),
|
||||
(16, 16, 64, 16, 16, True, False, False): (1, 4, 1, 1),
|
||||
(16, 16, 64, 16, 16, True, False, True): (1, 4, 2, 4),
|
||||
(16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 2),
|
||||
(16, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4),
|
||||
(16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2),
|
||||
(16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 1),
|
||||
(16, 32, 16, 16, 16, True, False, False): (1, 1, 1, 2),
|
||||
(16, 32, 16, 16, 16, True, False, True): (2, 1, 1, 2),
|
||||
(16, 32, 16, 16, 32, False, False, False): (1, 1, 1, 4),
|
||||
(16, 32, 16, 16, 32, False, False, True): (1, 1, 1, 8),
|
||||
(16, 32, 16, 16, 32, False, True, False): (1, 1, 1, 8),
|
||||
(16, 32, 16, 16, 32, False, True, True): (1, 1, 2, 4),
|
||||
(16, 32, 16, 16, 32, True, False, False): (1, 1, 1, 2),
|
||||
(16, 32, 16, 16, 32, True, False, True): (1, 1, 1, 1),
|
||||
(16, 32, 32, 16, 16, False, False, False): (2, 2, 1, 4),
|
||||
(16, 32, 32, 16, 16, False, False, True): (2, 2, 1, 2),
|
||||
(16, 32, 32, 16, 16, False, True, False): (1, 1, 2, 8),
|
||||
(16, 32, 32, 16, 16, False, True, True): (1, 2, 1, 1),
|
||||
(16, 32, 32, 16, 16, True, False, False): (1, 1, 1, 8),
|
||||
(16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4),
|
||||
(16, 32, 32, 16, 32, False, False, False): (1, 1, 2, 8),
|
||||
(16, 32, 32, 16, 32, False, False, True): (2, 1, 1, 8),
|
||||
(16, 32, 32, 16, 32, False, True, False): (1, 1, 1, 4),
|
||||
(16, 32, 32, 16, 32, False, True, True): (1, 1, 1, 4),
|
||||
(16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 8),
|
||||
(16, 32, 32, 16, 32, True, False, True): (1, 1, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, False, False): (1, 4, 3, 8),
|
||||
(16, 32, 64, 16, 16, False, False, True): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, True, True): (2, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, True, False, False): (1, 2, 1, 4),
|
||||
(16, 32, 64, 16, 16, True, False, True): (1, 2, 1, 4),
|
||||
(16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 8),
|
||||
(16, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 32, False, True, False): (1, 4, 1, 2),
|
||||
(16, 32, 64, 16, 32, False, True, True): (1, 2, 1, 4),
|
||||
(16, 32, 64, 16, 32, True, False, False): (1, 2, 1, 4),
|
||||
(16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 2),
|
||||
(16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2),
|
||||
(16, 64, 16, 16, 32, False, False, True): (1, 1, 2, 2),
|
||||
(16, 64, 16, 16, 32, False, True, False): (1, 1, 2, 8),
|
||||
(16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
|
||||
(16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 8),
|
||||
(16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 2),
|
||||
(16, 64, 32, 16, 32, False, False, True): (1, 2, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, True, False): (1, 2, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, True, True): (2, 2, 1, 4),
|
||||
(16, 64, 32, 16, 32, True, False, False): (1, 2, 1, 4),
|
||||
(16, 64, 32, 16, 32, True, False, True): (1, 2, 1, 8),
|
||||
(16, 64, 64, 16, 32, False, False, False): (1, 2, 1, 4),
|
||||
(16, 64, 64, 16, 32, False, False, True): (1, 4, 2, 2),
|
||||
(16, 64, 64, 16, 32, False, True, False): (1, 1, 1, 4),
|
||||
(16, 64, 64, 16, 32, False, True, True): (1, 4, 1, 2),
|
||||
(16, 64, 64, 16, 32, True, False, False): (1, 2, 1, 4),
|
||||
(16, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4),
|
||||
(32, 16, 16, 16, 16, False, False, False): (1, 1, 1, 8),
|
||||
(32, 16, 16, 16, 16, False, False, True): (1, 1, 2, 4),
|
||||
(32, 16, 16, 16, 16, False, True, False): (1, 1, 1, 4),
|
||||
(32, 16, 16, 16, 16, False, True, True): (1, 1, 2, 4),
|
||||
(32, 16, 16, 16, 16, True, False, False): (1, 1, 1, 2),
|
||||
(32, 16, 16, 16, 16, True, False, True): (1, 1, 1, 4),
|
||||
(32, 16, 32, 16, 16, False, False, False): (1, 1, 1, 4),
|
||||
(32, 16, 32, 16, 16, False, False, True): (2, 2, 1, 4),
|
||||
(32, 16, 32, 16, 16, False, True, False): (1, 2, 2, 2),
|
||||
(32, 16, 32, 16, 16, False, True, True): (2, 2, 1, 4),
|
||||
(32, 16, 32, 16, 16, True, False, False): (1, 2, 2, 8),
|
||||
(32, 16, 32, 16, 16, True, False, True): (1, 2, 1, 2),
|
||||
(32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4),
|
||||
(32, 16, 64, 16, 16, False, False, True): (1, 4, 2, 4),
|
||||
(32, 16, 64, 16, 16, False, True, False): (1, 2, 2, 2),
|
||||
(32, 16, 64, 16, 16, False, True, True): (3, 4, 1, 4),
|
||||
(32, 16, 64, 16, 16, True, False, False): (1, 2, 1, 2),
|
||||
(32, 16, 64, 16, 16, True, False, True): (1, 2, 1, 4),
|
||||
(32, 32, 16, 16, 16, False, False, False): (1, 1, 3, 4),
|
||||
(32, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2),
|
||||
(32, 32, 16, 16, 16, False, True, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 16, True, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 16, True, False, True): (1, 1, 2, 2),
|
||||
(32, 32, 16, 16, 32, False, False, False): (2, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, False, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, False, True, False): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, False, True, True): (3, 1, 2, 4),
|
||||
(32, 32, 16, 16, 32, True, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 32, 32, False, False, False): (1, 1, 1, 8),
|
||||
(32, 32, 16, 32, 32, False, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 32, 32, False, True, False): (1, 1, 2, 1),
|
||||
(32, 32, 16, 32, 32, False, True, True): (2, 1, 2, 2),
|
||||
(32, 32, 16, 32, 32, True, False, False): (1, 1, 1, 8),
|
||||
(32, 32, 16, 32, 32, True, False, True): (2, 1, 3, 4),
|
||||
(32, 32, 32, 16, 16, False, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 16, False, False, True): (2, 2, 1, 4),
|
||||
(32, 32, 32, 16, 16, False, True, False): (1, 1, 1, 8),
|
||||
(32, 32, 32, 16, 16, False, True, True): (2, 2, 1, 4),
|
||||
(32, 32, 32, 16, 16, True, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 32, 16, 16, True, False, True): (2, 2, 2, 4),
|
||||
(32, 32, 32, 16, 32, False, False, False): (2, 2, 1, 8),
|
||||
(32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 2),
|
||||
(32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, True, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, True, False, True): (1, 2, 1, 2),
|
||||
(32, 32, 32, 32, 32, False, False, False): (1, 1, 3, 8),
|
||||
(32, 32, 32, 32, 32, False, False, True): (1, 1, 1, 8),
|
||||
(32, 32, 32, 32, 32, False, True, False): (2, 1, 3, 4),
|
||||
(32, 32, 32, 32, 32, False, True, True): (2, 1, 1, 2),
|
||||
(32, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2),
|
||||
(32, 32, 32, 32, 32, True, False, True): (4, 1, 1, 1),
|
||||
(32, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 16, False, False, True): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 16, False, True, False): (1, 2, 1, 8),
|
||||
(32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 2),
|
||||
(32, 32, 64, 16, 16, True, False, False): (2, 4, 1, 2),
|
||||
(32, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2),
|
||||
(32, 32, 64, 16, 32, False, False, False): (1, 2, 1, 8),
|
||||
(32, 32, 64, 16, 32, False, False, True): (1, 4, 2, 2),
|
||||
(32, 32, 64, 16, 32, False, True, False): (1, 2, 1, 4),
|
||||
(32, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 32, True, False, False): (1, 4, 2, 2),
|
||||
(32, 32, 64, 16, 32, True, False, True): (3, 4, 2, 2),
|
||||
(32, 32, 64, 32, 32, False, False, False): (2, 2, 1, 4),
|
||||
(32, 32, 64, 32, 32, False, False, True): (1, 2, 1, 4),
|
||||
(32, 32, 64, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(32, 32, 64, 32, 32, False, True, True): (1, 1, 1, 4),
|
||||
(32, 32, 64, 32, 32, True, False, False): (1, 2, 1, 2),
|
||||
(32, 32, 64, 32, 32, True, False, True): (3, 2, 1, 8),
|
||||
(32, 64, 16, 16, 32, False, False, False): (1, 1, 2, 2),
|
||||
(32, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4),
|
||||
(32, 64, 16, 16, 32, False, True, False): (1, 1, 2, 4),
|
||||
(32, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
|
||||
(32, 64, 16, 16, 32, True, False, False): (1, 1, 1, 2),
|
||||
(32, 64, 16, 16, 32, True, False, True): (2, 1, 2, 2),
|
||||
(32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 1),
|
||||
(32, 64, 16, 32, 32, False, False, True): (2, 1, 1, 4),
|
||||
(32, 64, 16, 32, 32, False, True, False): (1, 1, 1, 1),
|
||||
(32, 64, 16, 32, 32, False, True, True): (1, 1, 2, 2),
|
||||
(32, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4),
|
||||
(32, 64, 16, 32, 32, True, False, True): (1, 1, 1, 4),
|
||||
(32, 64, 32, 16, 32, False, False, False): (2, 2, 1, 4),
|
||||
(32, 64, 32, 16, 32, False, False, True): (1, 2, 1, 4),
|
||||
(32, 64, 32, 16, 32, False, True, False): (1, 1, 1, 4),
|
||||
(32, 64, 32, 16, 32, False, True, True): (2, 2, 3, 4),
|
||||
(32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 2),
|
||||
(32, 64, 32, 16, 32, True, False, True): (1, 2, 1, 2),
|
||||
(32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 2),
|
||||
(32, 64, 32, 32, 32, False, False, True): (2, 1, 1, 4),
|
||||
(32, 64, 32, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(32, 64, 32, 32, 32, False, True, True): (1, 1, 2, 4),
|
||||
(32, 64, 32, 32, 32, True, False, False): (2, 1, 1, 4),
|
||||
(32, 64, 32, 32, 32, True, False, True): (1, 1, 2, 4),
|
||||
(32, 64, 64, 16, 32, False, False, False): (1, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, False, False, True): (1, 4, 2, 4),
|
||||
(32, 64, 64, 16, 32, False, True, False): (1, 4, 2, 2),
|
||||
(32, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 8),
|
||||
(32, 64, 64, 16, 32, True, False, True): (1, 4, 2, 1),
|
||||
(32, 64, 64, 32, 32, False, False, False): (1, 1, 1, 4),
|
||||
(32, 64, 64, 32, 32, False, False, True): (2, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, False, True, False): (1, 1, 1, 4),
|
||||
(32, 64, 64, 32, 32, False, True, True): (2, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, True, False, False): (1, 2, 2, 4),
|
||||
(32, 64, 64, 32, 32, True, False, True): (2, 2, 3, 4),
|
||||
(64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 4),
|
||||
(64, 32, 16, 32, 32, False, False, True): (1, 1, 1, 4),
|
||||
(64, 32, 16, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 4),
|
||||
(64, 32, 16, 32, 32, True, False, False): (1, 1, 1, 16),
|
||||
(64, 32, 16, 32, 32, True, False, True): (2, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, False, False, False): (1, 1, 3, 4),
|
||||
(64, 32, 32, 32, 32, False, False, True): (2, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, False, True, False): (1, 1, 2, 4),
|
||||
(64, 32, 32, 32, 32, False, True, True): (2, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, True, False, False): (2, 1, 1, 16),
|
||||
(64, 32, 32, 32, 32, True, False, True): (2, 1, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, False, True): (2, 2, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, True, True): (2, 2, 1, 4),
|
||||
(64, 32, 64, 32, 32, True, False, False): (1, 2, 1, 8),
|
||||
(64, 32, 64, 32, 32, True, False, True): (2, 2, 3, 4),
|
||||
(64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 16),
|
||||
(64, 64, 16, 32, 32, False, False, True): (1, 1, 3, 4),
|
||||
(64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 2),
|
||||
(64, 64, 16, 32, 32, False, True, True): (2, 1, 1, 4),
|
||||
(64, 64, 16, 32, 32, True, False, False): (2, 1, 3, 2),
|
||||
(64, 64, 16, 32, 32, True, False, True): (1, 1, 2, 4),
|
||||
(64, 64, 32, 32, 32, False, False, False): (1, 1, 1, 8),
|
||||
(64, 64, 32, 32, 32, False, False, True): (2, 1, 2, 4),
|
||||
(64, 64, 32, 32, 32, False, True, False): (2, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, False, True, True): (1, 1, 2, 4),
|
||||
(64, 64, 32, 32, 32, True, False, False): (2, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, True, False, True): (1, 1, 2, 4),
|
||||
(64, 64, 64, 32, 32, False, False, False): (1, 2, 2, 4),
|
||||
(64, 64, 64, 32, 32, False, False, True): (1, 2, 2, 2),
|
||||
(64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 2),
|
||||
(64, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, True, False, False): (1, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, True, False, True): (1, 2, 1, 4),
|
||||
(256, 256, 256, 16, 16, False, True, True): (4, 8, 5, 2),
|
||||
(256, 256, 256, 32, 32, False, True, True): (1, 8, 5, 4),
|
||||
(256, 256, 256, 64, 64, False, True, True): (2, 4, 4, 4),
|
||||
@ -1018,6 +1236,222 @@ _operation_device_version_data: Dict[Any, Dict] = {
|
||||
(16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 4),
|
||||
},
|
||||
("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): {
|
||||
(16, 16, 16, 16, 16, False, False, False): (1, 1, 1, 1),
|
||||
(16, 16, 16, 16, 16, False, False, True): (1, 1, 2, 2),
|
||||
(16, 16, 16, 16, 16, False, True, False): (1, 1, 1, 1),
|
||||
(16, 16, 16, 16, 16, False, True, True): (1, 1, 1, 8),
|
||||
(16, 16, 16, 16, 16, True, False, False): (3, 1, 3, 4),
|
||||
(16, 16, 16, 16, 16, True, False, True): (1, 1, 2, 1),
|
||||
(16, 16, 32, 16, 16, False, False, False): (1, 2, 1, 8),
|
||||
(16, 16, 32, 16, 16, False, False, True): (1, 2, 1, 2),
|
||||
(16, 16, 32, 16, 16, False, True, False): (2, 1, 1, 4),
|
||||
(16, 16, 32, 16, 16, False, True, True): (1, 2, 1, 4),
|
||||
(16, 16, 32, 16, 16, True, False, False): (1, 1, 1, 4),
|
||||
(16, 16, 32, 16, 16, True, False, True): (1, 2, 1, 2),
|
||||
(16, 16, 64, 16, 16, False, False, False): (1, 4, 1, 1),
|
||||
(16, 16, 64, 16, 16, False, False, True): (1, 2, 2, 4),
|
||||
(16, 16, 64, 16, 16, False, True, False): (1, 4, 1, 4),
|
||||
(16, 16, 64, 16, 16, False, True, True): (1, 2, 1, 4),
|
||||
(16, 16, 64, 16, 16, True, False, False): (1, 4, 1, 2),
|
||||
(16, 16, 64, 16, 16, True, False, True): (1, 1, 1, 2),
|
||||
(16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 4),
|
||||
(16, 32, 16, 16, 16, False, False, True): (1, 1, 1, 4),
|
||||
(16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2),
|
||||
(16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 2),
|
||||
(16, 32, 16, 16, 16, True, False, False): (1, 1, 2, 16),
|
||||
(16, 32, 16, 16, 16, True, False, True): (1, 1, 1, 4),
|
||||
(16, 32, 16, 16, 32, False, False, False): (2, 1, 1, 8),
|
||||
(16, 32, 16, 16, 32, False, False, True): (2, 1, 1, 8),
|
||||
(16, 32, 16, 16, 32, False, True, False): (1, 1, 2, 1),
|
||||
(16, 32, 16, 16, 32, False, True, True): (1, 1, 1, 4),
|
||||
(16, 32, 16, 16, 32, True, False, False): (2, 1, 1, 8),
|
||||
(16, 32, 16, 16, 32, True, False, True): (1, 1, 2, 4),
|
||||
(16, 32, 32, 16, 16, False, False, False): (1, 1, 1, 16),
|
||||
(16, 32, 32, 16, 16, False, False, True): (1, 2, 1, 2),
|
||||
(16, 32, 32, 16, 16, False, True, False): (1, 2, 1, 8),
|
||||
(16, 32, 32, 16, 16, False, True, True): (3, 2, 1, 4),
|
||||
(16, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4),
|
||||
(16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 2),
|
||||
(16, 32, 32, 16, 32, False, False, False): (1, 2, 1, 2),
|
||||
(16, 32, 32, 16, 32, False, False, True): (1, 1, 1, 4),
|
||||
(16, 32, 32, 16, 32, False, True, False): (1, 1, 2, 4),
|
||||
(16, 32, 32, 16, 32, False, True, True): (1, 2, 1, 2),
|
||||
(16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 2),
|
||||
(16, 32, 32, 16, 32, True, False, True): (1, 2, 1, 16),
|
||||
(16, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, False, True): (2, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, True, False, False): (3, 4, 1, 2),
|
||||
(16, 32, 64, 16, 16, True, False, True): (1, 4, 1, 1),
|
||||
(16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 16),
|
||||
(16, 32, 64, 16, 32, False, False, True): (1, 2, 1, 2),
|
||||
(16, 32, 64, 16, 32, False, True, False): (1, 4, 2, 2),
|
||||
(16, 32, 64, 16, 32, False, True, True): (1, 4, 1, 8),
|
||||
(16, 32, 64, 16, 32, True, False, False): (1, 4, 1, 8),
|
||||
(16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 4),
|
||||
(16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2),
|
||||
(16, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4),
|
||||
(16, 64, 16, 16, 32, False, True, False): (2, 1, 2, 4),
|
||||
(16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
|
||||
(16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 4),
|
||||
(16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 2),
|
||||
(16, 64, 32, 16, 32, False, False, True): (1, 1, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, True, False): (1, 1, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, True, True): (1, 2, 3, 2),
|
||||
(16, 64, 32, 16, 32, True, False, False): (1, 1, 1, 4),
|
||||
(16, 64, 32, 16, 32, True, False, True): (1, 1, 2, 4),
|
||||
(16, 64, 64, 16, 32, False, False, False): (1, 4, 1, 8),
|
||||
(16, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4),
|
||||
(16, 64, 64, 16, 32, False, True, False): (1, 4, 1, 1),
|
||||
(16, 64, 64, 16, 32, False, True, True): (2, 4, 1, 4),
|
||||
(16, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4),
|
||||
(16, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4),
|
||||
(32, 16, 16, 16, 16, False, False, False): (2, 1, 2, 4),
|
||||
(32, 16, 16, 16, 16, False, False, True): (2, 1, 1, 2),
|
||||
(32, 16, 16, 16, 16, False, True, False): (1, 1, 2, 4),
|
||||
(32, 16, 16, 16, 16, False, True, True): (1, 1, 1, 2),
|
||||
(32, 16, 16, 16, 16, True, False, False): (1, 1, 1, 4),
|
||||
(32, 16, 16, 16, 16, True, False, True): (2, 1, 1, 2),
|
||||
(32, 16, 32, 16, 16, False, False, False): (1, 1, 1, 4),
|
||||
(32, 16, 32, 16, 16, False, False, True): (1, 1, 1, 4),
|
||||
(32, 16, 32, 16, 16, False, True, False): (1, 2, 1, 4),
|
||||
(32, 16, 32, 16, 16, False, True, True): (2, 2, 1, 4),
|
||||
(32, 16, 32, 16, 16, True, False, False): (2, 1, 1, 4),
|
||||
(32, 16, 32, 16, 16, True, False, True): (2, 2, 1, 2),
|
||||
(32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 2),
|
||||
(32, 16, 64, 16, 16, False, False, True): (1, 4, 1, 4),
|
||||
(32, 16, 64, 16, 16, False, True, False): (1, 2, 1, 4),
|
||||
(32, 16, 64, 16, 16, False, True, True): (1, 4, 1, 2),
|
||||
(32, 16, 64, 16, 16, True, False, False): (1, 4, 2, 8),
|
||||
(32, 16, 64, 16, 16, True, False, True): (1, 4, 1, 1),
|
||||
(32, 32, 16, 16, 16, False, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4),
|
||||
(32, 32, 16, 16, 16, False, True, False): (1, 1, 2, 4),
|
||||
(32, 32, 16, 16, 16, False, True, True): (1, 1, 2, 2),
|
||||
(32, 32, 16, 16, 16, True, False, False): (1, 1, 1, 8),
|
||||
(32, 32, 16, 16, 16, True, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, False, False, False): (1, 1, 3, 2),
|
||||
(32, 32, 16, 16, 32, False, False, True): (2, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, False, True, False): (3, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, False, True, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, True, False, False): (2, 1, 1, 8),
|
||||
(32, 32, 16, 16, 32, True, False, True): (1, 1, 3, 2),
|
||||
(32, 32, 16, 32, 32, False, False, False): (1, 1, 1, 2),
|
||||
(32, 32, 16, 32, 32, False, False, True): (2, 1, 1, 8),
|
||||
(32, 32, 16, 32, 32, False, True, False): (1, 1, 1, 2),
|
||||
(32, 32, 16, 32, 32, False, True, True): (1, 1, 1, 8),
|
||||
(32, 32, 16, 32, 32, True, False, False): (1, 1, 2, 4),
|
||||
(32, 32, 16, 32, 32, True, False, True): (1, 1, 1, 2),
|
||||
(32, 32, 32, 16, 16, False, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 32, 16, 16, False, False, True): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 16, False, True, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 16, False, True, True): (1, 2, 1, 2),
|
||||
(32, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, False, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 2),
|
||||
(32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 2),
|
||||
(32, 32, 32, 16, 32, True, False, False): (1, 2, 1, 1),
|
||||
(32, 32, 32, 16, 32, True, False, True): (1, 2, 1, 2),
|
||||
(32, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 32, 32, 32, False, False, True): (2, 1, 1, 4),
|
||||
(32, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(32, 32, 32, 32, 32, False, True, True): (1, 1, 1, 8),
|
||||
(32, 32, 32, 32, 32, True, False, False): (1, 1, 3, 4),
|
||||
(32, 32, 32, 32, 32, True, False, True): (1, 1, 1, 8),
|
||||
(32, 32, 64, 16, 16, False, False, False): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 16, False, False, True): (1, 4, 1, 2),
|
||||
(32, 32, 64, 16, 16, False, True, False): (1, 1, 1, 4),
|
||||
(32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 16, True, False, False): (1, 4, 1, 8),
|
||||
(32, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2),
|
||||
(32, 32, 64, 16, 32, False, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 32, False, True, False): (1, 1, 1, 4),
|
||||
(32, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 32, True, False, False): (2, 2, 1, 8),
|
||||
(32, 32, 64, 16, 32, True, False, True): (1, 2, 1, 2),
|
||||
(32, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 64, 32, 32, False, False, True): (1, 2, 1, 1),
|
||||
(32, 32, 64, 32, 32, False, True, False): (1, 2, 2, 8),
|
||||
(32, 32, 64, 32, 32, False, True, True): (1, 1, 1, 4),
|
||||
(32, 32, 64, 32, 32, True, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 64, 32, 32, True, False, True): (2, 2, 1, 4),
|
||||
(32, 64, 16, 16, 32, False, False, False): (1, 1, 1, 8),
|
||||
(32, 64, 16, 16, 32, False, False, True): (1, 1, 1, 4),
|
||||
(32, 64, 16, 16, 32, False, True, False): (2, 1, 1, 4),
|
||||
(32, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
|
||||
(32, 64, 16, 16, 32, True, False, False): (1, 1, 2, 4),
|
||||
(32, 64, 16, 16, 32, True, False, True): (1, 1, 2, 2),
|
||||
(32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 8),
|
||||
(32, 64, 16, 32, 32, False, False, True): (2, 1, 1, 4),
|
||||
(32, 64, 16, 32, 32, False, True, False): (1, 1, 1, 4),
|
||||
(32, 64, 16, 32, 32, False, True, True): (1, 1, 2, 2),
|
||||
(32, 64, 16, 32, 32, True, False, False): (1, 1, 1, 2),
|
||||
(32, 64, 16, 32, 32, True, False, True): (2, 1, 2, 4),
|
||||
(32, 64, 32, 16, 32, False, False, False): (1, 1, 1, 4),
|
||||
(32, 64, 32, 16, 32, False, False, True): (1, 2, 1, 2),
|
||||
(32, 64, 32, 16, 32, False, True, False): (1, 2, 3, 4),
|
||||
(32, 64, 32, 16, 32, False, True, True): (2, 2, 1, 4),
|
||||
(32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 4),
|
||||
(32, 64, 32, 16, 32, True, False, True): (1, 2, 2, 1),
|
||||
(32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 8),
|
||||
(32, 64, 32, 32, 32, False, False, True): (1, 1, 1, 4),
|
||||
(32, 64, 32, 32, 32, False, True, False): (1, 1, 2, 4),
|
||||
(32, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4),
|
||||
(32, 64, 32, 32, 32, True, False, False): (2, 1, 1, 2),
|
||||
(32, 64, 32, 32, 32, True, False, True): (1, 1, 1, 4),
|
||||
(32, 64, 64, 16, 32, False, False, False): (1, 4, 2, 1),
|
||||
(32, 64, 64, 16, 32, False, False, True): (3, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, False, True, False): (1, 1, 1, 8),
|
||||
(32, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, True, False, True): (2, 2, 3, 4),
|
||||
(32, 64, 64, 32, 32, False, False, False): (1, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, False, True, False): (1, 2, 2, 8),
|
||||
(32, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, True, False, False): (1, 2, 2, 4),
|
||||
(32, 64, 64, 32, 32, True, False, True): (1, 2, 1, 4),
|
||||
(64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 1),
|
||||
(64, 32, 16, 32, 32, False, False, True): (1, 1, 2, 4),
|
||||
(64, 32, 16, 32, 32, False, True, False): (2, 1, 1, 8),
|
||||
(64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 4),
|
||||
(64, 32, 16, 32, 32, True, False, False): (2, 1, 1, 2),
|
||||
(64, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, False, False, False): (3, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, False, False, True): (1, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(64, 32, 32, 32, 32, False, True, True): (1, 1, 1, 2),
|
||||
(64, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2),
|
||||
(64, 32, 32, 32, 32, True, False, True): (1, 1, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 2),
|
||||
(64, 32, 64, 32, 32, False, False, True): (3, 2, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 1),
|
||||
(64, 32, 64, 32, 32, False, True, True): (1, 2, 1, 4),
|
||||
(64, 32, 64, 32, 32, True, False, False): (1, 1, 3, 4),
|
||||
(64, 32, 64, 32, 32, True, False, True): (1, 2, 2, 4),
|
||||
(64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 2),
|
||||
(64, 64, 16, 32, 32, False, False, True): (1, 1, 3, 2),
|
||||
(64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(64, 64, 16, 32, 32, False, True, True): (1, 1, 2, 4),
|
||||
(64, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4),
|
||||
(64, 64, 16, 32, 32, True, False, True): (2, 1, 2, 4),
|
||||
(64, 64, 32, 32, 32, False, False, False): (1, 1, 2, 8),
|
||||
(64, 64, 32, 32, 32, False, False, True): (1, 1, 2, 4),
|
||||
(64, 64, 32, 32, 32, False, True, False): (1, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, True, False, True): (2, 1, 2, 4),
|
||||
(64, 64, 64, 32, 32, False, False, False): (1, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, False, True, True): (3, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, True, False, False): (1, 2, 1, 8),
|
||||
(64, 64, 64, 32, 32, True, False, True): (1, 2, 3, 4),
|
||||
(256, 256, 256, 16, 16, False, True, True): (4, 8, 5, 2),
|
||||
(256, 256, 256, 32, 32, False, True, True): (1, 8, 5, 4),
|
||||
(256, 256, 256, 64, 64, False, True, True): (1, 4, 5, 4),
|
||||
@ -1300,6 +1734,222 @@ _operation_device_version_data: Dict[Any, Dict] = {
|
||||
(16384, 16384, 131072, 128, 128, False, True, True): (4, 1024, 1, 4),
|
||||
},
|
||||
("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): {
|
||||
(16, 16, 16, 16, 16, False, False, False): (2, 1, 1, 16),
|
||||
(16, 16, 16, 16, 16, False, False, True): (1, 1, 2, 4),
|
||||
(16, 16, 16, 16, 16, False, True, False): (1, 1, 2, 16),
|
||||
(16, 16, 16, 16, 16, False, True, True): (2, 1, 2, 8),
|
||||
(16, 16, 16, 16, 16, True, False, False): (1, 1, 1, 2),
|
||||
(16, 16, 16, 16, 16, True, False, True): (2, 1, 1, 4),
|
||||
(16, 16, 32, 16, 16, False, False, False): (1, 1, 1, 2),
|
||||
(16, 16, 32, 16, 16, False, False, True): (1, 1, 2, 8),
|
||||
(16, 16, 32, 16, 16, False, True, False): (1, 2, 1, 4),
|
||||
(16, 16, 32, 16, 16, False, True, True): (1, 2, 2, 4),
|
||||
(16, 16, 32, 16, 16, True, False, False): (1, 1, 2, 4),
|
||||
(16, 16, 32, 16, 16, True, False, True): (1, 2, 2, 4),
|
||||
(16, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4),
|
||||
(16, 16, 64, 16, 16, False, False, True): (2, 2, 1, 4),
|
||||
(16, 16, 64, 16, 16, False, True, False): (1, 4, 1, 4),
|
||||
(16, 16, 64, 16, 16, False, True, True): (1, 4, 1, 8),
|
||||
(16, 16, 64, 16, 16, True, False, False): (1, 2, 1, 4),
|
||||
(16, 16, 64, 16, 16, True, False, True): (1, 4, 2, 8),
|
||||
(16, 32, 16, 16, 16, False, False, False): (1, 1, 2, 8),
|
||||
(16, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4),
|
||||
(16, 32, 16, 16, 16, False, True, False): (1, 1, 1, 4),
|
||||
(16, 32, 16, 16, 16, False, True, True): (1, 1, 1, 4),
|
||||
(16, 32, 16, 16, 16, True, False, False): (1, 1, 1, 4),
|
||||
(16, 32, 16, 16, 16, True, False, True): (1, 1, 2, 8),
|
||||
(16, 32, 16, 16, 32, False, False, False): (1, 1, 2, 4),
|
||||
(16, 32, 16, 16, 32, False, False, True): (2, 1, 2, 2),
|
||||
(16, 32, 16, 16, 32, False, True, False): (1, 1, 1, 8),
|
||||
(16, 32, 16, 16, 32, False, True, True): (1, 1, 1, 2),
|
||||
(16, 32, 16, 16, 32, True, False, False): (3, 1, 1, 4),
|
||||
(16, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4),
|
||||
(16, 32, 32, 16, 16, False, False, False): (1, 2, 1, 4),
|
||||
(16, 32, 32, 16, 16, False, False, True): (2, 2, 1, 4),
|
||||
(16, 32, 32, 16, 16, False, True, False): (1, 2, 1, 2),
|
||||
(16, 32, 32, 16, 16, False, True, True): (1, 2, 1, 4),
|
||||
(16, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4),
|
||||
(16, 32, 32, 16, 16, True, False, True): (1, 2, 1, 4),
|
||||
(16, 32, 32, 16, 32, False, False, False): (1, 1, 2, 4),
|
||||
(16, 32, 32, 16, 32, False, False, True): (1, 2, 1, 4),
|
||||
(16, 32, 32, 16, 32, False, True, False): (1, 2, 2, 8),
|
||||
(16, 32, 32, 16, 32, False, True, True): (1, 2, 1, 1),
|
||||
(16, 32, 32, 16, 32, True, False, False): (1, 2, 1, 2),
|
||||
(16, 32, 32, 16, 32, True, False, True): (1, 2, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, False, False): (1, 2, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, False, True): (2, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, False, True, False): (1, 4, 2, 4),
|
||||
(16, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 16, True, False, False): (1, 2, 2, 8),
|
||||
(16, 32, 64, 16, 16, True, False, True): (1, 4, 1, 2),
|
||||
(16, 32, 64, 16, 32, False, False, False): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 32, False, False, True): (1, 4, 3, 4),
|
||||
(16, 32, 64, 16, 32, False, True, False): (1, 2, 1, 4),
|
||||
(16, 32, 64, 16, 32, False, True, True): (1, 4, 1, 4),
|
||||
(16, 32, 64, 16, 32, True, False, False): (1, 2, 1, 8),
|
||||
(16, 32, 64, 16, 32, True, False, True): (1, 2, 1, 4),
|
||||
(16, 64, 16, 16, 32, False, False, False): (1, 1, 1, 2),
|
||||
(16, 64, 16, 16, 32, False, False, True): (1, 1, 1, 8),
|
||||
(16, 64, 16, 16, 32, False, True, False): (1, 1, 1, 8),
|
||||
(16, 64, 16, 16, 32, False, True, True): (1, 1, 1, 4),
|
||||
(16, 64, 16, 16, 32, True, False, False): (1, 1, 1, 8),
|
||||
(16, 64, 16, 16, 32, True, False, True): (1, 1, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, False, False): (1, 2, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, False, True): (1, 1, 1, 4),
|
||||
(16, 64, 32, 16, 32, False, True, False): (1, 2, 1, 1),
|
||||
(16, 64, 32, 16, 32, False, True, True): (1, 2, 1, 8),
|
||||
(16, 64, 32, 16, 32, True, False, False): (2, 2, 1, 4),
|
||||
(16, 64, 32, 16, 32, True, False, True): (2, 2, 1, 4),
|
||||
(16, 64, 64, 16, 32, False, False, False): (1, 2, 1, 4),
|
||||
(16, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4),
|
||||
(16, 64, 64, 16, 32, False, True, False): (1, 4, 1, 4),
|
||||
(16, 64, 64, 16, 32, False, True, True): (1, 4, 1, 4),
|
||||
(16, 64, 64, 16, 32, True, False, False): (1, 4, 1, 2),
|
||||
(16, 64, 64, 16, 32, True, False, True): (3, 4, 1, 4),
|
||||
(32, 16, 16, 16, 16, False, False, False): (1, 1, 2, 4),
|
||||
(32, 16, 16, 16, 16, False, False, True): (1, 1, 1, 2),
|
||||
(32, 16, 16, 16, 16, False, True, False): (1, 1, 2, 4),
|
||||
(32, 16, 16, 16, 16, False, True, True): (1, 1, 2, 4),
|
||||
(32, 16, 16, 16, 16, True, False, False): (1, 1, 3, 8),
|
||||
(32, 16, 16, 16, 16, True, False, True): (1, 1, 2, 4),
|
||||
(32, 16, 32, 16, 16, False, False, False): (1, 2, 1, 4),
|
||||
(32, 16, 32, 16, 16, False, False, True): (1, 2, 3, 4),
|
||||
(32, 16, 32, 16, 16, False, True, False): (1, 1, 1, 8),
|
||||
(32, 16, 32, 16, 16, False, True, True): (1, 2, 1, 4),
|
||||
(32, 16, 32, 16, 16, True, False, False): (1, 1, 1, 2),
|
||||
(32, 16, 32, 16, 16, True, False, True): (1, 1, 1, 4),
|
||||
(32, 16, 64, 16, 16, False, False, False): (1, 4, 1, 4),
|
||||
(32, 16, 64, 16, 16, False, False, True): (3, 4, 1, 4),
|
||||
(32, 16, 64, 16, 16, False, True, False): (1, 4, 1, 1),
|
||||
(32, 16, 64, 16, 16, False, True, True): (1, 4, 1, 4),
|
||||
(32, 16, 64, 16, 16, True, False, False): (1, 4, 1, 4),
|
||||
(32, 16, 64, 16, 16, True, False, True): (1, 4, 1, 4),
|
||||
(32, 32, 16, 16, 16, False, False, False): (1, 1, 1, 2),
|
||||
(32, 32, 16, 16, 16, False, False, True): (2, 1, 1, 4),
|
||||
(32, 32, 16, 16, 16, False, True, False): (1, 1, 1, 2),
|
||||
(32, 32, 16, 16, 16, False, True, True): (2, 1, 1, 4),
|
||||
(32, 32, 16, 16, 16, True, False, False): (3, 1, 2, 4),
|
||||
(32, 32, 16, 16, 16, True, False, True): (1, 1, 2, 4),
|
||||
(32, 32, 16, 16, 32, False, False, False): (2, 1, 1, 2),
|
||||
(32, 32, 16, 16, 32, False, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, False, True, False): (1, 1, 1, 4),
|
||||
(32, 32, 16, 16, 32, False, True, True): (1, 1, 1, 8),
|
||||
(32, 32, 16, 16, 32, True, False, False): (1, 1, 1, 8),
|
||||
(32, 32, 16, 16, 32, True, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 16, 32, 32, False, False, False): (2, 1, 1, 4),
|
||||
(32, 32, 16, 32, 32, False, False, True): (1, 1, 2, 4),
|
||||
(32, 32, 16, 32, 32, False, True, False): (2, 1, 1, 1),
|
||||
(32, 32, 16, 32, 32, False, True, True): (2, 1, 2, 4),
|
||||
(32, 32, 16, 32, 32, True, False, False): (1, 1, 1, 8),
|
||||
(32, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 32, 16, 16, False, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 32, 16, 16, False, False, True): (1, 2, 1, 2),
|
||||
(32, 32, 32, 16, 16, False, True, False): (2, 2, 1, 4),
|
||||
(32, 32, 32, 16, 16, False, True, True): (1, 2, 2, 4),
|
||||
(32, 32, 32, 16, 16, True, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 16, True, False, True): (2, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, False, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, False, False, True): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, False, True, False): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, False, True, True): (1, 2, 1, 4),
|
||||
(32, 32, 32, 16, 32, True, False, False): (2, 1, 1, 2),
|
||||
(32, 32, 32, 16, 32, True, False, True): (2, 2, 2, 4),
|
||||
(32, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 32, 32, 32, False, False, True): (1, 1, 1, 2),
|
||||
(32, 32, 32, 32, 32, False, True, False): (1, 1, 1, 4),
|
||||
(32, 32, 32, 32, 32, False, True, True): (1, 1, 2, 2),
|
||||
(32, 32, 32, 32, 32, True, False, False): (1, 1, 1, 2),
|
||||
(32, 32, 32, 32, 32, True, False, True): (1, 1, 2, 1),
|
||||
(32, 32, 64, 16, 16, False, False, False): (2, 4, 1, 4),
|
||||
(32, 32, 64, 16, 16, False, False, True): (1, 4, 2, 4),
|
||||
(32, 32, 64, 16, 16, False, True, False): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 16, False, True, True): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 16, True, False, False): (1, 2, 1, 4),
|
||||
(32, 32, 64, 16, 16, True, False, True): (2, 4, 1, 4),
|
||||
(32, 32, 64, 16, 32, False, False, False): (1, 4, 1, 8),
|
||||
(32, 32, 64, 16, 32, False, False, True): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 32, False, True, False): (1, 4, 1, 4),
|
||||
(32, 32, 64, 16, 32, False, True, True): (2, 4, 1, 4),
|
||||
(32, 32, 64, 16, 32, True, False, False): (1, 2, 2, 4),
|
||||
(32, 32, 64, 16, 32, True, False, True): (2, 4, 1, 4),
|
||||
(32, 32, 64, 32, 32, False, False, False): (2, 2, 1, 4),
|
||||
(32, 32, 64, 32, 32, False, False, True): (1, 1, 1, 4),
|
||||
(32, 32, 64, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(32, 32, 64, 32, 32, False, True, True): (2, 1, 1, 4),
|
||||
(32, 32, 64, 32, 32, True, False, False): (1, 1, 1, 4),
|
||||
(32, 32, 64, 32, 32, True, False, True): (1, 2, 1, 1),
|
||||
(32, 64, 16, 16, 32, False, False, False): (1, 1, 2, 2),
|
||||
(32, 64, 16, 16, 32, False, False, True): (2, 1, 1, 4),
|
||||
(32, 64, 16, 16, 32, False, True, False): (1, 1, 1, 8),
|
||||
(32, 64, 16, 16, 32, False, True, True): (1, 1, 3, 4),
|
||||
(32, 64, 16, 16, 32, True, False, False): (1, 1, 1, 2),
|
||||
(32, 64, 16, 16, 32, True, False, True): (1, 1, 2, 4),
|
||||
(32, 64, 16, 32, 32, False, False, False): (1, 1, 1, 2),
|
||||
(32, 64, 16, 32, 32, False, False, True): (1, 1, 3, 4),
|
||||
(32, 64, 16, 32, 32, False, True, False): (1, 1, 2, 4),
|
||||
(32, 64, 16, 32, 32, False, True, True): (1, 1, 1, 8),
|
||||
(32, 64, 16, 32, 32, True, False, False): (1, 1, 2, 4),
|
||||
(32, 64, 16, 32, 32, True, False, True): (1, 1, 1, 8),
|
||||
(32, 64, 32, 16, 32, False, False, False): (1, 2, 1, 4),
|
||||
(32, 64, 32, 16, 32, False, False, True): (1, 2, 3, 4),
|
||||
(32, 64, 32, 16, 32, False, True, False): (1, 2, 1, 8),
|
||||
(32, 64, 32, 16, 32, False, True, True): (3, 2, 1, 4),
|
||||
(32, 64, 32, 16, 32, True, False, False): (1, 1, 1, 8),
|
||||
(32, 64, 32, 16, 32, True, False, True): (1, 2, 1, 4),
|
||||
(32, 64, 32, 32, 32, False, False, False): (1, 1, 1, 1),
|
||||
(32, 64, 32, 32, 32, False, False, True): (1, 1, 1, 4),
|
||||
(32, 64, 32, 32, 32, False, True, False): (1, 1, 1, 4),
|
||||
(32, 64, 32, 32, 32, False, True, True): (1, 1, 1, 4),
|
||||
(32, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4),
|
||||
(32, 64, 32, 32, 32, True, False, True): (1, 1, 2, 8),
|
||||
(32, 64, 64, 16, 32, False, False, False): (2, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, False, False, True): (1, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, False, True, False): (1, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, False, True, True): (2, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, True, False, False): (1, 4, 1, 4),
|
||||
(32, 64, 64, 16, 32, True, False, True): (1, 4, 1, 4),
|
||||
(32, 64, 64, 32, 32, False, False, False): (2, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, False, False, True): (1, 2, 1, 8),
|
||||
(32, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, False, True, True): (1, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, True, False, False): (2, 2, 1, 4),
|
||||
(32, 64, 64, 32, 32, True, False, True): (1, 2, 3, 8),
|
||||
(64, 32, 16, 32, 32, False, False, False): (1, 1, 1, 4),
|
||||
(64, 32, 16, 32, 32, False, False, True): (3, 1, 2, 4),
|
||||
(64, 32, 16, 32, 32, False, True, False): (2, 1, 1, 2),
|
||||
(64, 32, 16, 32, 32, False, True, True): (1, 1, 1, 8),
|
||||
(64, 32, 16, 32, 32, True, False, False): (1, 1, 1, 2),
|
||||
(64, 32, 16, 32, 32, True, False, True): (1, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, False, False, False): (1, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, False, False, True): (1, 1, 2, 8),
|
||||
(64, 32, 32, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(64, 32, 32, 32, 32, False, True, True): (1, 1, 1, 4),
|
||||
(64, 32, 32, 32, 32, True, False, False): (1, 1, 2, 4),
|
||||
(64, 32, 32, 32, 32, True, False, True): (1, 1, 3, 8),
|
||||
(64, 32, 64, 32, 32, False, False, False): (1, 2, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, False, True): (2, 2, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, True, False): (1, 1, 1, 4),
|
||||
(64, 32, 64, 32, 32, False, True, True): (1, 2, 1, 8),
|
||||
(64, 32, 64, 32, 32, True, False, False): (2, 2, 1, 4),
|
||||
(64, 32, 64, 32, 32, True, False, True): (1, 2, 1, 8),
|
||||
(64, 64, 16, 32, 32, False, False, False): (1, 1, 2, 8),
|
||||
(64, 64, 16, 32, 32, False, False, True): (2, 1, 2, 4),
|
||||
(64, 64, 16, 32, 32, False, True, False): (1, 1, 1, 2),
|
||||
(64, 64, 16, 32, 32, False, True, True): (1, 1, 2, 4),
|
||||
(64, 64, 16, 32, 32, True, False, False): (1, 1, 1, 2),
|
||||
(64, 64, 16, 32, 32, True, False, True): (1, 1, 2, 4),
|
||||
(64, 64, 32, 32, 32, False, False, False): (1, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, False, False, True): (2, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, False, True, False): (1, 1, 1, 8),
|
||||
(64, 64, 32, 32, 32, False, True, True): (2, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, True, False, False): (1, 1, 1, 4),
|
||||
(64, 64, 32, 32, 32, True, False, True): (1, 1, 1, 8),
|
||||
(64, 64, 64, 32, 32, False, False, False): (2, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, False, False, True): (1, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, False, True, False): (1, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, False, True, True): (2, 2, 1, 4),
|
||||
(64, 64, 64, 32, 32, True, False, False): (1, 1, 1, 8),
|
||||
(64, 64, 64, 32, 32, True, False, True): (1, 2, 2, 4),
|
||||
(256, 256, 256, 16, 16, False, True, True): (1, 16, 3, 4),
|
||||
(256, 256, 256, 32, 32, False, True, True): (1, 8, 5, 4),
|
||||
(256, 256, 256, 64, 64, False, True, True): (3, 4, 4, 8),
|
||||
@ -1582,6 +2232,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
|
||||
(16384, 16384, 131072, 128, 128, False, True, True): (2, 1024, 1, 32),
|
||||
},
|
||||
("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
|
||||
(16, 16, 32, 16, 16): (1, 1, 1),
|
||||
(16, 16, 64, 16, 16): (1, 2, 8),
|
||||
(16, 32, 32, 16, 16): (1, 1, 4),
|
||||
(16, 32, 64, 16, 16): (2, 1, 2),
|
||||
(32, 32, 32, 32, 32): (1, 2, 4),
|
||||
(32, 32, 64, 32, 32): (1, 1, 2),
|
||||
(32, 64, 32, 32, 32): (1, 2, 8),
|
||||
(32, 64, 64, 32, 32): (1, 1, 4),
|
||||
(256, 256, 256, 16, 16): (3, 3, 1),
|
||||
(256, 256, 256, 32, 32): (1, 5, 1),
|
||||
(256, 256, 256, 64, 64): (2, 3, 2),
|
||||
@ -2146,6 +2804,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
|
||||
(16384, 16384, 131072, 128, 128): (4, 1, 4),
|
||||
},
|
||||
("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): {
|
||||
(16, 16, 32, 16, 16): (4, 3, 4),
|
||||
(16, 16, 64, 16, 16): (1, 2, 4),
|
||||
(16, 32, 32, 16, 16): (1, 1, 2),
|
||||
(16, 32, 64, 16, 16): (2, 1, 2),
|
||||
(32, 32, 32, 32, 32): (1, 1, 4),
|
||||
(32, 32, 64, 32, 32): (1, 1, 1),
|
||||
(32, 64, 32, 32, 32): (1, 1, 2),
|
||||
(32, 64, 64, 32, 32): (2, 1, 4),
|
||||
(256, 256, 256, 16, 16): (8, 6, 1),
|
||||
(256, 256, 256, 32, 32): (4, 5, 2),
|
||||
(256, 256, 256, 64, 64): (3, 3, 4),
|
||||
@ -2710,6 +3376,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
|
||||
(16384, 16384, 131072, 128, 128): (4, 1, 4),
|
||||
},
|
||||
("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): {
|
||||
(16, 16, 32, 16, 16): (1, 1, 4),
|
||||
(16, 16, 64, 16, 16): (1, 1, 1),
|
||||
(16, 32, 32, 16, 16): (1, 2, 4),
|
||||
(16, 32, 64, 16, 16): (2, 1, 4),
|
||||
(32, 32, 32, 32, 32): (1, 1, 8),
|
||||
(32, 32, 64, 32, 32): (2, 1, 4),
|
||||
(32, 64, 32, 32, 32): (1, 1, 8),
|
||||
(32, 64, 64, 32, 32): (1, 2, 4),
|
||||
(256, 256, 256, 16, 16): (1, 1, 8),
|
||||
(256, 256, 256, 32, 32): (1, 3, 4),
|
||||
(256, 256, 256, 64, 64): (1, 1, 8),
|
||||
|
Reference in New Issue
Block a user