Compare commits


35 Commits

Author SHA1 Message Date
4f1639b7bf Merge branch 'main' into mwizak/expand-generic-benchmark-use 2025-10-21 14:38:05 +00:00
24733c61e3 Lint 2025-10-21 14:15:00 +00:00
5a4367979f Merge branch 'mwizak/expand-generic-benchmark-use' of github.com:graphcore/pytorch-fork into mwizak/expand-generic-benchmark-use 2025-10-21 10:54:46 +00:00
cfca9d7e04 Kernel benchmarks have args as type list 2025-10-21 10:30:06 +00:00
1641a44a8b Lint 2025-10-21 10:30:06 +00:00
a99440ed7f Add comment 2025-10-21 10:30:06 +00:00
b7ed0ee464 Add comment 2025-10-21 10:30:06 +00:00
130b853463 Preserve existing cuda options 2025-10-21 10:30:06 +00:00
6d6c218c53 Fix mutable benchmark fn 2025-10-21 10:30:06 +00:00
be75b493cc extract tensors from nested structures 2025-10-21 10:30:06 +00:00
220c2c9dd8 Revert comment change 2025-10-21 10:30:06 +00:00
2a865c5775 Make args kwargs optional 2025-10-21 10:30:06 +00:00
ba11b9f4b5 Remove unneccessary line change 2025-10-21 10:30:06 +00:00
b1bbc488f2 Allow benchmarking callables that take no arguments 2025-10-21 10:30:06 +00:00
8e7f35c158 Lint 2025-10-21 10:30:06 +00:00
3dcc865fcf Fix mypy lint error 2025-10-21 10:30:05 +00:00
4e2cf6a6b9 Lint 2025-10-21 10:30:05 +00:00
a563dbc6c0 Just use the standard CPU timer 2025-10-21 10:30:05 +00:00
3cac8e3e9d Expand generic benchmark use
Allow triton CPU to be benchmarked
2025-10-21 10:30:05 +00:00
d16293b29a Add types 2025-10-17 11:38:36 +00:00
ec87bf4f96 Lint 2025-10-17 11:23:36 +00:00
71b2a4981f Add comment 2025-10-17 11:22:41 +00:00
5275f5cecf Add comment 2025-10-17 11:22:41 +00:00
4b6b47eed6 Preserve existing cuda options 2025-10-17 11:22:41 +00:00
10c3a73663 Fix mutable benchmark fn 2025-10-17 11:22:39 +00:00
fc5d397564 extract tensors from nested structures 2025-10-17 11:22:07 +00:00
85cf6b6331 Revert comment change 2025-10-17 11:22:07 +00:00
3453e5ecca Make args kwargs optional 2025-10-17 11:22:04 +00:00
e833b190ae Remove unneccessary line change 2025-10-17 11:21:42 +00:00
3eedfc141b Allow benchmarking callables that take no arguments 2025-10-17 11:21:42 +00:00
e93ef10e81 Lint 2025-10-17 11:21:42 +00:00
11ac1bc54b Fix mypy lint error 2025-10-17 11:21:42 +00:00
4af85f3df6 Lint 2025-10-17 11:21:42 +00:00
d577ff11f1 Just use the standard CPU timer 2025-10-17 11:21:42 +00:00
e884eedb03 Expand generic benchmark use
Allow triton CPU to be benchmarked
2025-10-17 11:21:40 +00:00
10 changed files with 143 additions and 47 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any, Optional, Union
 from torch._inductor.ir import MultiTemplateBuffer
 from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
+from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 from torch.utils._ordered_set import OrderedSet
 from .. import config
@@ -372,16 +373,20 @@ class MultiKernelCall:
         be picked.
         """

-        def wrap_fn(kernel, index):
-            def inner():
-                filtered_args = self._get_filtered_args(args, index)
-                args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
-                return kernel.run(*args_clone, **kwargs_clone)
-
-            return inner
+        def get_args_kwargs(kernel, index) -> tuple[tuple, dict[str, Any]]:  # type: ignore[type-arg]
+            filtered_args = self._get_filtered_args(args, index)
+            args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
+            return args_clone, kwargs_clone

         return [
-            benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40)
+            benchmarker.benchmark(
+                kernel.run,
+                *get_args_kwargs(kernel, index),
+                device=kernel.device_props.type
+                if isinstance(kernel, CachingAutotuner)
+                else None,
+                rep=40,
+            )
             for index, kernel in enumerate(self.kernels)
         ]
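
The hunk above replaces a zero-argument closure passed to benchmark_gpu with the generic entry point, which receives the callable, its (already cloned) arguments, and an explicit device hint. A minimal sketch of that call shape, assuming this branch's benchmark signature; toy_kernel and its inputs are hypothetical stand-ins:

import torch
from torch._inductor.runtime.benchmarking import benchmarker

# Hypothetical stand-in for a compiled kernel; any callable works.
def toy_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

fn_args = (torch.randn(1024), torch.randn(1024))
# An explicit `device` skips tensor-based device inference entirely.
ms = benchmarker.benchmark(toy_kernel, fn_args, {}, device="cpu")
print(f"toy_kernel: {ms:.4f} ms")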

View File

@@ -110,7 +110,11 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
         bm_func([*sym_inputs, *args])
         if config.profile_bandwidth_with_do_bench_using_profiling:
             return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
-        return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
+        return benchmarker.benchmark(
+            # Shallow clone args since bm_func may clear args
+            lambda: bm_func([*sym_inputs, *args]),
+            device=benchmarker.infer_device(*sym_inputs, *args),
+        )

     def hash_key(self) -> str:
         return "-".join(

View File

@@ -4726,7 +4726,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
             result.writeline("args = get_args()")
             result.writeline(
-                "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
+                f"ms = benchmarker.benchmark(lambda: call(args), device='{V.graph.get_current_device_or_throw().type}', rep=40)"  # noqa: B950 line too long
             )
             result.writeline(f"num_gb = {num_gb}")
             result.writeline("gb_per_s = num_gb / (ms / 1e3)")
@@ -5666,18 +5666,21 @@ class TritonScheduling(SIMDScheduling):
                 # skip benchmarking the kernel if there are register spills
                 ms = float("inf")
             else:
+                device = V.graph.get_current_device_or_throw()
                 # We have to clone the inplace updated arguments to avoid earlier calls
                 # generating out of range indices for later calls.
-                ms = benchmarker.benchmark_gpu(
-                    lambda: call(wrapped_jit_function.clone_args(*args)[0])
+                ms = benchmarker.benchmark(
+                    lambda: call(wrapped_jit_function.clone_args(*args)[0]),
+                    device=device,
                 )

                 # overhead of cloning args gives bias for fusing the kernel
                 # in the case of mutating/in-placeable second fusion
                 # TODO - would be better as a hook in triton do_bench that reset
                 # the input values between benchmarking
                 if len(wrapped_jit_function.mutated_arg_names) > 0:
-                    ms = ms - benchmarker.benchmark_gpu(
-                        lambda: wrapped_jit_function.clone_args(*args)
+                    ms = ms - benchmarker.benchmark(
+                        lambda: wrapped_jit_function.clone_args(*args),
+                        device=str(device),
                     )

             log.debug(
@@ -5846,13 +5849,16 @@
             # skip benchmarking the kernel if there are register spills
             ms = ms_clone = float("inf")
         else:
+            device = V.graph.get_current_device_or_throw()
             # We have to clone the inplace updated arguments to avoid earlier calls
             # generating out of range indices for later calls.
-            ms = benchmarker.benchmark_gpu(
-                lambda: call(wrapped_jit_function.clone_args(*args)[0])
+            ms = benchmarker.benchmark(
+                lambda: call(wrapped_jit_function.clone_args(*args)[0]),
+                device=device,
             )
-            ms_clone = benchmarker.benchmark_gpu(
-                lambda: wrapped_jit_function.clone_args(*args)[0]
+            ms_clone = benchmarker.benchmark(
+                lambda: wrapped_jit_function.clone_args(*args)[0],
+                device=device,
             )

         log.debug(
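
Both hunks keep the existing compensation trick: when a kernel mutates its inputs, each timed iteration must re-clone the arguments, so the clone cost is measured separately and subtracted. A self-contained sketch of the pattern against this branch's benchmark API; clone_args and mutating_kernel are hypothetical stand-ins:

import torch
from torch._inductor.runtime.benchmarking import benchmarker

def clone_args(*args):
    # Fresh copies per call, since the kernel updates its inputs in place.
    return [a.clone() for a in args]

def mutating_kernel(xs):
    xs[0].add_(1.0)

args = (torch.randn(4096),)
total_ms = benchmarker.benchmark(
    lambda: mutating_kernel(clone_args(*args)), device="cpu"
)
# Subtract the measured cost of cloning alone to isolate the kernel time.
clone_ms = benchmarker.benchmark(lambda: clone_args(*args), device="cpu")
kernel_ms = total_ms - clone_ms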

View File

@@ -898,6 +898,7 @@ class ComboKernel(Kernel):
             result.writeline(f"return {', '.join(var_names)},")
         result.writelines(["\n", "\n", "def call(args):"])
+        device = V.graph.get_current_device_or_throw()
         index = V.graph.get_current_device_or_throw().index
         with result.indent():
             result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
@@ -932,7 +933,7 @@
             result.writeline("args = get_args()")
             result.writeline(
-                "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
+                f"ms = benchmarker.benchmark(call, fn_args=(args,), device='{device.type}', rep=40)"
             )
             result.writeline(f"num_gb = {num_gb}")
             result.writeline("gb_per_s = num_gb / (ms / 1e3)")

View File

@@ -5048,7 +5048,9 @@ class ChoiceCaller:
         }
         if config.profile_bandwidth_with_do_bench_using_profiling:
             return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs)  # type: ignore[arg-type]
-        return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs)
+        return benchmarker.benchmark(
+            algo, args, {"out": out}, device=None, **benchmark_configs
+        )

     def call_name(self) -> str:
         raise NotImplementedError

View File

@@ -4,10 +4,11 @@ import time
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Union
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar

 import torch
+import torch.utils._pytree as pytree
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.config import use_experimental_benchmarker
@@ -92,15 +93,45 @@ def time_and_count(
 class Benchmarker:
+    """
+    A device-agnostic benchmarking utility for measuring the runtime of
+    inductor generated callables.
+    """
+
     def __init__(self: Self) -> None:
         pass

+    def infer_device(self, *fn_args: Any, **fn_kwargs: Any) -> torch.device:
+        inferred_device: Optional[torch.device] = None
+        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+            # Some callables take nested structures as arguments so use the
+            # flattened form to find any tensors
+            for arg_or_kwarg_leaf in pytree.tree_leaves(arg_or_kwarg):
+                if not isinstance(arg_or_kwarg_leaf, torch.Tensor):
+                    continue
+                if inferred_device is None:
+                    inferred_device = arg_or_kwarg_leaf.device
+                elif arg_or_kwarg_leaf.device != inferred_device:
+                    raise ValueError(
+                        "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
+                    )
+        if inferred_device is None:
+            raise ValueError(
+                "Can't safely infer the device type of `fn` with no device types"
+                " in `fn_args` or `fn_kwargs`. Use a direct benchmarking method instead e.g. "
+                "`Benchmarker.benchmark_cpu` or `Benchmarker.benchmark_gpu`."
+            )
+        return inferred_device
+
     @time_and_count
     def benchmark(
         self: Self,
         fn: Callable[..., Any],
-        fn_args: tuple[Any, ...],
-        fn_kwargs: dict[str, Any],
+        fn_args: Optional[tuple[Any, ...]] = None,
+        fn_kwargs: Optional[dict[str, Any]] = None,
+        device: Optional[Union[str, torch.device]] = None,
         **kwargs: Any,
     ) -> float:
         """Benchmark `fn(*fn_args, **fn_kwargs)` and return the runtime, in milliseconds (the
@@ -109,7 +140,14 @@
         device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
         `ValueError(...)` if we can't safely infer the device type of `fn`; for example,
         if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
-        types are found.
+        types are found. To bypass device inference, provide the device via the `device`
+        parameter.
+
+        WARNING: if `fn` mutates `fn_args` or `fn_kwargs`, benchmarking may fail unexpectedly.
+        For example, if `fn` clears a mutable object, subsequent invocations of `fn` during
+        benchmarking will fail. In such cases, `fn` should handle cloning its arguments
+        internally. If device inference is still needed, `Benchmarker.infer_device` can be
+        called first and its result passed as `device`, leaving `fn_args` and `fn_kwargs` unset.

         Arguments:
         - fn: The function to benchmark.
@@ -117,27 +155,56 @@
         - fn_kwargs: The function's kwargs.

         Keyword Arguments:
+        - device: Which device to use for benchmarking. If not provided, the device is
+        inferred from `fn_args` and `fn_kwargs`.
         - **kwargs: The benchmarking implementation's kwargs.

         Returns:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
-        inferred_device = None
-        # pyrefly: ignore  # bad-assignment
-        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
-            if not isinstance(arg_or_kwarg, torch.Tensor):
-                continue
-            if inferred_device is None:
-                inferred_device = arg_or_kwarg.device
-            elif arg_or_kwarg.device != inferred_device:
-                raise ValueError(
-                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
-                )
+        inferred_device: Optional[torch.device] = None
+        if device is not None:
+            inferred_device = (
+                torch.device(device) if isinstance(device, str) else device
+            )
+        else:
+            if fn_args is None and fn_kwargs is None:
+                raise ValueError(
+                    "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided."
+                )
+            fn_args = fn_args or tuple()
+            fn_kwargs = fn_kwargs or {}
+            # pyrefly: ignore  # bad-assignment
+            for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+                # Some callables take nested structures as arguments so use the
+                # flattened form to find any tensors
+                for arg_or_kwarg_leaf in pytree.tree_leaves(arg_or_kwarg):
+                    if not isinstance(arg_or_kwarg_leaf, torch.Tensor):
+                        continue
+                    if inferred_device is None:
+                        inferred_device = arg_or_kwarg_leaf.device
+                    elif arg_or_kwarg_leaf.device != inferred_device:
+                        raise ValueError(
+                            "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
+                        )
         if inferred_device is None:
             raise ValueError(
-                "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
+                "Can't safely infer the device type of `fn` with no device types"
+                " in `fn_args` or `fn_kwargs` and `device` not explicitly provided!"
+                " You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."
             )
-        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+        fn_args = fn_args or tuple()
+        fn_kwargs = fn_kwargs or {}
+        # No need to wrap if the callable takes no arguments
+        if len(fn_args) == 0 and len(fn_kwargs) == 0:
+            _callable = fn
+        else:
+            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731

         if inferred_device == torch.device("cpu"):
             return self.benchmark_cpu(_callable, **kwargs)
         # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
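
The inference walk added above is what lets nested containers work: arguments are flattened to pytree leaves before looking for tensors. A standalone re-implementation of the same idea (a sketch, not the module's code), runnable on stock PyTorch; infer_device_sketch and batch are hypothetical names:

import torch
import torch.utils._pytree as pytree

def infer_device_sketch(*fn_args, **fn_kwargs):
    inferred = None
    for arg in (*fn_args, *fn_kwargs.values()):
        # Flatten nested lists/tuples/dicts down to their leaves.
        for leaf in pytree.tree_leaves(arg):
            if not isinstance(leaf, torch.Tensor):
                continue
            if inferred is None:
                inferred = leaf.device
            elif leaf.device != inferred:
                raise ValueError("mixed devices in fn_args/fn_kwargs")
    if inferred is None:
        raise ValueError("no tensors found; pass `device` explicitly")
    return inferred

# Tensors buried inside containers are found via the flattened leaves.
batch = {"inputs": [torch.ones(2), (torch.zeros(3),)], "scale": 2.0}
assert infer_device_sketch(batch) == torch.device("cpu")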

View File

@@ -916,11 +916,15 @@ class CachingAutotuner(KernelInterface):
             return do_bench_using_profiling(kernel_call, warmup=10, rep=40)

-        if self.device_props.type == "cpu":
-            return benchmarker.benchmark_cpu(kernel_call)
-
-        return benchmarker.benchmark_gpu(
-            kernel_call, rep=40, is_vetted_benchmarking=True
+        benchmark_kwargs = (
+            {"rep": 40, "is_vetted_benchmarking": True}
+            if self.device_props.type == "cuda"
+            else {}
+        )
+        return benchmarker.benchmark(
+            fn=kernel_call,
+            device=self.device_props.type,
+            **benchmark_kwargs,  # type: ignore[arg-type]
         )

     def copy_args_to_cpu_if_needed(self, *args, **kwargs):
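
With this change every device type funnels through the generic benchmark call, and only CUDA keeps its tuned knobs. A reduced sketch of that dispatch; is_vetted_benchmarking looks specific to this branch, so it is omitted here:

import torch
from torch._inductor.runtime.benchmarking import benchmarker

device_type = "cuda" if torch.cuda.is_available() else "cpu"
# Only CUDA gets the extra repetition budget; other devices use defaults.
benchmark_kwargs = {"rep": 40} if device_type == "cuda" else {}

ms = benchmarker.benchmark(
    fn=lambda: torch.randn(256, device=device_type).sum(),
    device=device_type,
    **benchmark_kwargs,
)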

View File

@@ -3269,8 +3269,8 @@ class Scheduler:
         device = node_list_1[0].get_device()
         assert device

-        # don't support benchmark fusion for CPU right now.
-        if device.type == "cpu":
+        # don't support benchmark fusion for CPU C++ backend right now.
+        if device.type == "cpu" and config.cpu_backend != "triton":
             return True

         node_list_2 = node2.get_nodes()
@@ -5585,8 +5585,8 @@
         subkernel_nodes = nodes
         device = subkernel_nodes[0].get_device()

-        # don't support benchmark fusion for CPU right now.
-        if device is None or device.type == "cpu":
+        # don't support benchmark fusion for CPU C++ backend right now.
+        if device is None or (device.type == "cpu" and config.cpu_backend != "triton"):
             return True

         from triton.compiler.errors import CompilationError

View File

@@ -2671,8 +2671,10 @@ class AlgorithmSelectorCache(PersistentCache):
             # Templates selected with input_gen_fns require specific input data to avoid IMA
             # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
-            # TODO(jgong5): support multi-template on CPU
-            if input_gen_fns is not None or layout.device.type == "cpu":
+            # TODO(jgong5): support multi-template on CPU C++ backend
+            if input_gen_fns is not None or (
+                layout.device.type == "cpu" and config.cpu_backend != "triton"
+            ):
                 return_multi_template = False

             # TODO - assert that we have not mutating kernels here

View File

@@ -93,6 +93,7 @@ def benchmark_all_kernels(
             continue

         triton_kernel = get_triton_kernel(kernel_mod)
+        device_type = triton_kernel.device_props.type
         kernel_category = get_kernel_category(kernel_mod)
         args = kernel_mod.get_args()
         num_in_out_ptrs = len(
@@ -137,7 +138,11 @@
                     f"  {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                 )
         else:
-            ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
+            ms = benchmarker.benchmark(
+                lambda: kernel_mod.call(args),
+                device=device_type,
+                rep=40,
+            )
             assert len(triton_kernel.launchers) == 1, (
                 "Autotuner should have selected the best config"
             )