[cutlass backend] Fix prescreening non-deterministic problem (#156144)

Differential Revision: [D76642615](https://our.internmc.facebook.com/intern/diff/D76642615/)

What do we expect to see when we run two identical matmul back to back? We expect to see the second one spending no time in precompilation, autotuning and prescreening.

However, the introduction of prescreening brings some non-determinism. Basically, we have
1. prescreening of first matmul chooses a set of kernels to advance to autotuning
2. autotuning re-does the autotuning of the winners, potentially changing their timings a bit
3. second prescreening results in a slightly different set of kernels
4. since not all timings are present, autotuning is re-done.

With this diff:
```
SingleProcess AUTOTUNE benchmarking takes 3.8633 seconds and 134.7364 seconds precompiling for 32 choices and 24.4472 seconds prescreening
SingleProcess AUTOTUNE benchmarking takes 0.0003 seconds and 0.0027 seconds precompiling for 32 choices and 0.0006 seconds prescreening
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156144
Approved by: https://github.com/mlazos
This commit is contained in:
henrylhtsang
2025-06-16 19:06:38 -07:00
committed by PyTorch MergeBot
parent cd66ff8030
commit bb462a6237

View File

@ -2122,11 +2122,14 @@ class AlgorithmSelectorCache(PersistentCache):
self.precompile_cache: dict[str, Callable[[], None]] = {}
# list of callbacks that are called after benchmarking
self.feedback_saver_fns: list[FeedbackFunction] = []
# cache for prescreening results to ensure deterministic candidate selection
self.prescreening_cache: dict[str, OrderedSet[str]] = {}
clear_on_fresh_cache(self)
# Reset both per-process result caches so the next identical matmul re-runs
# precompilation, prescreening, and autotuning from scratch (called on a
# fresh-cache event; registered via clear_on_fresh_cache above).
def cache_clear(self) -> None:
self.precompile_cache.clear()
self.prescreening_cache.clear()
def __call__(
self,
@ -2239,7 +2242,9 @@ class AlgorithmSelectorCache(PersistentCache):
precompile_elapse = time.time() - precompile_start_ts
log.debug("Precompilation elapsed time: %.02fs", precompile_elapse)
candidates = self.prescreen_choices(choices)
candidates = self.prescreen_choices(
choices, name, inputs_key, self.prescreening_cache
)
prescreening_elapse: Optional[float] = None
if candidates:
prescreening_start_ts = time.time()
@ -2249,7 +2254,9 @@ class AlgorithmSelectorCache(PersistentCache):
inputs_key,
autotune,
)
choices = self.prune_choices_postscreen(choices, timings)
choices = self.prune_choices_postscreen(
choices, timings, name, inputs_key, self.prescreening_cache
)
prescreening_elapse = time.time() - prescreening_start_ts
log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse)
@ -2758,11 +2765,39 @@ class AlgorithmSelectorCache(PersistentCache):
@staticmethod
def prescreen_choices(
choices: list[ChoiceCaller],
name: str,
inputs_key: str,
prescreen_cache: dict[str, OrderedSet[str]],
) -> list[ChoiceCaller]:
"""
Add prescreening phase. Motivation is to reduce the number of autotuning needed,
for example, when there are runtime params.
Figure out what choices need to be prescreened before autotuning with runtime
params.
Prescreening is a process of reducing the number of autotuning for choices with
runtime params via a two stage autotuning process. First, we fix a set of runtime
params (here we use swizzle=2) and run autotuning to get a set of candidates.
Then, we run autotuning again with the candidates and the full set of runtime
params.
Since we have the concept of runtime params, we need to differentiate between a
choice's hash_key and a choice's kernel_hash_key. The former includes information
like runtime params, while the latter does not. prescreen_cache, if present, stores
the set of hash_keys that should win the prescreening.
Right now, only CUTLASS choices have runtime params.
"""
# Create a cache key for prescreening results
prescreen_key = f"{name}:{inputs_key}"
# Check if we have cached prescreening results (prescreen_winners)
if prescreen_key in prescreen_cache:
prescreen_winners = [
choice
for choice in choices
if choice.hash_key() in prescreen_cache[prescreen_key]
]
return prescreen_winners
# prescreen cutlass
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
@ -2791,14 +2826,31 @@ class AlgorithmSelectorCache(PersistentCache):
def prune_choices_postscreen(
choices: list[ChoiceCaller],
candidate_timings: dict[ChoiceCaller, float],
name: str,
inputs_key: str,
prescreen_cache: dict[str, OrderedSet[str]],
) -> list[ChoiceCaller]:
"""
Prune the choices after prescreening.
"""
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
if len(candidate_timings) < 10:
return []
prescreen_key = f"{name}:{inputs_key}"
# Check if we have cached postscreen results
if prescreen_key in prescreen_cache:
# candidate_timings are from choices that have won prescreening already
winner_kernel_hashes = [
candidate.kernel_hash_key() for candidate in candidate_timings
]
pruned_choices = [
choice
for choice in choices
if not isinstance(choice, CUDATemplateCaller)
or choice.kernel_hash_key() in winner_kernel_hashes
]
return pruned_choices
log.debug("Before pruning using prescreening timings, %d choices", len(choices))
sorted_candidates = sorted(
@ -2835,21 +2887,28 @@ class AlgorithmSelectorCache(PersistentCache):
candidates_to_prune = OrderedSet(
candidate.kernel_hash_key() for candidate in sorted_candidates[num_to_keep:]
)
winner_hashes: OrderedSet[str] = OrderedSet()
for candidate in sorted_candidates[:num_to_keep]:
if candidate_timings[candidate] == float("inf"):
candidates_to_prune.add(candidate.kernel_hash_key())
else:
winner_hashes.add(candidate.hash_key())
if isinstance(candidate, CUDATemplateCaller):
candidate.bmreq.ensure_dll_loaded()
choices = [
pruned_choices = [
choice
for choice in choices
if choice.kernel_hash_key() not in candidates_to_prune # type: ignore[attr-defined]
]
log.debug("After pruning using prescreening timings, %d choices", len(choices))
return choices
# Cache the hash_key of winners of prescreening
prescreen_cache[prescreen_key] = winner_hashes
log.debug(
"After pruning using prescreening timings, %d choices", len(pruned_choices)
)
return pruned_choices
@staticmethod
def log_results(