From cc699c56dc14af5b652a49e3044c0fe7c8ca70ee Mon Sep 17 00:00:00 2001
From: Shunting Zhang
Date: Thu, 9 Mar 2023 22:23:41 +0000
Subject: [PATCH] reland #96248 [inductor] show performance for each autotune config for a kernel (#96458)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96458
Approved by: https://github.com/ngimel
---
 torch/_inductor/codegen/triton.py      | 15 +++++++++++-
 torch/_inductor/codegen/wrapper.py     |  5 +++-
 torch/_inductor/config.py              | 33 +++++++++++++++++--------
 torch/_inductor/triton_ops/autotune.py | 17 ++++++++++---
 torch/_inductor/utils.py               | 34 +++++++++++++++++++-------
 5 files changed, 79 insertions(+), 25 deletions(-)

diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index abaf523f9cfc..c4fbd4dfdb99 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1205,8 +1205,9 @@ class TritonKernel(Kernel):
         result.writelines(["\n", "\n", "def call(args):"])
         grid = []
         extra_args = []
+        extra_args_str = None
+        index = V.graph.scheduler.current_device.index
         with result.indent():
-            index = V.graph.scheduler.current_device.index
             result.writeline(f"with torch.cuda._DeviceGuard({index}):")
             with result.indent():
                 result.writeline(
                     f"torch.cuda.set_device({index})"
                 )  # no-op to ensure context
@@ -1226,6 +1227,18 @@ class TritonKernel(Kernel):
                     f"triton_.run(*args, {extra_args_str}grid=grid({', '.join(grid)}), stream={stream_name})"
                 )
 
+        # benchmark all configs
+        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
+        with result.indent():
+            result.writeline(f"with torch.cuda._DeviceGuard({index}):")
+            with result.indent():
+                result.writeline(
+                    f"torch.cuda.set_device({index})"
+                )  # no-op to ensure context
+                result.writeline(
+                    f"return triton_.benchmark_all_configs(*args, {extra_args_str}grid=grid({', '.join(grid)}))"
+                )
+
         result.writelines(["\n", "\n", "if __name__ == '__main__':"])
         with result.indent():
             result.writeline("from torch._inductor.utils import get_num_bytes")

diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 80351b2016e1..9de425065e8d 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -614,13 +614,16 @@
                 "",
                 "parser = argparse.ArgumentParser()",
                 'parser.add_argument("--benchmark-kernels", "-k", action="store_true", help="Whether to benchmark each individual kernels")',  # noqa: B950, line too long
+                'parser.add_argument("--benchmark-all-configs", "-c", action="store_true", help="Whether to benchmark each individual config for a kernel")',  # noqa: B950, line too long
                 "args = parser.parse_args()",
                 "",
                 "if args.benchmark_kernels:",
             ]
         )
         with output.indent():
-            output.writeline(f"benchmark_all_kernels('{get_benchmark_name()}')")
+            output.writeline(
+                f"benchmark_all_kernels('{get_benchmark_name()}', args.benchmark_all_configs)"
+            )
         output.writeline("else:")
         with output.indent():
             output.writeline("benchmark_compiled_module()")

diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 238f15005a2a..4f9b8a59cfab 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -90,16 +90,29 @@ def is_fbcode():
 # warnings intended for PyTorch developers, disable for point releases
 developer_warnings = is_fbcode() or "+" in torch.__version__
 
-compile_threads = (
-    1
-    if sys.platform == "win32" or is_fbcode()
-    else min(
-        32,
-        len(os.sched_getaffinity(0))
-        if hasattr(os, "sched_getaffinity")
-        else os.cpu_count(),
-    )
-)
+
+def decide_compile_threads():
+    """
+    Here is the order of precedence for deciding compile_threads:
+    1. The user can override it via TORCHINDUCTOR_COMPILE_THREADS. One may want to
+       set this to 1 to disable async compiling and make pdb happy.
+    2. Set to 1 for the win32 platform or fbcode builds.
+    3. Otherwise, decide based on the number of CPU cores.
+    """
+    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
+        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
+    elif sys.platform == "win32" or is_fbcode():
+        return 1
+    else:
+        return min(
+            32,
+            len(os.sched_getaffinity(0))
+            if hasattr(os, "sched_getaffinity")
+            else os.cpu_count(),
+        )
+
+
+compile_threads = decide_compile_threads()
 
 # autotuning global cache path
 if is_fbcode():

diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py
index 558a4240d8f5..edaa0889337c 100644
--- a/torch/_inductor/triton_ops/autotune.py
+++ b/torch/_inductor/triton_ops/autotune.py
@@ -155,8 +155,7 @@ class CachingAutotuner(KernelInterface):
         return do_bench(kernel_call, rep=40, fast_flush=True)
 
     @dynamo_timed
-    def autotune_to_one_config(self, *args, **kwargs):
-        """Do the actual autotuning"""
+    def benchmark_all_configs(self, *args, **kwargs):
         from ..compile_fx import clone_preserve_strides
 
         # clone inplace buffers to avoid autotune contaminating them if
@@ -171,9 +170,14 @@ class CachingAutotuner(KernelInterface):
                 cloned_args.append(arg)
 
         timings = {
-            launcher: self.bench(launcher, *cloned_args, **kwargs)
+            launcher: self.bench(launcher, *cloned_args, **kwargs)[0]
            for launcher in self.launchers
         }
+        return timings
+
+    def autotune_to_one_config(self, *args, **kwargs):
+        """Do the actual autotuning"""
+        timings = self.benchmark_all_configs(*args, **kwargs)
         self.launchers = [builtins.min(timings, key=timings.get)]
         if self.save_cache_hook:
             self.save_cache_hook(self.launchers[0].config)
@@ -313,8 +317,13 @@ def cached_autotune(
     configs = unique_configs(configs)
     assert len(configs) == 1 or filename
 
+    # The autotune cache would simply replace the list of candidate configs with
+    # the single cached best config. We don't want that when benchmarking Triton
+    # kernels; we need the perf for each of the candidate configs instead.
+    cache_autotune_result = not config.benchmark_kernel
+
     # on disk caching logic
-    if filename is not None and len(configs) > 1:
+    if cache_autotune_result and filename is not None and len(configs) > 1:
         cache_filename = os.path.splitext(filename)[0] + ".best_config"
         configs_hash = hash_configs(configs)
         best_config = load_cached_autotuning(cache_filename, configs_hash, configs)

diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 9956efd605c8..1441f2ea2e33 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -625,7 +625,7 @@ def get_benchmark_name():
                 return arg[len("--only=") :]
 
 
-def benchmark_all_kernels(benchmark_name):
+def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
     """
     An experimental API used only when config.benchmark_kernel is true.
@@ -642,18 +642,34 @@ def benchmark_all_kernels(benchmark_name):
         if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
             continue
         args = kernel_mod.get_args()
-        ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
         num_gb = get_num_bytes(*args) / 1e9
-        gb_per_s = num_gb / (ms / 1e3)
 
-        # follow what we do in DebugAutotuner
-        info_str = f"{benchmark_name:20} {kernel_key[:10]} {ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
-        import colorama
+        def get_info_str(ms, prefix=""):
+            gb_per_s = num_gb / (ms / 1e3)
+            # follow what we do in DebugAutotuner
+            info_str = f"{prefix}{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
+            import colorama
 
-        if ms > 0.012 and gb_per_s < 650:
-            print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
+            if ms > 0.012 and gb_per_s < 650:
+                info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET
+            return info_str
+
+        bench_result = []
+        if benchmark_all_configs:
+            assert hasattr(kernel_mod, "benchmark_all_configs")
+            bench_result = kernel_mod.benchmark_all_configs(args)
+            bench_result = [
+                (launcher.config, ms) for launcher, ms in bench_result.items()
+            ]
+            print(f"{benchmark_name:20} {kernel_key[:10]}")
+            for cfg, ms in bench_result:
+                print(f"  {get_info_str(ms)} @ {cfg}")
         else:
-            print(info_str)
+            ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
+            assert (
+                len(kernel_mod.triton_.launchers) == 1
+            ), "Autotuner should have selected the best config"
+            print(get_info_str(ms, prefix=f"{benchmark_name:20} {kernel_key[:10]} "))
 
         nfound += 1
     if nfound == 0:
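
For reviewers who want to try this end to end, a minimal sketch of the intended workflow follows. Only config.benchmark_kernel, the generated module's --benchmark-kernels / --benchmark-all-configs flags, and the TORCHINDUCTOR_COMPILE_THREADS override come from this patch; the compiled function, shapes, and paths below are hypothetical.

# Sketch only: f and the shapes are made up for illustration.
import torch
import torch._inductor.config as inductor_config

# From this patch: keep every autotune candidate config instead of replacing
# the candidate list with the single cached best config (see cached_autotune).
inductor_config.benchmark_kernel = True

@torch.compile
def f(x):
    return torch.relu(x) * 2.0 + 1.0

f(torch.randn(4096, 4096, device="cuda"))

# Inductor writes a standalone output-code module to its cache directory
# (exact path varies). Its __main__ block now understands the flags added
# in wrapper.py:
#
#   python <output_code.py> --benchmark-kernels --benchmark-all-configs
#
# With -c, benchmark_all_kernels prints one "<ms> <GB> <GB/s> @ <config>"
# line per candidate config for each Triton kernel instead of only the best
# one. Separately, TORCHINDUCTOR_COMPILE_THREADS=1 (config.py) disables
# async compiling, e.g. to make pdb usable.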