Tune bsr_dense_addmm for int8 inputs on A100 (#136088)

As in the title. The tuning is done for dimensions 1280 and 5120 that are used in Vit-H.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136088
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2024-09-24 20:33:34 +00:00
committed by PyTorch MergeBot
parent 9629835b1c
commit 8f2a4cc4b1
2 changed files with 349 additions and 5 deletions

View File

@ -793,7 +793,8 @@ def bsr_dense_addmm_meta(
# _triton_ops_meta.py for ways to avoid this warning
# message
warn_once(
f"bsr_dense_addmm uses non-optimal triton kernel parameters for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=}"
"bsr_dense_addmm uses non-optimal triton kernel parameters"
f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=}"
)
SPLIT_N = SPLIT_N or max(N // Ms, 1)
@ -1235,10 +1236,10 @@ def bsr_dense_addmm(
out = tile_to_blocksize(out, (BM, BN))
dense = tile_to_blocksize(dense, (BK, BN))
input = tile_to_blocksize(input, (BM, BN))
left_alpha = tile_to_blocksize(left_alpha, (BM, BN))
right_alpha = tile_to_blocksize(right_alpha, (BM, BN))
# tl.dot supports float16, float32, int32 as accumulator types.
dot_out_dtype = {
torch.float16: tl.float32,
torch.bfloat16: tl.float32,

View File

@ -190,8 +190,12 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
def update(op, device_name, version, key, value):
"""Update the db of op parameters."""
# avoid storing possible optimization failures:
assert value, (op, device_name, version, key, value)
# skip storing possible optimization failures:
if not value:
warnings.warn(
f"skipping empty value for {op}: {device_name=} {version=} {key=}"
)
return
if (op, device_name, version) in _operation_device_version_data:
if _operation_device_version_data[op, device_name, version].get(key) == value:
return
@ -782,10 +786,12 @@ def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True):
65536,
131072,
50432,
65792,
]
sizes3_lst = [3 * sz for sz in [64, 128] + sizes_lst if sz <= 2048]
shapes_lst = [(sz, sz) for sz in sizes_lst[:-4] + sizes3_lst]
shapes_lst.extend([(3072, 768), (768, 3072)])
shapes_lst.extend([(5120, 1280), (1280, 5120)])
if dtype is torch.int8:
# triton does not support smaller blocks than 32
blocksize_lst = [(32, 32), (64, 64), (128, 128), (256, 256)]
@ -1004,6 +1010,12 @@ _operation_device_version_data: Dict[Any, Dict] = {
(256, 256, 65536, 64, 64, True, False, True): (1, 512, 1, 4),
(256, 256, 65536, 128, 128, False, True, True): (2, 512, 1, 16),
(256, 256, 65536, 128, 128, True, False, True): (1, 512, 1, 4),
(256, 256, 65792, 32, 32, False, True, True): (1, 1028, 1, 8),
(256, 256, 65792, 32, 32, True, False, True): (1, 514, 1, 4),
(256, 256, 65792, 64, 64, False, True, True): (1, 1028, 1, 8),
(256, 256, 65792, 64, 64, True, False, True): (4, 257, 1, 4),
(256, 256, 65792, 128, 128, False, True, True): (2, 514, 1, 16),
(256, 256, 65792, 128, 128, True, False, True): (3, 514, 1, 4),
(256, 256, 131072, 32, 32, False, True, True): (1, 2048, 1, 8),
(256, 256, 131072, 32, 32, True, False, True): (2, 1024, 1, 4),
(256, 256, 131072, 64, 64, False, True, True): (1, 2048, 1, 8),
@ -1122,6 +1134,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(512, 512, 65536, 128, 128, True, False, True): (1, 512, 1, 4),
(512, 512, 65536, 256, 256, False, True, True): (1, 256, 1, 32),
(512, 512, 65536, 256, 256, True, False, True): (1, 256, 1, 32),
(512, 512, 65792, 32, 32, False, True, True): (1, 1028, 1, 8),
(512, 512, 65792, 32, 32, True, False, True): (1, 514, 3, 2),
(512, 512, 65792, 64, 64, False, True, True): (1, 1028, 1, 8),
(512, 512, 65792, 64, 64, True, False, True): (2, 257, 3, 4),
(512, 512, 65792, 128, 128, False, True, True): (4, 514, 1, 16),
(512, 512, 65792, 128, 128, True, False, True): (1, 514, 1, 4),
(512, 512, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(512, 512, 65792, 256, 256, True, False, True): (2, 257, 1, 32),
(512, 512, 131072, 32, 32, False, True, True): (1, 2048, 1, 8),
(512, 512, 131072, 32, 32, True, False, True): (1, 1024, 3, 2),
(512, 512, 131072, 64, 64, False, True, True): (1, 2048, 1, 8),
@ -1350,6 +1370,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(1024, 1024, 65536, 128, 128, True, False, True): (4, 512, 3, 4),
(1024, 1024, 65536, 256, 256, False, True, True): (1, 256, 1, 32),
(1024, 1024, 65536, 256, 256, True, False, True): (1, 256, 1, 32),
(1024, 1024, 65792, 32, 32, False, True, True): (1, 1028, 1, 8),
(1024, 1024, 65792, 32, 32, True, False, True): (1, 514, 3, 2),
(1024, 1024, 65792, 64, 64, False, True, True): (2, 514, 1, 4),
(1024, 1024, 65792, 64, 64, True, False, True): (4, 257, 3, 4),
(1024, 1024, 65792, 128, 128, False, True, True): (2, 514, 1, 16),
(1024, 1024, 65792, 128, 128, True, False, True): (2, 514, 2, 4),
(1024, 1024, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(1024, 1024, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
(1024, 1024, 131072, 32, 32, False, True, True): (1, 2048, 1, 8),
(1024, 1024, 131072, 32, 32, True, False, True): (1, 1024, 3, 2),
(1024, 1024, 131072, 64, 64, False, True, True): (2, 1024, 1, 4),
@ -1358,6 +1386,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(1024, 1024, 131072, 128, 128, True, False, True): (1, 1024, 3, 4),
(1024, 1024, 131072, 256, 256, False, True, True): (1, 512, 1, 32),
(1024, 1024, 131072, 256, 256, True, False, True): (1, 512, 1, 32),
(1280, 5120, 65792, 32, 32, False, True, True): (1, 1028, 1, 8),
(1280, 5120, 65792, 32, 32, True, False, True): (1, 514, 3, 2),
(1280, 5120, 65792, 64, 64, False, True, True): (1, 1028, 1, 8),
(1280, 5120, 65792, 64, 64, True, False, True): (2, 257, 3, 4),
(1280, 5120, 65792, 128, 128, False, True, True): (2, 514, 1, 16),
(1280, 5120, 65792, 128, 128, True, False, True): (1, 514, 3, 4),
(1280, 5120, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(1280, 5120, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
(1536, 1536, 256, 32, 32, False, True, True): (1, 8, 1, 4),
(1536, 1536, 256, 32, 32, True, False, True): (2, 8, 1, 8),
(1536, 1536, 256, 64, 64, False, True, True): (4, 4, 1, 16),
@ -1510,6 +1546,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(2048, 2048, 65536, 128, 128, True, False, True): (1, 512, 2, 4),
(2048, 2048, 65536, 256, 256, False, True, True): (1, 256, 1, 32),
(2048, 2048, 65536, 256, 256, True, False, True): (4, 256, 1, 32),
(2048, 2048, 65792, 32, 32, False, True, True): (1, 1028, 1, 8),
(2048, 2048, 65792, 32, 32, True, False, True): (1, 514, 3, 2),
(2048, 2048, 65792, 64, 64, False, True, True): (1, 514, 1, 4),
(2048, 2048, 65792, 64, 64, True, False, True): (2, 257, 3, 4),
(2048, 2048, 65792, 128, 128, False, True, True): (1, 514, 1, 8),
(2048, 2048, 65792, 128, 128, True, False, True): (1, 514, 2, 4),
(2048, 2048, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(2048, 2048, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
(2048, 2048, 131072, 32, 32, False, True, True): (1, 2048, 1, 8),
(2048, 2048, 131072, 32, 32, True, False, True): (1, 1024, 3, 2),
(2048, 2048, 131072, 64, 64, False, True, True): (1, 1024, 1, 4),
@ -1758,6 +1802,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(4096, 4096, 65536, 128, 128, True, False, True): (1, 512, 3, 4),
(4096, 4096, 65536, 256, 256, False, True, True): (1, 256, 1, 32),
(4096, 4096, 65536, 256, 256, True, False, True): (4, 256, 1, 32),
(4096, 4096, 65792, 32, 32, False, True, True): (1, 1028, 1, 8),
(4096, 4096, 65792, 32, 32, True, False, True): (1, 514, 3, 2),
(4096, 4096, 65792, 64, 64, False, True, True): (1, 1028, 1, 8),
(4096, 4096, 65792, 64, 64, True, False, True): (1, 514, 3, 2),
(4096, 4096, 65792, 128, 128, False, True, True): (1, 514, 1, 8),
(4096, 4096, 65792, 128, 128, True, False, True): (1, 514, 2, 4),
(4096, 4096, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(4096, 4096, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
(4096, 4096, 131072, 32, 32, False, True, True): (1, 2048, 1, 8),
(4096, 4096, 131072, 32, 32, True, False, True): (1, 1024, 3, 2),
(4096, 4096, 131072, 64, 64, False, True, True): (1, 2048, 1, 8),
@ -1766,6 +1818,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 3, 4),
(4096, 4096, 131072, 256, 256, False, True, True): (1, 512, 1, 32),
(4096, 4096, 131072, 256, 256, True, False, True): (4, 512, 1, 32),
(5120, 1280, 65792, 32, 32, False, True, True): (1, 1028, 1, 8),
(5120, 1280, 65792, 32, 32, True, False, True): (1, 514, 1, 2),
(5120, 1280, 65792, 64, 64, False, True, True): (1, 514, 1, 4),
(5120, 1280, 65792, 64, 64, True, False, True): (1, 514, 2, 2),
(5120, 1280, 65792, 128, 128, False, True, True): (1, 514, 1, 8),
(5120, 1280, 65792, 128, 128, True, False, True): (1, 514, 2, 4),
(5120, 1280, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(5120, 1280, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
(6144, 6144, 256, 32, 32, False, True, True): (2, 4, 1, 8),
(6144, 6144, 256, 32, 32, True, False, True): (2, 1, 4, 4),
(6144, 6144, 256, 64, 64, False, True, True): (1, 4, 1, 8),
@ -1918,6 +1978,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 3, 4),
(8192, 8192, 65536, 256, 256, False, True, True): (1, 256, 1, 32),
(8192, 8192, 65536, 256, 256, True, False, True): (4, 256, 1, 32),
(8192, 8192, 65792, 32, 32, False, True, True): (4, 1028, 1, 8),
(8192, 8192, 65792, 32, 32, True, False, True): (1, 514, 3, 2),
(8192, 8192, 65792, 64, 64, False, True, True): (4, 1028, 1, 8),
(8192, 8192, 65792, 64, 64, True, False, True): (2, 257, 3, 4),
(8192, 8192, 65792, 128, 128, False, True, True): (4, 514, 1, 16),
(8192, 8192, 65792, 128, 128, True, False, True): (2, 514, 3, 4),
(8192, 8192, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(8192, 8192, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
(8192, 8192, 131072, 32, 32, False, True, True): (4, 2048, 1, 8),
(8192, 8192, 131072, 32, 32, True, False, True): (4, 1024, 3, 2),
(8192, 8192, 131072, 64, 64, False, True, True): (4, 1024, 1, 4),
@ -1998,6 +2066,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 3, 4),
(16384, 16384, 65536, 256, 256, False, True, True): (1, 256, 1, 32),
(16384, 16384, 65536, 256, 256, True, False, True): (4, 256, 1, 32),
(16384, 16384, 65792, 32, 32, False, True, True): (4, 1028, 1, 8),
(16384, 16384, 65792, 32, 32, True, False, True): (1, 514, 3, 2),
(16384, 16384, 65792, 64, 64, False, True, True): (2, 514, 1, 4),
(16384, 16384, 65792, 64, 64, True, False, True): (2, 257, 3, 4),
(16384, 16384, 65792, 128, 128, False, True, True): (2, 514, 1, 16),
(16384, 16384, 65792, 128, 128, True, False, True): (2, 514, 3, 4),
(16384, 16384, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(16384, 16384, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
(16384, 16384, 131072, 32, 32, False, True, True): (4, 1024, 1, 8),
(16384, 16384, 131072, 32, 32, True, False, True): (4, 512, 3, 4),
(16384, 16384, 131072, 64, 64, False, True, True): (4, 1024, 1, 4),
@ -2006,6 +2082,78 @@ _operation_device_version_data: Dict[Any, Dict] = {
(16384, 16384, 131072, 128, 128, True, False, True): (4, 1024, 3, 4),
(16384, 16384, 131072, 256, 256, False, True, True): (4, 512, 1, 32),
(16384, 16384, 131072, 256, 256, True, False, True): (4, 512, 1, 32),
(32768, 32768, 256, 32, 32, False, True, True): (4, 4, 1, 8),
(32768, 32768, 256, 32, 32, True, False, True): (1, 2, 4, 2),
(32768, 32768, 256, 64, 64, False, True, True): (2, 2, 1, 4),
(32768, 32768, 256, 64, 64, True, False, True): (2, 1, 3, 4),
(32768, 32768, 256, 128, 128, False, True, True): (4, 2, 1, 8),
(32768, 32768, 256, 128, 128, True, False, True): (4, 2, 3, 4),
(32768, 32768, 256, 256, 256, False, True, True): (1, 1, 1, 32),
(32768, 32768, 256, 256, 256, True, False, True): (1, 1, 1, 32),
(32768, 32768, 512, 32, 32, False, True, True): (4, 8, 1, 8),
(32768, 32768, 512, 32, 32, True, False, True): (1, 4, 3, 2),
(32768, 32768, 512, 64, 64, False, True, True): (4, 4, 1, 4),
(32768, 32768, 512, 64, 64, True, False, True): (4, 2, 3, 4),
(32768, 32768, 512, 128, 128, False, True, True): (1, 2, 1, 8),
(32768, 32768, 512, 128, 128, True, False, True): (4, 4, 3, 4),
(32768, 32768, 512, 256, 256, False, True, True): (1, 2, 1, 32),
(32768, 32768, 512, 256, 256, True, False, True): (2, 2, 1, 32),
(32768, 32768, 1024, 32, 32, False, True, True): (4, 16, 1, 8),
(32768, 32768, 1024, 32, 32, True, False, True): (1, 8, 4, 2),
(32768, 32768, 1024, 64, 64, False, True, True): (4, 8, 1, 4),
(32768, 32768, 1024, 64, 64, True, False, True): (4, 4, 3, 4),
(32768, 32768, 1024, 128, 128, False, True, True): (1, 4, 1, 8),
(32768, 32768, 1024, 128, 128, True, False, True): (4, 8, 3, 4),
(32768, 32768, 1024, 256, 256, False, True, True): (1, 4, 1, 32),
(32768, 32768, 1024, 256, 256, True, False, True): (1, 4, 1, 32),
(32768, 32768, 2048, 32, 32, False, True, True): (2, 32, 1, 8),
(32768, 32768, 2048, 32, 32, True, False, True): (1, 16, 4, 2),
(32768, 32768, 2048, 64, 64, False, True, True): (2, 16, 1, 4),
(32768, 32768, 2048, 64, 64, True, False, True): (4, 8, 3, 4),
(32768, 32768, 2048, 128, 128, False, True, True): (1, 8, 1, 8),
(32768, 32768, 2048, 128, 128, True, False, True): (4, 16, 3, 4),
(32768, 32768, 2048, 256, 256, False, True, True): (1, 8, 1, 32),
(32768, 32768, 2048, 256, 256, True, False, True): (4, 8, 1, 32),
(32768, 32768, 4096, 32, 32, False, True, True): (2, 64, 1, 8),
(32768, 32768, 4096, 32, 32, True, False, True): (2, 32, 3, 2),
(32768, 32768, 4096, 64, 64, False, True, True): (2, 32, 1, 4),
(32768, 32768, 4096, 64, 64, True, False, True): (2, 16, 3, 4),
(32768, 32768, 4096, 128, 128, False, True, True): (1, 16, 1, 8),
(32768, 32768, 4096, 128, 128, True, False, True): (2, 32, 3, 4),
(32768, 32768, 4096, 256, 256, False, True, True): (1, 16, 1, 32),
(32768, 32768, 4096, 256, 256, True, False, True): (4, 16, 1, 32),
(32768, 32768, 8192, 32, 32, False, True, True): (2, 128, 1, 8),
(32768, 32768, 8192, 32, 32, True, False, True): (2, 64, 3, 2),
(32768, 32768, 8192, 64, 64, False, True, True): (2, 64, 1, 4),
(32768, 32768, 8192, 64, 64, True, False, True): (2, 32, 3, 4),
(32768, 32768, 8192, 128, 128, False, True, True): (1, 32, 1, 8),
(32768, 32768, 8192, 128, 128, True, False, True): (4, 64, 3, 4),
(32768, 32768, 8192, 256, 256, False, True, True): (1, 32, 1, 32),
(32768, 32768, 8192, 256, 256, True, False, True): (4, 32, 1, 32),
(32768, 32768, 16384, 32, 32, False, True, True): (2, 256, 1, 8),
(32768, 32768, 16384, 32, 32, True, False, True): (2, 128, 4, 2),
(32768, 32768, 16384, 64, 64, False, True, True): (2, 128, 1, 4),
(32768, 32768, 16384, 64, 64, True, False, True): (4, 64, 3, 4),
(32768, 32768, 16384, 128, 128, False, True, True): (1, 64, 1, 8),
(32768, 32768, 16384, 128, 128, True, False, True): (4, 128, 3, 4),
(32768, 32768, 16384, 256, 256, False, True, True): (1, 64, 1, 32),
(32768, 32768, 16384, 256, 256, True, False, True): (2, 64, 1, 32),
(32768, 32768, 32768, 32, 32, False, True, True): (2, 512, 1, 8),
(32768, 32768, 32768, 32, 32, True, False, True): (4, 256, 3, 2),
(32768, 32768, 32768, 64, 64, False, True, True): (1, 256, 1, 4),
(32768, 32768, 32768, 64, 64, True, False, True): (2, 128, 3, 4),
(32768, 32768, 32768, 128, 128, False, True, True): (1, 128, 1, 8),
(32768, 32768, 32768, 128, 128, True, False, True): (2, 256, 3, 4),
(32768, 32768, 32768, 256, 256, False, True, True): (1, 128, 1, 32),
(32768, 32768, 32768, 256, 256, True, False, True): (1, 128, 1, 32),
(32768, 32768, 65536, 32, 32, False, True, True): (2, 512, 1, 8),
(32768, 32768, 65536, 32, 32, True, False, True): (3, 512, 4, 2),
(32768, 32768, 65536, 64, 64, False, True, True): (1, 512, 1, 4),
(32768, 32768, 65536, 64, 64, True, False, True): (2, 512, 3, 2),
(32768, 32768, 65536, 128, 128, False, True, True): (1, 256, 1, 8),
(32768, 32768, 65536, 128, 128, True, False, True): (2, 512, 3, 4),
(32768, 32768, 65536, 256, 256, False, True, True): (1, 256, 1, 32),
(32768, 32768, 65536, 256, 256, True, False, True): (1, 256, 1, 32),
},
("_int_bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 0.56)): {
(192, 192, 256, 64, 64, False, True, True): (3, 4, 3, 32),
@ -2088,6 +2236,8 @@ _operation_device_version_data: Dict[Any, Dict] = {
(256, 256, 32768, 256, 256, True, False, True): (1, 128, 1, 4),
(256, 256, 65536, 256, 256, False, True, True): (1, 4, 1, 1),
(256, 256, 65536, 256, 256, True, False, True): (1, 128, 1, 4),
(256, 256, 65792, 256, 256, False, True, True): (1, 128, 2, 16),
(256, 256, 65792, 256, 256, True, False, True): (1, 16, 3, 4),
(256, 256, 131072, 256, 256, False, True, True): (1, 512, 1, 4),
(256, 256, 131072, 256, 256, True, False, True): (1, 512, 1, 2),
},
@ -2816,6 +2966,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(1024, 1024, 131072, 64, 64, True, False, True): (2, 1024, 3, 4),
(1024, 1024, 131072, 128, 128, False, True, True): (4, 1024, 1, 4),
(1024, 1024, 131072, 128, 128, True, False, True): (4, 1024, 1, 4),
(1280, 5120, 65792, 16, 16, False, True, True): (1, 257, 1, 4),
(1280, 5120, 65792, 16, 16, True, False, True): (5, 257, 4, 1),
(1280, 5120, 65792, 32, 32, False, True, True): (1, 514, 1, 8),
(1280, 5120, 65792, 32, 32, True, False, True): (2, 257, 3, 4),
(1280, 5120, 65792, 64, 64, False, True, True): (1, 514, 3, 4),
(1280, 5120, 65792, 64, 64, True, False, True): (1, 257, 3, 4),
(1280, 5120, 65792, 128, 128, False, True, True): (1, 514, 3, 8),
(1280, 5120, 65792, 128, 128, True, False, True): (2, 514, 3, 8),
(1536, 1536, 256, 16, 16, False, True, True): (1, 4, 6, 2),
(1536, 1536, 256, 16, 16, True, False, True): (3, 4, 5, 2),
(1536, 1536, 256, 32, 32, False, True, True): (2, 4, 3, 4),
@ -3224,6 +3382,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(4096, 4096, 131072, 64, 64, True, False, True): (3, 1024, 3, 4),
(4096, 4096, 131072, 128, 128, False, True, True): (1, 1024, 1, 4),
(4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 1, 4),
(5120, 1280, 65792, 16, 16, False, True, True): (1, 257, 1, 4),
(5120, 1280, 65792, 16, 16, True, False, True): (11, 257, 4, 1),
(5120, 1280, 65792, 32, 32, False, True, True): (1, 257, 1, 4),
(5120, 1280, 65792, 32, 32, True, False, True): (5, 257, 3, 4),
(5120, 1280, 65792, 64, 64, False, True, True): (1, 514, 1, 4),
(5120, 1280, 65792, 64, 64, True, False, True): (5, 257, 2, 4),
(5120, 1280, 65792, 128, 128, False, True, True): (3, 514, 1, 4),
(5120, 1280, 65792, 128, 128, True, False, True): (7, 514, 2, 4),
(6144, 6144, 256, 16, 16, False, True, True): (1, 2, 1, 4),
(6144, 6144, 256, 16, 16, True, False, True): (3, 1, 4, 4),
(6144, 6144, 256, 32, 32, False, True, True): (3, 2, 1, 8),
@ -3844,6 +4010,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(256, 256, 65536, 64, 64, True, False, True): (5, 512, 1, 4),
(256, 256, 65536, 128, 128, False, True, True): (3, 512, 1, 4),
(256, 256, 65536, 128, 128, True, False, True): (1, 512, 1, 4),
(256, 256, 65792, 16, 16, False, True, True): (2, 257, 1, 4),
(256, 256, 65792, 16, 16, True, False, True): (1, 257, 3, 2),
(256, 256, 65792, 32, 32, False, True, True): (2, 257, 1, 4),
(256, 256, 65792, 32, 32, True, False, True): (1, 257, 3, 4),
(256, 256, 65792, 64, 64, False, True, True): (2, 514, 1, 4),
(256, 256, 65792, 64, 64, True, False, True): (2, 514, 2, 4),
(256, 256, 65792, 128, 128, False, True, True): (3, 514, 1, 4),
(256, 256, 65792, 128, 128, True, False, True): (1, 514, 2, 4),
(256, 256, 131072, 16, 16, False, True, True): (1, 512, 3, 1),
(256, 256, 131072, 16, 16, True, False, True): (1, 512, 3, 2),
(256, 256, 131072, 32, 32, False, True, True): (2, 1024, 3, 2),
@ -3992,6 +4166,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(512, 512, 65536, 64, 64, True, False, True): (1, 512, 3, 4),
(512, 512, 65536, 128, 128, False, True, True): (7, 512, 1, 4),
(512, 512, 65536, 128, 128, True, False, True): (5, 512, 1, 4),
(512, 512, 65792, 16, 16, False, True, True): (2, 257, 1, 4),
(512, 512, 65792, 16, 16, True, False, True): (1, 257, 3, 4),
(512, 512, 65792, 32, 32, False, True, True): (2, 257, 1, 4),
(512, 512, 65792, 32, 32, True, False, True): (1, 257, 3, 4),
(512, 512, 65792, 64, 64, False, True, True): (4, 514, 1, 4),
(512, 512, 65792, 64, 64, True, False, True): (4, 257, 2, 4),
(512, 512, 65792, 128, 128, False, True, True): (5, 514, 1, 4),
(512, 512, 65792, 128, 128, True, False, True): (4, 514, 2, 4),
(512, 512, 131072, 16, 16, False, True, True): (1, 512, 3, 1),
(512, 512, 131072, 16, 16, True, False, True): (1, 512, 3, 1),
(512, 512, 131072, 32, 32, False, True, True): (1, 1024, 3, 2),
@ -4248,6 +4430,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(1024, 1024, 65536, 64, 64, True, False, True): (1, 512, 3, 4),
(1024, 1024, 65536, 128, 128, False, True, True): (10, 512, 1, 4),
(1024, 1024, 65536, 128, 128, True, False, True): (4, 512, 1, 4),
(1024, 1024, 65792, 16, 16, False, True, True): (1, 257, 1, 4),
(1024, 1024, 65792, 16, 16, True, False, True): (10, 257, 4, 1),
(1024, 1024, 65792, 32, 32, False, True, True): (2, 257, 1, 4),
(1024, 1024, 65792, 32, 32, True, False, True): (1, 257, 3, 4),
(1024, 1024, 65792, 64, 64, False, True, True): (2, 514, 1, 4),
(1024, 1024, 65792, 64, 64, True, False, True): (2, 257, 2, 4),
(1024, 1024, 65792, 128, 128, False, True, True): (6, 514, 1, 4),
(1024, 1024, 65792, 128, 128, True, False, True): (2, 514, 2, 4),
(1024, 1024, 131072, 16, 16, False, True, True): (11, 512, 3, 2),
(1024, 1024, 131072, 16, 16, True, False, True): (11, 512, 3, 2),
(1024, 1024, 131072, 32, 32, False, True, True): (7, 1024, 3, 2),
@ -4256,6 +4446,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(1024, 1024, 131072, 64, 64, True, False, True): (4, 1024, 3, 4),
(1024, 1024, 131072, 128, 128, False, True, True): (12, 1024, 1, 4),
(1024, 1024, 131072, 128, 128, True, False, True): (4, 1024, 1, 4),
(1280, 5120, 65792, 16, 16, False, True, True): (1, 257, 1, 4),
(1280, 5120, 65792, 16, 16, True, False, True): (5, 257, 4, 1),
(1280, 5120, 65792, 32, 32, False, True, True): (2, 257, 1, 4),
(1280, 5120, 65792, 32, 32, True, False, True): (2, 257, 3, 4),
(1280, 5120, 65792, 64, 64, False, True, True): (1, 514, 3, 4),
(1280, 5120, 65792, 64, 64, True, False, True): (2, 257, 3, 4),
(1280, 5120, 65792, 128, 128, False, True, True): (1, 514, 3, 8),
(1280, 5120, 65792, 128, 128, True, False, True): (1, 514, 3, 8),
(1536, 1536, 256, 16, 16, False, True, True): (5, 4, 4, 2),
(1536, 1536, 256, 16, 16, True, False, True): (3, 4, 5, 2),
(1536, 1536, 256, 32, 32, False, True, True): (2, 4, 4, 4),
@ -4416,6 +4614,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(2048, 2048, 65536, 64, 64, True, False, True): (9, 512, 3, 4),
(2048, 2048, 65536, 128, 128, False, True, True): (5, 512, 1, 4),
(2048, 2048, 65536, 128, 128, True, False, True): (1, 512, 1, 4),
(2048, 2048, 65792, 16, 16, False, True, True): (1, 257, 1, 4),
(2048, 2048, 65792, 16, 16, True, False, True): (7, 257, 4, 1),
(2048, 2048, 65792, 32, 32, False, True, True): (2, 257, 1, 4),
(2048, 2048, 65792, 32, 32, True, False, True): (7, 257, 3, 4),
(2048, 2048, 65792, 64, 64, False, True, True): (1, 514, 3, 4),
(2048, 2048, 65792, 64, 64, True, False, True): (1, 257, 2, 4),
(2048, 2048, 65792, 128, 128, False, True, True): (3, 514, 1, 4),
(2048, 2048, 65792, 128, 128, True, False, True): (1, 514, 2, 4),
(2048, 2048, 131072, 16, 16, False, True, True): (9, 512, 3, 2),
(2048, 2048, 131072, 16, 16, True, False, True): (9, 512, 4, 4),
(2048, 2048, 131072, 32, 32, False, True, True): (7, 512, 3, 4),
@ -4672,6 +4878,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(4096, 4096, 65536, 64, 64, True, False, True): (1, 512, 3, 4),
(4096, 4096, 65536, 128, 128, False, True, True): (3, 512, 1, 4),
(4096, 4096, 65536, 128, 128, True, False, True): (1, 512, 1, 4),
(4096, 4096, 65792, 16, 16, False, True, True): (1, 257, 1, 4),
(4096, 4096, 65792, 16, 16, True, False, True): (5, 257, 4, 1),
(4096, 4096, 65792, 32, 32, False, True, True): (1, 257, 1, 4),
(4096, 4096, 65792, 32, 32, True, False, True): (1, 257, 3, 4),
(4096, 4096, 65792, 64, 64, False, True, True): (1, 514, 3, 4),
(4096, 4096, 65792, 64, 64, True, False, True): (1, 257, 2, 4),
(4096, 4096, 65792, 128, 128, False, True, True): (3, 514, 1, 4),
(4096, 4096, 65792, 128, 128, True, False, True): (1, 514, 2, 4),
(4096, 4096, 131072, 16, 16, False, True, True): (4, 512, 3, 4),
(4096, 4096, 131072, 16, 16, True, False, True): (5, 512, 4, 4),
(4096, 4096, 131072, 32, 32, False, True, True): (1, 512, 4, 8),
@ -4680,6 +4894,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(4096, 4096, 131072, 64, 64, True, False, True): (1, 512, 2, 4),
(4096, 4096, 131072, 128, 128, False, True, True): (3, 1024, 1, 4),
(4096, 4096, 131072, 128, 128, True, False, True): (1, 1024, 1, 4),
(5120, 1280, 65792, 16, 16, False, True, True): (1, 257, 1, 4),
(5120, 1280, 65792, 16, 16, True, False, True): (7, 257, 4, 1),
(5120, 1280, 65792, 32, 32, False, True, True): (2, 257, 1, 4),
(5120, 1280, 65792, 32, 32, True, False, True): (5, 257, 3, 4),
(5120, 1280, 65792, 64, 64, False, True, True): (1, 514, 1, 4),
(5120, 1280, 65792, 64, 64, True, False, True): (5, 257, 2, 4),
(5120, 1280, 65792, 128, 128, False, True, True): (3, 514, 1, 4),
(5120, 1280, 65792, 128, 128, True, False, True): (4, 514, 2, 4),
(6144, 6144, 256, 16, 16, False, True, True): (1, 2, 1, 4),
(6144, 6144, 256, 16, 16, True, False, True): (1, 1, 4, 4),
(6144, 6144, 256, 32, 32, False, True, True): (3, 2, 1, 8),
@ -4837,6 +5059,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(8192, 8192, 65536, 64, 64, True, False, True): (4, 256, 3, 8),
(8192, 8192, 65536, 128, 128, False, True, True): (6, 512, 1, 4),
(8192, 8192, 65536, 128, 128, True, False, True): (4, 512, 1, 4),
(8192, 8192, 65792, 16, 16, False, True, True): (1, 257, 1, 1),
(8192, 8192, 65792, 16, 16, True, False, True): (3, 257, 4, 1),
(8192, 8192, 65792, 32, 32, False, True, True): (2, 257, 1, 4),
(8192, 8192, 65792, 32, 32, True, False, True): (1, 257, 3, 4),
(8192, 8192, 65792, 64, 64, False, True, True): (2, 514, 3, 4),
(8192, 8192, 65792, 64, 64, True, False, True): (1, 257, 3, 4),
(8192, 8192, 65792, 128, 128, False, True, True): (2, 514, 1, 4),
(8192, 8192, 65792, 128, 128, True, False, True): (2, 514, 3, 8),
(8192, 8192, 131072, 16, 16, False, True, True): (4, 512, 4, 4),
(8192, 8192, 131072, 16, 16, True, False, True): (3, 512, 4, 4),
(8192, 8192, 131072, 32, 32, False, True, True): (2, 512, 4, 8),
@ -4997,6 +5227,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(16384, 16384, 65536, 64, 64, True, False, True): (1, 256, 3, 8),
(16384, 16384, 65536, 128, 128, False, True, True): (4, 512, 2, 8),
(16384, 16384, 65536, 128, 128, True, False, True): (4, 512, 1, 4),
(16384, 16384, 65792, 16, 16, False, True, True): (1, 257, 1, 1),
(16384, 16384, 65792, 16, 16, True, False, True): (1, 257, 4, 1),
(16384, 16384, 65792, 32, 32, False, True, True): (1, 257, 1, 4),
(16384, 16384, 65792, 32, 32, True, False, True): (1, 257, 3, 4),
(16384, 16384, 65792, 64, 64, False, True, True): (2, 514, 3, 4),
(16384, 16384, 65792, 64, 64, True, False, True): (1, 257, 3, 4),
(16384, 16384, 65792, 128, 128, False, True, True): (2, 514, 3, 8),
(16384, 16384, 65792, 128, 128, True, False, True): (2, 514, 3, 8),
(16384, 16384, 131072, 16, 16, False, True, True): (1, 512, 4, 4),
(16384, 16384, 131072, 16, 16, True, False, True): (1, 512, 3, 2),
(16384, 16384, 131072, 32, 32, False, True, True): (1, 512, 4, 8),
@ -5072,6 +5310,77 @@ _operation_device_version_data: Dict[Any, Dict] = {
(24576, 24576, 65536, 16, 16, False, True, True): (2, 512, 1, 2),
(24576, 24576, 65536, 16, 16, True, False, True): (1, 256, 4, 4),
(32768, 32768, 256, 16, 16, False, True, True): (4, 2, 1, 2),
(32768, 32768, 256, 16, 16, True, False, True): (2, 2, 5, 4),
(32768, 32768, 256, 32, 32, False, True, True): (4, 2, 4, 2),
(32768, 32768, 256, 32, 32, True, False, True): (1, 1, 4, 8),
(32768, 32768, 256, 64, 64, False, True, True): (2, 2, 3, 4),
(32768, 32768, 256, 64, 64, True, False, True): (1, 1, 3, 8),
(32768, 32768, 256, 128, 128, False, True, True): (2, 2, 3, 8),
(32768, 32768, 256, 128, 128, True, False, True): (2, 2, 3, 8),
(32768, 32768, 512, 16, 16, False, True, True): (2, 2, 1, 4),
(32768, 32768, 512, 16, 16, True, False, True): (2, 2, 4, 2),
(32768, 32768, 512, 32, 32, False, True, True): (1, 2, 3, 4),
(32768, 32768, 512, 32, 32, True, False, True): (1, 2, 4, 8),
(32768, 32768, 512, 64, 64, False, True, True): (4, 4, 3, 4),
(32768, 32768, 512, 64, 64, True, False, True): (1, 2, 3, 4),
(32768, 32768, 512, 128, 128, False, True, True): (4, 4, 3, 8),
(32768, 32768, 512, 128, 128, True, False, True): (4, 4, 3, 8),
(32768, 32768, 1024, 16, 16, False, True, True): (2, 4, 1, 1),
(32768, 32768, 1024, 16, 16, True, False, True): (1, 4, 4, 2),
(32768, 32768, 1024, 32, 32, False, True, True): (2, 4, 1, 4),
(32768, 32768, 1024, 32, 32, True, False, True): (1, 4, 3, 4),
(32768, 32768, 1024, 64, 64, False, True, True): (4, 8, 3, 4),
(32768, 32768, 1024, 64, 64, True, False, True): (1, 4, 3, 4),
(32768, 32768, 1024, 128, 128, False, True, True): (4, 8, 3, 8),
(32768, 32768, 1024, 128, 128, True, False, True): (4, 8, 3, 8),
(32768, 32768, 2048, 16, 16, False, True, True): (1, 8, 1, 4),
(32768, 32768, 2048, 16, 16, True, False, True): (1, 8, 4, 4),
(32768, 32768, 2048, 32, 32, False, True, True): (2, 8, 1, 4),
(32768, 32768, 2048, 32, 32, True, False, True): (1, 8, 3, 4),
(32768, 32768, 2048, 64, 64, False, True, True): (4, 16, 3, 4),
(32768, 32768, 2048, 64, 64, True, False, True): (1, 8, 3, 4),
(32768, 32768, 2048, 128, 128, False, True, True): (4, 16, 3, 8),
(32768, 32768, 2048, 128, 128, True, False, True): (2, 16, 3, 8),
(32768, 32768, 4096, 16, 16, False, True, True): (1, 16, 1, 4),
(32768, 32768, 4096, 16, 16, True, False, True): (1, 16, 4, 4),
(32768, 32768, 4096, 32, 32, False, True, True): (2, 16, 1, 4),
(32768, 32768, 4096, 32, 32, True, False, True): (1, 16, 3, 4),
(32768, 32768, 4096, 64, 64, False, True, True): (2, 32, 3, 4),
(32768, 32768, 4096, 64, 64, True, False, True): (1, 16, 3, 4),
(32768, 32768, 4096, 128, 128, False, True, True): (4, 32, 3, 8),
(32768, 32768, 4096, 128, 128, True, False, True): (4, 32, 3, 8),
(32768, 32768, 8192, 16, 16, False, True, True): (1, 32, 1, 4),
(32768, 32768, 8192, 16, 16, True, False, True): (2, 64, 4, 1),
(32768, 32768, 8192, 32, 32, False, True, True): (2, 32, 1, 4),
(32768, 32768, 8192, 32, 32, True, False, True): (1, 32, 3, 4),
(32768, 32768, 8192, 64, 64, False, True, True): (2, 64, 3, 4),
(32768, 32768, 8192, 64, 64, True, False, True): (1, 32, 3, 4),
(32768, 32768, 8192, 128, 128, False, True, True): (4, 64, 3, 8),
(32768, 32768, 8192, 128, 128, True, False, True): (2, 64, 3, 8),
(32768, 32768, 16384, 16, 16, False, True, True): (1, 64, 1, 4),
(32768, 32768, 16384, 16, 16, True, False, True): (1, 64, 4, 1),
(32768, 32768, 16384, 32, 32, False, True, True): (2, 64, 1, 4),
(32768, 32768, 16384, 32, 32, True, False, True): (1, 64, 3, 4),
(32768, 32768, 16384, 64, 64, False, True, True): (2, 128, 3, 4),
(32768, 32768, 16384, 64, 64, True, False, True): (1, 64, 3, 4),
(32768, 32768, 16384, 128, 128, False, True, True): (4, 128, 3, 8),
(32768, 32768, 16384, 128, 128, True, False, True): (2, 128, 3, 8),
(32768, 32768, 32768, 16, 16, False, True, True): (1, 128, 1, 4),
(32768, 32768, 32768, 16, 16, True, False, True): (1, 128, 4, 1),
(32768, 32768, 32768, 32, 32, False, True, True): (2, 128, 1, 4),
(32768, 32768, 32768, 32, 32, True, False, True): (1, 128, 3, 4),
(32768, 32768, 32768, 64, 64, False, True, True): (2, 256, 3, 4),
(32768, 32768, 32768, 64, 64, True, False, True): (1, 128, 3, 4),
(32768, 32768, 32768, 128, 128, False, True, True): (2, 256, 3, 8),
(32768, 32768, 32768, 128, 128, True, False, True): (4, 256, 3, 8),
(32768, 32768, 65536, 16, 16, False, True, True): (1, 256, 1, 4),
(32768, 32768, 65536, 16, 16, True, False, True): (1, 256, 4, 1),
(32768, 32768, 65536, 32, 32, False, True, True): (1, 256, 3, 4),
(32768, 32768, 65536, 32, 32, True, False, True): (1, 256, 3, 4),
(32768, 32768, 65536, 64, 64, False, True, True): (1, 512, 3, 4),
(32768, 32768, 65536, 64, 64, True, False, True): (1, 256, 3, 4),
(32768, 32768, 65536, 128, 128, False, True, True): (4, 512, 1, 4),
(32768, 32768, 65536, 128, 128, True, False, True): (2, 512, 3, 8),
},
("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.56)): {
(192, 192, 256, 64, 64, False, True, True): (1, 4, 3, 4),
@ -5844,6 +6153,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(1024, 1024, 131072, 64, 64, True, False, True): (1, 2048, 2, 4),
(1024, 1024, 131072, 128, 128, False, True, True): (1, 1024, 1, 32),
(1024, 1024, 131072, 128, 128, True, False, True): (1, 1024, 1, 32),
(1280, 5120, 65792, 16, 16, False, True, True): (1, 1028, 3, 1),
(1280, 5120, 65792, 16, 16, True, False, True): (1, 257, 3, 4),
(1280, 5120, 65792, 32, 32, False, True, True): (1, 514, 3, 4),
(1280, 5120, 65792, 32, 32, True, False, True): (1, 514, 3, 4),
(1280, 5120, 65792, 64, 64, False, True, True): (2, 1028, 3, 4),
(1280, 5120, 65792, 64, 64, True, False, True): (1, 1028, 3, 4),
(1280, 5120, 65792, 128, 128, False, True, True): (2, 514, 2, 32),
(1280, 5120, 65792, 128, 128, True, False, True): (1, 514, 2, 32),
(1536, 1536, 256, 16, 16, False, True, True): (5, 4, 3, 2),
(1536, 1536, 256, 16, 16, True, False, True): (2, 2, 3, 4),
(1536, 1536, 256, 32, 32, False, True, True): (1, 8, 2, 4),
@ -6252,6 +6569,14 @@ _operation_device_version_data: Dict[Any, Dict] = {
(4096, 4096, 131072, 64, 64, True, False, True): (2, 2048, 2, 4),
(4096, 4096, 131072, 128, 128, False, True, True): (4, 1024, 1, 32),
(4096, 4096, 131072, 128, 128, True, False, True): (4, 1024, 1, 32),
(5120, 1280, 65792, 16, 16, False, True, True): (2, 1028, 3, 1),
(5120, 1280, 65792, 16, 16, True, False, True): (1, 257, 3, 4),
(5120, 1280, 65792, 32, 32, False, True, True): (1, 514, 3, 4),
(5120, 1280, 65792, 32, 32, True, False, True): (1, 514, 3, 4),
(5120, 1280, 65792, 64, 64, False, True, True): (1, 1028, 3, 4),
(5120, 1280, 65792, 64, 64, True, False, True): (5, 1028, 3, 4),
(5120, 1280, 65792, 128, 128, False, True, True): (1, 514, 1, 32),
(5120, 1280, 65792, 128, 128, True, False, True): (4, 514, 2, 32),
(6144, 6144, 256, 16, 16, False, True, True): (2, 2, 3, 4),
(6144, 6144, 256, 16, 16, True, False, True): (2, 2, 3, 4),
(6144, 6144, 256, 32, 32, False, True, True): (2, 4, 3, 4),
@ -6535,6 +6860,24 @@ _operation_device_version_data: Dict[Any, Dict] = {
(384, 384, 131072, 128, 128, False, True, True): (1, 1024, 1, 32),
(384, 384, 131072, 128, 128, True, False, True): (3, 1024, 1, 32),
},
("bsr_dense_addmm", "NVIDIA A100-SXM4-80GB", (0, torch.int8, 0.5)): {
(1280, 5120, 65792, 32, 32, False, True, True): (1, 1028, 1, 8),
(1280, 5120, 65792, 32, 32, True, False, True): (1, 514, 3, 2),
(1280, 5120, 65792, 64, 64, False, True, True): (2, 514, 1, 4),
(1280, 5120, 65792, 64, 64, True, False, True): (1, 514, 3, 2),
(1280, 5120, 65792, 128, 128, False, True, True): (2, 514, 1, 8),
(1280, 5120, 65792, 128, 128, True, False, True): (1, 514, 2, 4),
(1280, 5120, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(1280, 5120, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
(5120, 1280, 65792, 32, 32, False, True, True): (3, 1028, 1, 8),
(5120, 1280, 65792, 32, 32, True, False, True): (1, 514, 1, 2),
(5120, 1280, 65792, 64, 64, False, True, True): (1, 514, 1, 4),
(5120, 1280, 65792, 64, 64, True, False, True): (2, 514, 2, 2),
(5120, 1280, 65792, 128, 128, False, True, True): (2, 514, 1, 8),
(5120, 1280, 65792, 128, 128, True, False, True): (2, 514, 2, 4),
(5120, 1280, 65792, 256, 256, False, True, True): (1, 257, 1, 32),
(5120, 1280, 65792, 256, 256, True, False, True): (1, 257, 1, 32),
},
("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
(256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 2),
(256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4),
@ -7396,6 +7739,6 @@ if __name__ == "__main__":
for dtype in [torch.int8]:
for op in ["_int_bsr_dense_addmm"]:
main(op=op, force=False, dtype=dtype)
for dtype in [torch.float16, torch.bfloat16, torch.float32]:
for dtype in [torch.float16, torch.bfloat16, torch.float32, torch.int8]:
for op in ["bsr_dense_addmm"]:
main(op=op, force=False, dtype=dtype)