mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[cutlass backend] Fix prescreening non-deterministic problem (#156144)
Differential Revision: [D76642615](https://our.internmc.facebook.com/intern/diff/D76642615/) What do we expect to see when we run two identical matmuls back to back? We expect the second one to spend no time in precompilation, autotuning, or prescreening. However, the introduction of prescreening brings some non-determinism. Basically, we have: 1. prescreening of the first matmul chooses a set of kernels to advance to autotuning 2. autotuning re-does the autotuning of the winners, potentially changing their timings a bit 3. the second prescreening results in a slightly different set of kernels 4. since not all timings are present, autotuning is re-done. With this diff: ``` SingleProcess AUTOTUNE benchmarking takes 3.8633 seconds and 134.7364 seconds precompiling for 32 choices and 24.4472 seconds prescreening SingleProcess AUTOTUNE benchmarking takes 0.0003 seconds and 0.0027 seconds precompiling for 32 choices and 0.0006 seconds prescreening ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/156144 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
cd66ff8030
commit
bb462a6237
@ -2122,11 +2122,14 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
self.precompile_cache: dict[str, Callable[[], None]] = {}
|
||||
# list of callbacks that are called after benchmarking
|
||||
self.feedback_saver_fns: list[FeedbackFunction] = []
|
||||
# cache for prescreening results to ensure deterministic candidate selection
|
||||
self.prescreening_cache: dict[str, OrderedSet[str]] = {}
|
||||
|
||||
clear_on_fresh_cache(self)
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
self.precompile_cache.clear()
|
||||
self.prescreening_cache.clear()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -2239,7 +2242,9 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
precompile_elapse = time.time() - precompile_start_ts
|
||||
log.debug("Precompilation elapsed time: %.02fs", precompile_elapse)
|
||||
|
||||
candidates = self.prescreen_choices(choices)
|
||||
candidates = self.prescreen_choices(
|
||||
choices, name, inputs_key, self.prescreening_cache
|
||||
)
|
||||
prescreening_elapse: Optional[float] = None
|
||||
if candidates:
|
||||
prescreening_start_ts = time.time()
|
||||
@ -2249,7 +2254,9 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
inputs_key,
|
||||
autotune,
|
||||
)
|
||||
choices = self.prune_choices_postscreen(choices, timings)
|
||||
choices = self.prune_choices_postscreen(
|
||||
choices, timings, name, inputs_key, self.prescreening_cache
|
||||
)
|
||||
prescreening_elapse = time.time() - prescreening_start_ts
|
||||
log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse)
|
||||
|
||||
@ -2758,11 +2765,39 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
@staticmethod
|
||||
def prescreen_choices(
|
||||
choices: list[ChoiceCaller],
|
||||
name: str,
|
||||
inputs_key: str,
|
||||
prescreen_cache: dict[str, OrderedSet[str]],
|
||||
) -> list[ChoiceCaller]:
|
||||
"""
|
||||
Add prescreening phase. Motivation is to reduce the number of autotuning needed,
|
||||
for example, when there are runtime params.
|
||||
Figure out what choices need to be prescreened before autotuning with runtime
|
||||
params.
|
||||
|
||||
Prescreening is a process of reducing the number of autotuning for choices with
|
||||
runtime params via a two stage autotuning process. First, we fix a set of runtime
|
||||
params (here we use swizzle=2) and run autotuning to get a set of candidates.
|
||||
Then, we run autotuning again with the candidates and the full set of runtime
|
||||
params.
|
||||
|
||||
Since have the concept of runtime params, we need to differentiate between
|
||||
choice's hash_key and choice's kernel_hash_key. The former includes information
|
||||
like runtime params, while the latter does not. prescreen_cache, if exists, stores
|
||||
the set of hash_key that should win the prescreening.
|
||||
|
||||
Right now, only CUTLASS choices have runtime params.
|
||||
"""
|
||||
# Create a cache key for prescreening results
|
||||
prescreen_key = f"{name}:{inputs_key}"
|
||||
|
||||
# Check if we have cached prescreening results (prescreen_winners)
|
||||
if prescreen_key in prescreen_cache:
|
||||
prescreen_winners = [
|
||||
choice
|
||||
for choice in choices
|
||||
if choice.hash_key() in prescreen_cache[prescreen_key]
|
||||
]
|
||||
return prescreen_winners
|
||||
|
||||
# prescreen cutlass
|
||||
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
|
||||
|
||||
@ -2791,14 +2826,31 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
def prune_choices_postscreen(
|
||||
choices: list[ChoiceCaller],
|
||||
candidate_timings: dict[ChoiceCaller, float],
|
||||
name: str,
|
||||
inputs_key: str,
|
||||
prescreen_cache: dict[str, OrderedSet[str]],
|
||||
) -> list[ChoiceCaller]:
|
||||
"""
|
||||
Prune the choices after prescreening.
|
||||
"""
|
||||
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
|
||||
|
||||
if len(candidate_timings) < 10:
|
||||
return []
|
||||
prescreen_key = f"{name}:{inputs_key}"
|
||||
|
||||
# Check if we have cached postscreen results
|
||||
if prescreen_key in prescreen_cache:
|
||||
# candidate_timings are from choices that have won prescreening already
|
||||
winner_kernel_hashes = [
|
||||
candidate.kernel_hash_key() for candidate in candidate_timings
|
||||
]
|
||||
|
||||
pruned_choices = [
|
||||
choice
|
||||
for choice in choices
|
||||
if not isinstance(choice, CUDATemplateCaller)
|
||||
or choice.kernel_hash_key() in winner_kernel_hashes
|
||||
]
|
||||
return pruned_choices
|
||||
|
||||
log.debug("Before pruning using prescreening timings, %d choices", len(choices))
|
||||
sorted_candidates = sorted(
|
||||
@ -2835,21 +2887,28 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
candidates_to_prune = OrderedSet(
|
||||
candidate.kernel_hash_key() for candidate in sorted_candidates[num_to_keep:]
|
||||
)
|
||||
winner_hashes: OrderedSet[str] = OrderedSet()
|
||||
for candidate in sorted_candidates[:num_to_keep]:
|
||||
if candidate_timings[candidate] == float("inf"):
|
||||
candidates_to_prune.add(candidate.kernel_hash_key())
|
||||
else:
|
||||
winner_hashes.add(candidate.hash_key())
|
||||
if isinstance(candidate, CUDATemplateCaller):
|
||||
candidate.bmreq.ensure_dll_loaded()
|
||||
|
||||
choices = [
|
||||
pruned_choices = [
|
||||
choice
|
||||
for choice in choices
|
||||
if choice.kernel_hash_key() not in candidates_to_prune # type: ignore[attr-defined]
|
||||
]
|
||||
|
||||
log.debug("After pruning using prescreening timings, %d choices", len(choices))
|
||||
return choices
|
||||
# Cache the hash_key of winners of prescreening
|
||||
prescreen_cache[prescreen_key] = winner_hashes
|
||||
|
||||
log.debug(
|
||||
"After pruning using prescreening timings, %d choices", len(pruned_choices)
|
||||
)
|
||||
return pruned_choices
|
||||
|
||||
@staticmethod
|
||||
def log_results(
|
||||
|
Reference in New Issue
Block a user