mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix a bug in retrieving approximate bsr_dense_addmm kernel meta data (#124371)
Fixes #124333 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124371 Approved by: https://github.com/eqy, https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
a47f4253ab
commit
49f0d127fb
@ -571,10 +571,12 @@ def bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha,
|
||||
device_name, version=(0, dtype, 0.5))
|
||||
if meta is None:
|
||||
# find approximate meta such that N % SPLIT_N == 0.
|
||||
for mkey, meta_ in sorted(get_meta(
|
||||
'bsr_dense_addmm',
|
||||
(*key[:2], '*', *key[3:]),
|
||||
device_name, version=(0, dtype, 0.5)) or {}):
|
||||
matching_meta = get_meta(
|
||||
'bsr_dense_addmm',
|
||||
(*key[:2], '*', *key[3:]),
|
||||
device_name, version=(0, dtype, 0.5))
|
||||
for mkey in sorted(matching_meta or {}):
|
||||
meta_ = matching_meta[mkey]
|
||||
if N % meta_['SPLIT_N'] == 0 and mkey[2] <= N:
|
||||
meta = meta_
|
||||
if meta is not None:
|
||||
|
@ -160,9 +160,8 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
|
||||
values = op_data.get(key)
|
||||
if values is not None:
|
||||
matching_data[key] = values
|
||||
|
||||
matching_meta = {}
|
||||
for key, values in matching_data.items():
|
||||
for op_key, values in matching_data.items():
|
||||
if op == "scatter_mm":
|
||||
names = (
|
||||
"GROUP_SIZE",
|
||||
@ -182,7 +181,7 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
|
||||
if "*" not in key:
|
||||
return meta
|
||||
|
||||
matching_meta[key] = meta
|
||||
matching_meta[op_key] = meta
|
||||
|
||||
if "*" in key:
|
||||
return matching_meta
|
||||
|
Reference in New Issue
Block a user