Pull Request resolved: https://github.com/pytorch/pytorch/pull/96458
Approved by: https://github.com/ngimel
Commit: cc699c56dc
Parent: cf3d3a583e
Committed by: PyTorch MergeBot
@@ -1205,8 +1205,9 @@ class TritonKernel(Kernel):
         result.writelines(["\n", "\n", "def call(args):"])
         grid = []
         extra_args = []
+        extra_args_str = None
+        index = V.graph.scheduler.current_device.index
         with result.indent():
-            index = V.graph.scheduler.current_device.index
             result.writeline(f"with torch.cuda._DeviceGuard({index}):")
             with result.indent():
                 result.writeline(
@@ -1226,6 +1227,18 @@ class TritonKernel(Kernel):
                     f"triton_.run(*args, {extra_args_str}grid=grid({', '.join(grid)}), stream={stream_name})"
                 )

+        # benchmark all configs
+        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
+        with result.indent():
+            result.writeline(f"with torch.cuda._DeviceGuard({index}):")
+            with result.indent():
+                result.writeline(
+                    f"torch.cuda.set_device({index})"
+                )  # no-op to ensure context
+                result.writeline(
+                    f"return triton_.benchmark_all_configs(*args, {extra_args_str}grid=grid({', '.join(grid)}))"
+                )
+
         result.writelines(["\n", "\n", "if __name__ == '__main__':"])
         with result.indent():
             result.writeline("from torch._inductor.utils import get_num_bytes")
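For illustration only (not part of the commit): the writer calls above add a benchmark_all_configs entry point to each generated kernel module. The sketch below shows roughly how such IndentedBuffer calls compose into text, assuming IndentedBuffer is importable from torch._inductor.utils; the device index is hard-coded to 0 and a placeholder grid stands in for the extra_args_str/grid handling done by the real codegen.

from torch._inductor.utils import IndentedBuffer

result = IndentedBuffer()
index = 0  # stand-in for V.graph.scheduler.current_device.index
result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
with result.indent():
    result.writeline(f"with torch.cuda._DeviceGuard({index}):")
    with result.indent():
        # generated line is a no-op that just ensures the CUDA context exists
        result.writeline(f"torch.cuda.set_device({index})")
        result.writeline("return triton_.benchmark_all_configs(*args, grid=grid(1024))")

print(result.getvalue())  # prints the indented source that would land in the kernel module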
@@ -614,13 +614,16 @@ class WrapperCodeGen(CodeGen):
                 "",
                 "parser = argparse.ArgumentParser()",
                 'parser.add_argument("--benchmark-kernels", "-k", action="store_true", help="Whether to benchmark each individual kernels")',  # noqa: B950, line too long
+                'parser.add_argument("--benchmark-all-configs", "-c", action="store_true", help="Whether to benchmark each individual config for a kernel")',  # noqa: B950, line too long
                 "args = parser.parse_args()",
                 "",
                 "if args.benchmark_kernels:",
             ]
         )
         with output.indent():
-            output.writeline(f"benchmark_all_kernels('{get_benchmark_name()}')")
+            output.writeline(
+                f"benchmark_all_kernels('{get_benchmark_name()}', args.benchmark_all_configs)"
+            )
         output.writeline("else:")
         with output.indent():
             output.writeline("benchmark_compiled_module()")
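As an aside (not part of the commit), here is a self-contained sketch of the harness those writelines produce in a generated module. benchmark_all_kernels and benchmark_compiled_module are stubbed out so the snippet runs standalone, and "my_model" stands in for whatever get_benchmark_name() returns; in the real output they are the compiled module's own helpers.

import argparse


def benchmark_all_kernels(name, all_configs):  # stub for the real helper
    print(f"benchmark kernels of {name}, all configs: {all_configs}")


def benchmark_compiled_module():  # stub for the real helper
    print("benchmark the whole compiled module")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark-kernels", "-k", action="store_true",
                        help="Whether to benchmark each individual kernels")
    parser.add_argument("--benchmark-all-configs", "-c", action="store_true",
                        help="Whether to benchmark each individual config for a kernel")
    args = parser.parse_args()

    if args.benchmark_kernels:
        # the new second argument decides between one timing per kernel
        # and one timing per candidate config
        benchmark_all_kernels("my_model", args.benchmark_all_configs)
    else:
        benchmark_compiled_module()

Running the generated file with -k times each kernel with its chosen config; adding -c times every candidate config of each kernel.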
@@ -90,16 +90,29 @@ def is_fbcode():
 # warnings intended for PyTorch developers, disable for point releases
 developer_warnings = is_fbcode() or "+" in torch.__version__

-compile_threads = (
-    1
-    if sys.platform == "win32" or is_fbcode()
-    else min(
-        32,
-        len(os.sched_getaffinity(0))
-        if hasattr(os, "sched_getaffinity")
-        else os.cpu_count(),
-    )
-)
+
+def decide_compile_threads():
+    """
+    Here are the precedence to decide compile_threads
+    1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by
+       setting this to 1 to make pdb happy.
+    2. Set to 1 if it's win32 platform or it's a fbcode build
+    3. decide by the number of CPU cores
+    """
+    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
+        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
+    elif sys.platform == "win32" or is_fbcode():
+        return 1
+    else:
+        return min(
+            32,
+            len(os.sched_getaffinity(0))
+            if hasattr(os, "sched_getaffinity")
+            else os.cpu_count(),
+        )
+
+
+compile_threads = decide_compile_threads()

 # autotuning global cache path
 if is_fbcode():
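Not part of the commit, just an illustration of precedence rule 1 above: because compile_threads is computed at import time, the environment override has to be set before torch._inductor.config is first imported.

import os

# must run before the config module is imported, since the value is fixed at import time
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

from torch._inductor import config as inductor_config

print(inductor_config.compile_threads)  # 1: async compile workers disabled, pdb-friendly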
@@ -155,8 +155,7 @@ class CachingAutotuner(KernelInterface):
         return do_bench(kernel_call, rep=40, fast_flush=True)

     @dynamo_timed
-    def autotune_to_one_config(self, *args, **kwargs):
-        """Do the actual autotuning"""
+    def benchmark_all_configs(self, *args, **kwargs):
         from ..compile_fx import clone_preserve_strides

         # clone inplace buffers to avoid autotune contaminating them if
@@ -171,9 +170,14 @@ class CachingAutotuner(KernelInterface):
             cloned_args.append(arg)

         timings = {
-            launcher: self.bench(launcher, *cloned_args, **kwargs)
+            launcher: self.bench(launcher, *cloned_args, **kwargs)[0]
             for launcher in self.launchers
         }
+        return timings
+
+    def autotune_to_one_config(self, *args, **kwargs):
+        """Do the actual autotuning"""
+        timings = self.benchmark_all_configs(*args, **kwargs)
         self.launchers = [builtins.min(timings, key=timings.get)]
         if self.save_cache_hook:
             self.save_cache_hook(self.launchers[0].config)
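A toy illustration (not from the commit) of how autotune_to_one_config reduces the mapping returned by benchmark_all_configs: the launcher with the smallest timing wins. Plain strings stand in for launcher objects and the timings are made up.

import builtins

timings = {"XBLOCK=512": 0.052, "XBLOCK=1024": 0.041, "XBLOCK=2048": 0.063}  # hypothetical
best = builtins.min(timings, key=timings.get)
print(best)  # "XBLOCK=1024", the fastest candidate; the real code keeps only that launcher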
@@ -313,8 +317,13 @@ def cached_autotune(
     configs = unique_configs(configs)
     assert len(configs) == 1 or filename

+    # The autotune cache will simply replace the list of candidate configs with
+    # the best config cached. We don't want that when we benchmark triton kernels.
+    # We need the perf for each of the candidate config instead.
+    cache_autotune_result = not config.benchmark_kernel
+
     # on disk caching logic
-    if filename is not None and len(configs) > 1:
+    if cache_autotune_result and filename is not None and len(configs) > 1:
         cache_filename = os.path.splitext(filename)[0] + ".best_config"
         configs_hash = hash_configs(configs)
         best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
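For orientation (not part of the commit): the new flag only bypasses the .best_config disk cache while kernel benchmarking is on, so every candidate config still gets timed instead of being collapsed to the cached winner. A minimal sketch of the gating, setting the flag programmatically here purely for illustration:

from torch._inductor import config

config.benchmark_kernel = True  # illustration; normally enabled via inductor's config
cache_autotune_result = not config.benchmark_kernel
print(cache_autotune_result)  # False: skip loading/saving the .best_config file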
@@ -625,7 +625,7 @@ def get_benchmark_name():
             return arg[len("--only=") :]


-def benchmark_all_kernels(benchmark_name):
+def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
     """
     An experimental API used only when config.benchmark_kernel is true.

@@ -642,18 +642,34 @@ def benchmark_all_kernels(benchmark_name):
         if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
             continue
         args = kernel_mod.get_args()
-        ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
         num_gb = get_num_bytes(*args) / 1e9
-        gb_per_s = num_gb / (ms / 1e3)

-        # follow what we do in DebugAutotuner
-        info_str = f"{benchmark_name:20} {kernel_key[:10]} {ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
-        import colorama
+        def get_info_str(ms, prefix=""):
+            gb_per_s = num_gb / (ms / 1e3)
+            # follow what we do in DebugAutotuner
+            info_str = f"{prefix}{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
+            import colorama

-        if ms > 0.012 and gb_per_s < 650:
-            print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
+            if ms > 0.012 and gb_per_s < 650:
+                info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET
+            return info_str
+
+        bench_result = []
+        if benchmark_all_configs:
+            assert hasattr(kernel_mod, "benchmark_all_configs")
+            bench_result = kernel_mod.benchmark_all_configs(args)
+            bench_result = [
+                (launcher.config, ms) for launcher, ms in bench_result.items()
+            ]
+            print(f"{benchmark_name:20} {kernel_key[:10]}")
+            for cfg, ms in bench_result:
+                print(f" {get_info_str(ms)} @ {cfg}")
         else:
-            print(info_str)
+            ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
+            assert (
+                len(kernel_mod.triton_.launchers) == 1
+            ), "Autotuner should have selected the best config"
+            print(get_info_str(ms, prefix=f"{benchmark_name:20} {kernel_key[:10]} "))

         nfound += 1
     if nfound == 0:
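A worked example (not from the commit) of the bandwidth math inside get_info_str: num_gb is the total bytes the kernel reads and writes, in GB, and ms is the measured runtime in milliseconds, so achieved bandwidth is num_gb / (ms / 1e3). The numbers below are made up.

num_gb = 0.12                   # hypothetical: kernel reads + writes ~120 MB
ms = 0.080                      # hypothetical: measured runtime in milliseconds
gb_per_s = num_gb / (ms / 1e3)  # 0.12 GB / 0.00008 s = 1500.0 GB/s
info_str = f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
print(info_str)                 # 0.080ms 0.120GB 1500.00GB/s
# Results slower than 650 GB/s (and above 0.012 ms) are highlighted in red by the real code.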