Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

refactor benchmarking to use dynamo_timed (#144315)

Use dynamo_timed for all our wrapped calls, instead of our custom timer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144315
Approved by: https://github.com/eellison

Committed by: PyTorch MergeBot
Parent: 17c3a10cbd
Commit: ee3e89190a
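In short, the custom `maybe_time` wall-clock wrapper and the separate `@count` decorator are merged into a single `time_and_count` decorator that defers timing to `dynamo_timed`, so benchmarking calls show up in PT2 compile-event logging alongside other compile-time phases. A simplified before/after sketch of the pattern (not the literal code; the `torch._dynamo.utils` import path is an assumption based on the diff below):

```python
import time
from functools import wraps

from torch._dynamo.utils import dynamo_timed  # assumed import path

# Before: a hand-rolled wall-clock timer, visible only in debug logs.
def maybe_time(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        start_t = time.perf_counter()
        result = fn(self, *args, **kwargs)
        print(f"{fn.__name__}: {(time.perf_counter() - start_t) * 1e3:.3f}ms")
        return result
    return wrapper

# After: timing delegated to dynamo_timed, so each benchmarking call is also
# recorded as a PT2 compile event instead of only being printed.
def timed(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        key = f"{type(self).__name__}.{fn.__name__}"
        with dynamo_timed(key, log_pt2_compile_event=True):
            return fn(self, *args, **kwargs)
    return wrapper
```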
```diff
@@ -81,10 +81,6 @@ class TestBenchmarker(TestCase):
         timing = benchmarker.benchmark_gpu(_callable)
         self.assertGreater(timing, 0)
         self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_gpu"), 1)
-        if benchmarker_cls is TritonBenchmarker:
-            self.assertEqual(
-                self.get_counter_value(benchmarker_cls, "triton_do_bench"), 1
-            )
 
     @unittest.skipIf(not HAS_CPU and not HAS_GPU, "requires CPU or GPU")
     @unittest.expectedFailure
```
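The dropped assertion checked that `triton_do_bench` bumped its own counter; since the cached property no longer carries a counting decorator (see the later hunks), only the `benchmark_gpu` counter remains. A hypothetical sketch of what `get_counter_value` presumably reads, assuming it keys into the dynamo counters the same way the new decorator writes them:

```python
from torch._dynamo.utils import counters  # assumed location of the dynamo counters

# Hypothetical helper body; the real test helper may differ.
def get_counter_value(benchmarker_cls, fn_name):
    return counters["inductor"][f"benchmarking.{benchmarker_cls.__name__}.{fn_name}"]
```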
```diff
@@ -2,7 +2,7 @@ import time
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Tuple
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
 
 import torch
```
```diff
@@ -18,52 +18,20 @@ P = ParamSpec("P")
 T = TypeVar("T")
 
 
-def maybe_time(
+def time_and_count(
     fn: Callable[Concatenate[Any, P], T]
 ) -> Callable[Concatenate[Any, P], T]:
-    """Wrapper that logs the duration of `fn`, in milliseconds, along with a representation
-    of the function's args and kwargs, if logging is enabled. It is expected that `fn` is
-    a method of `Benchmarker` or one of its subclasses; typing limitations prevent us from
-    declaring this directly. If logging is disabled, this becomes a no-op.
-    """
-    # no-op if benchmarking-specific logging is disabled
-    if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"):
-        return fn
-
-    @wraps(fn)
-    def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
-        start_t = time.perf_counter()
-        result = fn(self, *args, **kwargs)
-        logger.debug(
-            "Call `benchmarking.%s.%s(*args=%r, **kwargs=%r)` took %f milliseconds.",
-            self.__class__.__name__,
-            fn.__name__,
-            args,
-            kwargs,
-            (time.perf_counter() - start_t) * MILLISECONDS_PER_SECOND,
-        )
-        return result
-
-    return wrapper
-
-
-def count(fn: Callable[Concatenate[Any, P], T]) -> Callable[Concatenate[Any, P], T]:
-    """Wrapper that increments relevant dynamo counters on `fn` call. It is expected that
-    `fn` is a method of `Benchmarker` or one of its subclass; typing limitations prevent
-    us from declaring this directly. The counter incrementation follows the formula,
-
-    `counters["inductor"]["benchmarking.Foo.bar] += 1`
-
-    where `Foo` is the class whose' instance called the function, and `bar` is the function name.
+    """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
+    counters. It is expected that `fn` is a method of `Benchmarker` or one of its
+    subclasses; typing limitations prevent us from declaring this directly.
     """
 
     @wraps(fn)
     def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
-        counters["inductor"][
-            "benchmarking." + self.__class__.__name__ + "." + fn.__name__
-        ] += 1
-        return fn(self, *args, **kwargs)
+        fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}"
+        counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1
+        with dynamo_timed(fn_qual_name, log_pt2_compile_event=True):
+            return fn(self, *args, **kwargs)
 
     return wrapper
 
```
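A hedged usage sketch of the merged decorator: any `Benchmarker` method wrapped in `time_and_count` should bump its inductor counter once per call while the call itself is timed by `dynamo_timed`. The import paths and the `ToyBenchmarker` subclass are assumptions for illustration:

```python
from torch._dynamo.utils import counters  # assumed location of the dynamo counters
from torch._inductor.runtime.benchmarking import (  # assumed module path
    Benchmarker,
    time_and_count,
)

class ToyBenchmarker(Benchmarker):  # hypothetical subclass, for illustration only
    @time_and_count
    def benchmark_noop(self) -> float:
        return 0.0

b = ToyBenchmarker()
b.benchmark_noop()
b.benchmark_noop()
# The counter key follows "benchmarking.{class}.{method}", per the decorator above.
assert counters["inductor"]["benchmarking.ToyBenchmarker.benchmark_noop"] == 2
```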
```diff
@@ -72,12 +40,11 @@ class Benchmarker:
     def __init__(self: Self) -> None:
         pass
 
-    @maybe_time
-    @count
+    @time_and_count
     def benchmark(
         self: Self,
         fn: Callable[..., Any],
-        fn_args: tuple[Any, ...],
+        fn_args: Tuple[Any, ...],
         fn_kwargs: Dict[str, Any],
         **kwargs: Any,
     ) -> float:
```
```diff
@@ -100,31 +67,29 @@ class Benchmarker:
         Returns:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
-        with dynamo_timed("Benchmarker.benchmark", log_pt2_compile_event=True):
-            inferred_device = None
-            for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
-                if not isinstance(arg_or_kwarg, torch.Tensor):
-                    continue
-                if inferred_device is None:
-                    inferred_device = arg_or_kwarg.device
-                elif arg_or_kwarg.device != inferred_device:
-                    raise ValueError(
-                        "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
-                    )
-            if inferred_device is None:
-                raise ValueError(
-                    "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
-                )
-            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
-            if inferred_device == torch.device("cpu"):
-                return self.benchmark_cpu(_callable, **kwargs)
-            # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
-            # implementation which was written specifically with CUDA devices in mind, we may want to
-            # explore alternate implementations for other device types.
-            return self.benchmark_gpu(_callable, **kwargs)
+        inferred_device = None
+        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+            if not isinstance(arg_or_kwarg, torch.Tensor):
+                continue
+            if inferred_device is None:
+                inferred_device = arg_or_kwarg.device
+            elif arg_or_kwarg.device != inferred_device:
+                raise ValueError(
+                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
+                )
+        if inferred_device is None:
+            raise ValueError(
+                "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
+            )
+        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+        if inferred_device == torch.device("cpu"):
+            return self.benchmark_cpu(_callable, **kwargs)
+        # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
+        # implementation which was written specifically with CUDA devices in mind, we may want to
+        # explore alternate implementations for other device types.
+        return self.benchmark_gpu(_callable, **kwargs)
 
-    @maybe_time
-    @count
+    @time_and_count
     def benchmark_cpu(
         self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
     ) -> float:
```
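With the `dynamo_timed` wrapper moved into the decorator, the body of `benchmark` simply de-indents; its logic is unchanged: scan `fn_args` and `fn_kwargs` for tensors, require them to share one device, then dispatch to the device-specific implementation. A usage sketch (the module-level `benchmarker` singleton and its import path are assumptions):

```python
import torch
from torch._inductor.runtime.benchmarking import benchmarker  # assumed path/singleton

x = torch.randn(256, 256)                         # CPU tensor
ms = benchmarker.benchmark(torch.mm, (x, x), {})  # routed to benchmark_cpu
print(f"torch.mm on {x.device}: {ms:.4f} ms")

# With no tensor arguments the device can't be inferred and `benchmark` raises
# ValueError, so call the device-specific entry point directly instead:
ms = benchmarker.benchmark_cpu(lambda: sum(range(10_000)))
```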
```diff
@@ -159,15 +124,13 @@
         run_for(warmup)
         return median(run_for(rep))
 
-    @count
+    @time_and_count
     def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
         raise NotImplementedError
 
 
 class TritonBenchmarker(Benchmarker):
     @cached_property
-    @maybe_time
-    @count
     def triton_do_bench(self: Self) -> Callable[..., Any]:
         """Lazily import Triton's `do_bench`."""
         try:
```
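Note that `triton_do_bench` keeps `@cached_property` but loses its timing and counting decorators entirely, which is why the test hunk above no longer asserts on a `triton_do_bench` counter. The lazy-import pattern itself, shown in isolation as a standalone sketch:

```python
from functools import cached_property
from typing import Any, Callable

class LazyTriton:
    @cached_property
    def do_bench(self) -> Callable[..., Any]:
        # The import is deferred to the first attribute access and the result is
        # cached on the instance, so constructing the object never requires
        # Triton to be installed.
        try:
            from triton.testing import do_bench
        except ImportError as e:
            raise NotImplementedError("requires Triton") from e
        return do_bench
```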
```diff
@@ -176,8 +139,7 @@ class TritonBenchmarker(Benchmarker):
             raise NotImplementedError("requires Triton") from e
         return do_bench
 
-    @maybe_time
-    @count
+    @time_and_count
     def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
         """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds.
 
```
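Finally, a hedged example of the GPU path, which forwards `_callable` to Triton's `do_bench`. CUDA availability and the `benchmarker` singleton/import path are assumptions:

```python
import torch
from torch._inductor.runtime.benchmarking import benchmarker  # assumed path

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    # The callable takes no arguments; do_bench handles warmup and repeats internally.
    ms = benchmarker.benchmark_gpu(lambda: a @ a)
    print(f"1024x1024 matmul: {ms:.4f} ms")
```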