Revert "multi-kernel matmuls based on varying hint sizes (#156628)"

This reverts commit 6c795306378c47341d58109da03371bba2bec46e.

Reverted https://github.com/pytorch/pytorch/pull/156628 on behalf of https://github.com/huydhn. Reason: sorry for reverting your change, but some ROCm jobs started misbehaving after this landed, so I am reverting it to see whether that helps ([comment](https://github.com/pytorch/pytorch/pull/156628#issuecomment-3064617123))
This commit is contained in:
PyTorch MergeBot
2025-07-12 03:48:39 +00:00
parent 2eff14c445
commit 9c189ed29a
14 changed files with 139 additions and 635 deletions

View File

@ -265,7 +265,6 @@ class PersistentCache(CacheBase):
op: str,
inputs: str,
benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
hint_override: Optional[int] = None,
) -> dict[ChoiceCaller, float]:
"""
Check to see if we have benchmarked the given choice callers. For each
@ -278,7 +277,6 @@ class PersistentCache(CacheBase):
b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
"""
precision = torch.get_float32_matmul_precision()
cache_key = f"{inputs}_{hint_override}" if hint_override is not None else inputs
timings = {}
@ -287,11 +285,9 @@ class PersistentCache(CacheBase):
hit = True
for choice in choices:
choice_hash = choice.hash_key()
if choice_hash in cache.get(op, {}).get(cache_key, {}).get(
precision, {}
):
if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}):
# cache hit
timings[choice] = cache[op][cache_key][precision][choice_hash]
timings[choice] = cache[op][inputs][precision][choice_hash]
else:
# cache miss
hit = False
@ -304,9 +300,9 @@ class PersistentCache(CacheBase):
timings = benchmark(choices)
assert all(choice in timings for choice in choices)
local_cache.setdefault(op, {})
local_cache[op].setdefault(cache_key, {}).setdefault(precision, {})
local_cache[op].setdefault(inputs, {}).setdefault(precision, {})
for choice, timing in timings.items():
local_cache[op][cache_key][precision][choice.hash_key()] = timing
local_cache[op][inputs][precision][choice.hash_key()] = timing
self.update_local_cache(local_cache)