Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

commit cc699c56dc (parent cf3d3a583e), committed by PyTorch MergeBot

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96458
Approved by: https://github.com/ngimel

@@ -1205,8 +1205,9 @@ class TritonKernel(Kernel):
         result.writelines(["\n", "\n", "def call(args):"])
         grid = []
         extra_args = []
+        extra_args_str = None
+        index = V.graph.scheduler.current_device.index
         with result.indent():
-            index = V.graph.scheduler.current_device.index
             result.writeline(f"with torch.cuda._DeviceGuard({index}):")
             with result.indent():
                 result.writeline(

@@ -1226,6 +1227,18 @@ class TritonKernel(Kernel):
                 f"triton_.run(*args, {extra_args_str}grid=grid({', '.join(grid)}), stream={stream_name})"
             )
 
+        # benchmark all configs
+        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
+        with result.indent():
+            result.writeline(f"with torch.cuda._DeviceGuard({index}):")
+            with result.indent():
+                result.writeline(
+                    f"torch.cuda.set_device({index})"
+                )  # no-op to ensure context
+                result.writeline(
+                    f"return triton_.benchmark_all_configs(*args, {extra_args_str}grid=grid({', '.join(grid)}))"
+                )
+
         result.writelines(["\n", "\n", "if __name__ == '__main__':"])
         with result.indent():
             result.writeline("from torch._inductor.utils import get_num_bytes")
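
This codegen emits, for each compiled Triton kernel, a standalone module whose `call` runs the tuned kernel and whose new `benchmark_all_configs` times every candidate config. A rough sketch of the emitted shape, with an illustrative device index and grid (the real output depends on the kernel; `triton_`, `grid`, and `get_cuda_stream` come from the generated module's preamble):

# Hypothetical shape of the generated module, not verbatim codegen output.
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        stream0 = get_cuda_stream(0)
        triton_.run(*args, grid=grid(1024), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        # returns a {launcher: latency_ms} mapping over all candidate configs
        return triton_.benchmark_all_configs(*args, grid=grid(1024))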

@@ -614,13 +614,16 @@ class WrapperCodeGen(CodeGen):
                 "",
                 "parser = argparse.ArgumentParser()",
                 'parser.add_argument("--benchmark-kernels", "-k", action="store_true", help="Whether to benchmark each individual kernels")',  # noqa: B950, line too long
+                'parser.add_argument("--benchmark-all-configs", "-c", action="store_true", help="Whether to benchmark each individual config for a kernel")',  # noqa: B950, line too long
                 "args = parser.parse_args()",
                 "",
                 "if args.benchmark_kernels:",
             ]
         )
         with output.indent():
-            output.writeline(f"benchmark_all_kernels('{get_benchmark_name()}')")
+            output.writeline(
+                f"benchmark_all_kernels('{get_benchmark_name()}', args.benchmark_all_configs)"
+            )
         output.writeline("else:")
         with output.indent():
             output.writeline("benchmark_compiled_module()")
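
Assembled, the `__main__` section of the generated wrapper now looks roughly like the sketch below; `"my_benchmark"` is a hypothetical placeholder for the name substituted via `get_benchmark_name()`, and `benchmark_all_kernels` / `benchmark_compiled_module` are defined elsewhere in the generated file:

# Approximate shape of the emitted __main__ block, not verbatim output.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark-kernels", "-k", action="store_true")
    parser.add_argument("--benchmark-all-configs", "-c", action="store_true")
    args = parser.parse_args()

    if args.benchmark_kernels:
        # "my_benchmark" stands in for the substituted benchmark name
        benchmark_all_kernels("my_benchmark", args.benchmark_all_configs)
    else:
        benchmark_compiled_module()

Running the saved wrapper with both flags (`-k -c`) benchmarks every config of every kernel rather than only the autotuned winner.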

@@ -90,16 +90,29 @@ def is_fbcode():
 # warnings intended for PyTorch developers, disable for point releases
 developer_warnings = is_fbcode() or "+" in torch.__version__
 
-compile_threads = (
-    1
-    if sys.platform == "win32" or is_fbcode()
-    else min(
-        32,
-        len(os.sched_getaffinity(0))
-        if hasattr(os, "sched_getaffinity")
-        else os.cpu_count(),
-    )
-)
+
+def decide_compile_threads():
+    """
+    Precedence for deciding compile_threads:
+    1. The user can override it via TORCHINDUCTOR_COMPILE_THREADS; setting it to 1
+       disables async compiling, which makes pdb happy.
+    2. Set to 1 on win32 or in fbcode builds.
+    3. Otherwise, use the number of CPU cores.
+    """
+    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
+        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
+    elif sys.platform == "win32" or is_fbcode():
+        return 1
+    else:
+        return min(
+            32,
+            len(os.sched_getaffinity(0))
+            if hasattr(os, "sched_getaffinity")
+            else os.cpu_count(),
+        )
+
+
+compile_threads = decide_compile_threads()
 
 # autotuning global cache path
 if is_fbcode():
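
The environment override is read once, when torch._inductor.config is first imported. A minimal sketch of using it to force synchronous compilation for debugging (the assertion simply checks the value the module computed):

import os

# Must be set before torch._inductor.config is imported, since
# compile_threads is computed at module import time.
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

import torch._inductor.config as inductor_config

assert inductor_config.compile_threads == 1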

@@ -155,8 +155,7 @@ class CachingAutotuner(KernelInterface):
         return do_bench(kernel_call, rep=40, fast_flush=True)
 
     @dynamo_timed
-    def autotune_to_one_config(self, *args, **kwargs):
-        """Do the actual autotuning"""
+    def benchmark_all_configs(self, *args, **kwargs):
         from ..compile_fx import clone_preserve_strides
 
         # clone inplace buffers to avoid autotune contaminating them if

@@ -171,9 +170,14 @@ class CachingAutotuner(KernelInterface):
                 cloned_args.append(arg)
 
         timings = {
-            launcher: self.bench(launcher, *cloned_args, **kwargs)
+            launcher: self.bench(launcher, *cloned_args, **kwargs)[0]
             for launcher in self.launchers
         }
+        return timings
+
+    def autotune_to_one_config(self, *args, **kwargs):
+        """Do the actual autotuning"""
+        timings = self.benchmark_all_configs(*args, **kwargs)
         self.launchers = [builtins.min(timings, key=timings.get)]
         if self.save_cache_hook:
             self.save_cache_hook(self.launchers[0].config)
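
The refactor separates measurement from selection: `benchmark_all_configs` now returns the full launcher-to-latency mapping, and `autotune_to_one_config` merely reduces it to the fastest entry. A minimal self-contained sketch of that pattern (the `bench` callable and launcher list stand in for the real CachingAutotuner attributes):

# Sketch of the measure-then-select split, assuming bench(launcher, *args)
# returns a latency in milliseconds for one candidate launcher.
def benchmark_all_configs(launchers, bench, *args):
    # benchmark every candidate config
    return {launcher: bench(launcher, *args) for launcher in launchers}


def autotune_to_one_config(launchers, bench, *args):
    # keep only the fastest candidate
    timings = benchmark_all_configs(launchers, bench, *args)
    return [min(timings, key=timings.get)]


# Example with fake launchers and fixed latencies:
fake_bench = lambda launcher: {"cfg_a": 3.0, "cfg_b": 1.5}[launcher]
assert autotune_to_one_config(["cfg_a", "cfg_b"], fake_bench) == ["cfg_b"]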

@@ -313,8 +317,13 @@ def cached_autotune(
     configs = unique_configs(configs)
     assert len(configs) == 1 or filename
 
+    # The autotune cache will simply replace the list of candidate configs with
+    # the best cached config. We don't want that when we benchmark triton kernels,
+    # since we need the perf of each candidate config instead.
+    cache_autotune_result = not config.benchmark_kernel
+
     # on disk caching logic
-    if filename is not None and len(configs) > 1:
+    if cache_autotune_result and filename is not None and len(configs) > 1:
         cache_filename = os.path.splitext(filename)[0] + ".best_config"
         configs_hash = hash_configs(configs)
         best_config = load_cached_autotuning(cache_filename, configs_hash, configs)

@@ -625,7 +625,7 @@ def get_benchmark_name():
         return arg[len("--only=") :]
 
 
-def benchmark_all_kernels(benchmark_name):
+def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
     """
     An experimental API used only when config.benchmark_kernel is true.
 

@@ -642,18 +642,34 @@ def benchmark_all_kernels(benchmark_name):
         if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
             continue
         args = kernel_mod.get_args()
-        ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
         num_gb = get_num_bytes(*args) / 1e9
-        gb_per_s = num_gb / (ms / 1e3)
 
-        # follow what we do in DebugAutotuner
-        info_str = f"{benchmark_name:20} {kernel_key[:10]} {ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
-        import colorama
+        def get_info_str(ms, prefix=""):
+            gb_per_s = num_gb / (ms / 1e3)
+            # follow what we do in DebugAutotuner
+            info_str = f"{prefix}{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
+            import colorama
 
-        if ms > 0.012 and gb_per_s < 650:
-            print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
+            if ms > 0.012 and gb_per_s < 650:
+                info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET
+            return info_str
+
+        bench_result = []
+        if benchmark_all_configs:
+            assert hasattr(kernel_mod, "benchmark_all_configs")
+            bench_result = kernel_mod.benchmark_all_configs(args)
+            bench_result = [
+                (launcher.config, ms) for launcher, ms in bench_result.items()
+            ]
+            print(f"{benchmark_name:20} {kernel_key[:10]}")
+            for cfg, ms in bench_result:
+                print(f"  {get_info_str(ms)} @ {cfg}")
         else:
-            print(info_str)
+            ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
+            assert (
+                len(kernel_mod.triton_.launchers) == 1
+            ), "Autotuner should have selected the best config"
+            print(get_info_str(ms, prefix=f"{benchmark_name:20} {kernel_key[:10]} "))
 
         nfound += 1
     if nfound == 0:
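
With `--benchmark-all-configs`, this prints a header per kernel followed by one line per candidate config. A hypothetical example of the output shape, derived from the format strings above (the benchmark name, kernel key, timings, and config fields are invented for illustration):

my_benchmark         triton_poi
  0.012ms 0.050GB 4166.67GB/s @ XBLOCK: 512, num_warps: 4, num_stages: 1
  0.015ms 0.050GB 3333.33GB/s @ XBLOCK: 128, num_warps: 4, num_stages: 1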