[BC breaking] move benchmarking + prefer inductor path (#132827)
move benchmarking out of `torch._inductor.runtime.runtime_utils` and into `torch._inductor.runtime.benchmarking`, and prefer this path over directly accessing Triton's benchmarking.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132827
Approved by: https://github.com/eellison
Commit 5cb05a82b4 (parent a9036e1cf8), committed by PyTorch MergeBot.
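The migration is mechanical at each call site: the removed helpers map onto methods of the new `benchmarker` object exported by `torch._inductor.runtime.benchmarking`. A hedged sketch of the mapping, runnable on CPU (the tensors and shapes are illustrative, not taken from the PR):

import torch
from torch._inductor.runtime.benchmarking import benchmarker

# Rough mapping from the removed helpers to the new entry points (illustrative):
#   triton.testing.do_bench(fn, ...)              -> benchmarker.benchmark_gpu(fn, ...)
#   runtime_utils.do_bench_gpu(fn, ...)           -> benchmarker.benchmark_gpu(fn, ...)
#   runtime_utils.do_bench_cpu(fn, ...)           -> benchmarker.benchmark_cpu(fn, ...)
#   runtime_utils.do_bench(fn, args, kwargs, ...) -> benchmarker.benchmark(fn, args, kwargs, ...)

a = torch.randn(512, 512)
b = torch.randn(512, 512)

# benchmark() dispatches on the device of its arguments; CPU tensors go to benchmark_cpu().
ms = benchmarker.benchmark(torch.mm, (a, b), {})
print(f"torch.mm on CPU: {ms:.3f} ms")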
@@ -6,6 +6,7 @@ from prettytable import PrettyTable
 import torch
 import torch._dynamo
 import torch._inductor.config
+from torch._inductor.runtime.benchmarking import benchmarker


 # torch._inductor.config.debug = True
@@ -74,12 +75,12 @@ def bench(shape, layer_id, p, fusion_types=[""]):
 return fn_mm(*args)

 torch._inductor.config.triton.mm = "aten"
-torch_mm_ms, _, _ = triton.testing.do_bench(fn)
+torch_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
 torch._inductor.config.triton.mm = "triton"
 # reset to force code gen new python code
 torch._dynamo.reset()
 torch._inductor.metrics.reset()
-triton_mm_ms, _, _ = triton.testing.do_bench(fn)
+triton_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
 assert (
 torch._inductor.metrics.generated_kernel_count == 1
 ), "codegen #kernel != 1"
@@ -5,6 +5,7 @@ import torch
 import torch._dynamo
 import torch._dynamo.config
 import torch._inductor.config as config
+from torch._inductor.runtime.benchmarking import benchmarker


 # The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
@@ -74,10 +75,12 @@ def test_GPU_time(shapes):
 config.triton.mm = "triton"
 inductor_triton_mm(a, b)

-torch_ms, _, _ = triton.testing.do_bench(lambda: torch_mm(a, b))
-triton_ms, _, _ = triton.testing.do_bench(lambda: triton_mm(a, b))
-ind_aten_ms, _, _ = triton.testing.do_bench(lambda: inductor_aten_mm(a, b))
-ind_triton_ms, _, _ = triton.testing.do_bench(lambda: inductor_triton_mm(a, b))
+torch_ms, _, _ = benchmarker.benchmark_gpu(lambda: torch_mm(a, b))
+triton_ms, _, _ = benchmarker.benchmark_gpu(lambda: triton_mm(a, b))
+ind_aten_ms, _, _ = benchmarker.benchmark_gpu(lambda: inductor_aten_mm(a, b))
+ind_triton_ms, _, _ = benchmarker.benchmark_gpu(
+lambda: inductor_triton_mm(a, b)
+)
 print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")

 torch._dynamo.reset()
@@ -10,6 +10,7 @@ from torch._dynamo.testing import same
 from torch._inductor.compile_fx import compile_fx
 from torch._inductor.decomposition import decompositions
 from torch._inductor.lowering import lowerings
+from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.utils import gen_gm_and_inputs
 from torch.utils._pytree import tree_map_only

@@ -37,13 +38,11 @@ def compute_speedups(
 # interleave the runs to handle frequency scaling and load changes
 for m, model in enumerate(models):
 if device == "cuda":
-import triton
-
 model(*example_inputs)

-# do_bench() clears L2 cache to hide the latency of CPU launch time
+# benchmarker.benchmark_gpu() clears L2 cache to hide the latency of CPU launch time
 # along with cuda synchronization
-timings[rep, m] = triton.testing.do_bench(
+timings[rep, m] = benchmarker.benchmark_gpu(
 lambda: model(*example_inputs)
 )
 else:
@@ -1,6 +1,6 @@
 import torch
 from torch._inductor import ir
-from torch._inductor.runtime.runtime_utils import do_bench
+from torch._inductor.runtime.benchmarking import benchmarker


 def to_channels_last(x):
@@ -54,8 +54,8 @@ def bench_conv(with_stack=True):
 test_out[0][0][0][:32],
 )

-baseline_ms = do_bench(baseline_fn, rep=40)
-test_ms = do_bench(test_fn, rep=40)
+baseline_ms = benchmarker.benchmark_gpu(baseline_fn, rep=40)
+test_ms = benchmarker.benchmark_gpu(test_fn, rep=40)
 print(f"baseline {baseline_ms} test {test_ms} speedup {baseline_ms / test_ms:.3f}x")


@@ -4,10 +4,10 @@ import dataclasses
 import os

 from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8
-from triton.testing import do_bench

 import torch
 import torch.nn as nn
+from torch._inductor.runtime.benchmarking import benchmarker
 from torch.utils.flop_counter import FlopCounterMode


@@ -71,7 +71,7 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"):
 for _ in range(WARMUP_ITER):
 compiled_mod(x)

-us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000
+us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000
 flops_utilization += us_per_iter * flops / 1e9 / A100_40G_BF16_TFLOPS

 flops_utilization = flops_utilization / len(input_shapes)
@@ -108,7 +108,7 @@ def run_layer_norm(device: str = "cuda"):
 for _ in range(WARMUP_ITER):
 compiled_mod(x)

-us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000
+us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000
 memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9

 memory_bandwidth = memory_bandwidth / len(input_shapes)
@@ -151,7 +151,9 @@ def run_gather_gemv(device: str = "cuda"):
 for _ in range(WARMUP_ITER):
 compiled_fn(W, score_idxs, x)

-us_per_iter = do_bench(lambda: compiled_fn(W, score_idxs, x)) * 1000
+us_per_iter = (
+benchmarker.benchmark_gpu(lambda: compiled_fn(W, score_idxs, x)) * 1000
+)
 memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9

 memory_bandwidth = memory_bandwidth / len(input_shapes)
@@ -192,7 +194,7 @@ def run_gemv(device: str = "cuda"):
 for _ in range(WARMUP_ITER):
 compiled_fn(W, x)

-us_per_iter = do_bench(lambda: compiled_fn(W, x)) * 1000
+us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_fn(W, x)) * 1000
 memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9

 memory_bandwidth = memory_bandwidth / len(input_shapes)
@@ -1,4 +1,5 @@
 import torch
+from torch._inductor.runtime.benchmarking import benchmarker


 def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
@@ -27,9 +28,7 @@ def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):


 def _test_worker(test_func):
-import triton
-
-ms, ms_min, ms_max = triton.testing.do_bench(
+ms, ms_min, ms_max = benchmarker.benchmark_gpu(
 test_func, warmup=500, rep=100, fast_flush=False
 )

@@ -24,14 +24,14 @@ torch._dynamo.config.automatic_dynamic_shapes = False
 torch._dynamo.config.cache_size_limit = 1000


-from triton.testing import do_bench
+from torch._inductor.runtime.benchmarking import benchmarker


 def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
 # warmup
 for _ in range(5):
 func(*args, **kwargs)
-return do_bench(lambda: func(*args, **kwargs)) * 1e3
+return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) * 1e3


 @dataclass(frozen=True)
@@ -721,6 +721,7 @@ exclusions = {
 "export",
 "trace_shape_events",
 "cudagraph_static_inputs",
+"benchmarking",
 }
 for name in torch._logging._internal.log_registry.artifact_names:
 if name not in exclusions:
@@ -3,7 +3,7 @@ import os
 import unittest

 import torch
-from torch._inductor.runtime.runtime_utils import do_bench
+from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -174,7 +174,7 @@ class B2BGEMMTest(TestCase):
 return torch.mm(torch.mm(m1, m2), m3)

 f_opt = torch.compile(f, dynamic=False)
-return do_bench(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)
+return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def run_with_b2b_gemm_on(
@@ -184,7 +184,7 @@ class B2BGEMMTest(TestCase):
 return torch.mm(torch.mm(m1, m2), m3)

 f_opt = torch.compile(f, dynamic=False)
-return do_bench(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)
+return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

 Ms = [128, 256, 300, 400, 512]
 Ns = [16, 20, 32, 40, 50, 64]
@@ -230,7 +230,7 @@ class B2BGEMMTest(TestCase):
 return torch.mm(g(torch.mm(m1, m2)), m3)

 f_opt = torch.compile(f, dynamic=False)
-return do_bench(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)
+return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def run_with_b2b_gemm_on(
@@ -241,7 +241,7 @@ class B2BGEMMTest(TestCase):
 return torch.mm(g(torch.mm(m1, m2)), m3)

 f_opt = torch.compile(f, dynamic=False)
-return do_bench(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)
+return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

 Ms = [128, 256, 300, 400, 512]
 Ns = [16, 20, 32, 40, 50, 64]
@@ -287,7 +287,7 @@ class B2BGEMMTest(TestCase):
 return torch.mm(g(torch.mm(m1, m2)), m3)

 f_opt = torch.compile(f, dynamic=False)
-return do_bench(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)
+return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

 @torch._inductor.config.patch(b2b_gemm_pass=True)
 def run_with_b2b_gemm_on(
@@ -298,7 +298,7 @@ class B2BGEMMTest(TestCase):
 return torch.mm(g(torch.mm(m1, m2)), m3)

 f_opt = torch.compile(f, dynamic=False)
-return do_bench(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)
+return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

 Ms = [128, 256, 300, 400, 512]
 Ns = [16, 20, 32, 40, 50, 64]
@@ -8,7 +8,7 @@ import torch
 from torch._inductor.codegen.cpp import cexpr
 from torch._inductor.codegen.triton import texpr
 from torch._inductor.codegen.wrapper import pexpr
-from torch._inductor.runtime.runtime_utils import do_bench_gpu
+from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.sizevars import SizeVarAllocator
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch._inductor.utils import run_and_get_triton_code
@@ -237,7 +237,7 @@ class TestIndexingSimplification(InductorTestCase):
 # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)),
 self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),"))
 if DO_PERF_TEST:
-ms = do_bench_gpu(lambda: f(x))
+ms = benchmarker.benchmark_gpu(lambda: f(x))
 print(f"{ms=:.03f}")


@@ -4,7 +4,7 @@ import functools
 import logging

 import torch
-from torch._inductor.runtime.runtime_utils import do_bench
+from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import do_bench_using_profiling

@@ -20,8 +20,8 @@ class TestBench(TestCase):
 w = torch.rand(512, 10).cuda().half()
 cls._bench_fn = functools.partial(torch.nn.functional.linear, x, w)

-def test_do_bench(self):
-res = do_bench(self._bench_fn, (), {})
+def test_benchmarker(self):
+res = benchmarker.benchmark(self._bench_fn, (), {})
 log.warning("do_bench result: %s", res)
 self.assertGreater(res, 0)

@@ -11,7 +11,7 @@ from torch._dynamo.test_case import run_tests, TestCase
 from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss
 from torch._inductor import config, ir, metrics
 from torch._inductor.fx_passes import pad_mm as pad_mm_pass
-from torch._inductor.runtime.runtime_utils import do_bench
+from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.utils import run_and_get_code
 from torch.testing._internal.common_utils import requires_cuda, serialTest
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -169,10 +169,10 @@ class PerfTestBetweenGoodAndBadShape(TestCaseBase):
 m_bad_shape_opt = torch.compile(m_bad_shape)
 m_good_shape_opt = torch.compile(m_good_shape)

-latency_good_shape = do_bench(
+latency_good_shape = benchmarker.benchmark_gpu(
 lambda: forward_and_backward_pass(m_good_shape_opt, inputs_good_shape)
 )
-latency_bad_shape = do_bench(
+latency_bad_shape = benchmarker.benchmark_gpu(
 lambda: forward_and_backward_pass(m_bad_shape_opt, inptus_bad_shape)
 )
 print(
@@ -213,9 +213,13 @@ class PerfTestBetweenGoodAndBadShape(TestCaseBase):
 f_bad_shape, inputs_bad_shape = create_model(30522)

 print("benchmark for good shape")
-latency_good_shape = do_bench(lambda: f_good_shape(**inputs_good_shape))
+latency_good_shape = benchmarker.benchmark_gpu(
+lambda: f_good_shape(**inputs_good_shape)
+)
 print("benchmark for bad shape")
-latency_bad_shape = do_bench(lambda: f_bad_shape(**inputs_bad_shape))
+latency_bad_shape = benchmarker.benchmark_gpu(
+lambda: f_bad_shape(**inputs_bad_shape)
+)
 print(
 f"Latency with good and bad shape: {latency_good_shape:.3f} v.s. {latency_bad_shape:.3f}"
 )
@@ -285,7 +289,7 @@ class PerfTestWithAndWithoutPadding(TestCaseBase):
 opt_f_with_padding = torch.compile(
 get_f(m_copy_with_padding, optim_with_padding)
 )
-latency_with_padding = do_bench(
+latency_with_padding = benchmarker.benchmark_gpu(
 lambda: opt_f_with_padding(*perf_args, **perf_kwargs)
 )
 latency_without_padding = None
@@ -296,7 +300,7 @@ class PerfTestWithAndWithoutPadding(TestCaseBase):
 opt_f_without_padding = torch.compile(
 get_f(m_copy_without_padding, optim_without_padding)
 )
-latency_without_padding = do_bench(
+latency_without_padding = benchmarker.benchmark_gpu(
 lambda: opt_f_without_padding(*perf_args, **perf_kwargs)
 )
 print(
@@ -387,7 +391,7 @@ class PaddingTest(TestCaseBase):
 ):
 a = torch.randn(M, K)
 b = torch.randn(K, N)
-ms = do_bench(lambda: f(a, b))
+ms = benchmarker.benchmark_gpu(lambda: f(a, b))
 print(f"MxKxN {M}x{K}x{N} {f.__name__}: {ms:.3f}ms")

 @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled")
@@ -413,8 +417,8 @@ class PaddingTest(TestCaseBase):
 mat2 = pad_dim(mat2, 6, 0)
 return torch.ops.aten.mm(mat1, mat2)

-ori_time = do_bench(f)
-pad_time = do_bench(g)
+ori_time = benchmarker.benchmark_gpu(f)
+pad_time = benchmarker.benchmark_gpu(g)

 print(
 f"Latency between origional matmul and padded matmul: {ori_time:.3f} v.s. {pad_time:.3f}"
@@ -447,8 +451,8 @@ class PaddingTest(TestCaseBase):
 f2 = torch.compile(
 functools.partial(f, x_bad_shape, weight_bad_shape, out_bad_shape)
 )
-latency_good_shape = do_bench(f1)
-latency_bad_shape = do_bench(f2)
+latency_good_shape = benchmarker.benchmark_gpu(f1)
+latency_bad_shape = benchmarker.benchmark_gpu(f2)
 print(
 f"Latency with good and bad shapes: {latency_good_shape:.3f} v.s. {latency_bad_shape:.3f}"
 )
@@ -480,7 +484,7 @@ class PaddingTest(TestCaseBase):
 )

 if DO_PERF_TEST:
-latency = do_bench(
+latency = benchmarker.benchmark_gpu(
 lambda: forward_and_backward_pass(m_bad_shape_opt, inputs_bad_shape)
 )
 print(f"latency: {latency:.3f}ms")
@@ -603,8 +607,8 @@ class PaddingTest(TestCaseBase):
 act = fun(x2, weight)
 self.check_close(ref, act)
 if DO_PERF_TEST:
-latency_with_padding = do_bench(lambda: fun(x2, weight))
-latency_without_padding = do_bench(lambda: fun(x1, weight))
+latency_with_padding = benchmarker.benchmark_gpu(lambda: fun(x2, weight))
+latency_without_padding = benchmarker.benchmark_gpu(lambda: fun(x1, weight))
 print(
 f"Latency with and without padding: {latency_with_padding:.3f} v.s. {latency_without_padding:.3f}"
 )
@@ -632,8 +636,8 @@ class PaddingTest(TestCaseBase):
 with config.patch("triton.cudagraphs", False):
 opt_f = torch.compile(f)
 opt_f(x)
-eager_time = do_bench(lambda: f(x))
-opt_time = do_bench(lambda: opt_f(x))
+eager_time = benchmarker.benchmark_gpu(lambda: f(x))
+opt_time = benchmarker.benchmark_gpu(lambda: opt_f(x))
 print(
 f"Latency between eager and compiled: {eager_time:.3f} v.s. {opt_time:.3f}"
 )
@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from torch._dynamo.utils import counters, same
 from torch._inductor import metrics
-from torch._inductor.runtime.runtime_utils import do_bench_gpu as do_bench
+from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.test_case import TestCase
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

@@ -191,7 +191,7 @@ class TestScatterOpt(TestCase):
 torch.cuda.reset_peak_memory_stats()
 for _ in range(3):
 opt_f(opt_model, x, label)
-ms = do_bench(lambda: opt_f(opt_model, x, label))
+ms = benchmarker.benchmark_gpu(lambda: opt_f(opt_model, x, label))
 peak_mem = torch.cuda.max_memory_allocated() / 10**9
 print(f"{ms=:.3f}, {peak_mem=:.3f} GB")

@@ -1476,11 +1476,11 @@ def estimate_runtime(node):
 return 1

 elif RUNTIME_MODE == "profile":
-from triton.testing import do_bench
-
 with no_dispatch():
+from torch._inductor.runtime.benchmarking import benchmarker
+
 args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
-ms = do_bench(lambda: node.target(*args, **kwargs))
+ms = benchmarker.benchmark_gpu(lambda: node.target(*args, **kwargs))
 return ms

 elif RUNTIME_MODE == "flops":
@@ -46,7 +46,7 @@ if TYPE_CHECKING:
 from torch._inductor.select_algorithm import TritonTemplateCaller

 from . import config
-from .runtime.runtime_utils import do_bench_cpu, do_bench_gpu
+from .runtime.benchmarking import benchmarker
 from .virtualized import V


@@ -595,7 +595,7 @@ class GPUDeviceBenchmarkRequest(BenchmarkRequest):
 device_idx = torch.cuda.current_device()

 with torch.cuda.device(device_idx):
-out = do_bench_gpu(fn)
+out = benchmarker.benchmark_gpu(fn)
 torch.cuda.synchronize() # shake out any CUDA errors

 return out
@@ -801,7 +801,7 @@ class CPUDeviceBenchmarkRequest(BenchmarkRequest):
 *input_tensors: torch.Tensor,
 output_tensor: Optional[torch.Tensor] = None,
 ) -> float:
-return do_bench_cpu(fn)
+return benchmarker.benchmark_cpu(fn)


 class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
@@ -9,7 +9,7 @@ from torch.utils._ordered_set import OrderedSet

 from .. import config
 from ..codecache import get_path, TritonFuture
-from ..runtime.runtime_utils import do_bench_gpu
+from ..runtime.benchmarking import benchmarker
 from ..utils import cache_on_self, IndentedBuffer
 from ..virtualized import V
 from .common import TensorArg
@@ -323,7 +323,7 @@ class MultiKernelCall:
 return inner

 return [
-do_bench_gpu(wrap_fn(kernel), rep=40, fast_flush=True)
+benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40, fast_flush=True)
 for kernel in self.kernels
 ]

@@ -37,8 +37,9 @@ from ...utils._sympy.value_ranges import ValueRanges
 from .. import config, ir
 from ..codecache import code_hash, get_path, PyCodeCache
 from ..metrics import is_metric_table_enabled, log_kernel_metadata
+from ..runtime.benchmarking import benchmarker
 from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
-from ..runtime.runtime_utils import do_bench_gpu, get_max_y_grid, next_power_of_2
+from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
 from ..utils import (
 cache_on_self,
 get_bounds_index_expr,
@@ -2472,12 +2473,14 @@ class TritonKernel(SIMDKernel):

 result.writelines(["\n", "\n", "if __name__ == '__main__':"])
 with result.indent():
-result.writeline("from triton.testing import do_bench")
+result.writeline(
+"from torch._inductor.runtime.benchmarking import benchmarker"
+)
 result.writeline("")

 result.writeline("args = get_args()")
 result.writeline(
-"ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
+"ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)"
 )
 result.writeline(f"num_gb = {num_gb}")
 result.writeline("gb_per_s = num_gb / (ms / 1e3)")
@@ -3080,7 +3083,7 @@ class TritonScheduling(SIMDScheduling):
 else:
 # We have to clone the inplace updated arguments to avoid earlier calls
 # generating out of range indices for later calls.
-ms = do_bench_gpu(
+ms = benchmarker.benchmark_gpu(
 lambda: call(wrapped_jit_function.clone_args(*args)[0])
 )

@@ -3088,7 +3091,9 @@
 # in the case of mutating/in-placeable second fusion
 # TODO - would be better as a hook in triton do_bench that reset
 # the input values between benchmarking
-ms = ms - do_bench_gpu(lambda: wrapped_jit_function.clone_args(*args))
+ms = ms - benchmarker.benchmark_gpu(
+lambda: wrapped_jit_function.clone_args(*args)
+)

 log.debug(
 "The fused kernel for %s took %.3f ms to run",
@@ -765,12 +765,14 @@ class ComboKernel(Kernel):

 result.writelines(["\n", "\n", "if __name__ == '__main__':"])
 with result.indent():
-result.writeline("from triton.testing import do_bench")
+result.writeline(
+"from torch._inductor.runtime.benchmarking import benchmarker"
+)
 result.writeline("")

 result.writeline("args = get_args()")
 result.writeline(
-"ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
+"ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)"
 )
 result.writeline(f"num_gb = {num_gb}")
 result.writeline("gb_per_s = num_gb / (ms / 1e3)")
@@ -367,7 +367,7 @@ def should_pad_bench(
 match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
 ) -> bool:
 do_bench = functools.partial(
-torch._inductor.runtime.runtime_utils.do_bench_gpu,
+torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
 warmup=5,
 )
 m_padded_length = 0
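In the pad_mm hunk just above, the call site keeps its local `do_bench` name but binds it to `benchmarker.benchmark_gpu` with `warmup=5`. The same pre-binding pattern works anywhere non-default benchmarking parameters are needed; a hedged sketch under the assumption that Triton and a CUDA device are available (the name `bench_gpu_warm5` and the matmul are illustrative):

import functools

import torch
from torch._inductor.runtime.benchmarking import benchmarker

# Pre-bind keyword arguments so the call site only supplies the callable,
# mirroring the functools.partial pattern in should_pad_bench above.
bench_gpu_warm5 = functools.partial(benchmarker.benchmark_gpu, warmup=5)

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    ms = bench_gpu_warm5(lambda: a @ b)  # median runtime in milliseconds
    print(f"matmul: {ms:.3f} ms")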
@@ -73,8 +73,8 @@ from .dependencies import (
 var_builder,
 )
 from .ops_handler import OpCounterCSE
+from .runtime.benchmarking import benchmarker
 from .runtime.hints import ReductionHint
-from .runtime.runtime_utils import do_bench
 from .utils import (
 argsort,
 cache_on_self,
@@ -3951,7 +3951,7 @@ class ChoiceCaller:

 def benchmark(self, *args, out) -> float:
 algo = self.to_callable()
-return do_bench(algo, args, {"out": out})
+return benchmarker.benchmark(algo, args, {"out": out})

 def call_name(self) -> str:
 raise NotImplementedError
torch/_inductor/runtime/benchmarking.py (new file, 141 lines)
@@ -0,0 +1,141 @@
import time
from functools import cached_property, wraps
from statistics import median
from typing import Any, Callable, Dict, List, Tuple
from typing_extensions import ParamSpec, Self, TypeVar

import torch
from torch._inductor.utils import is_cpu_device


log = torch._logging.getArtifactLogger(__name__, "benchmarking")


MILLISECONDS_PER_SECOND = 1000

P = ParamSpec("P")
T = TypeVar("T")


def maybe_time(fn: Callable[P, T]) -> Callable[P, T]:
    if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"):
        return fn

    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start_s = time.perf_counter()
        result = fn(*args, **kwargs)
        log.debug(
            "fn:%r args:[%r, %r] took %f seconds.",
            fn.__name__,
            args,
            kwargs,
            time.perf_counter() - start_s,
        )
        return result

    return wrapper


class Benchmarker:
    def __init__(self: Self) -> None:
        pass

    @maybe_time
    def benchmark(
        self: Self,
        fn: Callable[..., Any],
        fn_args: Tuple[Any],
        fn_kwargs: Dict[str, Any],
        **kwargs: Any,
    ) -> float:
        """Construct benchmarkable callable and dispatch benchmark request to the appropriate
        benchmarking function depending on the device type of `fn_args` and `fn_kwargs`.

        Arguments:
        - fn: The function to benchmark.
        - fn_args: The function's arguments.
        - fn_kwargs: The function's kwargs.

        Keyword Arguments:
        - **kwargs: The benchmarker's keyword arguments.

        Returns:
        - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
        """
        if is_cpu_device(list(fn_args) + list(fn_kwargs.values())):
            return self.benchmark_cpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
        return self.benchmark_gpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)

    @maybe_time
    def benchmark_cpu(
        self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
    ) -> float:
        """Benchmark a CPU callable.

        Arguments:
        - _callable: The callable to benchmark.

        Keyword Arguments:
        - warmup: Duration to run the callable before benchmarking, in milliseconds.
        - rep: Duration to run the benchmarking, in milliseconds.

        Returns:
        - The median runtime of `_callable`, in milliseconds.
        """

        def run_for(ms: int) -> List[float]:
            timings = []
            run_start_s = time.perf_counter()
            while True:
                start_s = time.perf_counter()
                _callable()
                end_s = time.perf_counter()
                timings.append((end_s - start_s) * MILLISECONDS_PER_SECOND)
                if ((end_s - run_start_s) * MILLISECONDS_PER_SECOND) > ms:
                    break
            return timings

        run_for(warmup)
        return median(run_for(rep))

    def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
        raise NotImplementedError


class TritonBenchmarker(Benchmarker):
    @cached_property
    @maybe_time
    def triton_do_bench(self: Self) -> Callable[..., Any]:
        """Lazily import Triton's do_bench."""
        try:
            from triton.testing import do_bench
        except ImportError as e:
            raise NotImplementedError("requires Triton") from e
        return do_bench

    @maybe_time
    def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
        """Benchmark a GPU callable using Triton's do_bench.

        Arguments:
        - _callable: The callable to benchmark.

        Keyword Arguments:
        - quantiles: A tuple of floats denoting the requested quantiles.
        - return_mode: The requested return mode, one of "min", "max", "mean", or "median".
        - **kwargs: Additional kwargs passed to triton.testing.do_bench.

        Returns:
        - The runtime of `callable`, in milliseconds. If `kwargs["quantiles"]` is specified,
        this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified,
        this is the requested return mode. Otherwise, this is the median.
        """
        if "quantiles" in kwargs:
            return self.triton_do_bench(_callable, **kwargs)[0]
        elif "return_mode" in kwargs:
            return self.triton_do_bench(_callable, **kwargs)
        return self.triton_do_bench(_callable, **kwargs, return_mode="median")


benchmarker = TritonBenchmarker()
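A small usage sketch of the API added above; the tensor shapes are arbitrary and the GPU portion is guarded because it requires Triton and a CUDA device:

import torch
from torch._inductor.runtime.benchmarking import benchmarker

# CPU benchmarking needs no Triton; warmup/rep are durations in milliseconds.
cpu_ms = benchmarker.benchmark_cpu(
    lambda: torch.randn(1024, 1024) @ torch.randn(1024, 1024), warmup=10, rep=50
)
print(f"cpu matmul: {cpu_ms:.3f} ms (median)")

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")

    # Default: do_bench is called with return_mode="median".
    median_ms = benchmarker.benchmark_gpu(lambda: torch.relu(x))
    # With quantiles, only the first requested quantile is returned.
    p50_ms = benchmarker.benchmark_gpu(lambda: torch.relu(x), quantiles=(0.5, 0.2, 0.8))
    # An explicit return_mode is forwarded unchanged.
    min_ms = benchmarker.benchmark_gpu(lambda: torch.relu(x), return_mode="min")
    print(median_ms, p50_ms, min_ms)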
@@ -4,12 +4,10 @@ from __future__ import annotations
 import contextlib
 import functools
 import getpass
-import inspect
 import operator
 import os
 import re
 import tempfile
-import time

 import torch

@@ -77,84 +75,6 @@ def get_max_y_grid():
 return 65535


-def do_bench(fn, fn_args, fn_kwargs, **kwargs):
-from torch._inductor.utils import is_cpu_device
-
-args = list(fn_args)
-args.extend(fn_kwargs.values())
-if is_cpu_device(args):
-return do_bench_cpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
-else:
-return do_bench_gpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
-
-
-def do_bench_gpu(*args, **kwargs):
-@functools.lru_cache(None)
-def load_triton():
-try:
-# NB: Lazily load triton, as importing triton is slow
-# see https://github.com/openai/triton/issues/1599
-from triton.testing import do_bench as triton_do_bench
-except ImportError as exc:
-raise NotImplementedError("requires Triton") from exc
-
-# triton PR https://github.com/openai/triton/pull/1513 change the
-# quantile fields name from 'percentiles' to 'quantiles'
-# and change the default value from (0.5, 0.2, 0.8) to None.
-# This may break inductor since a caller expects a tuple may get a item.
-#
-# Add a wrapper to maintain the same behavior for inductor.
-# Maybe we should have own implementation of this function?
-return triton_do_bench, (
-"quantiles"
-if inspect.signature(triton_do_bench).parameters.get("quantiles")
-is not None
-else "percentiles"
-)
-
-triton_do_bench, quantile_field_name = load_triton()
-
-if quantile_field_name not in kwargs:
-kwargs[quantile_field_name] = (0.5, 0.2, 0.8)
-return triton_do_bench(*args, **kwargs)[0]
-
-
-def do_bench_cpu(fn, warmup=20, rep=100):
-"""
-Benchmark a function on the CPU.
-
-Parameters:
-- fn: The function to be benchmarked.
-- warmup: The number of milliseconds to run the function before starting the benchmark.
-- rep: The number of milliseconds to run the function for the benchmark.
-
-Returns:
-- The median time (in milliseconds) taken by the function.
-
-"""
-start = time.perf_counter()
-while True:
-fn()
-if (time.perf_counter() - start) * 1000 > warmup:
-break
-durations = []
-start = time.perf_counter()
-while True:
-t0 = time.perf_counter()
-fn()
-t1 = time.perf_counter()
-durations.append((t1 - t0) * 1000)
-if (t1 - start) * 1000 > rep:
-break
-# return the median time
-sorted_durations = sorted(durations)
-times = len(durations)
-if times % 2 == 0:
-return (sorted_durations[times // 2 - 1] + sorted_durations[times // 2]) / 2
-else:
-return sorted_durations[times // 2]
-
-
 def cache_dir() -> str:
 cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
 if cache_dir is None:
@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple

 import torch

+from .benchmarking import benchmarker
 from .coordinate_descent_tuner import CoordescTuner
 from .hints import (
 _NUM_THREADS_PER_WARP,
@@ -33,7 +34,6 @@ from .runtime_utils import (
 ceildiv,
 conditional_product,
 create_bandwidth_info_str,
-do_bench_gpu,
 dynamo_timed,
 get_first_attr,
 get_max_y_grid,
@@ -664,7 +664,7 @@ class CachingAutotuner(KernelInterface):
 stream=stream,
 )

-return do_bench_gpu(kernel_call, rep=40, fast_flush=True)
+return benchmarker.benchmark_gpu(kernel_call, rep=40, fast_flush=True)

 def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
 from ..compile_fx import clone_preserve_strides
@@ -40,8 +40,8 @@ from .codegen.triton import (
 from .codegen.triton_utils import config_of, signature_to_meta
 from .exc import CUDACompileError
 from .ir import ChoiceCaller, PrimitiveInfoType
+from .runtime.benchmarking import benchmarker
 from .runtime.hints import DeviceProperties
-from .runtime.runtime_utils import do_bench
 from .utils import (
 FakeIndentedBuffer,
 get_dtype_size,
@@ -952,7 +952,7 @@ class ExternKernelCaller(ChoiceCaller):
 out_new, tuple(out.size()), tuple(out.stride())
 )
 out.copy_(out_new) # for correctness checking
-return do_bench(algo, args, {})
+return benchmarker.benchmark(algo, args, {})

 def to_callable(self):
 fn = self.choice.to_callable()
@@ -6,11 +6,8 @@ from collections import defaultdict
 import torch
 from torch.autograd import DeviceType

-from .runtime.runtime_utils import (
-create_bandwidth_info_str,
-do_bench_gpu,
-get_num_bytes,
-)
+from .runtime.benchmarking import benchmarker
+from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes


 _kernel_category_choices = [
@@ -123,7 +120,9 @@ def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
 f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
 )
 else:
-ms = do_bench_gpu(lambda: kernel_mod.call(args), rep=40, fast_flush=True)
+ms = benchmarker.benchmark_gpu(
+lambda: kernel_mod.call(args), rep=40, fast_flush=True
+)
 assert (
 len(triton_kernel.launchers) == 1
 ), "Autotuner should have selected the best config"
@@ -232,6 +232,7 @@ def set_logs(
 sym_node: bool = False,
 compiled_autograd_verbose: bool = False,
 cudagraph_static_inputs: bool = False,
+benchmarking: bool = False,
 ):
 """
 Sets the log level for individual components and toggles individual log
@@ -395,6 +396,9 @@
 export (:class:`Optional[int]`):
 The log level for export. Default: ``logging.WARN``

+benchmarking (:class:`bool`):
+Whether to emit detailed Inductor benchmarking information. Default: ``False``
+
 modules (dict):
 This argument provides an alternate way to specify the above log
 component and artifact settings, in the format of a keyword args
@@ -504,6 +508,7 @@
 cudagraphs=cudagraphs,
 compiled_autograd_verbose=compiled_autograd_verbose,
 cudagraph_static_inputs=cudagraph_static_inputs,
+benchmarking=benchmarking,
 )


@@ -160,11 +160,15 @@ register_artifact(
 "Logs traces for every ShapeEnv operation that we record for replay",
 off_by_default=True,
 )
-
 register_artifact(
 "cudagraph_static_inputs",
 "Logs static inputs handling in dynamo, AOT, and cudagraphs",
 off_by_default=True,
 )
+register_artifact(
+"benchmarking",
+"Detailed Inductor benchmarking information.",
+off_by_default=True,
+)

 register_artifact("custom_format_test_artifact", "Testing only", log_format="")
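The new `benchmarking` artifact is off by default; per the `set_logs` changes above it can be toggled programmatically, and, as with other artifacts, via the `TORCH_LOGS` environment variable. Note that `maybe_time` in benchmarking.py checks the artifact state when the decorator is applied, i.e. at import time, so the artifact should be enabled before the benchmarking module is imported. A minimal sketch:

import torch

# Turn on the detailed Inductor benchmarking logs registered above; the
# maybe_time wrapper then logs how long each benchmarking call itself took.
torch._logging.set_logs(benchmarking=True)

# Rough environment-variable equivalent (an assumption, mirroring how other
# artifacts are enabled):
#   TORCH_LOGS="benchmarking" python your_script.py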