pytorch/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py
Gabriel Ferns 254293b777 Add flag _metrics_log_runtime to disable runtime metric logging by default (#153506)
https://github.com/pytorch/pytorch/pull/152708 expanded support of `get_estimated_runtime` to many more types of `SchedulerNodes`. This caused an increase in compile time because we're always calling `get_estimated_runtime` to populate the metrics table. This PR adds a flag for this logging, which reduces the instruction count by 8%. Long term, we should probably merge metrics.py with TORCH_LOGS/tlparse (suggestion from @xmfan).

Update: added support for TORCH_LOGS for the metrics logging.
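
A minimal sketch of how the metrics logging is toggled (the `inductor_metrics` artifact name is the one used programmatically in the script below; the TORCH_LOGS spelling is an assumption based on that name):

    import torch

    # Programmatic toggle, as bench() in this file does:
    torch._logging.set_logs(inductor_metrics=True)  # enable Inductor metrics logging
    # ... run compiled workloads ...
    torch._logging.set_logs()  # restore default logging

    # Assumed environment-variable equivalent:
    #   TORCH_LOGS="inductor_metrics" python bench_mm_fusion.py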

Test Plan:
Covered by mm_loop.py and many existing tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153506
Approved by: https://github.com/eellison
2025-05-22 01:02:11 +00:00


# flake8: noqa: B902
from prettytable import PrettyTable

import torch
import torch._dynamo
import torch._inductor.config
from torch._inductor.runtime.benchmarking import benchmarker

# torch._inductor.config.debug = True
torch._inductor.config.triton.dense_indexing = True
torch.manual_seed(0)

# The flag below controls whether to allow TF32 on matmul.
torch.backends.cuda.matmul.allow_tf32 = True
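

# Each variant below is compiled through Inductor, so the epilogue
# (bias add and/or relu) can be fused into the generated matmul kernel.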
class Func:
    # mm
    @torch._dynamo.optimize("inductor")
    def mm(a, b, bias):
        y = torch.mm(a, b)
        return y

    # mm+bias
    @torch._dynamo.optimize("inductor")
    def mm_add(a, b, bias):
        y = torch.mm(a, b)
        return y + bias

    # relu(mm)
    @torch._dynamo.optimize("inductor")
    def mm_relu(a, b, bias):
        y = torch.mm(a, b)
        return torch.relu(y)

    # relu(mm+bias)
    @torch._dynamo.optimize("inductor")
    def mm_add_relu(a, b, bias):
        y = torch.mm(a, b)
        y += bias
        return torch.relu(y)
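

# Benchmark one (M, K) x (K, N) shape: for each requested fusion variant,
# time the aten mm lowering and the Triton mm lowering, and append the
# resulting throughput numbers as one row of the PrettyTable `p`.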
def bench(shape, layer_id, p, fusion_types=None):
    torch._logging.set_logs(inductor_metrics=True)
    if fusion_types is None:
        fusion_types = [""]
    dtype = torch.float16
    M, K = shape[0]
    _, N = shape[1]
    torch.manual_seed(0)
    # allocate inputs
    a = torch.randn(shape[0], device="cuda", dtype=dtype)
    b = torch.randn(shape[1], device="cuda", dtype=dtype)

    def tflops(ms):
        return M * K * N / ms * 1e-9

    row = [layer_id]
    for fusion_type in fusion_types:
        if fusion_type == "":
            fn_mm = Func.mm
        else:
            fn_mm = getattr(Func, f"mm_{fusion_type}")
        if "add" in fusion_type:
            bias = torch.randn((M, N), dtype=dtype, device="cuda")
        else:
            bias = None
        args = (a, b, bias)

        def fn():
            return fn_mm(*args)

        torch._inductor.config.triton.mm = "aten"
        torch_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
        torch._inductor.config.triton.mm = "triton"
        # reset to force code gen new python code
        torch._dynamo.reset()
        torch._inductor.metrics.reset()
        triton_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
        assert torch._inductor.metrics.generated_kernel_count == 1, (
            "codegen #kernel != 1"
        )
        row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])
    p.add_row(row)
    torch._logging.set_logs()


fusion_types = ["", "add", "relu", "add_relu"]
shapes = [
    # alexnet
    ([128, 9216], [9216, 4096]),
    ([128, 4096], [4096, 4096]),
    ([128, 4096], [4096, 1000]),
    # BERT
    ([2048, 768], [768, 768]),
    ([2048, 768], [768, 3072]),
    ([2048, 3072], [3072, 768]),
    # hf_GPT2
    ([1024, 768], [768, 768]),
    ([1024, 768], [768, 3072]),
    ([1024, 3072], [3072, 768]),
    ([1024, 768], [768, 2304]),
]
p = PrettyTable()
field_names = ["layer"]
for fusion_type in fusion_types:
    if fusion_type == "":
        field_names.append("torch mm")
        field_names.append("triton mm")
    else:
        field_names.append(f"torch mm+{fusion_type}")
        field_names.append(f"triton mm+{fusion_type}")
p.field_names = field_names
p.float_format = ".3"
for id, shape in enumerate(shapes):
    bench(shape, id, p, fusion_types)
print(p)