Fix a bug in retrieving approximate bsr_dense_addmm kernel meta data (#124371)

Fixes #124333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124371
Approved by: https://github.com/eqy, https://github.com/lezcano
This commit is contained in:
Pearu Peterson
2024-04-18 11:32:39 +03:00
committed by PyTorch MergeBot
parent a47f4253ab
commit 49f0d127fb
2 changed files with 8 additions and 7 deletions

View File

@ -571,10 +571,12 @@ def bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha,
device_name, version=(0, dtype, 0.5))
if meta is None:
# find approximate meta such that N % SPLIT_N == 0.
for mkey, meta_ in sorted(get_meta(
'bsr_dense_addmm',
(*key[:2], '*', *key[3:]),
device_name, version=(0, dtype, 0.5)) or {}):
matching_meta = get_meta(
'bsr_dense_addmm',
(*key[:2], '*', *key[3:]),
device_name, version=(0, dtype, 0.5))
for mkey in sorted(matching_meta or {}):
meta_ = matching_meta[mkey]
if N % meta_['SPLIT_N'] == 0 and mkey[2] <= N:
meta = meta_
if meta is not None:

View File

@ -160,9 +160,8 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
values = op_data.get(key)
if values is not None:
matching_data[key] = values
matching_meta = {}
for key, values in matching_data.items():
for op_key, values in matching_data.items():
if op == "scatter_mm":
names = (
"GROUP_SIZE",
@ -182,7 +181,7 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
if "*" not in key:
return meta
matching_meta[key] = meta
matching_meta[op_key] = meta
if "*" in key:
return matching_meta