refactor benchmarking to use dynamo_timed (#144315)

Use dynamo_timed for all of our wrapped calls instead of our custom timer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144315
Approved by: https://github.com/eellison
Nicolas Macchioni
2025-01-17 15:14:54 -08:00
committed by PyTorch MergeBot
parent 17c3a10cbd
commit ee3e89190a
2 changed files with 32 additions and 74 deletions


@@ -81,10 +81,6 @@ class TestBenchmarker(TestCase):
timing = benchmarker.benchmark_gpu(_callable)
self.assertGreater(timing, 0)
self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_gpu"), 1)
if benchmarker_cls is TritonBenchmarker:
self.assertEqual(
self.get_counter_value(benchmarker_cls, "triton_do_bench"), 1
)
@unittest.skipIf(not HAS_CPU and not HAS_GPU, "requires CPU or GPU")
@unittest.expectedFailure
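For reference, the surviving assertion reads the counter that the new decorator writes, keyed as benchmarking.<ClassName>.<method>; the triton_do_bench check is dropped because that property is no longer counted. A minimal sketch of checking the counter directly, assuming a CUDA device and the module-level benchmarker instance from torch/_inductor/runtime/benchmarking.py (and a fresh process, so no prior increments):

```python
# Illustrative check mirroring the assertion above; assumes CUDA is available.
import torch
from torch._dynamo.utils import counters
from torch._inductor.runtime.benchmarking import benchmarker

x = torch.randn(1024, 1024, device="cuda")
timing_ms = benchmarker.benchmark_gpu(lambda: x @ x)

# time_and_count keys its counter as "benchmarking.<ClassName>.<method>".
assert timing_ms > 0
assert counters["inductor"]["benchmarking.TritonBenchmarker.benchmark_gpu"] == 1
```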


@@ -2,7 +2,7 @@ import time
from functools import cached_property, wraps
from itertools import chain
from statistics import median
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Tuple
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
import torch
@@ -18,52 +18,20 @@ P = ParamSpec("P")
T = TypeVar("T")
def maybe_time(
def time_and_count(
fn: Callable[Concatenate[Any, P], T]
) -> Callable[Concatenate[Any, P], T]:
"""Wrapper that logs the duration of `fn`, in milliseconds, along with a representation
of the function's args and kwargs, if logging is enabled. It is expected that `fn` is
a method of `Benchmarker` or one of its subclasses; typing limitations prevent us from
declaring this directly. If logging is disabled, this becomes a no-op.
"""
# no-op if benchmarking-specific logging is disabled
if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"):
return fn
@wraps(fn)
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
start_t = time.perf_counter()
result = fn(self, *args, **kwargs)
logger.debug(
"Call `benchmarking.%s.%s(*args=%r, **kwargs=%r)` took %f milliseconds.",
self.__class__.__name__,
fn.__name__,
args,
kwargs,
(time.perf_counter() - start_t) * MILLISECONDS_PER_SECOND,
)
return result
return wrapper
def count(fn: Callable[Concatenate[Any, P], T]) -> Callable[Concatenate[Any, P], T]:
"""Wrapper that increments relevant dynamo counters on `fn` call. It is expected that
`fn` is a method of `Benchmarker` or one of its subclasses; typing limitations prevent
us from declaring this directly. The counter incrementation follows the formula,
`counters["inductor"]["benchmarking.Foo.bar"] += 1`
where `Foo` is the class whose instance called the function, and `bar` is the function name.
"""Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
counters. It is expected that `fn` is a method of `Benchmarker` or one of its
subclasses; typing limitations prevent us from declaring this directly.
"""
@wraps(fn)
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
counters["inductor"][
"benchmarking." + self.__class__.__name__ + "." + fn.__name__
] += 1
return fn(self, *args, **kwargs)
fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}"
counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1
with dynamo_timed(fn_qual_name, log_pt2_compile_event=True):
return fn(self, *args, **kwargs)
return wrapper
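Taken together, this hunk collapses the old maybe_time/count pair into a single time_and_count decorator that always increments the counter and delegates timing to dynamo_timed. A self-contained sketch of the same pattern, with a local counters dict and a _timed context manager standing in for torch._dynamo.utils.counters and dynamo_timed (illustrative stand-ins only):

```python
import collections
import contextlib
import time
from functools import wraps

# Stand-in for torch._dynamo.utils.counters.
counters = collections.defaultdict(collections.Counter)


@contextlib.contextmanager
def _timed(key):
    # Stand-in for dynamo_timed: time the wrapped region and report it.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{key} took {(time.perf_counter() - start) * 1e3:.3f} ms")


def time_and_count(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}"
        counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1
        with _timed(fn_qual_name):
            return fn(self, *args, **kwargs)

    return wrapper


class Foo:
    @time_and_count
    def bar(self):
        return 42


Foo().bar()
assert counters["inductor"]["benchmarking.Foo.bar"] == 1
```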
@@ -72,12 +40,11 @@ class Benchmarker:
def __init__(self: Self) -> None:
pass
@maybe_time
@count
@time_and_count
def benchmark(
self: Self,
fn: Callable[..., Any],
fn_args: tuple[Any, ...],
fn_args: Tuple[Any, ...],
fn_kwargs: Dict[str, Any],
**kwargs: Any,
) -> float:
@@ -100,31 +67,29 @@ class Benchmarker:
Returns:
- The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
"""
with dynamo_timed("Benchmarker.benchmark", log_pt2_compile_event=True):
inferred_device = None
for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
if not isinstance(arg_or_kwarg, torch.Tensor):
continue
if inferred_device is None:
inferred_device = arg_or_kwarg.device
elif arg_or_kwarg.device != inferred_device:
raise ValueError(
"Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
)
inferred_device = None
for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
if not isinstance(arg_or_kwarg, torch.Tensor):
continue
if inferred_device is None:
inferred_device = arg_or_kwarg.device
elif arg_or_kwarg.device != inferred_device:
raise ValueError(
"Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950
"Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
)
_callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731
if inferred_device == torch.device("cpu"):
return self.benchmark_cpu(_callable, **kwargs)
# TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
# implementation which was written specifically with CUDA devices in mind, we may want to
# explore alternate implementations for other device types.
return self.benchmark_gpu(_callable, **kwargs)
if inferred_device is None:
raise ValueError(
"Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950
)
_callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731
if inferred_device == torch.device("cpu"):
return self.benchmark_cpu(_callable, **kwargs)
# TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
# implementation which was written specifically with CUDA devices in mind, we may want to
# explore alternate implementations for other device types.
return self.benchmark_gpu(_callable, **kwargs)
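A short usage sketch of the inference-and-dispatch logic above (CPU case; it assumes the module-level benchmarker instance exported by this file):

```python
import torch
from torch._inductor.runtime.benchmarking import benchmarker

a = torch.randn(256, 256)  # CPU tensors, so the device is inferred as "cpu"
b = torch.randn(256, 256)

# Every tensor argument lives on the CPU, so this dispatches to benchmark_cpu();
# mixing CPU and GPU tensors raises the "multiple device types" ValueError, and
# passing no tensors at all raises the "no device types" ValueError.
timing_ms = benchmarker.benchmark(torch.mm, (a, b), {})
print(f"torch.mm took {timing_ms:.4f} ms")
```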
@maybe_time
@count
@time_and_count
def benchmark_cpu(
self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
) -> float:
@@ -159,15 +124,13 @@ class Benchmarker:
run_for(warmup)
return median(run_for(rep))
@count
@time_and_count
def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
raise NotImplementedError
class TritonBenchmarker(Benchmarker):
@cached_property
@maybe_time
@count
def triton_do_bench(self: Self) -> Callable[..., Any]:
"""Lazily import Triton's `do_bench`."""
try:
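The run_for(warmup) / median(run_for(rep)) tail at the top of this hunk is the unchanged CPU benchmarking strategy: a throwaway warmup budget followed by a measured budget whose per-call timings are reduced with the median. A rough, self-contained sketch of that pattern (an approximation, not the exact body of benchmark_cpu):

```python
import time
from statistics import median
from typing import Callable, List

MILLISECONDS_PER_SECOND = 1000


def benchmark_cpu_sketch(
    _callable: Callable[[], object], warmup: int = 20, rep: int = 100
) -> float:
    def run_for(ms: int) -> List[float]:
        # Keep timing individual calls until roughly `ms` milliseconds have elapsed.
        timings = []
        run_start_t = time.perf_counter()
        while True:
            start_t = time.perf_counter()
            _callable()
            end_t = time.perf_counter()
            timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND)
            if (end_t - run_start_t) * MILLISECONDS_PER_SECOND > ms:
                break
        return timings

    run_for(warmup)  # warmup timings are discarded
    return median(run_for(rep))  # report the median of the measurement runs


print(benchmark_cpu_sketch(lambda: sum(range(10_000))))
```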
@@ -176,8 +139,7 @@ class TritonBenchmarker(Benchmarker):
raise NotImplementedError("requires Triton") from e
return do_bench
@maybe_time
@count
@time_and_count
def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
"""Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds.