refactor benchmarking to use dynamo_timed (#144315)
use dynamo_timed for all our wrapped calls, instead of our custom timer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144315
Approved by: https://github.com/eellison
commit ee3e89190a
parent 17c3a10cbd
committed by PyTorch MergeBot
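Editorial note, for orientation: before this commit every wrapped call went through two decorators, `maybe_time` (manual `perf_counter` logging, active only when the "benchmarking" logging artifact is enabled) and `count` (dynamo counter increments); after it, a single `time_and_count` decorator increments the counter and runs the call inside a `dynamo_timed` region. The sketch below mirrors that pattern with stand-ins: `timed` is a hypothetical substitute for PyTorch's internal `dynamo_timed`, and `counters` stands in for the dynamo counters the diff indexes as `counters["inductor"]`.

# Editorial sketch of the new wrapping pattern; `timed` and `counters` are
# stand-ins, not the real torch._dynamo internals.
import time
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps

counters = defaultdict(lambda: defaultdict(int))  # stand-in for dynamo counters


@contextmanager
def timed(key: str):
    # stand-in for: dynamo_timed(key, log_pt2_compile_event=True)
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{key}: {(time.perf_counter() - start) * 1e3:.3f} ms")


def time_and_count(fn):
    """Count the call and run it inside a timed region, keyed by Class.method."""

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}"
        counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1
        with timed(fn_qual_name):
            return fn(self, *args, **kwargs)

    return wrapper

With this shape, every decorated `Benchmarker` method shows up both as a counter entry (e.g. `counters["inductor"]["benchmarking.Benchmarker.benchmark"]`) and as a named timing event, which is what the source hunks below switch to.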
@@ -81,10 +81,6 @@ class TestBenchmarker(TestCase):
         timing = benchmarker.benchmark_gpu(_callable)
         self.assertGreater(timing, 0)
         self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_gpu"), 1)
-        if benchmarker_cls is TritonBenchmarker:
-            self.assertEqual(
-                self.get_counter_value(benchmarker_cls, "triton_do_bench"), 1
-            )
 
     @unittest.skipIf(not HAS_CPU and not HAS_GPU, "requires CPU or GPU")
     @unittest.expectedFailure
@@ -2,7 +2,7 @@ import time
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Tuple
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
 
 import torch
@@ -18,52 +18,20 @@ P = ParamSpec("P")
 T = TypeVar("T")
 
 
-def maybe_time(
+def time_and_count(
     fn: Callable[Concatenate[Any, P], T]
 ) -> Callable[Concatenate[Any, P], T]:
-    """Wrapper that logs the duration of `fn`, in milliseconds, along with a representation
-    of the function's args and kwargs, if logging is enabled. It is expected that `fn` is
-    a method of `Benchmarker` or one of its subclasses; typing limitations prevent us from
-    declaring this directly. If logging is disabled, this becomes a no-op.
-    """
-
-    # no-op if benchmarking-specific logging is disabled
-    if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"):
-        return fn
-
-    @wraps(fn)
-    def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
-        start_t = time.perf_counter()
-        result = fn(self, *args, **kwargs)
-        logger.debug(
-            "Call `benchmarking.%s.%s(*args=%r, **kwargs=%r)` took %f milliseconds.",
-            self.__class__.__name__,
-            fn.__name__,
-            args,
-            kwargs,
-            (time.perf_counter() - start_t) * MILLISECONDS_PER_SECOND,
-        )
-        return result
-
-    return wrapper
-
-
-def count(fn: Callable[Concatenate[Any, P], T]) -> Callable[Concatenate[Any, P], T]:
-    """Wrapper that increments relevant dynamo counters on `fn` call. It is expected that
-    `fn` is a method of `Benchmarker` or one of its subclass; typing limitations prevent
-    us from declaring this directly. The counter incrementation follows the formula,
-
-    `counters["inductor"]["benchmarking.Foo.bar] += 1`
-
-    where `Foo` is the class whose' instance called the function, and `bar` is the function name.
+    """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
+    counters. It is expected that `fn` is a method of `Benchmarker` or one of its
+    subclasses; typing limitations prevent us from declaring this directly.
     """
 
     @wraps(fn)
     def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
-        counters["inductor"][
-            "benchmarking." + self.__class__.__name__ + "." + fn.__name__
-        ] += 1
-        return fn(self, *args, **kwargs)
+        fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}"
+        counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1
+        with dynamo_timed(fn_qual_name, log_pt2_compile_event=True):
+            return fn(self, *args, **kwargs)
 
     return wrapper
 
@@ -72,12 +40,11 @@ class Benchmarker:
     def __init__(self: Self) -> None:
         pass
 
-    @maybe_time
-    @count
+    @time_and_count
     def benchmark(
         self: Self,
        fn: Callable[..., Any],
-        fn_args: tuple[Any, ...],
+        fn_args: Tuple[Any, ...],
         fn_kwargs: Dict[str, Any],
         **kwargs: Any,
     ) -> float:
@@ -100,31 +67,29 @@ class Benchmarker:
         Returns:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
-        with dynamo_timed("Benchmarker.benchmark", log_pt2_compile_event=True):
-            inferred_device = None
-            for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
-                if not isinstance(arg_or_kwarg, torch.Tensor):
-                    continue
-                if inferred_device is None:
-                    inferred_device = arg_or_kwarg.device
-                elif arg_or_kwarg.device != inferred_device:
-                    raise ValueError(
-                        "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
-                    )
+        inferred_device = None
+        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+            if not isinstance(arg_or_kwarg, torch.Tensor):
+                continue
             if inferred_device is None:
+                inferred_device = arg_or_kwarg.device
+            elif arg_or_kwarg.device != inferred_device:
                 raise ValueError(
-                    "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
+                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
                 )
-            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
-            if inferred_device == torch.device("cpu"):
-                return self.benchmark_cpu(_callable, **kwargs)
-            # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
-            # implementation which was written specifically with CUDA devices in mind, we may want to
-            # explore alternate implementations for other device types.
-            return self.benchmark_gpu(_callable, **kwargs)
+        if inferred_device is None:
+            raise ValueError(
+                "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
+            )
+        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+        if inferred_device == torch.device("cpu"):
+            return self.benchmark_cpu(_callable, **kwargs)
+        # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
+        # implementation which was written specifically with CUDA devices in mind, we may want to
+        # explore alternate implementations for other device types.
+        return self.benchmark_gpu(_callable, **kwargs)
 
-    @maybe_time
-    @count
+    @time_and_count
     def benchmark_cpu(
         self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
     ) -> float:
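Editorial usage sketch for the hunk above (hedged: the import path is an assumption about where `Benchmarker` lives, and `torch` must be installed): `benchmark` infers the device from its tensor args/kwargs, routes CPU tensors to `benchmark_cpu`, everything else to `benchmark_gpu`, and raises a `ValueError` when the tensors disagree or when no tensor is passed.

# Editorial usage sketch of Benchmarker.benchmark's device inference, based on
# the hunk above. The import path below is an assumption (the class appears to
# live in torch/_inductor/runtime/benchmarking.py); adjust if it differs.
import torch
from torch._inductor.runtime.benchmarking import Benchmarker

b = Benchmarker()
x = torch.randn(1024)
y = torch.randn(1024)

# Both args are CPU tensors, so this routes to benchmark_cpu.
ms = b.benchmark(torch.add, (x, y), {})
print(f"torch.add on CPU: {ms:.4f} ms")

# Mixing devices (or passing no tensors at all) raises ValueError during inference.
try:
    b.benchmark(torch.add, (x, y.to("meta")), {})
except ValueError as e:
    print("expected:", e)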
@@ -159,15 +124,13 @@ class Benchmarker:
         run_for(warmup)
         return median(run_for(rep))
 
-    @count
+    @time_and_count
     def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
         raise NotImplementedError
 
 
 class TritonBenchmarker(Benchmarker):
     @cached_property
-    @maybe_time
-    @count
     def triton_do_bench(self: Self) -> Callable[..., Any]:
         """Lazily import Triton's `do_bench`."""
         try:
@@ -176,8 +139,7 @@ class TritonBenchmarker(Benchmarker):
             raise NotImplementedError("requires Triton") from e
         return do_bench
 
-    @maybe_time
-    @count
+    @time_and_count
     def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
         """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds.
 
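A last editorial aside: after this change `triton_do_bench` keeps only `@cached_property`, so Triton's `do_bench` is still imported lazily and cached on first access; it just no longer bumps a counter or logs a timing, which is why the test hunk at the top drops its `triton_do_bench` assertion. A generic sketch of that lazy-import-as-cached-property pattern follows (the class name is hypothetical, used only for illustration):

# Editorial sketch of the lazy-import pattern kept on TritonBenchmarker.triton_do_bench:
# the import happens once, on first attribute access, and the result is cached.
from functools import cached_property
from typing import Any, Callable


class GpuBenchmarker:  # hypothetical class for illustration
    @cached_property
    def do_bench(self) -> Callable[..., Any]:
        try:
            from triton.testing import do_bench  # optional dependency
        except ImportError as e:
            raise NotImplementedError("requires Triton") from e
        return do_bench

    def benchmark_gpu(self, _callable: Callable[[], Any], **kwargs: Any) -> float:
        # First call triggers the import; later calls reuse the cached function.
        return self.do_bench(_callable, **kwargs)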