mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Prereqs: - https://github.com/pytorch/pytorch/pull/152708 Features: 1. Adds inductor's estimate of flops and bandwidth to the json trace events that perfetto uses. 1. Only use the tflops estimation from triton if we don't have the info from the datasheet because Triton's estimates are inaccurate. I have a backlog item to fix triton flops estimation upstream. New `DeviceInfo` class, and new function `get_device_tflops`. 1. New helpers `countable_fx` and `count_flops_fx` helps get the flops of an `fx.Node`. 1. Extends Triton `torch.profiler` logging to `DebugAutotuner`. 1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto json trace, `--analyze` creates a summary table of these perf estimates, and `--diff` will compare two traces side by side: ```python Device(NVIDIA H100, 0): Kernel Name | resnet Kernel Count | resnet FLOPS | resnet bw gbps | resnet Dur (ms) | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS | newresnet bw gbps | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth % --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- triton_poi_fused__native_batch_norm_legi | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 triton_red_fused__native_batch_norm_legi | 39 | 0 | 0.13990024992108846 | 
5.752589743589743 | 0 | 0.004176126863316074 | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074 triton_poi_fused__native_batch_norm_legi | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 void cutlass::Kernel2<cutlass_80_tensoro | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 triton_red_fused__native_batch_norm_legi | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 void cutlass::Kernel2<cutlass_80_tensoro | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 triton_poi_fused__native_batch_norm_legi | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 triton_red_fused__native_batch_norm_legi | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523 | 34 | 0 | 0.41409928846284 | 2.853588235294117 
| 0 | 0.012361172789935523 triton_per_fused__native_batch_norm_legi | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 triton_poi_fused__native_batch_norm_legi | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 triton_per_fused__native_batch_norm_legi | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 triton_poi_fused__native_batch_norm_legi | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 void cublasLt::splitKreduce_kernel<32, 1 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 triton_per_fused__native_batch_norm_legi | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 triton_poi_fused__native_batch_norm_legi | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 1.1463614897015777 | 
4.124323529411764 | 0 | 0.03421974596124114 | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114 void cask_plugin_cudnn::xmma_cudnn::init | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 triton_per_fused__native_batch_norm_legi | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 triton_per_fused__native_batch_norm_legi | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 void cutlass::Kernel2<cutlass_80_tensoro | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 triton_poi_fused__native_batch_norm_legi | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 | 16 
| 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 triton_poi_fused__native_batch_norm_legi | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 triton_poi_fused__native_batch_norm_legi | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 triton_per_fused__native_batch_norm_legi | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 triton_poi_fused__native_batch_norm_legi | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 triton_per_fused__native_batch_norm_legi | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 triton_poi_fused__native_batch_norm_legi | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 triton_per_fused__native_batch_norm_legi | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 std::enable_if<!(false), void>::type int | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 triton_poi_fused_add_copy__38 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 
triton_poi_fused_convolution_0 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 triton_poi_fused_convolution_1 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 void convolve_common_engine_float_NHWC<f | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 triton_per_fused__native_batch_norm_legi | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 triton_per_fused__native_batch_norm_legi | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 triton_poi_fused__native_batch_norm_legi | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 void at::native::(anonymous namespace):: | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 void at::native::vectorized_elementwise_ | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697 Approved by: 
https://github.com/eellison, https://github.com/shunting314
718 lines
23 KiB
Python
718 lines
23 KiB
Python
import json
|
|
import logging
|
|
import math
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
import torch
|
|
from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
|
|
from torch._inductor.utils import tabulate_2d, zip_dicts
|
|
from torch.utils import _pytree as pytree
|
|
from torch.utils._ordered_set import OrderedSet
|
|
from torch.utils.flop_counter import flop_registry
|
|
|
|
|
|
log = logging.getLogger(__name__)


