pytorch/torch/_inductor/wrapper_benchmark.py
Gabriel Ferns 47f10d0ad0 Inductor logging + analysis of torch.profile (#149697)
Prereqs:
 - https://github.com/pytorch/pytorch/pull/152708

Features:
1. Adds Inductor's estimates of flops and bandwidth to the JSON trace events that Perfetto uses.
1. Uses the tflops estimate from Triton only if we don't have the info from the datasheet, because Triton's estimates are inaccurate; I have a backlog item to fix Triton's flops estimation upstream. Adds a new `DeviceInfo` class and a new function `get_device_tflops`.
1. New helpers `countable_fx` and `count_flops_fx` help compute the flops of an `fx.Node`.
1. Extends Triton `torch.profiler` logging to `DebugAutotuner`.
1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any Perfetto JSON trace, `--analyze` creates a summary table of these perf estimates, and `--diff` compares two traces side by side, as in the example output below (an invocation sketch follows the table):
```
Device(NVIDIA H100, 0):
 Kernel Name                              | resnet Kernel Count | resnet FLOPS       | resnet bw gbps        | resnet Dur (ms)    | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS    | newresnet bw gbps     | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth %
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 triton_poi_fused__native_batch_norm_legi | 24                  | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                       | 0.003401572611382541        | 24                     | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                          | 0.003401572611382541
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142                 | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583     | 0.007716441266265022        | 142                    | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583        | 0.007716441266265022
 triton_red_fused__native_batch_norm_legi | 39                  | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                       | 0.004176126863316074        | 39                     | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                          | 0.004176126863316074
 triton_poi_fused__native_batch_norm_legi | 25                  | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                       | 0.009499718184339253        | 25                     | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                          | 0.009499718184339253
 void cutlass::Kernel2<cutlass_80_tensoro | 98                  | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874     | 0.012827592254037562        | 98                     | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874        | 0.012827592254037562
 triton_red_fused__native_batch_norm_legi | 73                  | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                       | 0.009628003963020014        | 73                     | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                          | 0.009628003963020014
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                       | 0.043257347302946926        | 15                     | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                          | 0.043257347302946926
 void cutlass::Kernel2<cutlass_80_tensoro | 186                 | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027     | 0.007961586274361157        | 186                    | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027        | 0.007961586274361157
 triton_poi_fused__native_batch_norm_legi | 33                  | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                       | 0.044550915039384846        | 33                     | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                          | 0.044550915039384846
 triton_red_fused__native_batch_norm_legi | 29                  | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                       | 0.007630624036606301        | 29                     | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                          | 0.007630624036606301
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                       | 0.01752406619162008         | 13                     | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                          | 0.01752406619162008
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 0.41409928846284      | 2.853588235294117  | 0                       | 0.012361172789935523        | 34                     | 0                  | 0.41409928846284      | 2.853588235294117  | 0                          | 0.012361172789935523
 triton_per_fused__native_batch_norm_legi | 34                  | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                       | 0.0034941238826919864       | 34                     | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                          | 0.0034941238826919864
 triton_poi_fused__native_batch_norm_legi | 16                  | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                       | 0.005136672596156592        | 16                     | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                          | 0.005136672596156592
 triton_per_fused__native_batch_norm_legi | 30                  | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                       | 0.007879744244842555        | 30                     | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                          | 0.007879744244842555
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100                 | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531     | 0.005819245035648175        | 100                    | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531        | 0.005819245035648175
 triton_poi_fused__native_batch_norm_legi | 8                   | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                       | 0.029415213809625928        | 8                      | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                          | 0.029415213809625928
 void cublasLt::splitKreduce_kernel<32, 1 | 56                  | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628     | 0.024806865808245714        | 56                     | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628        | 0.024806865808245714
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                       | 0.02968359094286896         | 23                     | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                          | 0.02968359094286896
 triton_per_fused__native_batch_norm_legi | 10                  | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                       | 0.00545313748934644         | 10                     | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                          | 0.00545313748934644
 triton_poi_fused__native_batch_norm_legi | 10                  | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                       | 0.009459622642884923        | 10                     | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                          | 0.009459622642884923
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                       | 0.03421974596124114         | 34                     | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                          | 0.03421974596124114
 void cask_plugin_cudnn::xmma_cudnn::init | 44                  | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194     | 0.06167532194133924         | 44                     | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194        | 0.06167532194133924
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95                  | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802     | 0.014014750913273854        | 95                     | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802        | 0.014014750913273854
 triton_per_fused__native_batch_norm_legi | 41                  | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                       | 0.002037513395819492        | 41                     | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                          | 0.002037513395819492
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                       | 0.0026292999141582997       | 23                     | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                          | 0.0026292999141582997
 triton_per_fused__native_batch_norm_legi | 40                  | 0                  | 0.18179321034952417   | 4.556825           | 0                       | 0.005426662995508183        | 40                     | 0                  | 0.18179321034952417   | 4.556825           | 0                          | 0.005426662995508183
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                       | 0.017574373598370836        | 15                     | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                          | 0.017574373598370836
 void cutlass::Kernel2<cutlass_80_tensoro | 38                  | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546      | 0.007659474756834           | 38                     | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546         | 0.007659474756834
 triton_poi_fused__native_batch_norm_legi | 21                  | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                       | 0.017441376040091088        | 21                     | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                          | 0.017441376040091088
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                       | 0.0034356313950705724       | 16                     | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                          | 0.0034356313950705724
 triton_poi_fused__native_batch_norm_legi | 14                  | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                       | 0.00508857313505646         | 14                     | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                          | 0.00508857313505646
 triton_poi_fused__native_batch_norm_legi | 58                  | 0                  | 2.307520779930795     | 8.190706896551722  | 0                       | 0.06888121731136704         | 58                     | 0                  | 2.307520779930795     | 8.190706896551722  | 0                          | 0.06888121731136704
 triton_per_fused__native_batch_norm_legi | 29                  | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                       | 0.001111738775280038        | 29                     | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                          | 0.001111738775280038
 triton_poi_fused__native_batch_norm_legi | 20                  | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                       | 0.0014154327747549007       | 20                     | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                          | 0.0014154327747549007
 triton_per_fused__native_batch_norm_legi | 25                  | 0                  | 0.13357016893727824   | 3.37536            | 0                       | 0.003987169222008305        | 25                     | 0                  | 0.13357016893727824   | 3.37536            | 0                          | 0.003987169222008305
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                       | 0.009223469457612694        | 13                     | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                          | 0.009223469457612694
 triton_poi_fused__native_batch_norm_legi | 17                  | 0                  | 0.3129385387909844    | 2.673              | 0                       | 0.009341448919133863        | 17                     | 0                  | 0.3129385387909844    | 2.673              | 0                          | 0.009341448919133863
 triton_per_fused__native_batch_norm_legi | 19                  | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                       | 0.0066136363060691275       | 19                     | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                          | 0.0066136363060691275
 std::enable_if<!(false), void>::type int | 23                  | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447   | 0.030203868944223014        | 23                     | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447      | 0.030203868944223014
 triton_poi_fused_add_copy__38            | 56                  | 0                  | 0                     | 2.132482142857143  | 0                       | 0                           | 56                     | 0                  | 0                     | 2.132482142857143  | 0                          | 0
 triton_poi_fused_convolution_0           | 18                  | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                       | 0.012972719640279667        | 18                     | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                          | 0.012972719640279667
 triton_poi_fused_convolution_1           | 17                  | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                       | 0.0008601884319153051       | 17                     | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                          | 0.0008601884319153051
 void convolve_common_engine_float_NHWC<f | 44                  | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169     | 0.0007382250748795709       | 44                     | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169        | 0.0007382250748795709
 triton_per_fused__native_batch_norm_legi | 12                  | 0                  | 0.6809930918986744    | 4.82675            | 0                       | 0.020328151996975356        | 12                     | 0                  | 0.6809930918986744    | 4.82675            | 0                          | 0.020328151996975356
 triton_per_fused__native_batch_norm_legi | 14                  | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                       | 0.0008606061486377935       | 14                     | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                          | 0.0008606061486377935
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.0014658988233201874 | 2.098              | 0                       | 4.375817383045335e-05       | 16                     | 0                  | 0.0014658988233201874 | 2.098              | 0                          | 4.375817383045335e-05
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                       | 0.02963073785159611         | 13                     | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                          | 0.02963073785159611
 triton_poi_fused__native_batch_norm_legi | 9                   | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                       | 0.03883228983781048         | 9                      | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                          | 0.03883228983781048
 void at::native::(anonymous namespace):: | 98                  | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                       | 0.0027386076458833994       | 98                     | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                          | 0.0027386076458833994
 void at::native::vectorized_elementwise_ | 7                   | 0                  | 0                     | 1.7278571428571428 | 0                       | 0                           | 7                      | 0                  | 0                     | 1.7278571428571428 | 0                          | 0
```
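
A hedged sketch of how these modes might be driven; only the flag names (`--augment_trace`, `--analyze`, `--diff`) come from the description above, while the script path, argument arity, and trace filenames are assumptions for illustration:

```python
# Hypothetical invocation sketch -- flag names are from the PR description;
# the script location, number of positional arguments, and file names are
# assumptions, not the actual CLI contract.
import subprocess
import sys

script = "profile_analysis.py"  # placeholder path to the new script

# Add inductor's perf estimates to an existing Perfetto JSON trace.
subprocess.run([sys.executable, script, "--augment_trace", "resnet.json"], check=True)

# Summarize one augmented trace, or compare two traces side by side.
subprocess.run([sys.executable, script, "--analyze", "resnet.json"], check=True)
subprocess.run([sys.executable, script, "--diff", "resnet.json", "newresnet.json"], check=True)
```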

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-07-01 16:51:03 +00:00


import argparse
import datetime
import tempfile
from collections import defaultdict
from dataclasses import dataclass
from types import ModuleType
from typing import Any, Optional, Protocol

import torch
from torch.autograd import DeviceType
from torch.utils._ordered_set import OrderedSet

from .runtime.benchmarking import benchmarker
from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes


class BenchmarkCallableType(Protocol):
    def __call__(self, times: int, repeat: int) -> float: ...


_kernel_category_choices = [
    "foreach",
    "persistent_reduction",
    "pointwise",
    "reduction",
    "split_scan",
    "template",
]


def get_kernel_category_by_source_code(src_code: str) -> str:
    """
    Similar to get_kernel_category but uses the source code. Call this API
    if we have not compiled the src_code into a module yet.
    """
    choices = [
        ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code
    ]
    if len(choices) == 1:
        return choices[0]
    else:
        return "unknown"


def get_kernel_category(kernel_mod: ModuleType) -> str:
    """
    Given the module defining a triton kernel, return the category of the kernel.
    Category can be one of:
    - pointwise
    - reduction
    - persistent_reduction

    Currently we simply decide the category depending on what decorator is imported
    by the kernel.
    """
    choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__]
    if len(choices) == 1:
        return choices[0]
    else:
        return "unknown"


def get_triton_kernel(mod: ModuleType):  # type: ignore[no-untyped-def]
    from torch._inductor.runtime.triton_heuristics import CachingAutotuner

    cand_list = [
        v
        for k, v in mod.__dict__.items()
        if k.startswith("triton_") and isinstance(v, CachingAutotuner)
    ]
    assert len(cand_list) == 1
    return cand_list[0]


def benchmark_all_kernels(
    benchmark_name: str, benchmark_all_configs: Optional[dict[Any, Any]]
) -> None:
    """
    An experimental API used only when config.benchmark_kernel is true.

    Run the kernel benchmarks for all the kernels cached in PyCodeCache.
    Used in the compiled modules.

    Put this method here rather than codegen it for convenience since its implementation
    does not change based on different graph modules being compiled.
    """
    from torch._inductor.codecache import PyCodeCache

    nfound = 0
    for kernel_mod in PyCodeCache.modules:
        kernel_key = kernel_mod.key
        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
            continue

        triton_kernel = get_triton_kernel(kernel_mod)
        kernel_category = get_kernel_category(kernel_mod)
        args = kernel_mod.get_args()
        num_in_out_ptrs = len(
            [
                arg_name
                for arg_name in triton_kernel.fn.arg_names
                if arg_name.startswith("in_out_ptr")
            ]
        )
        num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None)
        if num_gb is None:
            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9

        def get_info_str(
            ms: float,
            n_regs: Optional[Any],
            n_spills: Optional[Any],
            shared: Optional[Any],
            prefix: str = "",
        ) -> str:
            if not any(x is None for x in [n_regs, n_spills, shared]):
                kernel_detail_str = (
                    f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem"
                )
            else:
                kernel_detail_str = ""

            gb_per_s = num_gb / (ms / 1e3)
            return create_bandwidth_info_str(
                ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str
            )

        kernel_desc = (
            f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}"
        )
        if benchmark_all_configs:
            assert hasattr(kernel_mod, "benchmark_all_configs")
            bench_result = kernel_mod.benchmark_all_configs(args)
            print(kernel_desc)
            for launcher, ms in bench_result.items():
                print(
                    f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                )
        else:
            ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
            assert len(triton_kernel.launchers) == 1, (
                "Autotuner should have selected the best config"
            )
            launcher = triton_kernel.launchers[0]
            print(
                get_info_str(
                    ms,
                    launcher.n_regs,
                    launcher.n_spills,
                    launcher.shared,
                    prefix=f"{kernel_desc} ",
                )
            )

        nfound += 1
    if nfound == 0:
        print(
            "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel set to True"
        )
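

# Worked example of the bandwidth math in get_info_str above (a sketch with
# made-up numbers, not part of the original file): a kernel that moves
# num_gb = 0.5 GB in ms = 0.25 ms achieves 0.5 / (0.25 / 1e3) = 2000 GB/s.
def _example_bandwidth_math() -> None:
    num_gb, ms = 0.5, 0.25
    gb_per_s = num_gb / (ms / 1e3)
    assert abs(gb_per_s - 2000.0) < 1e-9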


@dataclass
class ProfileEvent:
    category: str
    key: str
    self_device_time_ms: float
    # the benchmark is run multiple times and we average the count across all the
    # runs. It should be an integer but define a float just in case.
    count: float


def parse_profile_event_list(
    benchmark_name: str,
    event_list: torch.autograd.profiler_util.EventList,
    wall_time_ms: float,
    nruns: int,
    device_name: str,
) -> None:
    """
    Parse and generate a report for an event_list.
    """

    def get_self_device_time(
        ev: torch.autograd.profiler_util.EventList,
    ) -> float:
        """
        ev.self_device_time_total is in microseconds. Convert to milliseconds.
        """
        return ev.self_device_time_total / 1000 / nruns  # type: ignore[attr-defined]

    all_events: dict[str, list[ProfileEvent]] = defaultdict(list)

    def add_event(
        ev: torch.autograd.profiler_util.EventList,
        category: str,
    ) -> None:
        profile_ev = ProfileEvent(
            category=category,
            key=ev.key,  # type: ignore[attr-defined]
            self_device_time_ms=get_self_device_time(ev),
            count=ev.count / nruns,  # type: ignore[operator] # average across all runs
        )
        all_events[category].append(profile_ev)

    for ev in event_list:
        assert not ev.is_legacy, "Don't support the legacy profiler"
        if ev.device_type == DeviceType.CPU:
            # ignore the event on CPU side
            continue

        category = "unknown"
        if ev.key.startswith("triton_"):
            if ev.key.startswith("triton_poi"):
                category = "triton_pointwise"
            elif ev.key.startswith("triton_red"):
                category = "triton_reduction"
            elif ev.key.startswith("triton_per"):
                category = "triton_persistent_reduction"
            else:
                category = "triton_unknown"

        add_event(ev, category)
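
    # For reference, the kernel-name prefix -> category mapping applied above:
    #   triton_poi_* -> triton_pointwise
    #   triton_red_* -> triton_reduction
    #   triton_per_* -> triton_persistent_reduction
    #   any other triton_* -> triton_unknown; non-triton device events -> "unknown"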

    def report_category(category: str, profile_events: list[ProfileEvent]) -> float:
        if not device_name:
            return 0.0

        from tabulate import tabulate

        profile_events.sort(key=lambda ev: ev.self_device_time_ms, reverse=True)

        rows = []
        total_time = 0.0
        print(f"\n == {category} category kernels == ")
        for ev in profile_events:
            total_time += ev.self_device_time_ms
            percent = f"{ev.self_device_time_ms / wall_time_ms * 100:.2f}%"
            rows.append([ev.key[:120], ev.self_device_time_ms, ev.count, percent])
        rows.append(
            ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"]
        )
        print(
            tabulate(
                rows,
                headers=[
                    "Kernel",
                    f"Self {device_name.upper()} TIME (ms)",
                    "Count",
                    "Percent",
                ],
            )
        )
        return total_time

    def report() -> None:
        category_list = [
            "triton_pointwise",
            "triton_reduction",
            "triton_persistent_reduction",
            "triton_unknown",
            "unknown",
        ]
        assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), (
            f"{list(all_events.keys())}"
        )

        per_category_wall_time = {}
        total_device_ms = 0.0
        for category in category_list:
            if category in all_events:
                _time = report_category(category, all_events[category])
                per_category_wall_time[category] = _time
                total_device_ms += _time

        device_busy_percent = f"{total_device_ms / wall_time_ms * 100:.2f}%"
        if device_name:
            print(
                f"\nPercent of time when {device_name.upper()} is busy: {device_busy_percent}"
            )
        else:
            print("No device detected")
        print(f"Total wall time {wall_time_ms:.3f} ms")

        # Output such a line so we can gather these lines from all compiled modules
        # across benchmarks and tabulate them.
        # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent,
        # unknown_category_percent, device_busy_percent, wall_time_ms
        tabulate_line = f"Output for tabulate: {benchmark_name}"
        for category in category_list:
            percent = (
                f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%"
            )
            tabulate_line += f", {percent}"
        tabulate_line += f", {device_busy_percent}, {wall_time_ms:.3f}ms"

        print(tabulate_line)

    report()


PROFILE_DIR = tempfile.gettempdir()
PROFILE_PATH = f"{PROFILE_DIR}/compiled_module_profile.json"


def perf_profile(
    wall_time_ms: float,
    times: int,
    repeat: int,
    benchmark_name: str,
    benchmark_compiled_module_fn: BenchmarkCallableType,
) -> None:
    with torch.profiler.profile(record_shapes=True) as p:
        benchmark_compiled_module_fn(times=times, repeat=repeat)

    path = PROFILE_PATH
    p.export_chrome_trace(path)
    print(f"Profiling result for a compiled module of benchmark {benchmark_name}:")
    print(f"Chrome trace for the profile is written to {path}")
    event_list = p.key_averages(group_by_input_shape=True)
    print(event_list.table(sort_by="self_device_time_total", row_limit=10))
    parse_profile_event_list(
        benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device or ""
    )


def ncu_analyzer(
    benchmark_name: str,
    benchmark_compiled_module_fn: BenchmarkCallableType,
    args: argparse.Namespace,
) -> None:
    import inspect
    import os
    import subprocess

    kernel_regex = args.ncu_kernel_regex
    metrics = args.ncu_metrics
    module_file = inspect.getfile(benchmark_compiled_module_fn)
    module_dir = os.path.dirname(module_file)
    module_name = os.path.splitext(os.path.basename(module_file))[0]
    ncu_dir = tempfile.gettempdir()
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    ncu_output = os.path.join(ncu_dir, f"ncu_output_{timestamp}.ncu-rep")
    python_cmd = (
        f"""import sys; sys.path.insert(0, '{module_dir}'); """
        f"""from {module_name} import benchmark_compiled_module; """
        """benchmark_compiled_module(times=1, repeat=1)"""
    )
    ncu_cmd = [
        "ncu",
        "--target-processes",
        "all",
        "--replay-mode",
        "kernel",
        "--kernel-name-base",
        "function",
        "--print-units",
        "base",
        "--import-source",
        "yes",
        "--force-overwrite",
        "--export",
        ncu_output,
    ]
    if kernel_regex:
        ncu_cmd.extend(["--kernel-name", f"regex:{kernel_regex}"])
    if metrics:
        ncu_cmd.extend(["--metrics", metrics])
    else:
        ncu_cmd.extend(["--set", "full"])
    ncu_cmd.extend(
        [
            "python",
            "-c",
            python_cmd,
        ]
    )
    try:
        subprocess.run(ncu_cmd, check=True)
        print(f"\nNCU profiling results for benchmark {benchmark_name}:")
        print(f"NCU report has been written to {ncu_output}")
    except subprocess.CalledProcessError as e:
        print(f"NCU profiling failed with error: {e}")
        return
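

# For reference, the command assembled above resembles the following (paths and
# timestamp are illustrative; bracketed parts depend on the CLI flags):
#   ncu --target-processes all --replay-mode kernel --kernel-name-base function
#       --print-units base --import-source yes --force-overwrite
#       --export /tmp/ncu_output_<timestamp>.ncu-rep
#       [--kernel-name regex:<regex>] [--metrics <metrics> | --set full]
#       python -c "<python_cmd>"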


def collect_memory_snapshot(
    benchmark_compiled_module_fn: BenchmarkCallableType,
) -> None:
    assert torch.cuda.is_available()

    torch.cuda.memory._record_memory_history(max_entries=100000)
    benchmark_compiled_module_fn(times=10, repeat=1)  # run 10 times
    snapshot_path = f"{tempfile.gettempdir()}/memory_snapshot.pickle"
    torch.cuda.memory._dump_snapshot(snapshot_path)
    torch.cuda.memory._record_memory_history(enabled=None)
    print(f"The collected memory snapshot has been written to {snapshot_path}")


# With AOTAutograd cache, we directly call the compiled module. So prevent
# Dynamo from reentering.
@torch.compiler.disable  # type: ignore[misc]
def compiled_module_main(
    benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType
) -> None:
    """
    This is the function called in the __main__ block of a compiled module.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--benchmark-kernels",
        "-k",
        action="store_true",
        help="Whether to benchmark each individual kernel",
    )
    parser.add_argument(
        "--benchmark-all-configs",
        "-c",
        action="store_true",
        help="Whether to benchmark each individual config for a kernel",
    )
    parser.add_argument(
        "--profile",
        "-p",
        action="store_true",
        help="Whether to profile the compiled module",
    )
    parser.add_argument(
        "--cuda-memory-snapshot",
        action="store_true",
        help="""
            Whether to collect CUDA memory snapshot. Refer to
            https://pytorch.org/blog/understanding-gpu-memory-1/
            for details about how to visualize the collected snapshot
        """,
    )
    parser.add_argument(
        "--ncu",
        action="store_true",
        help="Whether to run ncu analysis",
    )
    parser.add_argument(
        "--ncu-kernel-regex",
        type=str,
        default=None,
        help=(
            "Filter kernels profiled by NCU using a regex (e.g., '^triton_.*'). "
            "Maps to '--kernel-name regex:<regex>'. "
            "If None, NCU will profile all kernels."
        ),
    )
    parser.add_argument(
        "--ncu-metrics",
        type=str,
        default=None,
        help=(
            "Comma-separated list of NCU metrics to collect (e.g., 'dram__bytes.sum.per_second'). "
            "If None, NCU will use '--set full'."
        ),
    )
    args = parser.parse_args()

    if args.benchmark_kernels:
        benchmark_all_kernels(benchmark_name, args.benchmark_all_configs)
    else:
        times = 10
        repeat = 10

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000

        if torch.cuda.is_available():
            peak_mem = torch.cuda.max_memory_allocated()
            print(f"Peak GPU memory usage {peak_mem / 1e6:.3f} MB")

        if torch.cuda.is_available() and args.cuda_memory_snapshot:
            collect_memory_snapshot(benchmark_compiled_module_fn)

        if args.profile:
            perf_profile(
                wall_time_ms,
                times,
                repeat,
                benchmark_name,
                benchmark_compiled_module_fn,
            )
        if args.ncu:
            ncu_analyzer(
                benchmark_name,
                benchmark_compiled_module_fn,
                args=args,
            )
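

# Typical invocations of a compiled module whose __main__ calls
# compiled_module_main ("compiled_module.py" is a placeholder name):
#   python compiled_module.py                        # wall time + peak GPU memory
#   python compiled_module.py -k                     # benchmark every cached kernel
#   python compiled_module.py -k -c                  # ...and every autotuner config
#   python compiled_module.py -p                     # torch.profiler table + chrome trace
#   python compiled_module.py --cuda-memory-snapshot
#   python compiled_module.py --ncu --ncu-kernel-regex '^triton_.*'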