mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Prereqs: - https://github.com/pytorch/pytorch/pull/152708 Features: 1. Adds inductor's estimate of flops and bandwidth to the json trace events that perfetto uses. 1. Only use the tflops estimation from triton if we don't have the info from the datasheet because Triton's estimates are inaccurate. I have a backlog item to fix triton flops estimation upstream. New `DeviceInfo` class, and new function `get_device_tflops`. 1. New helpers `countable_fx` and `count_flops_fx` helps get the flops of an `fx.Node`. 1. Extends Triton `torch.profiler` logging to `DebugAutotuner`. 1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto json trace, `--analyze` creates a summary table of these perf estimates, and `--diff` will compare two traces side by side: ```python Device(NVIDIA H100, 0): Kernel Name | resnet Kernel Count | resnet FLOPS | resnet bw gbps | resnet Dur (ms) | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS | newresnet bw gbps | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth % --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- triton_poi_fused__native_batch_norm_legi | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 triton_red_fused__native_batch_norm_legi | 39 | 0 | 0.13990024992108846 | 
5.752589743589743 | 0 | 0.004176126863316074 | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074 triton_poi_fused__native_batch_norm_legi | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 void cutlass::Kernel2<cutlass_80_tensoro | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 triton_red_fused__native_batch_norm_legi | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 void cutlass::Kernel2<cutlass_80_tensoro | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 triton_poi_fused__native_batch_norm_legi | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 triton_red_fused__native_batch_norm_legi | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523 | 34 | 0 | 0.41409928846284 | 2.853588235294117 
| 0 | 0.012361172789935523 triton_per_fused__native_batch_norm_legi | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 triton_poi_fused__native_batch_norm_legi | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 triton_per_fused__native_batch_norm_legi | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 triton_poi_fused__native_batch_norm_legi | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 void cublasLt::splitKreduce_kernel<32, 1 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 triton_per_fused__native_batch_norm_legi | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 triton_poi_fused__native_batch_norm_legi | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 1.1463614897015777 | 
4.124323529411764 | 0 | 0.03421974596124114 | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114 void cask_plugin_cudnn::xmma_cudnn::init | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 triton_per_fused__native_batch_norm_legi | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 triton_per_fused__native_batch_norm_legi | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 void cutlass::Kernel2<cutlass_80_tensoro | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 triton_poi_fused__native_batch_norm_legi | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 | 16 
| 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 triton_poi_fused__native_batch_norm_legi | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 triton_poi_fused__native_batch_norm_legi | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 triton_per_fused__native_batch_norm_legi | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 triton_poi_fused__native_batch_norm_legi | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 triton_per_fused__native_batch_norm_legi | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 triton_poi_fused__native_batch_norm_legi | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 triton_per_fused__native_batch_norm_legi | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 std::enable_if<!(false), void>::type int | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 triton_poi_fused_add_copy__38 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 
triton_poi_fused_convolution_0 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 triton_poi_fused_convolution_1 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 void convolve_common_engine_float_NHWC<f | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 triton_per_fused__native_batch_norm_legi | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 triton_per_fused__native_batch_norm_legi | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 triton_poi_fused__native_batch_norm_legi | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 void at::native::(anonymous namespace):: | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 void at::native::vectorized_elementwise_ | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697 Approved by: 
https://github.com/eellison, https://github.com/shunting314
609 lines
20 KiB
Python
609 lines
20 KiB
Python
# Owner(s): ["module: inductor"]
|
|
|
|
import json
|
|
import tempfile
|
|
import unittest
|
|
import uuid
|
|
from io import StringIO
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch._inductor.analysis.profile_analysis import (
|
|
_augment_trace_helper,
|
|
_create_extern_mapping,
|
|
JsonProfile,
|
|
main,
|
|
)
|
|
from torch._inductor.utils import fresh_inductor_cache, tabulate_2d, zip_dicts
|
|
from torch.testing._internal.common_cuda import SM80OrLater
|
|
from torch.testing._internal.common_device_type import (
|
|
dtypes,
|
|
instantiate_device_type_tests,
|
|
skipIf,
|
|
)
|
|
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
|
|
from torch.testing._internal.inductor_utils import IS_BIG_GPU
|
|
|
|
|
|
example_profile = """
|
|
{
|
|
"schemaVersion": 1,
|
|
"deviceProperties": [
|
|
{
|
|
"id": 0, "name": "NVIDIA H100", "totalGlobalMem": 101997215744,
|
|
"computeMajor": 9, "computeMinor": 0,
|
|
"maxThreadsPerBlock": 1024, "maxThreadsPerMultiprocessor": 2048,
|
|
"regsPerBlock": 65536, "warpSize": 32,
|
|
"sharedMemPerBlock": 49152, "numSms": 132
|
|
, "regsPerMultiprocessor": 65536, "sharedMemPerBlockOptin": 232448, "sharedMemPerMultiprocessor": 233472
|
|
}
|
|
],
|
|
"cupti_version": 24,
|
|
"cuda_runtime_version": 12060,
|
|
"with_flops": 1,
|
|
"record_shapes": 1,
|
|
"cuda_driver_version": 12040,
|
|
"profile_memory": 1,
|
|
"trace_id": "301995E163ED42048FBD783860E6E7DC",
|
|
"displayTimeUnit": "ms",
|
|
"baseTimeNanoseconds": 1743521598000000000,
|
|
"traceEvents": [
|
|
{
|
|
"ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
|
|
"ts": 198093488368.463, "dur": 425.453,
|
|
"args": {
|
|
"External id": 1340,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \
|
|
["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", "False", "[0, 0]", "1"], "Input type": ["float", "float", "", \
|
|
"ScalarList", "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar"], "Input Strides": [[150528, 1, 672, 3],\
|
|
[147, 1, 21, 3], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], \
|
|
[], []], "Ev Idx": 1339
|
|
}
|
|
},
|
|
{
|
|
"ph": "X", "cat": "cpu_op", "name": "aten::_convolution", "pid": 1147039, "tid": 1147039,
|
|
"ts": 198093488444.498, "dur": 341.867,
|
|
"args": {
|
|
"External id": 1341,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]",\
|
|
"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList",\
|
|
"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input Strides": \
|
|
[[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], \
|
|
[64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340
|
|
}
|
|
},
|
|
{
|
|
"ph": "X", "cat": "cpu_op", "name": "aten::addmm", "pid": 1147039, "tid": 1147039,
|
|
"ts": 198093513655.849, "dur": 251.130,
|
|
"args": {
|
|
"External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \
|
|
["", "", "", "1", "1", ""], "Input type": ["float", "float", "float", "Scalar", "Scalar", "float"], "Input Strides":\
|
|
[[1], [0, 1], [1, 2048], [], [], [1000, 1]], "Input Dims": [[1000], [1, 2048], [2048, 1000], [], [], [1, 1000]], \
|
|
"Ev Idx": 1618
|
|
}
|
|
},
|
|
{
|
|
"ph": "X", "cat": "kernel", "name": "void cutlass_addmm", "pid": 1147039, "tid": 1147039,
|
|
"ts": 198093513655.849, "dur": 251.130,
|
|
"args": {
|
|
"External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
|
|
}
|
|
},
|
|
{
|
|
"ph": "X", "cat": "kernel", "name": "void convolution_kernel", "pid": 1147039, "tid": 1147039,
|
|
"ts": 198093513655.849, "dur": 200.130,
|
|
"args": {
|
|
"External id": 1342, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
|
|
}
|
|
},
|
|
{
|
|
"ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
|
|
"ts": 198093488444.498, "dur": 341.867,
|
|
"args": {
|
|
"External id": 1342,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", \
|
|
"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList", \
|
|
"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input \
|
|
Strides": [[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": \
|
|
[[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340
|
|
}
|
|
}
|
|
],
|
|
"traceName": "/tmp/compiled_module_profile.json"
|
|
}
|
|
"""
|
|
|
|
|
|
def verify_flops(self, expected_flops, out_profile):
    """Check recorded ``kernel_flop`` values against ``expected_flops`` in order.

    Walks ``out_profile["traceEvents"]`` and, for each event whose ``"args"``
    dict contains ``"kernel_flop"``, asserts through ``self.assertEqual`` that
    the value equals the next unconsumed entry of ``expected_flops``. Events
    without ``"kernel_flop"`` are skipped and do not consume an expected value.
    Entries of ``expected_flops`` beyond the number of annotated events are
    not checked.
    """
    # Lazy on purpose: each event is inspected right before its assertion.
    annotated_flops = (
        event["args"]["kernel_flop"]
        for event in out_profile["traceEvents"]
        if "kernel_flop" in event["args"]
    )
    for position, actual_flop in enumerate(annotated_flops):
        self.assertEqual(actual_flop, expected_flops[position])
|
|
|
|
|
|
def random_tensor(size, dtype, **kwargs):
    """Create a random tensor of shape ``size`` and the given ``dtype``.

    Floating dtypes (half, bfloat16, float, double) are drawn with
    ``torch.randn``; integer dtypes (uint8, int8, short, int, long) are drawn
    with ``torch.randint(0, 100, ...)``. Extra keyword arguments are forwarded
    to the torch factory function.

    Raises:
        ValueError: if ``dtype`` is in neither group.
    """
    floating_dtypes = (torch.half, torch.bfloat16, torch.float, torch.double)
    integer_dtypes = (torch.uint8, torch.int8, torch.short, torch.int, torch.long)

    if dtype in floating_dtypes:
        return torch.randn(size, dtype=dtype, **kwargs)
    if dtype in integer_dtypes:
        return torch.randint(0, 100, size, dtype=dtype, **kwargs)
    raise ValueError("Unsupported data type")
|
|
|
|
|
|
def cT(device, dtype):
    """Return a tensor factory bound to ``device`` and ``dtype``.

    The returned callable ``T(*shape, requires_grad=False)`` passes ``shape``
    (collected into a tuple) to ``random_tensor`` together with the captured
    ``device`` and ``dtype``.
    """

    def T(*shape, requires_grad=False):
        factory_kwargs = {
            "requires_grad": requires_grad,
            "device": device,
            "dtype": dtype,
        }
        return random_tensor(shape, **factory_kwargs)

    return T
|
|
|
|
|
|
def FlopCounterMode(*args, **kwargs):
|
|
return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)
|
|
|
|
|
|
# Scratch directory created once at import time; trace_files() places every
# generated trace path inside it.
TMP_DIR = tempfile.mkdtemp()


def trace_files():
    """Return two fresh, unique trace file paths located in ``TMP_DIR``.

    Only the path strings are produced (``trace1-<uuid4>.json`` and
    ``trace2-<uuid4>.json``); no file is created by this function.
    """
    first_path, second_path = (
        f"{TMP_DIR}/trace{index}-{uuid.uuid4()}.json" for index in (1, 2)
    )
    return first_path, second_path
|
|
|
|
|
|
def _test_model(device, dtype, compile=True, addmm=True, bmm=True):
    """Build the mixed convolution/matmul workload used by the analysis tests.

    The returned zero-argument callable creates fresh random inputs through
    ``cT(device, dtype)`` on every call, runs ``conv2d`` (scaled by 10) and
    ``mm``, optionally ``addmm`` / ``bmm`` / ``baddbmm``, and returns all
    results flattened and concatenated into a single 1-D tensor.

    Args:
        device: device passed to the tensor factory.
        dtype: dtype passed to the tensor factory.
        compile: when true, the workload is wrapped with ``torch.compile``
            using the ``benchmark_kernel`` and ``profile_bandwidth`` options.
        addmm: include the ``torch.addmm`` result.
        bmm: include the ``torch.bmm`` result. ``torch.baddbmm`` is included
            only when both ``addmm`` and ``bmm`` are true.
    """
    T = cT(device, dtype)

    def model():
        image = T(1, 3, 56, 56)
        filters = T(12, 3, 5, 5)

        # Increased matrix sizes
        batch, rows, inner, cols = 8, 256, 512, 768
        lhs = T(rows, inner)
        rhs = T(inner, cols)

        batched_lhs = T(batch, rows, inner)
        batched_rhs = T(batch, inner, cols)

        conv_result = F.conv2d(image, filters)
        conv_result = conv_result * 10
        mm_result = torch.mm(lhs, rhs)
        outputs = [conv_result.flatten(), mm_result.flatten()]

        if addmm:
            addmm_bias = torch.zeros(
                mm_result.shape, device=lhs.device, dtype=lhs.dtype
            )
            outputs.append(torch.addmm(addmm_bias, lhs, rhs).flatten())

        if bmm:
            outputs.append(torch.bmm(batched_lhs, batched_rhs).flatten())

        if bmm and addmm:
            # The bias has a leading dimension of 1 in front of the mm output shape.
            baddbmm_bias = torch.zeros(
                1,
                *mm_result.shape,
                device=batched_lhs.device,
                dtype=batched_lhs.dtype,
            )
            outputs.append(
                torch.baddbmm(baddbmm_bias, batched_lhs, batched_rhs).flatten()
            )

        return torch.cat(outputs)

    if not compile:
        return model

    compile_options = {"benchmark_kernel": True, "profile_bandwidth": True}
    return torch.compile(model, options=compile_options)
|
|
|
|
|
|
def _pointwise_test_model(device, dtype, compile=True):
    """Build a purely pointwise workload: ``sin(a + b)`` on two random matrices.

    The returned zero-argument callable creates two fresh 1024 x 512 tensors
    through ``cT(device, dtype)`` on every call and returns
    ``torch.add(a, b).sin()``.

    Args:
        device: device passed to the tensor factory.
        dtype: dtype passed to the tensor factory.
        compile: when true, the workload is wrapped with ``torch.compile``
            using the ``benchmark_kernel`` and ``profile_bandwidth`` options.
    """
    T = cT(device, dtype)

    def model():
        rows, cols = 1024, 512
        lhs = T(rows, cols)
        rhs = T(rows, cols)
        return torch.add(lhs, rhs).sin()

    if not compile:
        return model

    compile_options = {"benchmark_kernel": True, "profile_bandwidth": True}
    return torch.compile(model, options=compile_options)
|
|
|
|
|
|
# Leading element(s) of the fake command line: the tests patch ``sys.argv``
# with ``[*prefix, ...flags]`` before calling ``main()``, so this stands in
# for the program name.
prefix = ["profile.py"]
|
|
|
|
|
|
class TestUtils(TestCase):
    """Unit tests for the ``tabulate_2d`` and ``zip_dicts`` helpers."""

    def test_tabulate2d(self):
        """Rendered lines of ``tabulate_2d`` match the expected table line by line."""
        column_names = ["Kernel", "Self H100 TIME (ms)", "Count", "Percent"]
        data_rows = [
            ["aten::mm", 0.500, 7, 0.0],
            ["aten::bmm", 0.400, 6, 0.0],
            ["aten::baddbmm", 0.300, 5, 0.0],
            ["aten::convolution", 0.200, 4, 0.0],
            ["aten::cudnn_convolution", 0.100, 3, 0.0],
        ]
        # NOTE(review): the row literals below are much narrower than the
        # dashed separator row, so their column padding may have been
        # collapsed at some point - verify against tabulate_2d's real output.
        expected_lines = [
            " Kernel | Self H100 TIME (ms) | Count | Percent ",
            "-----------------------------------------------------------------",
            " aten::mm | 0.5 | 7 | 0.0 ",
            " aten::bmm | 0.4 | 6 | 0.0 ",
            " aten::baddbmm | 0.3 | 5 | 0.0 ",
            " aten::convolution | 0.2 | 4 | 0.0 ",
            " aten::cudnn_convolution | 0.1 | 3 | 0.0 ",
        ]
        rendered = tabulate_2d(data_rows, column_names)
        # zip() stops at the shorter side, so only lines present in both the
        # rendered output and the expectation are compared.
        rendered_lines = rendered.split("\n")
        for actual_line, expected_line in zip(rendered_lines, expected_lines):
            self.assertEqual(actual_line, expected_line)

    def test_zip_dicts(self):
        """``zip_dicts`` yields ``(key, v1, v2)`` triples, filling gaps with defaults."""
        left = {"a": 1, "b": 2}
        right = {"a": 3, "c": 4}

        # Explicit defaults replace the value missing on either side.
        with_defaults = zip_dicts(left, right, d1_default=32, d2_default=48)
        self.assertEqual(
            set(with_defaults), {("a", 1, 3), ("b", 2, 48), ("c", 32, 4)}
        )

        # Without explicit defaults, missing values show up as None.
        without_defaults = zip_dicts(left, right)
        self.assertEqual(
            set(without_defaults), {("a", 1, 3), ("b", 2, None), ("c", None, 4)}
        )
|
|
|
|
|
|
class TestAnalysis(TestCase):
    """Tests for the ``profile_analysis`` command line (``main``) and its helpers.

    Most tests profile a small workload with ``torch.profiler``, export a
    Chrome trace into ``TMP_DIR`` and then drive ``main()`` by patching
    ``sys.argv``. Tests that take a ``device`` argument return immediately on
    ``"cpu"``.
    """

    @skipIf(not SM80OrLater, "Requires SM80")
    def test_noop(self):
        """``main()`` without any flags prints nothing to stdout."""
        with (
            patch("sys.stdout", new_callable=StringIO) as mock_stdout,
            patch("sys.argv", [*prefix]),
        ):
            main()
            self.assertEqual(mock_stdout.getvalue(), "")

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.double, torch.float16)
    def test_diff(self, device, dtype):
        """
        diff, testing out the nruns feature too.

        Profiles the model once for the first trace and ``REPEAT`` times for
        the second, then runs ``--diff`` on the two traces. This is a smoke
        test: it only checks that ``main()`` completes.
        """
        if device == "cpu":
            # TODO cpu support
            return
        om = _test_model(device, dtype)
        REPEAT = 5
        trace1, trace2 = trace_files()
        dtype_name = str(dtype).split(".")[-1]

        print(f"first trace {trace1}")
        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as prof:
                om()
        prof.export_chrome_trace(trace1)

        print(f"second trace {trace2}")
        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as prof:
                for _ in range(REPEAT):
                    om()
        prof.export_chrome_trace(trace2)

        print("diffing...")
        with patch(
            "sys.argv",
            [
                *prefix,
                "--diff",
                trace1,
                "foo",
                trace2,
                "bar",
                dtype_name,
                "--name_limit",
                "30",
            ],
        ):
            main()

    @skipIf(not SM80OrLater, "Requires SM80")
    def test_augment_trace_helper_unit(self):
        """``_augment_trace_helper`` annotates the fixture trace with the expected flops."""
        js = json.loads(example_profile)
        out_profile = _augment_trace_helper(js)
        expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0]
        verify_flops(self, expected_flops, out_profile)

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.double, torch.float16)
    def test_augment_trace_helper_args(self, device, dtype):
        """``--augment_trace`` output can be loaded and reported by ``JsonProfile``.

        Checks that the textual report contains the expected column titles and
        that every value in a column whose header contains ``%`` lies within
        ``[0, 100]``.
        """
        if device == "cpu":
            # cpu doesn't produce traces currently
            return
        om = _test_model(device, dtype)
        dtype_name = str(dtype).split(".")[-1]
        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as prof:
                om()
        trace1, trace2 = trace_files()
        prof.export_chrome_trace(trace1)
        print(f"first trace {trace1}")

        with patch(
            "sys.argv",
            [*prefix, "--augment_trace", trace1, trace2, dtype_name],
        ):
            main()
        profile = JsonProfile(trace2, benchmark_name="foo", dtype=dtype_name)
        rep = profile.report()
        self.assertTrue(len(rep.split("\n")) > 3, f"Error, empty table:\n{rep}")
        # If these fail, just update them. They could change over time
        self.assertIn("Kernel Name", rep)
        self.assertIn("Kernel Count", rep)
        self.assertIn("FLOPS", rep)
        self.assertIn("Kernel Reads", rep)
        self.assertIn("Dur", rep)
        self.assertIn("Achieved", rep)
        self.assertIn("|", rep)
        self.assertIn("-----", rep)

        tables = profile._create_tables(profile._devices)

        # check to make sure all % values are between 0% and 100%
        for header, rows in tables.values():
            # Computed per table so that indices from one table's header are
            # never applied to another table's rows.
            percent_columns = [i for i, title in enumerate(header) if "%" in title]
            self.assertTrue(
                len(percent_columns) > 0, "There are no headers with % in them"
            )
            for row in rows.values():
                for col in percent_columns:
                    # Row values are shifted by one relative to the header
                    # (presumably because the first header column is the key
                    # of ``rows``), so header[col] describes row[col - 1].
                    idx = col - 1
                    self.assertTrue(
                        float(row[idx]) <= 100.0,
                        f"column values from column {idx} with header '{header[col]}' is greater than 100%: {row[idx]}",
                    )
                    self.assertTrue(
                        float(row[idx]) >= 0.0,
                        f"column values from column {idx} with header '{header[col]}' is less than 0%: {row[idx]}",
                    )

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.double, torch.float16)
    @parametrize(
        "maxat",
        [
            (True, "TRITON"),
        ],
    )
    @skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune")
    def test_triton_has_metadata(self, device, dtype, maxat):
        """
        make sure that the chrome trace of triton kernels contains certain values

        Concretely: after compiling a single ``conv2d`` with the Triton-only
        max-autotune backend, the exported trace must contain at least one
        event whose name includes both "triton" and "conv".
        """
        if device == "cpu":
            return

        T = cT(device, dtype)
        input_conv = T(1, 3, 56, 56)
        conv_weight = T(12, 3, 5, 5)

        def om(i, w):
            # Convolution operation
            conv_output = F.conv2d(i, w)
            return conv_output

        max_autotune, backends = maxat
        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "force_disable_caches": True,
                "max_autotune": max_autotune,
            },
        )

        def verify_triton(comp):
            torch._dynamo.reset()  # reset the cache
            with fresh_inductor_cache():
                with torch.profiler.profile(record_shapes=True) as profile:
                    comp(input_conv, conv_weight)

            trace1, _ = trace_files()
            profile.export_chrome_trace(trace1)
            with open(trace1) as f:
                out_profile = json.load(f)
            seen = any(
                "triton" in event["name"] and "conv" in event["name"]
                for event in out_profile["traceEvents"]
            )
            self.assertTrue(seen, "no triton conv found")

        verify_triton(comp_omni)

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.float16)
    @parametrize(
        "maxat",
        [
            (False, "ATEN,TRITON"),
            (True, "ATEN,TRITON"),
            (True, "ATEN"),
            (True, "TRITON"),
        ],
    )
    @unittest.skipIf(
        not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune"
    )
    def test_augment_trace_against_flop_counter(self, device, dtype, maxat):
        """Flops written by ``--augment_trace`` agree with ``FlopCounterMode``.

        For every kernel event of the augmented trace that carries an
        ``"External id"``, the originating op is looked up through
        ``_create_extern_mapping`` and, for mm / convolution / baddbmm / bmm,
        ``kernel_flop`` is compared with the flop counter's global count for
        that op. Each of the four op families must be seen at least once.
        """
        # this tests to see if we can only use a Triton backend for max autotune
        max_autotune, backends = maxat
        if device == "cpu":
            return
        om = _test_model(device, dtype, compile=False)
        dtype_name = str(dtype).split(".")[-1]

        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "force_disable_caches": True,
                "max_autotune": max_autotune,
            },
        )
        comp_omni()

        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as profile:
                comp_omni()

        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with FlopCounterMode() as mode:
                comp_omni()

        trace1, trace2 = trace_files()
        profile.export_chrome_trace(trace1)
        with patch(
            "sys.argv",
            [*prefix, "--augment_trace", trace1, trace2, dtype_name],
        ):
            main()

        with open(trace2) as f:
            out_profile = json.load(f)

        flop_counts = mode.flop_counts
        extern_mapping = _create_extern_mapping(out_profile)

        seen_mm = False
        seen_bmm = False
        seen_baddbmm = False
        seen_conv = False
        for event in out_profile["traceEvents"]:
            # Only kernel events that can be mapped back to an op are checked.
            if event.get("cat") != "kernel" or "External id" not in event.get(
                "args", {}
            ):
                continue

            external_op = extern_mapping[event["args"]["External id"]][0]
            name: str = external_op["name"]
            self.assertIsNotNone(name)
            self.assertIsInstance(name, str)
            if name.startswith("aten::mm") or "_mm_" in name:
                seen_mm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.mm],
                )
            if (
                name.startswith(
                    (
                        "aten::cudnn_convolution",
                        "aten::convolution",
                        "aten::_convolution",
                    )
                )
                or "conv" in name
            ):
                seen_conv = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.convolution],
                )
            if name.startswith("aten::baddbmm") or "_baddbmm_" in name:
                seen_baddbmm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.baddbmm],
                )
            if name.startswith("aten::bmm") or "_bmm_" in name:
                seen_bmm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.bmm],
                )
        self.assertTrue(seen_mm)
        self.assertTrue(seen_bmm)
        self.assertTrue(seen_baddbmm)
        self.assertTrue(seen_conv)

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.float16)
    @parametrize(
        "maxat",
        [
            (False, "ATEN,TRITON"),
            (True, "ATEN,TRITON"),
            (True, "ATEN"),
            (True, "TRITON"),
        ],
    )
    @unittest.skipIf(
        not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune"
    )
    def test_pointwise_bandwidth(self, device, dtype, maxat):
        """Profile the pointwise workload and run ``--analysis`` on its trace.

        As written this is a smoke test: it checks that profiling, export and
        ``main()`` complete, but makes no assertion about the trace contents
        (see the note on the final loop).
        """
        # this tests to see if we can only use a Triton backend for max autotune
        max_autotune, backends = maxat
        if device == "cpu":
            return
        om = _pointwise_test_model(device, dtype, compile=False)
        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "force_disable_caches": True,
                "max_autotune": max_autotune,
            },
        )
        comp_omni()

        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as profile:
                comp_omni()
        trace1, _ = trace_files()
        profile.export_chrome_trace(trace1)

        with patch(
            "sys.argv",
            [*prefix, "--analysis", trace1, str(dtype).split(".")[-1]],
        ):
            main()

        with open(trace1) as f:
            out_profile = json.load(f)

        # NOTE(review): this loop assigns ``kernel_num_gb`` on the in-memory
        # dict instead of asserting on it, and nothing reads the value
        # afterwards, so it verifies nothing. An assertion was presumably
        # intended - confirm the expected value for each dtype before
        # converting it.
        for event in out_profile["traceEvents"]:
            if event["name"] == "triton_poi_fused_add_randn_sin_0":
                event["args"]["kernel_num_gb"] = 0.002097168
|
|
|
|
|
|
# Instantiate TestAnalysis for the available device types (its tests take a
# ``device`` argument) and register the resulting classes in this module's
# globals.
instantiate_device_type_tests(TestAnalysis, globals())

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()
|