mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "multi-kernel matmuls based on varying hint sizes (#156628)"
This reverts commit 6c795306378c47341d58109da03371bba2bec46e. Reverted https://github.com/pytorch/pytorch/pull/156628 on behalf of https://github.com/huydhn due to the following: "Sorry for reverting your change, but some ROCm jobs started misbehaving after this landed, so I am trying to see whether reverting helps" ([comment](https://github.com/pytorch/pytorch/pull/156628#issuecomment-3064617123))
This commit is contained in:
@ -265,7 +265,6 @@ class PersistentCache(CacheBase):
|
||||
op: str,
|
||||
inputs: str,
|
||||
benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> dict[ChoiceCaller, float]:
|
||||
"""
|
||||
Check to see if we have benchmarked the given choice callers. For each
|
||||
@ -278,7 +277,6 @@ class PersistentCache(CacheBase):
|
||||
b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
|
||||
"""
|
||||
precision = torch.get_float32_matmul_precision()
|
||||
cache_key = f"{inputs}_{hint_override}" if hint_override is not None else inputs
|
||||
|
||||
timings = {}
|
||||
|
||||
@ -287,11 +285,9 @@ class PersistentCache(CacheBase):
|
||||
hit = True
|
||||
for choice in choices:
|
||||
choice_hash = choice.hash_key()
|
||||
if choice_hash in cache.get(op, {}).get(cache_key, {}).get(
|
||||
precision, {}
|
||||
):
|
||||
if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}):
|
||||
# cache hit
|
||||
timings[choice] = cache[op][cache_key][precision][choice_hash]
|
||||
timings[choice] = cache[op][inputs][precision][choice_hash]
|
||||
else:
|
||||
# cache miss
|
||||
hit = False
|
||||
@ -304,9 +300,9 @@ class PersistentCache(CacheBase):
|
||||
timings = benchmark(choices)
|
||||
assert all(choice in timings for choice in choices)
|
||||
local_cache.setdefault(op, {})
|
||||
local_cache[op].setdefault(cache_key, {}).setdefault(precision, {})
|
||||
local_cache[op].setdefault(inputs, {}).setdefault(precision, {})
|
||||
for choice, timing in timings.items():
|
||||
local_cache[op][cache_key][precision][choice.hash_key()] = timing
|
||||
local_cache[op][inputs][precision][choice.hash_key()] = timing
|
||||
|
||||
self.update_local_cache(local_cache)
|
||||
|
||||
|
Reference in New Issue
Block a user