# Prefix the profiler puts on aten op event names (e.g. "aten::mm");
# stripped off in _parse_kernel_name to recover the op name.
ATEN_PREFIX = "aten::"
|
|
|
|
|
|
@dataclass
class ProfileEvent:
    """A single aggregated profiler event (category, key, timing, count)."""

    category: str
    key: str
    # self device time in milliseconds
    self_device_time_ms: float
    # the benchmark is run multiple times and we average the count across all the
    # runs. It should be an integer but define a float just in case.
    count: float
|
|
|
|
|
|
# adapters convert the json trace into a format that works with flops_counter.
# An adapter takes ("Input Dims", "Concrete Inputs") from a trace event and
# returns (args, kwargs) in the shape the flop-counting formula expects.
ArgsType = tuple[tuple[Any, ...], dict[Any, Any]]
AdapterType = Callable[[tuple[Any, ...], tuple[Any, ...]], ArgsType]
# op name (without the "aten::" prefix) -> adapter; populated via @register_adapter
adapters_map: dict[str, AdapterType] = {}
|
|
|
|
|
|
def parse_list(lst: str) -> list[int]:
    """Parse a bracketed integer-list string such as "[1, 2, 3]" into [1, 2, 3].

    Fix: the original called int("") and raised ValueError when given "[]" or
    an empty/whitespace-only string; an empty list is now returned instead.
    """
    stripped = lst.replace("[", "").replace("]", "").strip()
    if not stripped:
        return []
    return [int(part.strip()) for part in stripped.split(",")]
|
|
|
|
|
|
def register_adapter(
    aten: Union[str, list[str]],
) -> Callable[
    [AdapterType],
    AdapterType,
]:
    """Decorator factory: register an adapter for one or several aten op names.

    The decorated function is stored in the module-level ``adapters_map`` under
    each given name and returned unchanged.

    Fix: removed ``global _adapters_map`` — that name does not exist anywhere in
    the module (the dict is ``adapters_map``), and no ``global`` statement is
    needed since the dict is mutated in place, never rebound.
    """

    def decorator(func: AdapterType) -> AdapterType:
        if isinstance(aten, str):
            adapters_map[aten] = func
        else:
            for at in aten:
                adapters_map[at] = func
        return func

    return decorator
|
|
|
|
|
|
@register_adapter(["_slow_conv2d_forward"])
def _slow_conv2d_adapter(
    shapes: tuple[Any, ...], concrete: tuple[Any, ...]
) -> tuple[tuple[Any], dict[Any, Any]]:
    """Rewrite _slow_conv2d_forward args into the layout conv_adapter expects."""
    # conv_adapter wants a transposed flag appended to the shapes, and reads the
    # stride string from concrete index 3 (slow_conv2d records it at index 4).
    padded_shapes = [*shapes, False]
    shifted_concrete = list(concrete)
    shifted_concrete[3] = shifted_concrete[4]
    return conv_adapter(tuple(padded_shapes), tuple(shifted_concrete))
|
|
|
|
|
|
@register_adapter(["convolution", "_convolution", "cudnn_convolution"])
def conv_adapter(
    shapes: tuple[Any, ...], concrete: tuple[Any, ...]
) -> tuple[tuple[Any], dict[Any, Any]]:
    """Adapt profiler-recorded convolution arguments for the flop counter.

    Normalizes the transposed flag to a bool and, for non-transposed convs,
    precomputes the output shape and passes it via the "out_val" kwarg.
    """
    adapted = list(shapes)
    if len(adapted) == 4:
        # short arg list: no transposed flag recorded
        transposed = False
    else:
        transposed = bool(adapted[6])
        adapted[6] = transposed

    kwargs: dict[Any, Any] = {}
    if not transposed:
        # calculate output shape if not transposed.
        def out_dim(x: int, kernel: int, stride: int) -> int:
            return (x - kernel) // stride + 1

        strides = parse_list(concrete[3])
        inp, weight = shapes[0], shapes[1]
        spatial = [out_dim(*dims) for dims in zip(inp[2:], weight[2:], strides)]
        # we only need the xy values
        kwargs["out_val"] = [inp[0], weight[0]] + spatial

    return tuple(adapted), kwargs
