refactor benchmarking to use dynamo_timed (#144315)

use dynamo_timed for all our wrapped calls, instead of our custom timer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144315
Approved by: https://github.com/eellison
This commit is contained in:
Nicolas Macchioni
2025-01-17 15:14:54 -08:00
committed by PyTorch MergeBot
parent 17c3a10cbd
commit ee3e89190a
2 changed files with 32 additions and 74 deletions

View File

@ -81,10 +81,6 @@ class TestBenchmarker(TestCase):
timing = benchmarker.benchmark_gpu(_callable) timing = benchmarker.benchmark_gpu(_callable)
self.assertGreater(timing, 0) self.assertGreater(timing, 0)
self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_gpu"), 1) self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_gpu"), 1)
if benchmarker_cls is TritonBenchmarker:
self.assertEqual(
self.get_counter_value(benchmarker_cls, "triton_do_bench"), 1
)
@unittest.skipIf(not HAS_CPU and not HAS_GPU, "requires CPU or GPU") @unittest.skipIf(not HAS_CPU and not HAS_GPU, "requires CPU or GPU")
@unittest.expectedFailure @unittest.expectedFailure

View File

@ -2,7 +2,7 @@ import time
from functools import cached_property, wraps from functools import cached_property, wraps
from itertools import chain from itertools import chain
from statistics import median from statistics import median
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List, Tuple
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
import torch import torch
@ -18,52 +18,20 @@ P = ParamSpec("P")
T = TypeVar("T") T = TypeVar("T")
def maybe_time( def time_and_count(
fn: Callable[Concatenate[Any, P], T] fn: Callable[Concatenate[Any, P], T]
) -> Callable[Concatenate[Any, P], T]: ) -> Callable[Concatenate[Any, P], T]:
"""Wrapper that logs the duration of `fn`, in milliseconds, along with a representation """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
of the function's args and kwargs, if logging is enabled. It is expected that `fn` is counters. It is expected that `fn` is a method of `Benchmarker` or one of its
a method of `Benchmarker` or one of its subclasses; typing limitations prevent us from subclasses; typing limitations prevent us from declaring this directly.
declaring this directly. If logging is disabled, this becomes a no-op.
"""
# no-op if benchmarking-specific logging is disabled
if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"):
return fn
@wraps(fn)
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
start_t = time.perf_counter()
result = fn(self, *args, **kwargs)
logger.debug(
"Call `benchmarking.%s.%s(*args=%r, **kwargs=%r)` took %f milliseconds.",
self.__class__.__name__,
fn.__name__,
args,
kwargs,
(time.perf_counter() - start_t) * MILLISECONDS_PER_SECOND,
)
return result
return wrapper
def count(fn: Callable[Concatenate[Any, P], T]) -> Callable[Concatenate[Any, P], T]:
"""Wrapper that increments relevant dynamo counters on `fn` call. It is expected that
`fn` is a method of `Benchmarker` or one of its subclass; typing limitations prevent
us from declaring this directly. The counter incrementation follows the formula,
`counters["inductor"]["benchmarking.Foo.bar] += 1`
where `Foo` is the class whose' instance called the function, and `bar` is the function name.
""" """
@wraps(fn) @wraps(fn)
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
counters["inductor"][ fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}"
"benchmarking." + self.__class__.__name__ + "." + fn.__name__ counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1
] += 1 with dynamo_timed(fn_qual_name, log_pt2_compile_event=True):
return fn(self, *args, **kwargs) return fn(self, *args, **kwargs)
return wrapper return wrapper
@ -72,12 +40,11 @@ class Benchmarker:
def __init__(self: Self) -> None: def __init__(self: Self) -> None:
pass pass
@maybe_time @time_and_count
@count
def benchmark( def benchmark(
self: Self, self: Self,
fn: Callable[..., Any], fn: Callable[..., Any],
fn_args: tuple[Any, ...], fn_args: Tuple[Any, ...],
fn_kwargs: Dict[str, Any], fn_kwargs: Dict[str, Any],
**kwargs: Any, **kwargs: Any,
) -> float: ) -> float:
@ -100,31 +67,29 @@ class Benchmarker:
Returns: Returns:
- The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds. - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
""" """
with dynamo_timed("Benchmarker.benchmark", log_pt2_compile_event=True): inferred_device = None
inferred_device = None for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): if not isinstance(arg_or_kwarg, torch.Tensor):
if not isinstance(arg_or_kwarg, torch.Tensor): continue
continue
if inferred_device is None:
inferred_device = arg_or_kwarg.device
elif arg_or_kwarg.device != inferred_device:
raise ValueError(
"Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
)
if inferred_device is None: if inferred_device is None:
inferred_device = arg_or_kwarg.device
elif arg_or_kwarg.device != inferred_device:
raise ValueError( raise ValueError(
"Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950 "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
) )
_callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731 if inferred_device is None:
if inferred_device == torch.device("cpu"): raise ValueError(
return self.benchmark_cpu(_callable, **kwargs) "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950
# TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking )
# implementation which was written specifically with CUDA devices in mind, we may want to _callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731
# explore alternate implementations for other device types. if inferred_device == torch.device("cpu"):
return self.benchmark_gpu(_callable, **kwargs) return self.benchmark_cpu(_callable, **kwargs)
# TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
# implementation which was written specifically with CUDA devices in mind, we may want to
# explore alternate implementations for other device types.
return self.benchmark_gpu(_callable, **kwargs)
@maybe_time @time_and_count
@count
def benchmark_cpu( def benchmark_cpu(
self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100 self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
) -> float: ) -> float:
@ -159,15 +124,13 @@ class Benchmarker:
run_for(warmup) run_for(warmup)
return median(run_for(rep)) return median(run_for(rep))
@count @time_and_count
def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float: def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
raise NotImplementedError raise NotImplementedError
class TritonBenchmarker(Benchmarker): class TritonBenchmarker(Benchmarker):
@cached_property @cached_property
@maybe_time
@count
def triton_do_bench(self: Self) -> Callable[..., Any]: def triton_do_bench(self: Self) -> Callable[..., Any]:
"""Lazily import Triton's `do_bench`.""" """Lazily import Triton's `do_bench`."""
try: try:
@ -176,8 +139,7 @@ class TritonBenchmarker(Benchmarker):
raise NotImplementedError("requires Triton") from e raise NotImplementedError("requires Triton") from e
return do_bench return do_bench
@maybe_time @time_and_count
@count
def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float: def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
"""Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds. """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds.