mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Prereqs: - https://github.com/pytorch/pytorch/pull/152708 Features: 1. Adds inductor's estimate of flops and bandwidth to the json trace events that perfetto uses. 1. Only use the tflops estimation from triton if we don't have the info from the datasheet because Triton's estimates are inaccurate. I have a backlog item to fix triton flops estimation upstream. New `DeviceInfo` class, and new function `get_device_tflops`. 1. New helpers `countable_fx` and `count_flops_fx` helps get the flops of an `fx.Node`. 1. Extends Triton `torch.profiler` logging to `DebugAutotuner`. 1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto json trace, `--analyze` creates a summary table of these perf estimates, and `--diff` will compare two traces side by side: ```python Device(NVIDIA H100, 0): Kernel Name | resnet Kernel Count | resnet FLOPS | resnet bw gbps | resnet Dur (ms) | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS | newresnet bw gbps | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth % --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- triton_poi_fused__native_batch_norm_legi | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 triton_red_fused__native_batch_norm_legi | 39 | 0 | 0.13990024992108846 | 
5.752589743589743 | 0 | 0.004176126863316074 | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074 triton_poi_fused__native_batch_norm_legi | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 void cutlass::Kernel2<cutlass_80_tensoro | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 triton_red_fused__native_batch_norm_legi | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 void cutlass::Kernel2<cutlass_80_tensoro | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 triton_poi_fused__native_batch_norm_legi | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 triton_red_fused__native_batch_norm_legi | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523 | 34 | 0 | 0.41409928846284 | 2.853588235294117 
| 0 | 0.012361172789935523 triton_per_fused__native_batch_norm_legi | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 triton_poi_fused__native_batch_norm_legi | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 triton_per_fused__native_batch_norm_legi | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 triton_poi_fused__native_batch_norm_legi | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 void cublasLt::splitKreduce_kernel<32, 1 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 triton_per_fused__native_batch_norm_legi | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 triton_poi_fused__native_batch_norm_legi | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 1.1463614897015777 | 
4.124323529411764 | 0 | 0.03421974596124114 | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114 void cask_plugin_cudnn::xmma_cudnn::init | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 triton_per_fused__native_batch_norm_legi | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 triton_per_fused__native_batch_norm_legi | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 void cutlass::Kernel2<cutlass_80_tensoro | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 triton_poi_fused__native_batch_norm_legi | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 | 16 
| 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 triton_poi_fused__native_batch_norm_legi | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 triton_poi_fused__native_batch_norm_legi | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 triton_per_fused__native_batch_norm_legi | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 triton_poi_fused__native_batch_norm_legi | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 triton_per_fused__native_batch_norm_legi | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 triton_poi_fused__native_batch_norm_legi | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 triton_per_fused__native_batch_norm_legi | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 std::enable_if<!(false), void>::type int | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 triton_poi_fused_add_copy__38 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 
triton_poi_fused_convolution_0 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 triton_poi_fused_convolution_1 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 void convolve_common_engine_float_NHWC<f | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 triton_per_fused__native_batch_norm_legi | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 triton_per_fused__native_batch_norm_legi | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 triton_poi_fused__native_batch_norm_legi | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 void at::native::(anonymous namespace):: | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 void at::native::vectorized_elementwise_ | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697 Approved by: 
https://github.com/eellison, https://github.com/shunting314
194 lines
6.5 KiB
Python
194 lines
6.5 KiB
Python
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
|
|
# Module-level logger; datasheet_tops() reports missing device/dtype entries through it at INFO level.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
class DeviceInfo:
    """
    Theoretical peak numbers for one device, taken from its vendor data sheet.

    If a data sheet gives two numbers for a dtype (Tensor/Matrix Core vs. not), the
    higher number is reported. Sparsity is not considered.

    Bandwidth numbers are tricky, because there are platform differences that may not
    show up in the profiler trace, so bandwidth-based estimates derived from
    ``dram_bw_gbs`` should be treated as approximate.
    """

    # Peak throughput per dtype, in TFLOPS (TOPS for integer dtypes). Keys are
    # ``torch.dtype`` objects, plus the string ``"torch.tf32"`` because TF32 has no
    # ``torch.dtype`` of its own.
    tops: dict[Union[torch.dtype, str], float]
    # Peak DRAM bandwidth, in GB/s.
    dram_bw_gbs: float
    # DRAM capacity, in GB.
    dram_gb: float
|
|
|
|
|
|
# Indexing is based on `torch.cuda.get_device_name()`; lookups are exact string matches.
# `tops` values are theoretical peaks in TFLOPS (TOPS for int8), `dram_bw_gbs` is in GB/s,
# and `dram_gb` is in GB.
# Each fp8 row lists every torch fp8 dtype once: e4m3fn, e4m3fnuz, e5m2, e5m2fnuz, e8m0fnu.
# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
_device_mapping: dict[str, DeviceInfo] = {
    # Source:
    # @lint-ignore https://www.nvidia.com/en-us/data-center/h100/
    # NOTE(review): "torch.tf32" below equals the A100 entry's value; looks like a
    # copy-paste — confirm against the H100 data sheet.
    "NVIDIA H100": DeviceInfo(
        tops={
            torch.float64: 67.0,
            torch.float32: 67.5,
            "torch.tf32": 156.0,
            torch.bfloat16: 1979.0,
            torch.float16: 1979.0,
            torch.float8_e4m3fn: 3958.0,
            torch.float8_e4m3fnuz: 3958.0,
            torch.float8_e5m2: 3958.0,
            torch.float8_e5m2fnuz: 3958.0,
            torch.float8_e8m0fnu: 3958.0,
            torch.int8: 3958.0,
        },
        dram_bw_gbs=3350.0,
        dram_gb=80.0,
    ),
    # Source:
    # @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/
    # nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
    "NVIDIA A100": DeviceInfo(
        tops={
            torch.float64: 19.5,
            torch.float32: 19.5,
            torch.bfloat16: 312.5,
            torch.float16: 312.5,
            # Not in datasheet: float8
            torch.int8: 624.0,
            "torch.tf32": 156.0,
        },
        dram_bw_gbs=2039.0,
        dram_gb=80.0,
    ),
    # Source:
    # @lint-ignore https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet
    "NVIDIA L4": DeviceInfo(
        tops={
            # This is a guess, not in datasheet
            torch.float64: 15.1,
            torch.float32: 30.3,
            "torch.tf32": 120.0,
            torch.bfloat16: 242.0,
            torch.float16: 242.0,
            torch.float8_e4m3fn: 485.0,
            torch.float8_e4m3fnuz: 485.0,
            torch.float8_e5m2: 485.0,
            torch.float8_e5m2fnuz: 485.0,
            torch.float8_e8m0fnu: 485.0,
            torch.int8: 485.0,
        },
        # The L4 data sheet lists 300 GB/s of GDDR6 bandwidth (previously this entry
        # carried the H100's 3350 GB/s by mistake).
        dram_bw_gbs=300.0,
        dram_gb=24.0,
    ),
    # Source:
    # @lint-ignore https://www.amd.com/content/dam/amd/en/documents\
    # /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
    "AMD MI300A": DeviceInfo(
        tops={
            torch.float64: 122.6,
            torch.float32: 122.6,
            "torch.tf32": 490.3,
            torch.bfloat16: 980.6,
            torch.float16: 980.6,
            torch.float8_e4m3fn: 1961.2,
            torch.float8_e4m3fnuz: 1961.2,
            torch.float8_e5m2: 1961.2,
            torch.float8_e5m2fnuz: 1961.2,
            torch.float8_e8m0fnu: 1961.2,
            torch.int8: 1961.2,
        },
        dram_bw_gbs=5300.0,
        dram_gb=128.0,
    ),
    # Source:
    # @lint-ignore https://www.amd.com/content/dam/amd/en/documents/\
    # instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
    "AMD MI300X": DeviceInfo(
        tops={
            torch.float64: 163.4,
            torch.float32: 163.4,
            "torch.tf32": 653.7,
            torch.bfloat16: 1307.4,
            torch.float16: 1307.4,
            torch.float8_e4m3fn: 2614.9,
            torch.float8_e4m3fnuz: 2614.9,
            torch.float8_e5m2: 2614.9,
            torch.float8_e5m2fnuz: 2614.9,
            torch.float8_e8m0fnu: 2614.9,
            torch.int8: 2614.9,
        },
        dram_bw_gbs=5300.0,
        dram_gb=192.0,
    ),
    # Source:
    # @lint-ignore https://www.amd.com/content/dam/amd/\
    # en/documents/instinct-business-docs/product-briefs/instinct-mi210-brochure.pdf
    "AMD MI210X": DeviceInfo(
        tops={
            torch.float64: 45.3,
            torch.float32: 45.3,
            # not specified, fall back to float32 numbers
            "torch.tf32": 45.3,
            torch.bfloat16: 181.0,
            torch.float16: 181.0,
            # not specified, fall back to float16 numbers
            torch.float8_e4m3fn: 181.0,
            torch.float8_e4m3fnuz: 181.0,
            torch.float8_e5m2: 181.0,
            torch.float8_e5m2fnuz: 181.0,
            torch.float8_e8m0fnu: 181.0,
            torch.int8: 181.0,
        },
        # pcie4.0x16
        dram_bw_gbs=1600.0,
        dram_gb=64.0,
    ),
}

# Aliases for the alternative "AMD INSTINCT ..." spelling of the AMD device names.
_device_mapping["AMD INSTINCT MI300A"] = _device_mapping["AMD MI300A"]
_device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"]
_device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"]
|
|
|
|
|
|
def lookup_device_info(name: str) -> Optional[DeviceInfo]:
    """
    Look up the static data-sheet entry for a device by name.

    Problem: when diffing profiles between AMD and NVIDIA, we don't have access to the
    device information of the other one. Also, since the analysis is static, we should
    be able to do it on a device unrelated to the recorded device. Therefore,
    ``_device_mapping`` statically contains the information for lots of devices.

    Args:
        name: name of the device to look up. Should map onto
            ``torch.cuda.get_device_name()``. The lookup is an exact string match, so
            a name that differs from every key (and alias) is reported as unknown.

    Returns:
        The matching ``DeviceInfo``, or None if the device is unknown. If a device is
        missing, add a ``DeviceInfo`` entry (or an alias) for it to ``_device_mapping``.
    """
    return _device_mapping.get(name)
|
|
|
|
|
|
def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]:
    """
    Get the theoretical TFLOPS of the current device for a given dtype.

    Args:
        dtype: dtype to look up in the device's data-sheet entry.
        is_tf32: if True and ``dtype`` is ``torch.float32``, report the TF32 number
            (stored under the ``"torch.tf32"`` key) instead of the plain FP32 one.

    Returns:
        The data-sheet number, or None if the device is not in the data-sheet table
        or has no entry for the requested dtype.

    Note:
        ``torch.cuda.get_device_name()`` is called unconditionally, so this
        presumably raises when no CUDA/ROCm device is available. A device that is
        merely missing from the table returns None rather than raising.
    """
    name = torch.cuda.get_device_name()
    if not name:
        log.info("No device found, returning None")
        return None

    device_info = lookup_device_info(name)
    if device_info is None:
        log.info("Device %s not in datasheet, returning None", name)
        return None

    # TF32 has no torch.dtype, so it is stored under a string key. Resolve the key
    # once so the membership check and the lookup can never disagree (previously a
    # device without a tf32 entry would raise KeyError instead of returning None).
    key: Union[torch.dtype, str] = (
        "torch.tf32" if dtype == torch.float32 and is_tf32 else dtype
    )
    tops = device_info.tops.get(key)
    if tops is None:
        log.info(
            "Device %s does not have a datasheet entry for %s, returning None",
            name,
            key,
        )
    return tops
|