Eliminate unnecessary multiplications by 1 in addmm with sparse compressed tensor operand (#114026)

This PR:
- updates `torch/sparse/_triton_ops_meta.py` for the API change in `triton.testing.do_bench`
- forces `num_stages` to be 1 when the blocksize is 128x128 to avoid an out-of-resources exception when `bsr_dense_mm` is called from `nn.linear`.
- as in the title. The performance of `nn.linear` on BSR tensor weights (dtypes `float16` and `bfloat16`) is increased as follows (`NVIDIA A100-SXM4-80GB`):
  - for blocksize 16x16, the average/maximum speed up is about 11/20 %
  - for blocksize 32x32, the average/maximum speed up is about 15/24 %
  - for blocksize 64x64, the average/maximum speed up is about 18/26 %
  - for blocksize 128x128, the average/maximum speed up is about 15/28 %

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114026
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2023-11-18 21:19:21 +00:00
committed by PyTorch MergeBot
parent 826ab0e32d
commit 12f95df0e9
2 changed files with 99 additions and 79 deletions

View File

@ -218,27 +218,42 @@ Tensor& _compressed_row_strided_addmm_out(
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
auto alpha_val = alpha.toComplexDouble();
auto beta_val = beta.toComplexDouble();
// If result is not the same as self, it could always be used as out argument to mm.
if (!result.is_same(self)) {
_compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
_compressed_row_strided_mm_out(mat1, mat2, result);
if (alpha_val != 1.) {
result.mul_(alpha);
}
// Process beta
if (beta.toComplexDouble() != 0.) {
if (beta_val != 0.) {
if (beta_val == 1.) {
result.add_(self);
} else {
result.add_(self.mul(beta));
}
}
}
// Otherwise we need to allocate external memory for mm if beta != 0.
else {
// Process beta
if (beta.toComplexDouble() != 0.) {
if (beta_val != 0.) {
if (beta_val != 1.) {
result.mul_(beta);
}
auto mm = at::empty_like(result);
_compressed_row_strided_mm_out(mat1, mat2, mm);
if (alpha_val != 1.) {
mm.mul_(alpha);
}
result.add_(mm);
}
else {
_compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
_compressed_row_strided_mm_out(mat1, mat2, result);
if (alpha_val != 1.) {
result.mul_(alpha);
}
}
}

View File

@ -315,11 +315,11 @@ def optimize_scatter_mm(
def test_func():
return bsr_scatter_mm(bsr, dense, indices_data=indices_data)
ms, ms_min, ms_max = triton.testing.do_bench(
ms_min = triton.testing.do_bench(
test_func, warmup=500, rep=100, fast_flush=False
)
return ms
return ms_min
def step_meta_parameter(name, value, direction, meta, m=m, n=n, k=k, bm=bm, bk=bk):
# return next value in positive or negative direction, or
@ -364,8 +364,7 @@ def optimize_scatter_mm(
meta, speedup, timing = minimize(
bench, initial_meta, reference_meta, step_meta_parameter
)
print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
if initial_meta is not reference_meta and initial_meta == meta:
if initial_meta is not reference_meta and initial_meta == meta and not force:
return
print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
if speedup < 0:
@ -411,11 +410,11 @@ def optimize_bsr_dense_mm(
def test_func():
return bsr_dense_mm(bsr, dense, meta=meta)
ms, ms_min, ms_max = triton.testing.do_bench(
ms_min = triton.testing.do_bench(
test_func, warmup=500, rep=100, fast_flush=False
)
return ms
return ms_min
def step_meta_parameter(name, value, direction, meta, m=m, n=n, k=k, bm=bm, bk=bk):
# return next value in positive or negative direction, or
@ -424,6 +423,11 @@ def optimize_bsr_dense_mm(
is_log = name in {"num_warps"}
min_value = dict(num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name]
max_value = dict().get(name)
if (bm, bk) == (128, 128) and name == "num_stages":
# For some reason, when bsr_dense_mm is called through
# nn.linear, the call will fail with out of resource
# error. So, we'll set a hard limit for such cases:
max_value = 1
value_step = dict(num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
if is_log:
next_value = (
@ -442,7 +446,8 @@ def optimize_bsr_dense_mm(
meta, speedup, timing = minimize(
bench, initial_meta, reference_meta, step_meta_parameter, max_step=2
)
if initial_meta is not reference_meta and initial_meta == meta:
if initial_meta is not reference_meta and initial_meta == meta and not force:
return
print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
if speedup < 0:
@ -540,11 +545,11 @@ def main(op="scatter_mm", force=False, dtype=torch.float16):
else:
raise NotImplementedError(op)
ms, ms_min, ms_max = triton.testing.do_bench(
ms_min = triton.testing.do_bench(
test_func, warmup=500, rep=100, fast_flush=False
)
return ms
return ms_min
meta_lst.append(
(bench(meta), sparsity, tuple(meta[k] for k in sorted(meta)))
@ -1157,215 +1162,215 @@ _operation_device_version_data: Dict[Any, Dict] = {
(256, 256, 256, 16, 16): (8, 6, 1),
(256, 256, 256, 32, 32): (4, 5, 2),
(256, 256, 256, 64, 64): (3, 3, 4),
(256, 256, 256, 128, 128): (4, 2, 8),
(256, 256, 256, 128, 128): (3, 1, 8),
(256, 256, 512, 16, 16): (4, 1, 4),
(256, 256, 512, 32, 32): (4, 3, 4),
(256, 256, 512, 64, 64): (6, 3, 4),
(256, 256, 512, 128, 128): (2, 2, 8),
(256, 256, 512, 128, 128): (2, 1, 8),
(256, 256, 1024, 16, 16): (4, 1, 2),
(256, 256, 1024, 32, 32): (1, 3, 2),
(256, 256, 1024, 64, 64): (4, 3, 4),
(256, 256, 1024, 128, 128): (4, 2, 8),
(256, 256, 1024, 128, 128): (5, 1, 8),
(256, 256, 2048, 16, 16): (2, 3, 1),
(256, 256, 2048, 32, 32): (5, 4, 1),
(256, 256, 2048, 64, 64): (3, 3, 4),
(256, 256, 2048, 128, 128): (4, 2, 8),
(256, 256, 2048, 128, 128): (4, 1, 8),
(256, 256, 4096, 16, 16): (4, 3, 1),
(256, 256, 4096, 32, 32): (4, 1, 4),
(256, 256, 4096, 64, 64): (4, 3, 4),
(256, 256, 4096, 128, 128): (4, 2, 8),
(256, 256, 4096, 128, 128): (2, 1, 8),
(256, 256, 8192, 16, 16): (5, 3, 1),
(256, 256, 8192, 32, 32): (4, 3, 1),
(256, 256, 8192, 64, 64): (4, 2, 4),
(256, 256, 8192, 128, 128): (2, 1, 4),
(256, 256, 8192, 128, 128): (3, 1, 4),
(256, 256, 16384, 16, 16): (7, 3, 1),
(256, 256, 16384, 32, 32): (3, 3, 1),
(256, 256, 16384, 64, 64): (4, 2, 4),
(256, 256, 16384, 128, 128): (4, 1, 4),
(256, 256, 16384, 128, 128): (3, 1, 4),
(256, 256, 32768, 16, 16): (7, 3, 1),
(256, 256, 32768, 32, 32): (1, 3, 1),
(256, 256, 32768, 64, 64): (4, 2, 4),
(256, 256, 32768, 128, 128): (4, 1, 4),
(256, 256, 32768, 128, 128): (1, 1, 4),
(256, 256, 65536, 16, 16): (5, 3, 1),
(256, 256, 65536, 32, 32): (4, 3, 1),
(256, 256, 65536, 64, 64): (6, 2, 4),
(256, 256, 65536, 128, 128): (1, 1, 4),
(256, 256, 131072, 16, 16): (4, 1, 2),
(256, 256, 65536, 128, 128): (2, 1, 4),
(256, 256, 131072, 16, 16): (1, 1, 4),
(256, 256, 131072, 32, 32): (4, 2, 2),
(256, 256, 131072, 64, 64): (4, 2, 4),
(256, 256, 131072, 128, 128): (4, 1, 4),
(256, 256, 131072, 128, 128): (2, 1, 4),
(512, 512, 256, 16, 16): (4, 5, 1),
(512, 512, 256, 32, 32): (4, 5, 2),
(512, 512, 256, 64, 64): (4, 3, 4),
(512, 512, 256, 128, 128): (4, 2, 8),
(512, 512, 256, 128, 128): (4, 1, 8),
(512, 512, 512, 16, 16): (2, 4, 1),
(512, 512, 512, 32, 32): (4, 4, 4),
(512, 512, 512, 64, 64): (4, 5, 4),
(512, 512, 512, 128, 128): (4, 2, 8),
(512, 512, 512, 128, 128): (4, 1, 8),
(512, 512, 1024, 16, 16): (1, 4, 1),
(512, 512, 1024, 32, 32): (2, 3, 1),
(512, 512, 1024, 64, 64): (4, 4, 4),
(512, 512, 1024, 128, 128): (4, 2, 8),
(512, 512, 1024, 128, 128): (6, 1, 8),
(512, 512, 2048, 16, 16): (5, 3, 1),
(512, 512, 2048, 32, 32): (2, 3, 2),
(512, 512, 2048, 64, 64): (1, 3, 2),
(512, 512, 2048, 128, 128): (4, 2, 8),
(512, 512, 2048, 128, 128): (5, 1, 8),
(512, 512, 4096, 16, 16): (4, 3, 1),
(512, 512, 4096, 32, 32): (4, 4, 2),
(512, 512, 4096, 64, 64): (5, 3, 4),
(512, 512, 4096, 128, 128): (2, 2, 8),
(512, 512, 4096, 128, 128): (3, 1, 8),
(512, 512, 8192, 16, 16): (2, 3, 1),
(512, 512, 8192, 32, 32): (1, 3, 1),
(512, 512, 8192, 64, 64): (5, 3, 2),
(512, 512, 8192, 128, 128): (4, 1, 16),
(512, 512, 8192, 128, 128): (6, 1, 8),
(512, 512, 16384, 16, 16): (4, 3, 1),
(512, 512, 16384, 32, 32): (4, 3, 1),
(512, 512, 16384, 64, 64): (4, 3, 4),
(512, 512, 16384, 128, 128): (4, 1, 4),
(512, 512, 16384, 128, 128): (1, 1, 4),
(512, 512, 32768, 16, 16): (1, 2, 1),
(512, 512, 32768, 32, 32): (5, 3, 1),
(512, 512, 32768, 64, 64): (4, 3, 2),
(512, 512, 32768, 128, 128): (5, 1, 4),
(512, 512, 32768, 128, 128): (1, 1, 4),
(512, 512, 65536, 16, 16): (4, 3, 1),
(512, 512, 65536, 32, 32): (1, 3, 1),
(512, 512, 65536, 64, 64): (4, 3, 2),
(512, 512, 65536, 128, 128): (5, 1, 4),
(512, 512, 131072, 16, 16): (4, 1, 4),
(512, 512, 65536, 128, 128): (1, 1, 4),
(512, 512, 131072, 16, 16): (1, 1, 4),
(512, 512, 131072, 32, 32): (4, 2, 2),
(512, 512, 131072, 64, 64): (4, 3, 2),
(512, 512, 131072, 128, 128): (4, 1, 4),
(512, 512, 131072, 128, 128): (1, 1, 4),
(1024, 1024, 256, 16, 16): (4, 4, 1),
(1024, 1024, 256, 32, 32): (2, 4, 2),
(1024, 1024, 256, 64, 64): (4, 4, 4),
(1024, 1024, 256, 128, 128): (4, 2, 8),
(1024, 1024, 256, 128, 128): (2, 1, 8),
(1024, 1024, 512, 16, 16): (3, 3, 1),
(1024, 1024, 512, 32, 32): (5, 4, 2),
(1024, 1024, 512, 64, 64): (4, 3, 4),
(1024, 1024, 512, 128, 128): (3, 2, 8),
(1024, 1024, 512, 128, 128): (1, 1, 16),
(1024, 1024, 1024, 16, 16): (7, 3, 1),
(1024, 1024, 1024, 32, 32): (2, 3, 1),
(1024, 1024, 1024, 64, 64): (1, 3, 2),
(1024, 1024, 1024, 128, 128): (1, 2, 8),
(1024, 1024, 1024, 128, 128): (4, 1, 8),
(1024, 1024, 2048, 16, 16): (2, 3, 1),
(1024, 1024, 2048, 32, 32): (1, 4, 1),
(1024, 1024, 2048, 64, 64): (3, 3, 4),
(1024, 1024, 2048, 128, 128): (4, 2, 8),
(1024, 1024, 2048, 128, 128): (4, 1, 16),
(1024, 1024, 4096, 16, 16): (4, 3, 1),
(1024, 1024, 4096, 32, 32): (4, 4, 1),
(1024, 1024, 4096, 64, 64): (4, 3, 4),
(1024, 1024, 4096, 128, 128): (4, 2, 8),
(1024, 1024, 4096, 128, 128): (4, 1, 8),
(1024, 1024, 8192, 16, 16): (2, 3, 1),
(1024, 1024, 8192, 32, 32): (4, 3, 1),
(1024, 1024, 8192, 64, 64): (4, 3, 2),
(1024, 1024, 8192, 128, 128): (4, 1, 4),
(1024, 1024, 8192, 128, 128): (1, 1, 4),
(1024, 1024, 16384, 16, 16): (4, 2, 1),
(1024, 1024, 16384, 32, 32): (4, 3, 1),
(1024, 1024, 16384, 64, 64): (4, 3, 2),
(1024, 1024, 16384, 128, 128): (4, 1, 4),
(1024, 1024, 16384, 128, 128): (1, 1, 4),
(1024, 1024, 32768, 16, 16): (4, 2, 1),
(1024, 1024, 32768, 32, 32): (8, 3, 1),
(1024, 1024, 32768, 64, 64): (8, 3, 2),
(1024, 1024, 32768, 128, 128): (4, 1, 4),
(1024, 1024, 32768, 128, 128): (1, 1, 4),
(1024, 1024, 65536, 16, 16): (4, 4, 1),
(1024, 1024, 65536, 32, 32): (7, 3, 1),
(1024, 1024, 65536, 64, 64): (7, 3, 2),
(1024, 1024, 65536, 128, 128): (4, 1, 4),
(1024, 1024, 65536, 128, 128): (1, 1, 4),
(1024, 1024, 131072, 16, 16): (4, 1, 4),
(1024, 1024, 131072, 32, 32): (4, 2, 1),
(1024, 1024, 131072, 64, 64): (5, 3, 2),
(1024, 1024, 131072, 128, 128): (4, 1, 4),
(1024, 1024, 131072, 128, 128): (1, 1, 4),
(2048, 2048, 256, 16, 16): (5, 4, 1),
(2048, 2048, 256, 32, 32): (4, 4, 2),
(2048, 2048, 256, 64, 64): (2, 3, 4),
(2048, 2048, 256, 128, 128): (4, 2, 8),
(2048, 2048, 256, 128, 128): (7, 1, 8),
(2048, 2048, 512, 16, 16): (4, 4, 1),
(2048, 2048, 512, 32, 32): (5, 3, 2),
(2048, 2048, 512, 64, 64): (8, 3, 4),
(2048, 2048, 512, 128, 128): (4, 2, 8),
(2048, 2048, 512, 128, 128): (8, 1, 16),
(2048, 2048, 1024, 16, 16): (4, 3, 1),
(2048, 2048, 1024, 32, 32): (2, 4, 1),
(2048, 2048, 1024, 64, 64): (5, 3, 4),
(2048, 2048, 1024, 128, 128): (6, 2, 8),
(2048, 2048, 1024, 128, 128): (4, 1, 4),
(2048, 2048, 2048, 16, 16): (3, 3, 1),
(2048, 2048, 2048, 32, 32): (2, 4, 1),
(2048, 2048, 2048, 64, 64): (3, 3, 2),
(2048, 2048, 2048, 128, 128): (4, 2, 8),
(2048, 2048, 2048, 128, 128): (6, 1, 16),
(2048, 2048, 4096, 16, 16): (4, 3, 1),
(2048, 2048, 4096, 32, 32): (4, 4, 2),
(2048, 2048, 4096, 64, 64): (6, 3, 2),
(2048, 2048, 4096, 128, 128): (2, 1, 4),
(2048, 2048, 4096, 128, 128): (1, 1, 4),
(2048, 2048, 8192, 16, 16): (6, 2, 1),
(2048, 2048, 8192, 32, 32): (6, 4, 2),
(2048, 2048, 8192, 64, 64): (6, 3, 2),
(2048, 2048, 8192, 128, 128): (2, 1, 4),
(2048, 2048, 8192, 128, 128): (1, 1, 4),
(2048, 2048, 16384, 16, 16): (4, 2, 1),
(2048, 2048, 16384, 32, 32): (4, 4, 1),
(2048, 2048, 16384, 64, 64): (4, 3, 2),
(2048, 2048, 16384, 128, 128): (2, 1, 4),
(2048, 2048, 16384, 128, 128): (1, 1, 4),
(2048, 2048, 32768, 16, 16): (8, 2, 1),
(2048, 2048, 32768, 32, 32): (7, 4, 1),
(2048, 2048, 32768, 64, 64): (9, 3, 2),
(2048, 2048, 32768, 128, 128): (4, 1, 4),
(2048, 2048, 32768, 128, 128): (1, 1, 4),
(2048, 2048, 65536, 16, 16): (3, 2, 1),
(2048, 2048, 65536, 32, 32): (9, 3, 1),
(2048, 2048, 65536, 64, 64): (4, 3, 2),
(2048, 2048, 65536, 128, 128): (2, 1, 4),
(2048, 2048, 131072, 16, 16): (4, 1, 4),
(2048, 2048, 65536, 128, 128): (1, 1, 4),
(2048, 2048, 131072, 16, 16): (1, 1, 1),
(2048, 2048, 131072, 32, 32): (4, 1, 1),
(2048, 2048, 131072, 64, 64): (4, 3, 2),
(2048, 2048, 131072, 128, 128): (4, 1, 4),
(2048, 2048, 131072, 128, 128): (1, 1, 4),
(4096, 4096, 256, 16, 16): (4, 4, 1),
(4096, 4096, 256, 32, 32): (1, 3, 2),
(4096, 4096, 256, 64, 64): (3, 3, 4),
(4096, 4096, 256, 128, 128): (3, 2, 8),
(4096, 4096, 256, 128, 128): (8, 1, 16),
(4096, 4096, 512, 16, 16): (1, 3, 1),
(4096, 4096, 512, 32, 32): (1, 3, 4),
(4096, 4096, 512, 64, 64): (6, 3, 4),
(4096, 4096, 512, 128, 128): (2, 2, 8),
(4096, 4096, 512, 128, 128): (1, 1, 4),
(4096, 4096, 1024, 16, 16): (1, 3, 1),
(4096, 4096, 1024, 32, 32): (4, 4, 2),
(4096, 4096, 1024, 64, 64): (4, 4, 4),
(4096, 4096, 1024, 128, 128): (2, 2, 8),
(4096, 4096, 1024, 128, 128): (1, 1, 4),
(4096, 4096, 2048, 16, 16): (1, 3, 1),
(4096, 4096, 2048, 32, 32): (3, 4, 2),
(4096, 4096, 2048, 64, 64): (4, 3, 2),
(4096, 4096, 2048, 128, 128): (4, 1, 4),
(4096, 4096, 2048, 128, 128): (1, 1, 4),
(4096, 4096, 4096, 16, 16): (2, 3, 1),
(4096, 4096, 4096, 32, 32): (2, 4, 2),
(4096, 4096, 4096, 64, 64): (1, 3, 2),
(4096, 4096, 4096, 128, 128): (4, 1, 4),
(4096, 4096, 4096, 128, 128): (1, 1, 4),
(4096, 4096, 8192, 16, 16): (8, 2, 1),
(4096, 4096, 8192, 32, 32): (2, 4, 2),
(4096, 4096, 8192, 64, 64): (4, 3, 2),
(4096, 4096, 8192, 128, 128): (4, 1, 4),
(4096, 4096, 8192, 128, 128): (1, 1, 4),
(4096, 4096, 16384, 16, 16): (1, 1, 1),
(4096, 4096, 16384, 32, 32): (4, 4, 1),
(4096, 4096, 16384, 64, 64): (4, 3, 2),
(4096, 4096, 16384, 128, 128): (4, 1, 4),
(4096, 4096, 16384, 128, 128): (1, 1, 4),
(4096, 4096, 32768, 16, 16): (4, 2, 1),
(4096, 4096, 32768, 32, 32): (5, 3, 1),
(4096, 4096, 32768, 64, 64): (3, 3, 2),
(4096, 4096, 32768, 128, 128): (4, 1, 4),
(4096, 4096, 32768, 128, 128): (1, 1, 4),
(4096, 4096, 65536, 16, 16): (4, 2, 1),
(4096, 4096, 65536, 32, 32): (2, 4, 1),
(4096, 4096, 65536, 64, 64): (3, 3, 2),
(4096, 4096, 65536, 128, 128): (4, 1, 4),
(4096, 4096, 131072, 16, 16): (1, 1, 4),
(4096, 4096, 65536, 128, 128): (1, 1, 4),
(4096, 4096, 131072, 16, 16): (1, 1, 1),
(4096, 4096, 131072, 32, 32): (4, 2, 1),
(4096, 4096, 131072, 64, 64): (7, 3, 2),
(4096, 4096, 131072, 128, 128): (4, 1, 4),
(4096, 4096, 131072, 128, 128): (1, 1, 4),
(8192, 8192, 256, 16, 16): (4, 4, 1),
(8192, 8192, 256, 32, 32): (4, 5, 2),
(8192, 8192, 256, 64, 64): (4, 3, 4),
(8192, 8192, 256, 128, 128): (5, 1, 4),
(8192, 8192, 256, 128, 128): (1, 1, 4),
(8192, 8192, 512, 16, 16): (4, 5, 1),
(8192, 8192, 512, 32, 32): (4, 4, 2),
(8192, 8192, 512, 64, 64): (4, 4, 4),
(8192, 8192, 512, 128, 128): (4, 2, 8),
(8192, 8192, 512, 128, 128): (4, 1, 16),
(8192, 8192, 1024, 16, 16): (4, 5, 1),
(8192, 8192, 1024, 32, 32): (4, 4, 2),
(8192, 8192, 1024, 64, 64): (4, 4, 4),
(8192, 8192, 1024, 128, 128): (4, 1, 4),
(8192, 8192, 1024, 128, 128): (1, 1, 4),
(8192, 8192, 2048, 16, 16): (4, 5, 1),
(8192, 8192, 2048, 32, 32): (4, 4, 2),
(8192, 8192, 2048, 64, 64): (4, 3, 2),
@ -1390,22 +1395,22 @@ _operation_device_version_data: Dict[Any, Dict] = {
(8192, 8192, 65536, 32, 32): (4, 4, 1),
(8192, 8192, 65536, 64, 64): (4, 4, 2),
(8192, 8192, 65536, 128, 128): (4, 1, 4),
(8192, 8192, 131072, 16, 16): (4, 1, 4),
(8192, 8192, 131072, 16, 16): (2, 1, 16),
(8192, 8192, 131072, 32, 32): (4, 2, 1),
(8192, 8192, 131072, 64, 64): (4, 3, 2),
(8192, 8192, 131072, 128, 128): (4, 1, 4),
(16384, 16384, 256, 16, 16): (4, 7, 1),
(16384, 16384, 256, 32, 32): (4, 4, 2),
(16384, 16384, 256, 64, 64): (4, 4, 4),
(16384, 16384, 256, 128, 128): (6, 2, 8),
(16384, 16384, 256, 128, 128): (1, 1, 4),
(16384, 16384, 512, 16, 16): (4, 7, 1),
(16384, 16384, 512, 32, 32): (4, 5, 2),
(16384, 16384, 512, 64, 64): (4, 3, 2),
(16384, 16384, 512, 128, 128): (4, 2, 8),
(16384, 16384, 512, 128, 128): (4, 1, 16),
(16384, 16384, 1024, 16, 16): (4, 9, 1),
(16384, 16384, 1024, 32, 32): (4, 4, 1),
(16384, 16384, 1024, 64, 64): (4, 4, 4),
(16384, 16384, 1024, 128, 128): (4, 1, 4),
(16384, 16384, 1024, 128, 128): (1, 1, 4),
(16384, 16384, 2048, 16, 16): (4, 9, 1),
(16384, 16384, 2048, 32, 32): (4, 4, 1),
(16384, 16384, 2048, 64, 64): (4, 4, 4),
@ -2478,7 +2483,7 @@ _operation_device_version_data: Dict[Any, Dict] = {
(4096, 4096, 65536, 32, 32): (5, 1, 32, 128, 1, 4),
(4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 3, 4),
(4096, 4096, 65536, 128, 128): (3, 16, 128, 64, 2, 4),
(4096, 4096, 131072, 16, 16): (5, 1, 16, 128, 1, 2),
(4096, 4096, 131072, 16, 16): (3, 1, 16, 128, 1, 2),
(4096, 4096, 131072, 32, 32): (3, 1, 32, 128, 3, 2),
(4096, 4096, 131072, 64, 64): (2, 1, 64, 64, 3, 4),
(4096, 4096, 131072, 128, 128): (1, 1, 128, 64, 1, 4),