diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py
index 01055f5cd6e5..e2cf718aa7e0 100644
--- a/torch/_inductor/codegen/multi_kernel.py
+++ b/torch/_inductor/codegen/multi_kernel.py
@@ -8,6 +8,7 @@ from typing import Any, Optional, Union
 from torch._inductor.ir import MultiTemplateBuffer
 from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
+from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 from torch.utils._ordered_set import OrderedSet

 from .. import config
@@ -369,16 +370,20 @@ class MultiKernelCall:
             be picked.
         """

-        def wrap_fn(kernel, index):
-            def inner():
-                filtered_args = self._get_filtered_args(args, index)
-                args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
-                return kernel.run(*args_clone, **kwargs_clone)
-
-            return inner
+        def get_args_kwargs(kernel, index) -> tuple[tuple, dict[str, Any]]:  # type: ignore[type-arg]
+            filtered_args = self._get_filtered_args(args, index)
+            args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
+            return args_clone, kwargs_clone

         return [
-            benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40)
+            benchmarker.benchmark(
+                kernel.run,
+                *get_args_kwargs(kernel, index),
+                device=kernel.device_props.type
+                if isinstance(kernel, CachingAutotuner)
+                else None,
+                rep=40,
+            )
             for index, kernel in enumerate(self.kernels)
         ]
diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py
index 1fbed50db91c..ac39d839591f 100644
--- a/torch/_inductor/codegen/subgraph.py
+++ b/torch/_inductor/codegen/subgraph.py
@@ -109,7 +109,10 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
         bm_func([*sym_inputs, *args])
         if config.profile_bandwidth_with_do_bench_using_profiling:
             return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
-        return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
+        return benchmarker.benchmark(
+            bm_func,
+            fn_args=([*sym_inputs, *args],),
+        )

     def hash_key(self) -> str:
         return "-".join(
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 166413e341d5..56211ec005c4 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -4682,7 +4682,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

         result.writeline("args = get_args()")
         result.writeline(
-            "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
+            f"ms = benchmarker.benchmark(lambda: call(args), device={V.graph.get_current_device_or_throw().type!r}, rep=40)"  # noqa: B950 line too long
         )
         result.writeline(f"num_gb = {num_gb}")
         result.writeline("gb_per_s = num_gb / (ms / 1e3)")
@@ -5624,18 +5624,21 @@ class TritonScheduling(SIMDScheduling):
             # skip benchmarking the kernel if there are register spills
             ms = float("inf")
         else:
+            device = V.graph.get_current_device_or_throw()
             # We have to clone the inplace updated arguments to avoid earlier calls
             # generating out of range indices for later calls.
-            ms = benchmarker.benchmark_gpu(
-                lambda: call(wrapped_jit_function.clone_args(*args)[0])
+            ms = benchmarker.benchmark(
+                lambda: call(wrapped_jit_function.clone_args(*args)[0]),
+                device=device,
             )
             # overhead of cloning args gives bias for fusing the kernel
             # in the case of mutating/in-placeable second fusion
             # TODO - would be better as a hook in triton do_bench that reset
             # the input values between benchmarking
             if len(wrapped_jit_function.mutated_arg_names) > 0:
-                ms = ms - benchmarker.benchmark_gpu(
-                    lambda: wrapped_jit_function.clone_args(*args)
+                ms = ms - benchmarker.benchmark(
+                    lambda: wrapped_jit_function.clone_args(*args),
+                    device=device,
                 )

         log.debug(
@@ -5804,13 +5807,16 @@ class TritonScheduling(SIMDScheduling):
             # skip benchmarking the kernel if there are register spills
             ms = ms_clone = float("inf")
         else:
+            device = V.graph.get_current_device_or_throw()
             # We have to clone the inplace updated arguments to avoid earlier calls
             # generating out of range indices for later calls.
-            ms = benchmarker.benchmark_gpu(
-                lambda: call(wrapped_jit_function.clone_args(*args)[0])
+            ms = benchmarker.benchmark(
+                lambda: call(wrapped_jit_function.clone_args(*args)[0]),
+                device=device,
             )
-            ms_clone = benchmarker.benchmark_gpu(
-                lambda: wrapped_jit_function.clone_args(*args)[0]
+            ms_clone = benchmarker.benchmark(
+                lambda: wrapped_jit_function.clone_args(*args)[0],
+                device=device,
             )

         log.debug(
diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py
index c28321923c5e..e3134935da0b 100644
--- a/torch/_inductor/codegen/triton_combo_kernel.py
+++ b/torch/_inductor/codegen/triton_combo_kernel.py
@@ -889,6 +889,7 @@ class ComboKernel(Kernel):
             result.writeline(f"return {', '.join(var_names)},")

         result.writelines(["\n", "\n", "def call(args):"])
+        device = V.graph.get_current_device_or_throw()
         index = V.graph.get_current_device_or_throw().index
         with result.indent():
             result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
@@ -923,7 +924,7 @@ class ComboKernel(Kernel):

         result.writeline("args = get_args()")
         result.writeline(
-            "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
+            f"ms = benchmarker.benchmark(call, fn_args=(args,), device={device.type!r}, rep=40)"
         )
         result.writeline(f"num_gb = {num_gb}")
         result.writeline("gb_per_s = num_gb / (ms / 1e3)")
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 4952daee3095..5ce9cfa93c40 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -5050,7 +5050,9 @@ class ChoiceCaller:
         }
         if config.profile_bandwidth_with_do_bench_using_profiling:
             return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs)  # type: ignore[arg-type]
-        return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs)
+        return benchmarker.benchmark(
+            algo, args, {"out": out}, device=None, **benchmark_configs
+        )

     def call_name(self) -> str:
         raise NotImplementedError
diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py
index 21ee339b7df6..6387299ba67e 100644
--- a/torch/_inductor/runtime/benchmarking.py
+++ b/torch/_inductor/runtime/benchmarking.py
@@ -92,6 +92,11 @@ def time_and_count(


 class Benchmarker:
+    """
+    A device-agnostic benchmarking utility for measuring the runtime of
+    inductor-generated callables.
+ """ + def __init__(self: Self) -> None: pass @@ -99,8 +104,9 @@ class Benchmarker: def benchmark( self: Self, fn: Callable[..., Any], - fn_args: tuple[Any, ...], - fn_kwargs: dict[str, Any], + fn_args: Optional[tuple[Any, ...]] = None, + fn_kwargs: Optional[dict[str, Any]] = None, + device: Optional[Union[str, torch.device]] = None, **kwargs: Any, ) -> float: """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the @@ -109,7 +115,8 @@ class Benchmarker: device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises `ValueError(...)` if we can't safely infer the device type of `fn`; for example, if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device - types are found. + types are found. To bypass device inference, provide the device to the `device` + parameter. Arguments: - fn: The function to benchmark. @@ -117,26 +124,52 @@ class Benchmarker: - fn_kwargs: The function's kwargs. Keyword Arguments: + - device: Which device to use for benchmarking. If not provided the device will be attempted + to be inferred from `fn_args` and `fn_kwargs`. - **kwargs: The benchmarking implementation's kwargs. Returns: - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds. """ - inferred_device = None - for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): - if not isinstance(arg_or_kwarg, torch.Tensor): - continue - if inferred_device is None: - inferred_device = arg_or_kwarg.device - elif arg_or_kwarg.device != inferred_device: + inferred_device: Optional[torch.device] = None + if device is not None: + inferred_device = ( + torch.device(device) if isinstance(device, str) else device + ) + else: + if fn_args is None and fn_kwargs is None: raise ValueError( - "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" + "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided." ) + + fn_args = fn_args or tuple() + fn_kwargs = fn_kwargs or {} + for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): + if not isinstance(arg_or_kwarg, torch.Tensor): + continue + if inferred_device is None: + inferred_device = arg_or_kwarg.device + elif arg_or_kwarg.device != inferred_device: + raise ValueError( + "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" + ) + if inferred_device is None: raise ValueError( - "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950 + "Can't safely infer the device type of `fn` with no device types" + " in `fn_args` or `fn_kwargs` and `device` not explicitly provided!" + " You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." 
             )
-        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+
+        fn_args = fn_args or tuple()
+        fn_kwargs = fn_kwargs or {}
+
+        # No need to wrap if the callable takes no arguments
+        if len(fn_args) == 0 and len(fn_kwargs) == 0:
+            _callable = fn
+        else:
+            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+
         if inferred_device == torch.device("cpu"):
             return self.benchmark_cpu(_callable, **kwargs)
         # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 709f0ec8b11a..edcc7d574dc0 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -927,11 +927,11 @@ class CachingAutotuner(KernelInterface):

             return do_bench_using_profiling(kernel_call, warmup=10, rep=40)

-        if self.device_props.type == "cpu":
-            return benchmarker.benchmark_cpu(kernel_call)
-
-        return benchmarker.benchmark_gpu(
-            kernel_call, rep=40, is_vetted_benchmarking=True
+        benchmark_kwargs = {"rep": 40} if self.device_props.type == "cuda" else {}
+        return benchmarker.benchmark(
+            fn=kernel_call,
+            device=self.device_props.type,
+            **benchmark_kwargs,  # type: ignore[arg-type]
         )

     def copy_args_to_cpu_if_needed(self, *args, **kwargs):
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index f85b5c7e39d9..0c39408e13a9 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -3269,8 +3269,8 @@ class Scheduler:
         device = node_list_1[0].get_device()
         assert device
-        # don't support benchmark fusion for CPU right now.
-        if device.type == "cpu":
+        # don't support benchmark fusion for CPU C++ backend right now.
+        if device.type == "cpu" and config.cpu_backend != "triton":
             return True

         node_list_2 = node2.get_nodes()
@@ -5569,8 +5569,8 @@ class Scheduler:
         subkernel_nodes = nodes
         device = subkernel_nodes[0].get_device()
-        # don't support benchmark fusion for CPU right now.
-        if device is None or device.type == "cpu":
+        # don't support benchmark fusion for CPU C++ backend right now.
+        if device is None or (device.type == "cpu" and config.cpu_backend != "triton"):
             return True

         from triton.compiler.errors import CompilationError
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index b0e81444ad84..f9badd8b39de 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -2671,8 +2671,10 @@ class AlgorithmSelectorCache(PersistentCache):
         # Templates selected with input_gen_fns require specific input data to avoid IMA
         # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
-        # TODO(jgong5): support multi-template on CPU
-        if input_gen_fns is not None or layout.device.type == "cpu":
+        # TODO(jgong5): support multi-template on CPU C++ backend
+        if input_gen_fns is not None or (
+            layout.device.type == "cpu" and config.cpu_backend != "triton"
+        ):
             return_multi_template = False

         # TODO - assert that we have not mutating kernels here
diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py
index f8430064917e..a721393b2bfb 100644
--- a/torch/_inductor/wrapper_benchmark.py
+++ b/torch/_inductor/wrapper_benchmark.py
@@ -93,6 +93,7 @@ def benchmark_all_kernels(
             continue

         triton_kernel = get_triton_kernel(kernel_mod)
+        device_type = triton_kernel.device_props.type
         kernel_category = get_kernel_category(kernel_mod)
         args = kernel_mod.get_args()
         num_in_out_ptrs = len(
@@ -137,7 +138,12 @@ def benchmark_all_kernels(
                     f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                 )
         else:
-            ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
+            ms = benchmarker.benchmark(
+                kernel_mod.call,
+                fn_args=(args,),
+                device=device_type,
+                rep=40,
+            )
             assert len(triton_kernel.launchers) == 1, (
                 "Autotuner should have selected the best config"
             )
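Usage note (illustration only, not part of the patch): the reworked `Benchmarker.benchmark` either infers the device from tensor values in `fn_args`/`fn_kwargs` or takes it explicitly via the new `device` parameter, which is why the call sites above can either keep a zero-argument closure with an explicit `device=` or pass `fn` plus `fn_args`/`fn_kwargs` and let the device be inferred. A minimal sketch of both call styles, assuming a CUDA build and the module-level `benchmarker` singleton exported by `torch/_inductor/runtime/benchmarking.py`:

    import torch
    from torch._inductor.runtime.benchmarking import benchmarker

    x = torch.randn(1024, device="cuda")

    # Device inferred from the tensor in `fn_args`; dispatches to `benchmark_gpu`.
    ms = benchmarker.benchmark(torch.relu, fn_args=(x,))

    # Device passed explicitly, bypassing inference; needed when the callable takes
    # no tensor arguments (e.g. a zero-argument closure). Extra kwargs such as `rep`
    # are forwarded to the device-specific implementation.
    ms_closure = benchmarker.benchmark(lambda: torch.relu(x), device="cuda", rep=40)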