|
|
|
|
|
|
def default_adapter(
    shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
    """Fallback adapter: forward the recorded shapes unchanged, no kwargs."""
    no_kwargs: dict[Any, Any] = {}
    return shapes, no_kwargs
|
|
|
|
|
|
@register_adapter("addmm")
def addmm_adapter(
    shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
    """Keep only the first three shapes (bias, mat1, mat2) for addmm."""
    return tuple(shapes[:3]), {}
|
|
|
|
|
|
@register_adapter("bmm")
def bmm_adapter(
    shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
    """Keep only the two batched matrix shapes for bmm."""
    return tuple(shapes[:2]), {}
|
|
|
|
|
|
@register_adapter("baddbmm")
def baddbmm_adapter(
    shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
    """Keep only the first three shapes (bias, batch1, batch2) for baddbmm."""
    return tuple(shapes[:3]), {}
|
|
|
|
|
|
@register_adapter("mm")
def mm_adapter(
    shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
    """mm shapes already match the flop-counter signature; pass them through."""
    no_kwargs: dict[Any, Any] = {}
    return shapes, no_kwargs
|
|
|
|
|
|
def _parse_kernel_name(name: str) -> Optional[str]:
    """
    parse the name of the kernel from the event name.

    Returns the aten op name this kernel most likely implements, or None if it
    cannot be classified.

    Fix: "baddbmm" must be tested before "bmm" — "bmm" is a substring of
    "baddbmm", so with the original ordering the baddbmm branch was unreachable
    and baddbmm kernels were misclassified as bmm.
    """
    if name.startswith(ATEN_PREFIX):
        return name[len(ATEN_PREFIX) :]
    elif "conv" in name:
        return "convolution"
    elif "addmm" in name:
        return "addmm"
    elif "baddbmm" in name:
        # checked before "bmm": "bmm" is a substring of "baddbmm"
        return "baddbmm"
    elif "bmm" in name:
        return "bmm"
    elif "_mm" in name:
        return "mm"
    else:
        return None
|
|
|
|
|
|
def _calculate_flops(event: dict[str, Any]) -> int:
    """
    Return a flop count for a trace event, or 0 if it cannot be estimated.

    This function has to parse the kernel name, which is error prone. There doesn't seem to be another solution that
    will support all the different backends that can generate kernels, so make sure to update this function when new
    ops and backends are desired.
    """
    name = event["name"]
    # prefer a flop count already recorded on the event (e.g. by inductor)
    if "kernel_flop" in event["args"] and event["args"]["kernel_flop"] != 0:
        return event["args"]["kernel_flop"]
    op_name = _parse_kernel_name(name)
    if op_name is None:
        return 0

    # only ops present in torch's flop_registry can be counted
    op_obj = getattr(torch.ops.aten, op_name, None)
    if op_obj is None or op_obj not in flop_registry:
        return 0

    flop_function = flop_registry[op_obj]

    assert "Input Dims" in event["args"] and "Concrete Inputs" in event["args"]
    input_shapes = event["args"]["Input Dims"]
    concrete = event["args"]["Concrete Inputs"]
    # adapt the profiler-recorded args into the signature flop_function expects
    if op_name in adapters_map:
        args, kwargs = adapters_map[op_name](input_shapes, concrete)
    else:
        args, kwargs = default_adapter(input_shapes, concrete)
    return flop_function(*args, **kwargs)
|
|
|
|
|
|
def _get_size_from_string(type_string: str) -> int:
|
|
if not hasattr(torch, type_string):
|
|
return 1
|
|
else:
|
|
return getattr(torch, type_string).itemsize
|
|
|
|
|
|
def _default_estimate_gb(event: dict[str, Any]) -> float:
    """Fallback traffic estimate: total byte size of all inputs, in GB."""
    total_bytes = 0
    for dims, type_name in zip(
        event["args"]["Input Dims"], event["args"]["Input type"]
    ):
        item_size = _get_size_from_string(type_name)
        # flatten nested dim lists and multiply out the element count
        total_bytes += item_size * math.prod(pytree.tree_flatten(dims)[0])
    return total_bytes / 1e9
|
|
|
|
|
|
def _estimate_gb(event: dict[str, Any]) -> float:
    """
    Our best effort to estimate the gb, should be refactored soon with MemoryCounter.

    Returns the estimated memory traffic of the event in GB, using op-specific
    formulas for mm/addmm/bmm/baddbmm/convolutions and falling back to the sum
    of input sizes (_default_estimate_gb) otherwise.
    """
    name = event["name"]
    # prefer a value already recorded on the event (e.g. by inductor)
    if "kernel_num_gb" in event["args"] and event["args"]["kernel_num_gb"] != 0:
        return event["args"]["kernel_num_gb"]
    if "Input type" not in event["args"] or "Input Dims" not in event["args"]:
        return 0
    op_name = _parse_kernel_name(name)
    if op_name is None:
        return _default_estimate_gb(event)

    op_obj = getattr(torch.ops.aten, op_name, None)
    if op_obj is None:
        return _default_estimate_gb(event)

    assert "Input Dims" in event["args"] and "Concrete Inputs" in event["args"]
    input_shapes = event["args"]["Input Dims"]

    # NOTE these will be refactored into a similar object to FlopCounter soon
    def mm_formula(M: int, N: int, K: int, size: int) -> int:
        # bytes moved by an MxK @ KxN matmul: read both operands, write output
        return 2 * (M * K + N * K + M * N) * size

    if op_name == "addmm":
        # bias read is counted separately from the matmul traffic
        add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
        add_type_size = _get_size_from_string(event["args"]["Input type"][0])
        M = input_shapes[1][0]
        N = input_shapes[1][1]
        assert input_shapes[1][1] == input_shapes[2][0]
        K = input_shapes[2][1]
        mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
        return (mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size) / 1e9
    elif op_name == "mm":
        M = input_shapes[0][0]
        N = input_shapes[0][1]
        assert input_shapes[0][1] == input_shapes[1][0]
        K = input_shapes[1][1]
        type_size = _get_size_from_string(event["args"]["Input type"][0])
        return mm_formula(M, N, K, type_size) / 1e9
    elif op_name == "baddbmm":
        add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
        add_type_size = _get_size_from_string(event["args"]["Input type"][0])
        # B batched matmuls plus one bias read
        # NOTE(review): N/K here index the inner/outer dims differently than the
        # bmm branch below; mm_formula is symmetric in N and K, so the result
        # is the same either way.
        B = input_shapes[0][0]
        M = input_shapes[1][1]
        N = input_shapes[1][2]
        K = input_shapes[2][2]
        mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
        return (
            B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size
        ) / 1e9
    elif op_name == "bmm":
        add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
        add_type_size = _get_size_from_string(event["args"]["Input type"][0])
        B = input_shapes[0][0]
        M = input_shapes[0][1]
        N = input_shapes[0][2]
        K = input_shapes[1][2]
        mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
        return (
            B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size
        ) / 1e9
    elif op_name in [
        "convolution",
        "_convolution",
        "cudnn_convolution",
        "_slow_conv2d_forward",
    ]:
        concrete = event["args"]["Concrete Inputs"]

        def conv_out_dim(x: int, kernel: int, stride: int) -> int:
            return (x - kernel) // stride + 1

        # slow_conv2d records the stride string at a different arg index
        stride = parse_list(
            concrete[3] if op_name != "_slow_conv2d_forward" else concrete[4]
        )
        inp = input_shapes[0]
        w = input_shapes[1]
        out_x_y = [conv_out_dim(*args) for args in zip(inp[2:], w[2:], stride)]
        out = [inp[0], w[0]] + out_x_y
        # each output element reads in * w * w chunk
        input_reads = out[0] * out[1] * out[2] * out[3] * inp[1] * w[2] * w[3]
        # Assume weights are in cache, so only read once
        weight_reads = w[0] * w[1] * w[2] * w[3]
        return (input_reads + weight_reads) / 1e9

    return _default_estimate_gb(event)
|
|
|
|
|
|
def _create_extern_mapping(
|
|
data: dict[str, Any],
|
|
) -> defaultdict[int, list[dict[str, Any]]]:
|
|
"""
|
|
compute a mapping from exteral ids to non kernels, which contain the information we need to estimate flops etc
|
|
"""
|
|
extern_mapping: defaultdict[int, list[dict[str, Any]]] = defaultdict(list)
|
|
for event in data["traceEvents"]:
|
|
if (
|
|
"args" not in event
|
|
or "External id" not in event["args"]
|
|
or event["cat"] != "cpu_op"
|
|
):
|
|
continue
|
|
if len(extern_mapping[event["args"]["External id"]]) > 0:
|
|
raise ParseException("duplicate external id in event")
|
|
extern_mapping[event["args"]["External id"]].append(event)
|
|
return extern_mapping
|
|
|
|
|
|
def _augment_trace_helper(data: dict[str, Any]) -> dict[str, Any]:
    """Annotate kernel events (and their launching cpu_ops) in-place with
    "kernel_flop" and "kernel_num_gb" estimates, and return the same dict.

    Raises ParseException if a kernel event has no "args".
    """
    extern_mapping = _create_extern_mapping(data)

    for event in data["traceEvents"]:
        if "cat" not in event or event["cat"] != "kernel":
            continue
        if "args" not in event:
            raise ParseException(f"kernel has no args: {event}")
        if "External id" not in event["args"]:
            # can't link this kernel back to a cpu_op; log and skip
            event_str = f"kernel has no External id: {event}"
            log.info(event_str)
            continue

        # the cpu_op that launched this kernel carries Input Dims / Concrete
        # Inputs needed for the estimates
        external_op = extern_mapping[event["args"]["External id"]][0]
        flops = _calculate_flops(external_op)
        if flops == 0:
            # fall back to any estimate recorded on the kernel event itself
            flops = _calculate_flops(event)
        # record on the cpu_op first: _estimate_gb prefers an existing
        # kernel_num_gb value, so the write order matters
        external_op["args"]["kernel_flop"] = flops
        external_op["args"]["kernel_num_gb"] = _estimate_gb(external_op)
        # copy the estimates onto the kernel event as well
        event["args"]["kernel_flop"] = external_op["args"]["kernel_flop"]
        event["args"]["kernel_num_gb"] = external_op["args"]["kernel_num_gb"]
    return data
|
|
|
|
|
|
# Maps profiler "Input type" strings (and CLI dtype arguments) to concrete
# torch dtypes; used by convert_dtype and JsonProfile.__init__.
_dtype_map = {
    "float": torch.float,
    "float32": torch.float,
    "int": torch.int,
    "int8": torch.int8,
    "int16": torch.int16,
    "int32": torch.int,
    "long": torch.long,
    "long int": torch.long,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float64": torch.double,
}
|
|
|
|
|
|
@dataclass(frozen=True)
class KernelStats:
    """Perf estimates for one kernel launch, derived from an augmented trace event."""

    # achieved flop/s (kernel_flop / duration)
    flops: int
    # achieved bandwidth in GB/s (kernel_num_gb / duration)
    bw: float
    latency: float  # us
    # percent of the device's datasheet peak flops; 0 when the device is unknown
    achieved_flops: float
    # percent of the device's datasheet DRAM bandwidth; 0 when the device is unknown
    achieved_bandwidth: float


# kernel name -> set of stats, one entry per distinct launch of that kernel
KernelNameMap = defaultdict[str, OrderedSet[KernelStats]]
|
|
|
|
|
|
@dataclass(frozen=False)
class Device:
    """A device seen in the trace's deviceProperties, plus its kernel stats."""

    name: str
    index: int
    # datasheet info from lookup_device_info; None when the device is unknown
    info: Optional[DeviceInfo]
    stats: KernelNameMap

    def __repr__(self) -> str:
        return f"Device({self.name}, {self.index}): {self.info}"


# device id -> Device
DeviceMap = dict[int, Device]
# (headers, {kernel name -> row values}) used when rendering reports
Table = tuple[list[str], dict[str, list[str]]]
|
|
|
|
|
|
class JsonProfile:
    """
    Convenience class for running common operations on chrome/perfetto json
    traces: augmenting them with flop/bandwidth estimates, and rendering
    per-device summary (or side-by-side diff) reports.
    """

    _devices: DeviceMap

    def __init__(
        self,
        path: str,
        benchmark_name: Optional[str] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
    ):
        """
        Load the json trace at `path`.

        benchmark_name: label used as the column prefix in diff reports.
        dtype: default dtype for achieved-FLOPS calculations when an event's
        dtype can't be determined; a string is looked up in _dtype_map.
        """
        self.path = path
        with open(path) as f:
            self.data = json.load(f)
        self.events = self.data["traceEvents"]
        self.benchmark_name = benchmark_name
        if dtype is None:
            self.dtype = None
        elif isinstance(dtype, torch.dtype):
            self.dtype = dtype
        else:
            # string dtype: unknown names silently fall back to None
            if dtype in _dtype_map:
                self.dtype = _dtype_map[dtype]
            else:
                self.dtype = None
        self._create_devices()

    def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]:
        """
        Each op has a list of dtypes for each input arg. We need to convert these into a single dtype for flop estimation.
        Issues:
        - converting the strings to concrete torch.dtypes
        - What if we have float32, float, float16 all in the inputs? Our choice is to use the largest buffer dtype.

        Raises RuntimeError when the inputs are empty or the chosen type string
        is not in _dtype_map.
        """

        if (
            "Input Dims" not in event["args"]
            or "Input type" not in event["args"]
            or "Concrete Inputs" not in event["args"]
        ):
            # no input metadata: fall back to guessing from the kernel name
            if "bfloat16" in event["name"]:
                return torch.bfloat16
            elif "float16" in event["name"]:
                return torch.float16
            else:
                return None

        input_sizes = event["args"]["Input Dims"]
        input_types = event["args"]["Input type"]
        concrete_inputs = event["args"]["Concrete Inputs"]
        assert len(input_sizes) == len(input_types)
        assert len(input_types) == len(concrete_inputs)

        if len(input_sizes) == 0:
            raise RuntimeError("Empty input_sizes and input_types")

        # pick the dtype of the input with the largest total size
        biggest_size = 0
        biggest_index = 0
        for i in range(len(input_sizes)):
            if concrete_inputs[i] != "":
                # concrete inputs are usually small tensors, so we can just skip
                continue
            my_size = input_sizes[i]
            total_size = sum(parse_list(my_size))
            if total_size > biggest_size:
                biggest_size = total_size
                biggest_index = i
        ret_type = input_types[biggest_index]
        if ret_type in _dtype_map:
            return _dtype_map[ret_type]
        raise RuntimeError(f"Unknown type: {ret_type}. Please add to _dtype_map.")

    def _create_devices(self) -> None:
        """Build the device-id -> Device map from the trace's deviceProperties."""
        self._devices = {}
        for dev in self.data["deviceProperties"]:
            name = dev["name"]
            device_info = lookup_device_info(name)

            if device_info is None:
                # unknown devices still get a Device entry, just with info=None
                log.info(
                    "Unsupported device in profile: %s, please consider contributing to _device_mapping.",
                    name,
                )
            self._devices[dev["id"]] = Device(
                name, dev["id"], device_info, defaultdict(OrderedSet)
            )

    def calculate_flops(self, event: dict[str, Any]) -> int:
        """Delegates to the module-level _calculate_flops."""
        return _calculate_flops(event)

    def estimate_gb(self, event: dict[str, Any]) -> float:
        """Delegates to the module-level _estimate_gb."""
        return _estimate_gb(event)

    def augment_trace(self) -> None:
        """Add kernel_flop / kernel_num_gb estimates to every kernel event."""
        self.data = _augment_trace_helper(self.data)

    def _compute_stats(self) -> None:
        """populates the name -> stats map"""
        for event in self.events:
            if "cat" not in event or "args" not in event or event["cat"] != "kernel":
                continue
            dev = self._devices[event["args"]["device"]]

            dur = event["dur"]  # us
            if "kernel_flop" in event["args"]:
                assert dur != 0
                # 1,000,000us/s * flop / us
                op_flops = event["args"]["kernel_flop"] / (dur / 1e6)
            else:
                op_flops = 0

            if "kernel_num_gb" in event["args"]:
                assert dur != 0
                # 1,000,000us/s * gb = gb/s
                op_gbps = event["args"]["kernel_num_gb"] / (dur / 1e6)
            else:
                op_gbps = 0

            if dev.info is not None:
                # need a dtype to look up the device's peak tops
                dtype = self.convert_dtype(event) or self.dtype
                if dtype is None:
                    raise RuntimeError(
                        "dtype is not found on tensor and default dtype is not set"
                    )
                achieved_flops = 100 * op_flops / (1e12 * dev.info.tops[dtype])
                achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_gbs
            else:
                achieved_flops = 0
                achieved_bandwidth = 0

            dev.stats[event["name"]].add(
                KernelStats(
                    flops=op_flops,
                    bw=op_gbps,
                    latency=dur,
                    achieved_bandwidth=achieved_bandwidth,
                    achieved_flops=achieved_flops,
                )
            )

    def _create_single_table(self, dev: Device) -> Table:
        """Create a table with the devices mapped to indices."""
        headers = [
            "Kernel Name",
            "Kernel Count",
            "FLOPS",
            "Kernel Reads (GB)",
            "Dur (us)",
            "Achieved FLOPS %",
            "Achieved Bandwidth %",
        ]
        rows: dict[str, list[str]] = {}

        def safe_div_format(x: float, y: float) -> str:
            # avoid ZeroDivisionError when no launches contributed to a column
            if y == 0:
                return "0.0"
            return f"{x / y:.4f}"

        for kernel_name, stats_set in dev.stats.items():
            # averages are taken only over launches with a nonzero estimate
            ker_count = 0
            flops = 0
            flops_count = 0
            achieved_flops = 0.0
            bw = 0.0
            bw_count = 0
            achieved_bandwidth = 0.0
            latency = 0.0
            for stats in stats_set:
                if stats.flops != 0:
                    flops += stats.flops
                    achieved_flops += stats.achieved_flops
                    flops_count += 1
                if stats.bw != 0:
                    bw += stats.bw
                    achieved_bandwidth += stats.achieved_bandwidth
                    bw_count += 1
                latency += stats.latency
                ker_count += 1
            assert ker_count != 0
            rows[kernel_name] = [
                str(ker_count),
                safe_div_format(flops, flops_count),
                safe_div_format(bw, bw_count),
                safe_div_format(latency, ker_count),
                safe_div_format(achieved_flops, flops_count),
                safe_div_format(achieved_bandwidth, bw_count),
            ]

        return headers, rows

    def _create_tables(self, devs: DeviceMap) -> dict[int, Table]:
        """One summary table per device."""
        return {idx: self._create_single_table(dev) for idx, dev in devs.items()}

    def _combine_tables(
        self, table1: Table, table1_name: str, table2: Table, table2_name: str
    ) -> Table:
        """Merge two per-device tables side by side, prefixing column headers
        with each table's name; missing rows are filled with "Empty"."""
        new_headers = (
            ["Kernel Name"]
            + [f"{table1_name} {head}" for head in table1[0][1:]]
            + [f"{table2_name} {head}" for head in table2[0][1:]]
        )
        t1_length = len(table1[0][1:])
        t2_length = len(table2[0][1:])
        new_rows = {}

        for key, row1, row2 in zip_dicts(
            table1[1],
            table2[1],
            d1_default=["Empty"] * t1_length,
            d2_default=["Empty"] * t2_length,
        ):
            assert row1 is not None
            assert row2 is not None
            new_rows[key] = row1 + row2
        return new_headers, new_rows

    def report(
        self, other: Optional["JsonProfile"] = None, name_limit: int = 40
    ) -> str:
        """Render per-device summary tables as a string.

        With `other`, renders a side-by-side diff of the two profiles instead.
        Kernel names are truncated to `name_limit` characters.
        """

        def create_ret(
            table_headers: list[str], table_rows: dict[str, list[str]]
        ) -> str:
            table_flattened = [
                [kernel_name[:name_limit], *kernel_vals]
                for kernel_name, kernel_vals in table_rows.items()
            ]
            return tabulate_2d(table_flattened, headers=table_headers)

        if other is not None:
            # diff mode: compute stats for both profiles and merge their tables
            self._compute_stats()
            other._compute_stats()

            self_tables = self._create_tables(self._devices)
            other_tables = self._create_tables(other._devices)

            self_name = (
                self.benchmark_name if self.benchmark_name is not None else "Table 1"
            )
            other_name = (
                other.benchmark_name if other.benchmark_name is not None else "Table 2"
            )

            ret = []
            assert self._devices.keys() == other._devices.keys()
            for device_idx, t1, t2 in zip_dicts(
                self_tables, other_tables, d1_default=None, d2_default=None
            ):
                assert t1 is not None
                assert t2 is not None
                table_headers, table_rows = self._combine_tables(
                    t1, self_name, t2, other_name
                )
                tab_string = create_ret(table_headers, table_rows)
                ret.append(f"{self._devices[device_idx]}:\n{tab_string}")
            return "\n".join(ret)
        # single-profile mode
        self._compute_stats()

        self_tables = self._create_tables(self._devices)

        ret = []
        for idx, table in self_tables.items():
            table_headers, table_rows = table
            tab_string = create_ret(table_headers, table_rows)
            ret.append(f"{self._devices[idx]}:\n{tab_string}")
        return "\n".join(ret)

    def dump(self, out: str) -> None:
        """Write the (possibly augmented) trace json to `out`."""
        with open(out, "w") as f:
            json.dump(self.data, f)
|
|
|
|
|
|
class ParseException(RuntimeError):
    """Raised when a json trace is malformed (e.g. duplicate external ids or a
    kernel event with no args)."""

    pass
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: --augment_trace, --analysis, and --diff modes.

    The modes are not mutually exclusive; each provided mode runs in turn.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--diff",
        nargs=5,
        metavar=(
            "input_file1",
            "name1",
            "input_file2",
            "name2",
            "dtype",
        ),
        help="Two json traces to compare with, specified as <file1> <name1> <file2> <name2> <dtype>",
    )
    parser.add_argument(
        "--name_limit",
        type=int,
        help="the maximum name size in the final report",
    )
    parser.add_argument(
        "--augment_trace",
        "-a",
        nargs=3,
        metavar=("input_file", "output_file", "dtype"),
        help="Augment a trace with inductor meta information. Provide input and output file paths.",
    )
    parser.add_argument(
        "--analysis",
        nargs=2,
        metavar=("input_file", "dtype"),
        help="Run analysis on a single trace, specified as <file> <dtype>",
    )
    args = parser.parse_args()

    if args.diff:
        # both traces are augmented in memory before being diffed
        p1 = JsonProfile(args.diff[0], args.diff[1], dtype=args.diff[4])
        p1.augment_trace()
        p2 = JsonProfile(args.diff[2], args.diff[3], dtype=args.diff[4])
        p2.augment_trace()
        if args.name_limit:
            print(p1.report(p2, name_limit=args.name_limit))
        else:
            print(p1.report(p2))
    if args.analysis:
        p1 = JsonProfile(
            args.analysis[0],
            dtype=args.analysis[1],
        )
        p1.augment_trace()
        if args.name_limit:
            print(p1.report(name_limit=args.name_limit))
        else:
            print(p1.report())
    if args.augment_trace:
        # augment and write back out; no report is printed in this mode
        p = JsonProfile(args.augment_trace[0], dtype=args.augment_trace[2])
        p.augment_trace()
        p.dump(args.augment_trace[1])
|
|
|
|
|
|
# script entry point (profile_analysis.py is run directly as a CLI tool)
if __name__ == "__main__":
    main()
|