Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[inductor] Expand use of generic benchmark function (#164938)"
This reverts commit 5c583e2573f29243742e00b9fa36b266c5c78bb3.
Reverted https://github.com/pytorch/pytorch/pull/164938 on behalf of https://github.com/clee2000 due to: I think this broke test/inductor/test_cuda_repro.py::CudaReproTests::test_epilogue_fusion_with_view ([GH job link](https://github.com/pytorch/pytorch/actions/runs/18529735968/job/52813191763), [HUD commit link](f58f301313)) on both ROCm and the slow grad check for Linux. It did run successfully on the CUDA workflow on trunk, so I wonder if this is a GPU capability thing? No clue though. ([comment](https://github.com/pytorch/pytorch/pull/164938#issuecomment-3407600224))
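Note: the reverted PR replaced the device-specific benchmarker entry points (`benchmark_gpu`, `benchmark_cpu`) with a generic `benchmark(fn, fn_args, fn_kwargs, device=..., rep=...)`. A minimal toy sketch contrasting the two call shapes this revert toggles between (`bench_ms` and `run_kernel` are made-up names, not PyTorch source):

```python
# Hypothetical toy benchmarker; illustrates call shapes only, not the
# torch._inductor implementation.
import time

import torch


def bench_ms(fn, rep: int = 40) -> float:
    """Return the mean wall-clock runtime of fn() in milliseconds."""
    start = time.perf_counter()
    for _ in range(rep):
        fn()
    return (time.perf_counter() - start) / rep * 1e3


def run_kernel(x: torch.Tensor) -> torch.Tensor:
    return x * 2


args = (torch.randn(1024),)

# Reverted (generic) shape: pass the function, its args, and a device, and
# let the benchmarker dispatch internally:
#     benchmarker.benchmark(run_kernel, fn_args=args, device="cuda", rep=40)
# Restored (device-specific) shape: wrap the call in a zero-arg closure and
# pick the device-specific entry point directly:
#     benchmarker.benchmark_gpu(lambda: run_kernel(*args), rep=40)
print(f"{bench_ms(lambda: run_kernel(*args)):.4f} ms")
```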
@@ -8,7 +8,6 @@ from typing import Any, Optional, Union
 
 from torch._inductor.ir import MultiTemplateBuffer
 from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
-from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 from torch.utils._ordered_set import OrderedSet
 
 from .. import config
@@ -370,20 +369,16 @@ class MultiKernelCall:
         be picked.
         """
 
-        def get_args_kwargs(kernel, index) -> tuple[tuple, dict[str, Any]]:  # type: ignore[type-arg]
-            filtered_args = self._get_filtered_args(args, index)
-            args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
-            return args_clone, kwargs_clone
+        def wrap_fn(kernel, index):
+            def inner():
+                filtered_args = self._get_filtered_args(args, index)
+                args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
+                return kernel.run(*args_clone, **kwargs_clone)
+
+            return inner
 
         return [
-            benchmarker.benchmark(
-                kernel.run,
-                *get_args_kwargs(kernel, index),
-                device=kernel.device_props.type
-                if isinstance(kernel, CachingAutotuner)
-                else None,
-                rep=40,
-            )
+            benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40)
             for index, kernel in enumerate(self.kernels)
         ]
 
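The restored code builds each callable through the `wrap_fn(kernel, index)` factory rather than a bare lambda inside the comprehension, so every closure captures its own `kernel` and `index` at call time. A toy illustration of the late-binding pitfall this pattern avoids (not PyTorch source):

```python
# Closures over a loop variable bind late: every lambda sees the final value.
fns_late = [lambda: i for i in range(3)]
print([f() for f in fns_late])  # [2, 2, 2]


def wrap(i):
    def inner():
        return i  # i is bound when wrap() is called, not when inner() runs

    return inner


fns_eager = [wrap(i) for i in range(3)]
print([f() for f in fns_eager])  # [0, 1, 2]
```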
@@ -109,10 +109,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
         bm_func([*sym_inputs, *args])
         if config.profile_bandwidth_with_do_bench_using_profiling:
             return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
-        return benchmarker.benchmark(
-            bm_func,
-            fn_args=([*sym_inputs, *args],),
-        )
+        return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
 
     def hash_key(self) -> str:
         return "-".join(
@@ -4682,7 +4682,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
 
             result.writeline("args = get_args()")
             result.writeline(
-                f"ms = benchmarker.benchmark(lambda: call(args), device={V.graph.get_current_device_or_throw().type}, rep=40)"  # noqa: B950 line too long
+                "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
             )
             result.writeline(f"num_gb = {num_gb}")
             result.writeline("gb_per_s = num_gb / (ms / 1e3)")
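In both versions the generated harness then converts the measured time into achieved bandwidth via `gb_per_s = num_gb / (ms / 1e3)`, i.e. gigabytes moved divided by the runtime converted from milliseconds to seconds. A worked example:

```python
# A kernel that moves 0.5 GB and runs in 0.25 ms achieves 2000 GB/s.
num_gb = 0.5
ms = 0.25
gb_per_s = num_gb / (ms / 1e3)  # 0.5 GB / 0.00025 s
print(f"{gb_per_s:.0f} GB/s")  # 2000 GB/s
```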
@@ -5624,21 +5624,18 @@ class TritonScheduling(SIMDScheduling):
                 # skip benchmarking the kernel if there are register spills
                 ms = float("inf")
             else:
-                device = V.graph.get_current_device_or_throw()
                 # We have to clone the inplace updated arguments to avoid earlier calls
                 # generating out of range indices for later calls.
-                ms = benchmarker.benchmark(
-                    lambda: call(wrapped_jit_function.clone_args(*args)[0]),
-                    device=device,
+                ms = benchmarker.benchmark_gpu(
+                    lambda: call(wrapped_jit_function.clone_args(*args)[0])
                 )
                 # overhead of cloning args gives bias for fusing the kernel
                 # in the case of mutating/in-placeable second fusion
                 # TODO - would be better as a hook in triton do_bench that reset
                 # the input values between benchmarking
                 if len(wrapped_jit_function.mutated_arg_names) > 0:
-                    ms = ms - benchmarker.benchmark(
-                        lambda: wrapped_jit_function.clone_args(*args),
-                        device=str(device),
+                    ms = ms - benchmarker.benchmark_gpu(
+                        lambda: wrapped_jit_function.clone_args(*args)
                     )
 
                 log.debug(
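Both sides of this hunk keep the same correction for in-place kernels: each benchmark iteration must run on freshly cloned arguments, so the clone cost is measured separately and subtracted from the total. A self-contained sketch of the idea, assuming a toy wall-clock benchmarker rather than the inductor one:

```python
# Toy sketch of the clone-cost correction: time call-with-clone, then time
# clone alone, and subtract to estimate the kernel time itself.
import time

import torch


def bench_ms(fn, rep: int = 40) -> float:
    start = time.perf_counter()
    for _ in range(rep):
        fn()
    return (time.perf_counter() - start) / rep * 1e3


def mutating_kernel(x: torch.Tensor) -> None:
    x.add_(1)  # in-place update: reusing x across reps would skew results


args = (torch.randn(1 << 20),)
clone_args = lambda: tuple(a.clone() for a in args)  # noqa: E731

total_ms = bench_ms(lambda: mutating_kernel(*clone_args()))
clone_ms = bench_ms(lambda: clone_args())
kernel_ms = total_ms - clone_ms  # cloning overhead removed, as in the hunk
print(f"~{kernel_ms:.4f} ms")
```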
@@ -5807,16 +5804,13 @@ class TritonScheduling(SIMDScheduling):
                 # skip benchmarking the kernel if there are register spills
                 ms = ms_clone = float("inf")
             else:
-                device = V.graph.get_current_device_or_throw()
                 # We have to clone the inplace updated arguments to avoid earlier calls
                 # generating out of range indices for later calls.
-                ms = benchmarker.benchmark(
-                    lambda: call(wrapped_jit_function.clone_args(*args)[0]),
-                    device=device,
+                ms = benchmarker.benchmark_gpu(
+                    lambda: call(wrapped_jit_function.clone_args(*args)[0])
                 )
-                ms_clone = benchmarker.benchmark(
-                    lambda: wrapped_jit_function.clone_args(*args)[0],
-                    device=device,
+                ms_clone = benchmarker.benchmark_gpu(
+                    lambda: wrapped_jit_function.clone_args(*args)[0]
                 )
 
                 log.debug(
@@ -889,7 +889,6 @@ class ComboKernel(Kernel):
             result.writeline(f"return {', '.join(var_names)},")
 
         result.writelines(["\n", "\n", "def call(args):"])
-        device = V.graph.get_current_device_or_throw()
         index = V.graph.get_current_device_or_throw().index
         with result.indent():
             result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
@@ -924,7 +923,7 @@ class ComboKernel(Kernel):
 
         result.writeline("args = get_args()")
         result.writeline(
-            f"ms = benchmarker.benchmark(call, fn_args=(args,), device={device.type},rep=40)"
+            "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
         )
         result.writeline(f"num_gb = {num_gb}")
         result.writeline("gb_per_s = num_gb / (ms / 1e3)")
@@ -5050,9 +5050,7 @@ class ChoiceCaller:
         }
         if config.profile_bandwidth_with_do_bench_using_profiling:
             return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs)  # type: ignore[arg-type]
-        return benchmarker.benchmark(
-            algo, args, {"out": out}, device=None, **benchmark_configs
-        )
+        return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs)
 
     def call_name(self) -> str:
         raise NotImplementedError
@@ -92,11 +92,6 @@ def time_and_count(
 
 
 class Benchmarker:
-    """
-    A device-agnostic benchmarking utility for measuring the runtime of
-    inductor generated callables.
-    """
-
     def __init__(self: Self) -> None:
         pass
 
@@ -104,9 +99,8 @@ class Benchmarker:
     def benchmark(
         self: Self,
         fn: Callable[..., Any],
-        fn_args: Optional[tuple[Any, ...]] = None,
-        fn_kwargs: Optional[dict[str, Any]] = None,
-        device: Optional[Union[str, torch.device]] = None,
+        fn_args: tuple[Any, ...],
+        fn_kwargs: dict[str, Any],
         **kwargs: Any,
     ) -> float:
         """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the
@@ -115,8 +109,7 @@ class Benchmarker:
         device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
         `ValueError(...)` if we can't safely infer the device type of `fn`; for example,
         if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
-        types are found. To bypass device inference, provide the device to the `device`
-        parameter.
+        types are found.
 
         Arguments:
         - fn: The function to benchmark.
@@ -124,52 +117,26 @@ class Benchmarker:
         - fn_kwargs: The function's kwargs.
 
         Keyword Arguments:
-        - device: Which device to use for benchmarking. If not provided the device will be attempted
-        to be inferred from `fn_args` and `fn_kwargs`.
         - **kwargs: The benchmarking implementation's kwargs.
 
         Returns:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
-        inferred_device: Optional[torch.device] = None
-        if device is not None:
-            inferred_device = (
-                torch.device(device) if isinstance(device, str) else device
-            )
-        else:
-            if fn_args is None and fn_kwargs is None:
-                raise ValueError(
-                    "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided."
-                )
-
-        fn_args = fn_args or tuple()
-        fn_kwargs = fn_kwargs or {}
-        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
-            if not isinstance(arg_or_kwarg, torch.Tensor):
-                continue
-            if inferred_device is None:
-                inferred_device = arg_or_kwarg.device
-            elif arg_or_kwarg.device != inferred_device:
-                raise ValueError(
-                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
-                )
+        inferred_device = None
+        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+            if not isinstance(arg_or_kwarg, torch.Tensor):
+                continue
+            if inferred_device is None:
+                inferred_device = arg_or_kwarg.device
+            elif arg_or_kwarg.device != inferred_device:
+                raise ValueError(
+                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
+                )
 
         if inferred_device is None:
             raise ValueError(
-                "Can't safely infer the device type of `fn` with no device types"
-                " in `fn_args` or `fn_kwargs` and `device` not explicitly provided!"
-                " You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."
+                "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
             )
-
-        fn_args = fn_args or tuple()
-        fn_kwargs = fn_kwargs or {}
-
-        # No need to wrap if the callable takes no arguments
-        if len(fn_args) == 0 and len(fn_kwargs) == 0:
-            _callable = fn
-        else:
-            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
 
         if inferred_device == torch.device("cpu"):
             return self.benchmark_cpu(_callable, **kwargs)
         # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
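The restored `Benchmarker.benchmark` infers the device by scanning tensor arguments, while the reverted variant additionally accepted an explicit `device=` to bypass inference. A simplified standalone restatement of the restored inference loop (`infer_device` is an illustrative name, not the actual API):

```python
# Simplified re-statement of the restored device-inference logic: scan all
# tensor args/kwargs, require exactly one device, otherwise raise.
from itertools import chain
from typing import Any, Optional

import torch


def infer_device(fn_args: tuple[Any, ...], fn_kwargs: dict[str, Any]) -> torch.device:
    inferred: Optional[torch.device] = None
    for arg in chain(fn_args, fn_kwargs.values()):
        if not isinstance(arg, torch.Tensor):
            continue  # non-tensor arguments carry no device information
        if inferred is None:
            inferred = arg.device
        elif arg.device != inferred:
            raise ValueError("multiple device types in fn_args/fn_kwargs")
    if inferred is None:
        raise ValueError("no device types found; call benchmark_cpu/benchmark_gpu directly")
    return inferred


print(infer_device((torch.randn(2), 3), {"out": torch.empty(2)}))  # cpu
```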
@@ -927,11 +927,11 @@ class CachingAutotuner(KernelInterface):
 
             return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
 
-        benchmark_kwargs = {"rep": 40} if self.device_props.type == "cuda" else {}
-        return benchmarker.benchmark(
-            fn=kernel_call,
-            device=self.device_props.type,
-            **benchmark_kwargs,  # type: ignore[arg-type]
+        if self.device_props.type == "cpu":
+            return benchmarker.benchmark_cpu(kernel_call)
+
+        return benchmarker.benchmark_gpu(
+            kernel_call, rep=40, is_vetted_benchmarking=True
         )
 
     def copy_args_to_cpu_if_needed(self, *args, **kwargs):
@@ -3269,8 +3269,8 @@ class Scheduler:
         device = node_list_1[0].get_device()
         assert device
 
-        # don't support benchmark fusion for CPU C++ backend right now.
-        if device.type == "cpu" and config.cpu_backend != "triton":
+        # don't support benchmark fusion for CPU right now.
+        if device.type == "cpu":
             return True
 
         node_list_2 = node2.get_nodes()
@@ -5569,8 +5569,8 @@ class Scheduler:
         subkernel_nodes = nodes
         device = subkernel_nodes[0].get_device()
 
-        # don't support benchmark fusion for CPU C++ backend right now.
-        if device is None or (device.type == "cpu" and config.cpu_backend != "triton"):
+        # don't support benchmark fusion for CPU right now.
+        if device is None or device.type == "cpu":
             return True
 
         from triton.compiler.errors import CompilationError
@@ -2671,10 +2671,8 @@ class AlgorithmSelectorCache(PersistentCache):
 
         # Templates selected with input_gen_fns require specific input data to avoid IMA
         # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
-        # TODO(jgong5): support multi-template on CPU C++ backend
-        if input_gen_fns is not None or (
-            layout.device.type == "cpu" and config.cpu_backend != "triton"
-        ):
+        # TODO(jgong5): support multi-template on CPU
+        if input_gen_fns is not None or layout.device.type == "cpu":
             return_multi_template = False
 
         # TODO - assert that we have not mutating kernels here
@@ -93,7 +93,6 @@ def benchmark_all_kernels(
             continue
 
         triton_kernel = get_triton_kernel(kernel_mod)
-        device_type = triton_kernel.device_props.type
         kernel_category = get_kernel_category(kernel_mod)
         args = kernel_mod.get_args()
         num_in_out_ptrs = len(
@@ -138,12 +137,7 @@ def benchmark_all_kernels(
                 f"  {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
             )
         else:
-            ms = benchmarker.benchmark(
-                kernel_mod.call,
-                fn_args=(args,),
-                device=device_type,
-                rep=40,
-            )
+            ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
         assert len(triton_kernel.launchers) == 1, (
             "Autotuner should have selected the best config"
         )