Safely infer device type + docstrings + tests (#133668)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133668
Approved by: https://github.com/eellison
Author: Nicolas Macchioni
Date: 2024-08-20 11:49:34 -07:00
Committed by: PyTorch MergeBot
parent b39ec7fbe9
commit af664882dd
2 changed files with 206 additions and 30 deletions
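
For context, a minimal usage sketch of the new device-type inference, assuming a CUDA-capable GPU and Triton are available (illustrative only):

import torch
from torch._inductor.runtime.benchmarking import TritonBenchmarker

benchmarker = TritonBenchmarker()

# Single-device args: the device type is inferred from the tensor arguments
# and the call dispatches to `benchmark_cpu` or `benchmark_gpu` accordingly.
cpu_ms = benchmarker.benchmark(torch.sum, (torch.randn(100, device="cpu"),), {})
gpu_ms = benchmarker.benchmark(torch.sum, (torch.randn(100, device="cuda"),), {})

# Mixed-device args now raise ValueError instead of silently picking a path.
try:
    benchmarker.benchmark(
        torch.add,
        (torch.randn(100, device="cpu"), torch.randn(100, device="cuda")),
        {},
    )
except ValueError as err:
    print(err)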


@ -0,0 +1,113 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch._dynamo.utils import counters
from torch._inductor.runtime.benchmarking import Benchmarker, TritonBenchmarker
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import (
decorateIf,
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
ALL_BENCHMARKER_CLASSES = (
Benchmarker,
TritonBenchmarker,
)
@instantiate_parametrized_tests
class TestBenchmarker(TestCase):
def setUp(self):
super().setUp()
torch.manual_seed(12345)
counters.clear()
@staticmethod
def get_counter_value(benchmarker_cls, fn_name):
return counters["inductor"][
f"benchmarking.{benchmarker_cls.__name__}.{fn_name}"
]
@staticmethod
def make_params(device, size=100):
fn, fn_args, fn_kwargs = torch.sum, (torch.randn(size, device=device),), {}
_callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731
return (fn, fn_args, fn_kwargs), _callable
@unittest.skipIf(not HAS_CPU or not HAS_GPU, "requires CPU and GPU")
@decorateIf(
unittest.expectedFailure,
lambda params: params["benchmarker_cls"] is Benchmarker
and params["device"] == GPU_TYPE,
)
@parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
@parametrize("device", (GPU_TYPE, "cpu"))
def test_benchmark_smoke(self, benchmarker_cls, device):
benchmarker = benchmarker_cls()
(fn, fn_args, fn_kwargs), _ = self.make_params(device)
timing = benchmarker.benchmark(fn, fn_args, fn_kwargs)
self.assertGreater(timing, 0)
self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark"), 1)
self.assertEqual(
self.get_counter_value(
benchmarker_cls, "benchmark_cpu" if device == "cpu" else "benchmark_gpu"
),
1,
)
@unittest.skipIf(not HAS_CPU, "requires CPU")
@parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
def test_benchmark_cpu_smoke(self, benchmarker_cls, device="cpu"):
benchmarker = benchmarker_cls()
_, _callable = self.make_params(device)
timing = benchmarker.benchmark_cpu(_callable)
self.assertGreater(timing, 0)
self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_cpu"), 1)
@unittest.skipIf(not HAS_GPU, "requires GPU")
@decorateIf(
unittest.expectedFailure,
lambda params: params["benchmarker_cls"] is Benchmarker,
)
@parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
def test_benchmark_gpu_smoke(self, benchmarker_cls, device=GPU_TYPE):
benchmarker = benchmarker_cls()
_, _callable = self.make_params(device)
timing = benchmarker.benchmark_gpu(_callable)
self.assertGreater(timing, 0)
self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_gpu"), 1)
if benchmarker_cls is TritonBenchmarker:
self.assertEqual(
self.get_counter_value(benchmarker_cls, "triton_do_bench"), 1
)
@unittest.skipIf(not HAS_CPU and not HAS_GPU, "requires CPU or GPU")
@unittest.expectedFailure
@parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
def test_benchmark_safely_infers_device_no_devices(
self, benchmarker_cls, device="cpu" if HAS_CPU else GPU_TYPE
):
benchmarker = benchmarker_cls()
(fn, _, _), _ = self.make_params(device)
benchmarker.benchmark(fn, (), {})
@unittest.skipIf(not HAS_CPU or not HAS_GPU, "requires CPU and GPU")
@unittest.expectedFailure
@parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
def test_benchmark_safely_infers_device_many_devices(self, benchmarker_cls):
benchmarker = benchmarker_cls()
(fn, cpu_args, cpu_kwargs), _ = self.make_params("cpu")
(_, gpu_args, gpu_kwargs), _ = self.make_params(GPU_TYPE)
many_devices_args = cpu_args + gpu_args
many_devices_kwargs = cpu_kwargs
many_devices_kwargs.update(gpu_kwargs)
benchmarker.benchmark(fn, many_devices_args, many_devices_kwargs)
if __name__ == "__main__":
run_tests()
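
As a side note on the counter bookkeeping these tests assert on: the `@count` decorator keys counters by class and method name, so a standalone check (a sketch, assuming a CPU-only setup) could look like:

import torch
from torch._dynamo.utils import counters
from torch._inductor.runtime.benchmarking import TritonBenchmarker

counters.clear()
benchmarker = TritonBenchmarker()
benchmarker.benchmark_cpu(lambda: torch.sum(torch.randn(100)))
# `@count` increments counters["inductor"]["benchmarking.<ClassName>.<method>"].
assert counters["inductor"]["benchmarking.TritonBenchmarker.benchmark_cpu"] == 1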


@ -1,14 +1,15 @@
import time
from functools import cached_property, wraps
from itertools import chain
from statistics import median
from typing import Any, Callable, Dict, List, Tuple
from typing_extensions import ParamSpec, Self, TypeVar
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
import torch
from torch._inductor.utils import is_cpu_device
from torch._dynamo.utils import counters
log = torch._logging.getArtifactLogger(__name__, "benchmarking")
logger = torch._logging.getArtifactLogger(__name__, "benchmarking")
MILLISECONDS_PER_SECOND = 1000
@ -17,31 +18,62 @@ P = ParamSpec("P")
T = TypeVar("T")
def maybe_time(fn: Callable[P, T]) -> Callable[P, T]:
def maybe_time(
fn: Callable[Concatenate[Any, P], T]
) -> Callable[Concatenate[Any, P], T]:
"""Wrapper that logs the duration of `fn`, in milliseconds, along with a representation
of the function's args and kwargs, if logging is enabled. It is expected that `fn` is
a method of `Benchmarker` or one of its subclasses; typing limitations prevent us from
declaring this directly. If logging is disabled, this becomes a no-op.
"""
# no-op if benchmarking-specific logging is disabled
if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"):
return fn
@wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
start_s = time.perf_counter()
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
start_t = time.perf_counter()
result = fn(*args, **kwargs)
log.debug(
"fn:%r args:[%r, %r] took %f seconds.",
logger.debug(
"Call `benchmarking.%s.%s(*args=%r, **kwargs=%r)` took %f milliseconds.",
self.__class__.__name__,
fn.__name__,
args,
kwargs,
time.perf_counter() - start_s,
(time.perf_counter() - start_t) * MILLISECONDS_PER_SECOND,
)
return result
return wrapper
def count(fn: Callable[Concatenate[Any, P], T]) -> Callable[Concatenate[Any, P], T]:
"""Wrapper that increments relevant dynamo counters on `fn` call. It is expected that
`fn` is a method of `Benchmarker` or one of its subclasses; typing limitations prevent
us from declaring this directly. The counter increment follows the formula
`counters["inductor"]["benchmarking.Foo.bar"] += 1`,
where `Foo` is the class whose instance called the function, and `bar` is the function name.
"""
@wraps(fn)
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
counters["inductor"][
"benchmarking." + self.__class__.__name__ + "." + fn.__name__
] += 1
return fn(self, *args, **kwargs)
return wrapper
class Benchmarker:
def __init__(self: Self) -> None:
pass
@maybe_time
@count
def benchmark(
self: Self,
fn: Callable[..., Any],
@ -49,8 +81,13 @@ class Benchmarker:
fn_kwargs: Dict[str, Any],
**kwargs: Any,
) -> float:
"""Construct benchmarkable callable and dispatch benchmark request to the appropriate
benchmarking function depending on the device type of `fn_args` and `fn_kwargs`.
"""Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the
actual runtime calculation is dictated by the benchmarking implementation, but may be
one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around
device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
`ValueError(...)` if we can't safely infer the device type of `fn`; for example,
if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
types are found.
Arguments:
- fn: The function to benchmark.
@ -58,27 +95,49 @@ class Benchmarker:
- fn_kwargs: The function's kwargs.
Keyword Arguments:
- **kwargs: The benchmarker's keyword arguments.
- **kwargs: The benchmarking implementation's kwargs.
Returns:
- The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
"""
if is_cpu_device(list(fn_args) + list(fn_kwargs.values())):
return self.benchmark_cpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
return self.benchmark_gpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
inferred_device = None
for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
if not isinstance(arg_or_kwarg, torch.Tensor):
continue
if inferred_device is None:
inferred_device = arg_or_kwarg.device
elif arg_or_kwarg.device != inferred_device:
raise ValueError(
"Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
)
if inferred_device is None:
raise ValueError(
"Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950
)
_callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731
if inferred_device == torch.device("cpu"):
return self.benchmark_cpu(_callable, **kwargs)
# TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
# implementation, which was written specifically with CUDA devices in mind; we may want to
# explore alternate implementations for other device types.
return self.benchmark_gpu(_callable, **kwargs)
@maybe_time
@count
def benchmark_cpu(
self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
) -> float:
"""Benchmark a CPU callable.
"""Benchmark the CPU callable, `_callable`, and return the median runtime,
in milliseconds.
Arguments:
- _callable: The callable to benchmark.
- _callable: The CPU callable to benchmark.
Keyword Arguments:
- warmup: Duration to run the callable before benchmarking, in milliseconds.
- rep: Duration to run the benchmarking, in milliseconds.
- warmup: Optionally, the duration, in milliseconds, to run `_callable`
before benchmarking starts.
- rep: Optionally, the duration, in milliseconds, to run `_callable`
during benchmarking.
Returns:
- The median runtime of `_callable`, in milliseconds.
@ -86,19 +145,20 @@ class Benchmarker:
def run_for(ms: int) -> List[float]:
timings = []
run_start_s = time.perf_counter()
run_start_t = time.perf_counter()
while True:
start_s = time.perf_counter()
start_t = time.perf_counter()
_callable()
end_s = time.perf_counter()
timings.append((end_s - start_s) * MILLISECONDS_PER_SECOND)
if ((end_s - run_start_s) * MILLISECONDS_PER_SECOND) > ms:
end_t = time.perf_counter()
timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND)
if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms:
break
return timings
run_for(warmup)
return median(run_for(rep))
@count
def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
raise NotImplementedError
@ -106,8 +166,9 @@ class Benchmarker:
class TritonBenchmarker(Benchmarker):
@cached_property
@maybe_time
@count
def triton_do_bench(self: Self) -> Callable[..., Any]:
"""Lazily import Triton's do_bench."""
"""Lazily import Triton's `do_bench`."""
try:
from triton.testing import do_bench
except ImportError as e:
@ -115,16 +176,18 @@ class TritonBenchmarker(Benchmarker):
return do_bench
@maybe_time
@count
def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
"""Benchmark a GPU callable using Triton's do_bench.
"""Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds.
Arguments:
- _callable: The callable to benchmark.
- _callable: The GPU callable to benchmark.
Keyword Arguments:
- quantiles: A tuple of floats denoting the requested quantiles.
- return_mode: The requested return mode, one of "min", "max", "mean", or "median".
- **kwargs: Additional kwargs passed to triton.testing.do_bench.
- quantiles: Optionally, a tuple of floats denoting the requested quantiles.
- return_mode: Optionally, the requested return mode. Currently, Triton's
`do_bench` supports min, max, mean, and median return modes.
- **kwargs: Additional kwargs passed to Triton's `do_bench`.
Returns:
- The runtime of `_callable`, in milliseconds. If `kwargs["quantiles"]` is specified,