mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Inductor logging + analysis of torch.profile (#149697)
Prereqs: - https://github.com/pytorch/pytorch/pull/152708 Features: 1. Adds inductor's estimate of flops and bandwidth to the json trace events that perfetto uses. 1. Only use the tflops estimation from triton if we don't have the info from the datasheet because Triton's estimates are inaccurate. I have a backlog item to fix triton flops estimation upstream. New `DeviceInfo` class, and new function `get_device_tflops`. 1. New helpers `countable_fx` and `count_flops_fx` helps get the flops of an `fx.Node`. 1. Extends Triton `torch.profiler` logging to `DebugAutotuner`. 1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any perfetto json trace, `--analyze` creates a summary table of these perf estimates, and `--diff` will compare two traces side by side: ```python Device(NVIDIA H100, 0): Kernel Name | resnet Kernel Count | resnet FLOPS | resnet bw gbps | resnet Dur (ms) | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS | newresnet bw gbps | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth % --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- triton_poi_fused__native_batch_norm_legi | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 triton_red_fused__native_batch_norm_legi | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074 | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074 triton_poi_fused__native_batch_norm_legi | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 void cutlass::Kernel2<cutlass_80_tensoro | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 triton_red_fused__native_batch_norm_legi | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 void cutlass::Kernel2<cutlass_80_tensoro | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 triton_poi_fused__native_batch_norm_legi | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 triton_red_fused__native_batch_norm_legi | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523 | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523 triton_per_fused__native_batch_norm_legi | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 triton_poi_fused__native_batch_norm_legi | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 triton_per_fused__native_batch_norm_legi | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 triton_poi_fused__native_batch_norm_legi | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 void cublasLt::splitKreduce_kernel<32, 1 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 triton_per_fused__native_batch_norm_legi | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 triton_poi_fused__native_batch_norm_legi | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 triton_poi_fused__native_batch_norm_legi | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114 | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114 void cask_plugin_cudnn::xmma_cudnn::init | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 triton_per_fused__native_batch_norm_legi | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 triton_per_fused__native_batch_norm_legi | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 triton_poi_fused__native_batch_norm_legi | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 void cutlass::Kernel2<cutlass_80_tensoro | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 triton_poi_fused__native_batch_norm_legi | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 triton_poi_fused__native_batch_norm_legi | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 triton_poi_fused__native_batch_norm_legi | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 triton_per_fused__native_batch_norm_legi | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 triton_poi_fused__native_batch_norm_legi | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 triton_per_fused__native_batch_norm_legi | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 triton_poi_fused__native_batch_norm_legi | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 triton_per_fused__native_batch_norm_legi | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 std::enable_if<!(false), void>::type int | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 triton_poi_fused_add_copy__38 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 triton_poi_fused_convolution_0 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 triton_poi_fused_convolution_1 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 void convolve_common_engine_float_NHWC<f | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 triton_per_fused__native_batch_norm_legi | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 triton_per_fused__native_batch_norm_legi | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 triton_poi_fused__native_batch_norm_legi | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 void at::native::(anonymous namespace):: | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 void at::native::vectorized_elementwise_ | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697 Approved by: https://github.com/eellison, https://github.com/shunting314
This commit is contained in:
committed by
PyTorch MergeBot
parent
0f9c1b374f
commit
47f10d0ad0
608
test/inductor/test_analysis.py
Normal file
608
test/inductor/test_analysis.py
Normal file
@ -0,0 +1,608 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch._inductor.analysis.profile_analysis import (
|
||||
_augment_trace_helper,
|
||||
_create_extern_mapping,
|
||||
JsonProfile,
|
||||
main,
|
||||
)
|
||||
from torch._inductor.utils import fresh_inductor_cache, tabulate_2d, zip_dicts
|
||||
from torch.testing._internal.common_cuda import SM80OrLater
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
skipIf,
|
||||
)
|
||||
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
|
||||
from torch.testing._internal.inductor_utils import IS_BIG_GPU
|
||||
|
||||
|
||||
example_profile = """
|
||||
{
|
||||
"schemaVersion": 1,
|
||||
"deviceProperties": [
|
||||
{
|
||||
"id": 0, "name": "NVIDIA H100", "totalGlobalMem": 101997215744,
|
||||
"computeMajor": 9, "computeMinor": 0,
|
||||
"maxThreadsPerBlock": 1024, "maxThreadsPerMultiprocessor": 2048,
|
||||
"regsPerBlock": 65536, "warpSize": 32,
|
||||
"sharedMemPerBlock": 49152, "numSms": 132
|
||||
, "regsPerMultiprocessor": 65536, "sharedMemPerBlockOptin": 232448, "sharedMemPerMultiprocessor": 233472
|
||||
}
|
||||
],
|
||||
"cupti_version": 24,
|
||||
"cuda_runtime_version": 12060,
|
||||
"with_flops": 1,
|
||||
"record_shapes": 1,
|
||||
"cuda_driver_version": 12040,
|
||||
"profile_memory": 1,
|
||||
"trace_id": "301995E163ED42048FBD783860E6E7DC",
|
||||
"displayTimeUnit": "ms",
|
||||
"baseTimeNanoseconds": 1743521598000000000,
|
||||
"traceEvents": [
|
||||
{
|
||||
"ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
|
||||
"ts": 198093488368.463, "dur": 425.453,
|
||||
"args": {
|
||||
"External id": 1340,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \
|
||||
["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", "False", "[0, 0]", "1"], "Input type": ["float", "float", "", \
|
||||
"ScalarList", "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar"], "Input Strides": [[150528, 1, 672, 3],\
|
||||
[147, 1, 21, 3], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], \
|
||||
[], []], "Ev Idx": 1339
|
||||
}
|
||||
},
|
||||
{
|
||||
"ph": "X", "cat": "cpu_op", "name": "aten::_convolution", "pid": 1147039, "tid": 1147039,
|
||||
"ts": 198093488444.498, "dur": 341.867,
|
||||
"args": {
|
||||
"External id": 1341,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]",\
|
||||
"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList",\
|
||||
"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input Strides": \
|
||||
[[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], \
|
||||
[64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340
|
||||
}
|
||||
},
|
||||
{
|
||||
"ph": "X", "cat": "cpu_op", "name": "aten::addmm", "pid": 1147039, "tid": 1147039,
|
||||
"ts": 198093513655.849, "dur": 251.130,
|
||||
"args": {
|
||||
"External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \
|
||||
["", "", "", "1", "1", ""], "Input type": ["float", "float", "float", "Scalar", "Scalar", "float"], "Input Strides":\
|
||||
[[1], [0, 1], [1, 2048], [], [], [1000, 1]], "Input Dims": [[1000], [1, 2048], [2048, 1000], [], [], [1, 1000]], \
|
||||
"Ev Idx": 1618
|
||||
}
|
||||
},
|
||||
{
|
||||
"ph": "X", "cat": "kernel", "name": "void cutlass_addmm", "pid": 1147039, "tid": 1147039,
|
||||
"ts": 198093513655.849, "dur": 251.130,
|
||||
"args": {
|
||||
"External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
|
||||
}
|
||||
},
|
||||
{
|
||||
"ph": "X", "cat": "kernel", "name": "void convolution_kernel", "pid": 1147039, "tid": 1147039,
|
||||
"ts": 198093513655.849, "dur": 200.130,
|
||||
"args": {
|
||||
"External id": 1342, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
|
||||
}
|
||||
},
|
||||
{
|
||||
"ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
|
||||
"ts": 198093488444.498, "dur": 341.867,
|
||||
"args": {
|
||||
"External id": 1342,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", \
|
||||
"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList", \
|
||||
"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input \
|
||||
Strides": [[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": \
|
||||
[[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340
|
||||
}
|
||||
}
|
||||
],
|
||||
"traceName": "/tmp/compiled_module_profile.json"
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def verify_flops(self, expected_flops, out_profile):
|
||||
j = 0
|
||||
for i in range(len(out_profile["traceEvents"])):
|
||||
if "kernel_flop" in out_profile["traceEvents"][i]["args"]:
|
||||
self.assertEqual(
|
||||
out_profile["traceEvents"][i]["args"]["kernel_flop"],
|
||||
expected_flops[j],
|
||||
)
|
||||
j += 1
|
||||
|
||||
|
||||
def random_tensor(size, dtype, **kwargs):
|
||||
if dtype in [torch.half, torch.bfloat16, torch.float, torch.double]:
|
||||
return torch.randn(size, dtype=dtype, **kwargs)
|
||||
elif dtype in [torch.uint8, torch.int8, torch.short, torch.int, torch.long]:
|
||||
return torch.randint(0, 100, size, dtype=dtype, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unsupported data type")
|
||||
|
||||
|
||||
def cT(device, dtype):
|
||||
def T(*shape, requires_grad=False):
|
||||
return random_tensor(
|
||||
shape, requires_grad=requires_grad, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
return T
|
||||
|
||||
|
||||
def FlopCounterMode(*args, **kwargs):
|
||||
return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)
|
||||
|
||||
|
||||
TMP_DIR = tempfile.mkdtemp()
|
||||
|
||||
|
||||
def trace_files():
|
||||
TRACE1 = f"{TMP_DIR}/trace1-{uuid.uuid4()}.json"
|
||||
TRACE2 = f"{TMP_DIR}/trace2-{uuid.uuid4()}.json"
|
||||
return TRACE1, TRACE2
|
||||
|
||||
|
||||
def _test_model(device, dtype, compile=True, addmm=True, bmm=True):
|
||||
T = cT(device, dtype)
|
||||
|
||||
def model():
|
||||
input_conv = T(1, 3, 56, 56)
|
||||
conv_weight = T(12, 3, 5, 5)
|
||||
|
||||
# Increased matrix sizes
|
||||
B = 8
|
||||
M = 256
|
||||
N = 512
|
||||
K = 768
|
||||
mat1 = T(M, N)
|
||||
mat2 = T(N, K)
|
||||
|
||||
batch_mat1 = T(B, M, N)
|
||||
batch_mat2 = T(B, N, K)
|
||||
|
||||
conv_output = F.conv2d(input_conv, conv_weight)
|
||||
|
||||
conv_output = conv_output * 10
|
||||
mm_output = torch.mm(mat1, mat2)
|
||||
ret = [
|
||||
conv_output.flatten(),
|
||||
mm_output.flatten(),
|
||||
]
|
||||
|
||||
if addmm:
|
||||
addmm_output = torch.addmm(
|
||||
torch.zeros(mm_output.shape, device=mat1.device, dtype=mat1.dtype),
|
||||
mat1,
|
||||
mat2,
|
||||
)
|
||||
ret.append(addmm_output.flatten())
|
||||
|
||||
if bmm:
|
||||
bmm_output = torch.bmm(batch_mat1, batch_mat2)
|
||||
ret.append(bmm_output.flatten())
|
||||
|
||||
if bmm and addmm:
|
||||
baddbmm_output = torch.baddbmm(
|
||||
torch.zeros(
|
||||
1,
|
||||
*mm_output.shape,
|
||||
device=batch_mat1.device,
|
||||
dtype=batch_mat1.dtype,
|
||||
),
|
||||
batch_mat1,
|
||||
batch_mat2,
|
||||
)
|
||||
ret.append(baddbmm_output.flatten())
|
||||
|
||||
return torch.cat(ret)
|
||||
|
||||
if compile:
|
||||
return torch.compile(
|
||||
model, options={"benchmark_kernel": True, "profile_bandwidth": True}
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def _pointwise_test_model(device, dtype, compile=True):
|
||||
T = cT(device, dtype)
|
||||
|
||||
def model():
|
||||
M = 1024
|
||||
N = 512
|
||||
mat3 = T(M, N)
|
||||
mat4 = T(M, N)
|
||||
pointwise_output = torch.add(mat3, mat4).sin()
|
||||
return pointwise_output
|
||||
|
||||
if compile:
|
||||
return torch.compile(
|
||||
model, options={"benchmark_kernel": True, "profile_bandwidth": True}
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
prefix = ["profile.py"]
|
||||
|
||||
|
||||
class TestUtils(TestCase):
|
||||
def test_tabulate2d(self):
|
||||
headers = ["Kernel", "Self H100 TIME (ms)", "Count", "Percent"]
|
||||
rows = [
|
||||
["aten::mm", 0.500, 7, 0.0],
|
||||
["aten::bmm", 0.400, 6, 0.0],
|
||||
["aten::baddbmm", 0.300, 5, 0.0],
|
||||
["aten::convolution", 0.200, 4, 0.0],
|
||||
["aten::cudnn_convolution", 0.100, 3, 0.0],
|
||||
]
|
||||
table = [
|
||||
" Kernel | Self H100 TIME (ms) | Count | Percent ",
|
||||
"-----------------------------------------------------------------",
|
||||
" aten::mm | 0.5 | 7 | 0.0 ",
|
||||
" aten::bmm | 0.4 | 6 | 0.0 ",
|
||||
" aten::baddbmm | 0.3 | 5 | 0.0 ",
|
||||
" aten::convolution | 0.2 | 4 | 0.0 ",
|
||||
" aten::cudnn_convolution | 0.1 | 3 | 0.0 ",
|
||||
]
|
||||
res = tabulate_2d(rows, headers)
|
||||
for r, t in zip(res.split("\n"), table):
|
||||
self.assertEqual(r, t)
|
||||
|
||||
def test_zip_dicts(self):
|
||||
d1 = {"a": 1, "b": 2}
|
||||
d2 = {"a": 3, "c": 4}
|
||||
res1 = zip_dicts(d1, d2, d1_default=32, d2_default=48)
|
||||
self.assertEqual(set(res1), {("a", 1, 3), ("b", 2, 48), ("c", 32, 4)})
|
||||
res2 = zip_dicts(d1, d2)
|
||||
self.assertEqual(set(res2), {("a", 1, 3), ("b", 2, None), ("c", None, 4)})
|
||||
|
||||
|
||||
class TestAnalysis(TestCase):
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
def test_noop(self):
|
||||
with (
|
||||
patch("sys.stdout", new_callable=StringIO) as mock_stdout,
|
||||
patch("sys.argv", [*prefix]),
|
||||
):
|
||||
main()
|
||||
self.assertEqual(mock_stdout.getvalue(), "")
|
||||
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
@dtypes(torch.float, torch.double, torch.float16)
|
||||
def test_diff(self, device, dtype):
|
||||
"""
|
||||
diff, testing out the nruns feature too.
|
||||
"""
|
||||
if device == "cpu":
|
||||
# TODO cpu support
|
||||
return
|
||||
om = _test_model(device, dtype)
|
||||
REPEAT = 5
|
||||
trace1, trace2 = trace_files()
|
||||
print(f"first trace {trace1}")
|
||||
torch._dynamo.reset() # reset the cache
|
||||
with fresh_inductor_cache():
|
||||
with torch.profiler.profile(record_shapes=True) as p:
|
||||
om()
|
||||
p.export_chrome_trace(trace1)
|
||||
|
||||
print(f"second trace {trace2}")
|
||||
torch._dynamo.reset() # reset the cache
|
||||
with fresh_inductor_cache():
|
||||
with torch.profiler.profile(record_shapes=True) as p:
|
||||
for _ in range(REPEAT):
|
||||
om()
|
||||
p.export_chrome_trace(trace2)
|
||||
|
||||
print("diffing...")
|
||||
with patch(
|
||||
"sys.argv",
|
||||
[
|
||||
*prefix,
|
||||
"--diff",
|
||||
trace1,
|
||||
"foo",
|
||||
trace2,
|
||||
"bar",
|
||||
str(dtype).split(".")[-1],
|
||||
"--name_limit",
|
||||
"30",
|
||||
],
|
||||
):
|
||||
main()
|
||||
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
def test_augment_trace_helper_unit(self):
|
||||
js = json.loads(example_profile)
|
||||
out_profile = _augment_trace_helper(js)
|
||||
expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0]
|
||||
verify_flops(self, expected_flops, out_profile)
|
||||
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
@dtypes(torch.float, torch.double, torch.float16)
|
||||
def test_augment_trace_helper_args(self, device, dtype):
|
||||
if device == "cpu":
|
||||
# cpu doesn't produce traces currently
|
||||
return
|
||||
om = _test_model(device, dtype)
|
||||
torch._dynamo.reset() # reset the cache
|
||||
with fresh_inductor_cache():
|
||||
with torch.profiler.profile(record_shapes=True) as p:
|
||||
om()
|
||||
trace1, trace2 = trace_files()
|
||||
p.export_chrome_trace(trace1)
|
||||
print(f"first trace {trace1}")
|
||||
|
||||
with patch(
|
||||
"sys.argv",
|
||||
[*prefix, "--augment_trace", trace1, trace2, str(dtype).split(".")[-1]],
|
||||
):
|
||||
main()
|
||||
profile = JsonProfile(
|
||||
trace2, benchmark_name="foo", dtype=str(dtype).split(".")[-1]
|
||||
)
|
||||
rep = profile.report()
|
||||
self.assertTrue(len(rep.split("\n")) > 3, f"Error, empty table:\n{rep}")
|
||||
# If these fail, just update them. They could change over time
|
||||
self.assertIn("Kernel Name", rep)
|
||||
self.assertIn("Kernel Count", rep)
|
||||
self.assertIn("FLOPS", rep)
|
||||
self.assertIn("Kernel Reads", rep)
|
||||
self.assertIn("Dur", rep)
|
||||
self.assertIn("Achieved", rep)
|
||||
self.assertIn("|", rep)
|
||||
self.assertIn("-----", rep)
|
||||
|
||||
tables = profile._create_tables(profile._devices)
|
||||
|
||||
# check to make sure all % values are less than 100%
|
||||
percents = []
|
||||
for tab in tables.values():
|
||||
header, rows = tab
|
||||
for i, h in enumerate(header):
|
||||
if "%" in h:
|
||||
percents.append(i)
|
||||
self.assertTrue(len(percents) > 0, "There are no headers with % in them")
|
||||
for row in rows.values():
|
||||
for p in percents:
|
||||
idx = p - 1
|
||||
self.assertTrue(
|
||||
float(row[idx]) <= 100.0,
|
||||
f"column values from column {idx} with header '{header[idx]}' is greater than 100%: {row[idx]}",
|
||||
)
|
||||
self.assertTrue(
|
||||
float(row[idx]) >= 0.0,
|
||||
f"column values from column {idx} with header '{header[idx]}' is less than 0%: {row[idx]}",
|
||||
)
|
||||
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
@dtypes(torch.float, torch.double, torch.float16)
|
||||
@parametrize(
|
||||
"maxat",
|
||||
[
|
||||
(True, "TRITON"),
|
||||
],
|
||||
)
|
||||
@skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune")
|
||||
def test_triton_has_metadata(self, device, dtype, maxat):
|
||||
"""
|
||||
make sure that the chrome trace of triton kernels contains certain values
|
||||
"""
|
||||
if device == "cpu":
|
||||
return
|
||||
|
||||
T = cT(device, dtype)
|
||||
input_conv = T(1, 3, 56, 56)
|
||||
conv_weight = T(12, 3, 5, 5)
|
||||
|
||||
def om(i, w):
|
||||
# Convolution operation
|
||||
conv_output = F.conv2d(i, w)
|
||||
return conv_output
|
||||
|
||||
max_autotune, backends = maxat
|
||||
comp_omni = torch.compile(
|
||||
om,
|
||||
options={
|
||||
"benchmark_kernel": True,
|
||||
"max_autotune_gemm_backends": backends,
|
||||
"force_disable_caches": True,
|
||||
"max_autotune": max_autotune,
|
||||
},
|
||||
)
|
||||
|
||||
def verify_triton(comp):
|
||||
torch._dynamo.reset() # reset the cache
|
||||
with fresh_inductor_cache():
|
||||
with torch.profiler.profile(record_shapes=True) as profile:
|
||||
comp(input_conv, conv_weight)
|
||||
|
||||
trace1, _ = trace_files()
|
||||
profile.export_chrome_trace(trace1)
|
||||
with open(trace1) as f:
|
||||
out_profile = json.load(f)
|
||||
seen = False
|
||||
for event in out_profile["traceEvents"]:
|
||||
if "triton" in event["name"] and "conv" in event["name"]:
|
||||
seen = True
|
||||
self.assertTrue(seen, "no triton conv found")
|
||||
|
||||
verify_triton(comp_omni)
|
||||
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
@dtypes(torch.float, torch.float16)
|
||||
@parametrize(
|
||||
"maxat",
|
||||
[
|
||||
(False, "ATEN,TRITON"),
|
||||
(True, "ATEN,TRITON"),
|
||||
(True, "ATEN"),
|
||||
(True, "TRITON"),
|
||||
],
|
||||
)
|
||||
@unittest.skipIf(
|
||||
not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune"
|
||||
)
|
||||
def test_augment_trace_against_flop_counter(self, device, dtype, maxat):
|
||||
# this tests to see if we can only use a Triton backend for max autotune
|
||||
max_autotune, backends = maxat
|
||||
if device == "cpu":
|
||||
return
|
||||
om = _test_model(device, dtype, compile=False)
|
||||
|
||||
comp_omni = torch.compile(
|
||||
om,
|
||||
options={
|
||||
"benchmark_kernel": True,
|
||||
"max_autotune_gemm_backends": backends,
|
||||
"force_disable_caches": True,
|
||||
"max_autotune": max_autotune,
|
||||
},
|
||||
)
|
||||
comp_omni()
|
||||
|
||||
torch._dynamo.reset() # reset the cache
|
||||
with fresh_inductor_cache():
|
||||
with torch.profiler.profile(record_shapes=True) as profile:
|
||||
comp_omni()
|
||||
|
||||
torch._dynamo.reset() # reset the cache
|
||||
with fresh_inductor_cache():
|
||||
with FlopCounterMode() as mode:
|
||||
comp_omni()
|
||||
|
||||
trace1, trace2 = trace_files()
|
||||
profile.export_chrome_trace(trace1)
|
||||
with patch(
|
||||
"sys.argv",
|
||||
[*prefix, "--augment_trace", trace1, trace2, str(dtype).split(".")[-1]],
|
||||
):
|
||||
main()
|
||||
|
||||
with open(trace2) as f:
|
||||
out_profile = json.load(f)
|
||||
|
||||
flop_counts = mode.flop_counts
|
||||
extern_mapping = _create_extern_mapping(out_profile)
|
||||
|
||||
seen_mm = False
|
||||
seen_bmm = False
|
||||
seen_baddbmm = False
|
||||
seen_conv = False
|
||||
for event in out_profile["traceEvents"]:
|
||||
if (
|
||||
"cat" not in event
|
||||
or event["cat"] != "kernel"
|
||||
or "args" not in event
|
||||
or "External id" not in event["args"]
|
||||
):
|
||||
continue
|
||||
|
||||
external_op = extern_mapping[event["args"]["External id"]][0]
|
||||
name: str = external_op["name"]
|
||||
self.assertNotEqual(name, None)
|
||||
self.assertEqual(type(name), str)
|
||||
if name.startswith("aten::mm") or "_mm_" in name:
|
||||
seen_mm = True
|
||||
self.assertEqual(
|
||||
event["args"]["kernel_flop"],
|
||||
flop_counts["Global"][torch.ops.aten.mm],
|
||||
)
|
||||
if (
|
||||
name.startswith(
|
||||
(
|
||||
"aten::cudnn_convolution",
|
||||
"aten::convolution",
|
||||
"aten::_convolution",
|
||||
)
|
||||
)
|
||||
or "conv" in name
|
||||
):
|
||||
seen_conv = True
|
||||
self.assertEqual(
|
||||
event["args"]["kernel_flop"],
|
||||
flop_counts["Global"][torch.ops.aten.convolution],
|
||||
)
|
||||
if name.startswith("aten::baddbmm") or "_baddbmm_" in name:
|
||||
seen_baddbmm = True
|
||||
self.assertEqual(
|
||||
event["args"]["kernel_flop"],
|
||||
flop_counts["Global"][torch.ops.aten.baddbmm],
|
||||
)
|
||||
if name.startswith("aten::bmm") or "_bmm_" in name:
|
||||
seen_bmm = True
|
||||
self.assertEqual(
|
||||
event["args"]["kernel_flop"],
|
||||
flop_counts["Global"][torch.ops.aten.bmm],
|
||||
)
|
||||
self.assertTrue(seen_mm)
|
||||
self.assertTrue(seen_bmm)
|
||||
self.assertTrue(seen_baddbmm)
|
||||
self.assertTrue(seen_conv)
|
||||
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
@dtypes(torch.float, torch.float16)
|
||||
@parametrize(
|
||||
"maxat",
|
||||
[
|
||||
(False, "ATEN,TRITON"),
|
||||
(True, "ATEN,TRITON"),
|
||||
(True, "ATEN"),
|
||||
(True, "TRITON"),
|
||||
],
|
||||
)
|
||||
@unittest.skipIf(
|
||||
not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune"
|
||||
)
|
||||
def test_pointwise_bandwidth(self, device, dtype, maxat):
|
||||
# this tests to see if we can only use a Triton backend for max autotune
|
||||
max_autotune, backends = maxat
|
||||
if device == "cpu":
|
||||
return
|
||||
om = _pointwise_test_model(device, dtype, compile=False)
|
||||
comp_omni = torch.compile(
|
||||
om,
|
||||
options={
|
||||
"benchmark_kernel": True,
|
||||
"max_autotune_gemm_backends": backends,
|
||||
"force_disable_caches": True,
|
||||
"max_autotune": max_autotune,
|
||||
},
|
||||
)
|
||||
comp_omni()
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with torch.profiler.profile(record_shapes=True) as profile:
|
||||
comp_omni()
|
||||
trace1, _ = trace_files()
|
||||
profile.export_chrome_trace(trace1)
|
||||
|
||||
with patch(
|
||||
"sys.argv",
|
||||
[*prefix, "--analysis", trace1, str(dtype).split(".")[-1]],
|
||||
):
|
||||
main()
|
||||
|
||||
with open(trace1) as f:
|
||||
out_profile = json.load(f)
|
||||
|
||||
for event in out_profile["traceEvents"]:
|
||||
if event["name"] == "triton_poi_fused_add_randn_sin_0":
|
||||
event["args"]["kernel_num_gb"] = 0.002097168
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestAnalysis, globals())
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -1,11 +1,12 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
from unittest import skipIf
|
||||
|
||||
import torch
|
||||
import torch._inductor.metrics as metrics
|
||||
import torch.utils.flop_counter
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.ir import FixedLayout
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch.testing._internal.common_cuda import SM70OrLater
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
@ -13,6 +14,7 @@ from torch.testing._internal.common_device_type import (
|
||||
skipCUDAIf,
|
||||
)
|
||||
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
|
||||
from torch.testing._internal.inductor_utils import IS_BIG_GPU
|
||||
|
||||
|
||||
def FlopCounterMode(*args, **kwargs):
|
||||
@ -77,7 +79,7 @@ class TestScheduler(TestCase):
|
||||
for op, example_inputs, kwargs in tc:
|
||||
comp = torch.compile(op)
|
||||
torch._dynamo.reset()
|
||||
with fresh_cache():
|
||||
with fresh_inductor_cache():
|
||||
comp(*example_inputs, **kwargs)
|
||||
self.assertEqual(metrics.num_bytes_accessed, 0)
|
||||
self.assertEqual(any(m[1] for m in metrics.node_runtimes), False)
|
||||
@ -85,37 +87,6 @@ class TestScheduler(TestCase):
|
||||
metrics.reset()
|
||||
torch._logging.set_logs()
|
||||
|
||||
@dtypes(torch.float, torch.float16)
|
||||
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
||||
def test_get_estimated_runtime_logging(self, device, dtype):
|
||||
if device == "cpu":
|
||||
return
|
||||
tc = _test_cases(device, dtype)
|
||||
expected_metrics = [
|
||||
# num_bytes_accessed, number of nonzero node_runtimes
|
||||
(74 * dtype.itemsize, 1),
|
||||
(60 * dtype.itemsize, 1),
|
||||
(222 * dtype.itemsize, 4),
|
||||
(77 * dtype.itemsize, 2),
|
||||
]
|
||||
tc_plus_metrics = zip(tc, expected_metrics)
|
||||
|
||||
metrics.reset()
|
||||
torch._logging.set_logs(inductor_metrics=True)
|
||||
for test_case, met in tc_plus_metrics:
|
||||
op, example_inputs, kwargs = test_case
|
||||
enba, enr = met
|
||||
|
||||
comp = torch.compile(op)
|
||||
torch._dynamo.reset()
|
||||
with fresh_cache():
|
||||
comp(*example_inputs, **kwargs)
|
||||
self.assertEqual(enba, metrics.num_bytes_accessed)
|
||||
nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0)
|
||||
self.assertEqual(enr, nonzero_node_runtimes)
|
||||
metrics.reset()
|
||||
torch._logging.set_logs()
|
||||
|
||||
@dtypes(torch.float, torch.float16)
|
||||
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
||||
@parametrize(
|
||||
@ -133,17 +104,10 @@ class TestScheduler(TestCase):
|
||||
},
|
||||
],
|
||||
)
|
||||
@skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune")
|
||||
def test_flop_counter_op(self, device, dtype, options):
|
||||
if device == "cpu":
|
||||
return
|
||||
if (
|
||||
options["max_autotune_gemm_backends"] == "TRITON"
|
||||
and torch.cuda.is_available()
|
||||
and not torch._inductor.utils.use_triton_template(
|
||||
FixedLayout(torch.device("cuda"), torch.float16, [400, 800])
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
tc = _test_cases(device, dtype)
|
||||
|
||||
@ -152,7 +116,7 @@ class TestScheduler(TestCase):
|
||||
comp = torch.compile(op, options=options)
|
||||
# next two lines are required, otherwise the flops will be cached from pervious runs of this function.
|
||||
torch._dynamo.reset()
|
||||
with fresh_cache():
|
||||
with fresh_inductor_cache():
|
||||
# actually run to set the counters
|
||||
comp(*example_inputs, **kwargs)
|
||||
with FlopCounterMode() as mode:
|
||||
|
@ -1,12 +1,18 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import unittest
|
||||
|
||||
from sympy import Symbol, sympify
|
||||
|
||||
import torch
|
||||
from torch._inductor.fx_utils import count_flops_fx, countable_fx
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import sympy_str, sympy_subs
|
||||
from torch._inductor.utils import get_device_tflops, sympy_str, sympy_subs
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestUtils(TestCase):
|
||||
@ -188,6 +194,14 @@ class TestUtils(TestCase):
|
||||
countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}"
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "skip if no device")
|
||||
@dtypes(torch.float16, torch.bfloat16, torch.float32)
|
||||
def test_get_device_tflops(self, dtype):
|
||||
ret = get_device_tflops(dtype)
|
||||
self.assertTrue(type(ret) == float)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestUtils, globals())
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -27,6 +27,7 @@ import torch.nn as nn
|
||||
import torch.optim
|
||||
import torch.utils.data
|
||||
from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
|
||||
from torch._inductor.utils import is_big_gpu
|
||||
from torch.autograd.profiler import KinetoStepTracker, profile as _profile
|
||||
from torch.autograd.profiler_legacy import profile as _profile_legacy
|
||||
from torch.profiler import (
|
||||
@ -3045,6 +3046,53 @@ aten::mm""",
|
||||
assert "Overload Name" in key_averages.table()
|
||||
validate_json(prof)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "requries CUDA")
|
||||
def test_profiler_debug_autotuner(self):
|
||||
"""
|
||||
This test makes sure that profiling events will be present when the kernel is run using the DebugAutotuner.
|
||||
"""
|
||||
if not is_big_gpu():
|
||||
raise unittest.SkipTest("requires large gpu to max-autotune")
|
||||
in1 = torch.randn((256, 512), device="cuda", dtype=torch.float16)
|
||||
in2 = torch.randn((512, 768), device="cuda", dtype=torch.float16)
|
||||
|
||||
def mm():
|
||||
return torch.mm(in1, in2)
|
||||
|
||||
pb_mm = torch.compile(
|
||||
mm,
|
||||
options={
|
||||
"benchmark_kernel": True,
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "TRITON",
|
||||
"profile_bandwidth": True,
|
||||
},
|
||||
)
|
||||
comp_mm = torch.compile(
|
||||
mm,
|
||||
options={
|
||||
"benchmark_kernel": True,
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "TRITON",
|
||||
},
|
||||
)
|
||||
|
||||
with profile() as prof1:
|
||||
pb_mm()
|
||||
with profile() as prof2:
|
||||
comp_mm()
|
||||
|
||||
def names(prof):
|
||||
return {
|
||||
ev.name
|
||||
for ev in prof.events()
|
||||
if "mm" in ev.name or "triton" in ev.name
|
||||
}
|
||||
|
||||
n1 = names(prof1)
|
||||
n2 = names(prof2)
|
||||
self.assertEqual(n1, n2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
41
torch/_inductor/analysis/README.md
Normal file
41
torch/_inductor/analysis/README.md
Normal file
@ -0,0 +1,41 @@
|
||||
# `torch._inductor.analysis`
|
||||
Contains scripts for inductor performance analysis.
|
||||
|
||||
## Analysis
|
||||
This will analyze a chrome trace to create a table useful for performance work. We mainly care about . Currently, it will add the flops and the memory reads of a kernel via formula (it's not looking at program counters or anything.) These, combined with the kernel duration, can be use to calculate achieved flops, achieved memory bandwidth, and roofline calculations.
|
||||
|
||||
### Usage
|
||||
```
|
||||
python profile_analysis.py --analysis <input_json_profile> <default_dtype>
|
||||
```
|
||||
|
||||
### Arguments
|
||||
- `input_json_profile`: The json profile files generated by `torch.profile.export_chrome_trace()`.
|
||||
- `default_dtype`: The default dtype of the model. Sometimes the dtypes of the kernel inputs are not available in the profile, so we use the default dtype to infer the dtypes of the inputs.
|
||||
|
||||
## Diff
|
||||
This mode will diff two different profiles and output a table of the differences. It groups by kernel name, which can fail to properly match accross hardware vendors. More intelligent grouping coming soon.
|
||||
|
||||
### Usage
|
||||
```
|
||||
python profile_analysis.py --diff <json_profile_1> <profile_name_1> <json_profile_2> <profile_name_2> <default_dtype> --name_limit 50
|
||||
```
|
||||
|
||||
### Arguments
|
||||
- `json_profile_1` `json_profile_2`: The json profile files generated by `torch.profile.export_chrome_trace()`.
|
||||
- `profile_name_1` `profile_name_2`: The name of the profile. This is used to identify the profile in the output table.
|
||||
- `default_dtype`: The default dtype of the model. Sometimes the dtypes of the kernel inputs are not available in the profile, so we use the default dtype to infer the dtypes of the inputs.
|
||||
- `name_limit`: The maximum number of characters in the kernel name (they can be quite lengthly and hard to read).
|
||||
|
||||
## Augment
|
||||
This mode will add post-hoc analysis to a profile. Currently, it will add the flops and the memory reads of a kernel via formula (it's not looking at program counters or anything.) These, combined with the kernel duration, can be use to calculate achieved flops, achieved memory bandwidth, and roofline calculations.
|
||||
|
||||
### Usage
|
||||
```
|
||||
python profile_analysis.py --augment_trace <input_json_profile> <output_json_profile> <default_dtype>
|
||||
```
|
||||
|
||||
### Arguments
|
||||
- `input_json_profile`: The json profile files generated by `torch.profile.export_chrome_trace()`.
|
||||
- `output_json_profile`: Where the augmented profile is written.
|
||||
- `default_dtype`: The default dtype of the model. Sometimes the dtypes of the kernel inputs are not available in the profile, so we use the default dtype to infer the dtypes of the inputs.
|
0
torch/_inductor/analysis/__init__.py
Normal file
0
torch/_inductor/analysis/__init__.py
Normal file
193
torch/_inductor/analysis/device_info.py
Normal file
193
torch/_inductor/analysis/device_info.py
Normal file
@ -0,0 +1,193 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeviceInfo:
|
||||
"""
|
||||
Theoretical Numbers from data sheet. If two numbers are given, Tensor/Matrix Core vs not,
|
||||
then the higher number is reported. Sparsity is not considered.
|
||||
|
||||
|
||||
Bandwidth numbers are tricky, because there are platform differences that may not show up in the profiler trace.
|
||||
For example,
|
||||
"""
|
||||
|
||||
tops: dict[Union[torch.dtype, str], float]
|
||||
dram_bw_gbs: float
|
||||
dram_gb: float
|
||||
|
||||
|
||||
# Indexing is based on `torch.cuda.get_device_name()`
|
||||
# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
|
||||
_device_mapping: dict[str, DeviceInfo] = {
|
||||
# Source:
|
||||
# @lint-ignore https://www.nvidia.com/en-us/data-center/h100/
|
||||
"NVIDIA H100": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 67.0,
|
||||
torch.float32: 67.5,
|
||||
"torch.tf32": 156.0,
|
||||
torch.bfloat16: 1979.0,
|
||||
torch.float16: 1979.0,
|
||||
torch.float8_e8m0fnu: 3958.0,
|
||||
torch.float8_e8m0fnu: 3958.0,
|
||||
torch.float8_e4m3fnuz: 3958.0,
|
||||
torch.float8_e5m2: 3958.0,
|
||||
torch.float8_e5m2fnuz: 3958.0,
|
||||
torch.float8_e8m0fnu: 3958.0,
|
||||
torch.int8: 3958.0,
|
||||
},
|
||||
dram_bw_gbs=3350,
|
||||
dram_gb=80,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/
|
||||
# nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
|
||||
"NVIDIA A100": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 19.5,
|
||||
torch.float32: 19.5,
|
||||
torch.bfloat16: 312.5,
|
||||
torch.float16: 312.5,
|
||||
# Not in datasheet: float8
|
||||
torch.int8: 624.0,
|
||||
"torch.tf32": 156.0,
|
||||
},
|
||||
dram_bw_gbs=2039.0,
|
||||
dram_gb=80.0,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet
|
||||
"NVIDIA L4": DeviceInfo(
|
||||
tops={
|
||||
# This is a guess, not in datasheet
|
||||
torch.float64: 15.1,
|
||||
torch.float32: 30.3,
|
||||
"torch.tf32": 120.0,
|
||||
torch.bfloat16: 242.0,
|
||||
torch.float16: 242.0,
|
||||
torch.float8_e8m0fnu: 485.0,
|
||||
torch.float8_e8m0fnu: 485.0,
|
||||
torch.float8_e4m3fnuz: 485.0,
|
||||
torch.float8_e5m2: 485.0,
|
||||
torch.float8_e5m2fnuz: 485.0,
|
||||
torch.float8_e8m0fnu: 485.0,
|
||||
torch.int8: 485.0,
|
||||
},
|
||||
dram_bw_gbs=3350,
|
||||
dram_gb=24,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://www.amd.com/content/dam/amd/en/documents\
|
||||
# /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
|
||||
"AMD MI300A": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 122.6,
|
||||
torch.float32: 122.6,
|
||||
"torch.tf32": 490.3,
|
||||
torch.bfloat16: 980.6,
|
||||
torch.float16: 980.6,
|
||||
torch.float8_e8m0fnu: 1961.2,
|
||||
torch.float8_e8m0fnu: 1961.2,
|
||||
torch.float8_e4m3fnuz: 1961.2,
|
||||
torch.float8_e5m2: 1961.2,
|
||||
torch.float8_e5m2fnuz: 1961.2,
|
||||
torch.float8_e8m0fnu: 1961.2,
|
||||
torch.int8: 1961.2,
|
||||
},
|
||||
dram_bw_gbs=5300.0,
|
||||
dram_gb=128.0,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://www.amd.com/content/dam/amd/en/documents/\
|
||||
# instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
|
||||
"AMD MI300X": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 163.4,
|
||||
torch.float32: 163.4,
|
||||
"torch.tf32": 653.7,
|
||||
torch.bfloat16: 1307.4,
|
||||
torch.float16: 1307.4,
|
||||
torch.float8_e8m0fnu: 2614.9,
|
||||
torch.float8_e8m0fnu: 2614.9,
|
||||
torch.float8_e4m3fnuz: 2614.9,
|
||||
torch.float8_e5m2: 2614.9,
|
||||
torch.float8_e5m2fnuz: 2614.9,
|
||||
torch.float8_e8m0fnu: 2614.9,
|
||||
torch.int8: 2614.9,
|
||||
},
|
||||
dram_bw_gbs=5300.0,
|
||||
dram_gb=192.0,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://www.amd.com/content/dam/amd/\
|
||||
# en/documents/instinct-business-docs/product-briefs/instinct-mi210-brochure.pdf
|
||||
"AMD MI210X": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 45.3,
|
||||
torch.float32: 45.3,
|
||||
# not specified, fall back to float32 numbers
|
||||
"torch.tf32": 45.3,
|
||||
torch.bfloat16: 181.0,
|
||||
torch.float16: 181.0,
|
||||
# not specified, fall back to float16 numbers
|
||||
torch.float8_e8m0fnu: 181.0,
|
||||
torch.float8_e8m0fnu: 181.0,
|
||||
torch.float8_e4m3fnuz: 181.0,
|
||||
torch.float8_e5m2: 181.0,
|
||||
torch.float8_e5m2fnuz: 181.0,
|
||||
torch.float8_e8m0fnu: 181.0,
|
||||
torch.int8: 181.0,
|
||||
},
|
||||
# pcie4.0x16
|
||||
dram_bw_gbs=1600.0,
|
||||
dram_gb=64.0,
|
||||
),
|
||||
}
|
||||
_device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"]
|
||||
_device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"]
|
||||
|
||||
|
||||
def lookup_device_info(name: str) -> Optional[DeviceInfo]:
|
||||
"""
|
||||
Problem: when diffing profiles between amd and nvidia, we don't have access to the device information
|
||||
of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated
|
||||
to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices.
|
||||
If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.
|
||||
name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name().
|
||||
"""
|
||||
return _device_mapping.get(name, None)
|
||||
|
||||
|
||||
def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]:
|
||||
"""
|
||||
Get the theoretical TFLOPS of the device for a given dtype. This can throw an exception if the device
|
||||
is not in the datasheet list above.
|
||||
"""
|
||||
name: Optional[str] = torch.cuda.get_device_name()
|
||||
if name is None:
|
||||
log.info("No device found, returning None")
|
||||
return None
|
||||
device_info = lookup_device_info(name)
|
||||
if device_info is None:
|
||||
log_str = f"Device {name} not in datasheet, returning None"
|
||||
log.info(log_str)
|
||||
return None
|
||||
if dtype not in device_info.tops:
|
||||
log.info(
|
||||
"Device %s does not have a datasheet entry for %s, returning None",
|
||||
name,
|
||||
dtype,
|
||||
)
|
||||
return None
|
||||
|
||||
return device_info.tops[
|
||||
"torch.tf32" if dtype == torch.float32 and is_tf32 else dtype
|
||||
]
|
717
torch/_inductor/analysis/profile_analysis.py
Normal file
717
torch/_inductor/analysis/profile_analysis.py
Normal file
@ -0,0 +1,717 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
|
||||
from torch._inductor.utils import tabulate_2d, zip_dicts
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils.flop_counter import flop_registry
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ATEN_PREFIX = "aten::"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileEvent:
|
||||
category: str
|
||||
key: str
|
||||
self_device_time_ms: float
|
||||
# the benchmark is run multiple times and we average the count across all the
|
||||
# runs. It should be an integer but define a float just in case.
|
||||
count: float
|
||||
|
||||
|
||||
# adapters convert the json trace into a format that works with flops_counter
|
||||
ArgsType = tuple[tuple[Any, ...], dict[Any, Any]]
|
||||
AdapterType = Callable[[tuple[Any, ...], tuple[Any, ...]], ArgsType]
|
||||
adapters_map: dict[str, AdapterType] = {}
|
||||
|
||||
|
||||
def parse_list(lst: str) -> list[int]:
|
||||
lst = lst.replace("[", "").replace("]", "")
|
||||
substrings = lst.split(",")
|
||||
|
||||
return [int(substring.strip()) for substring in substrings]
|
||||
|
||||
|
||||
def register_adapter(
|
||||
aten: Union[str, list[str]],
|
||||
) -> Callable[
|
||||
[AdapterType],
|
||||
AdapterType,
|
||||
]:
|
||||
def decorator(func: AdapterType) -> AdapterType:
|
||||
global _adapters_map
|
||||
|
||||
if isinstance(aten, str):
|
||||
adapters_map[aten] = func
|
||||
else:
|
||||
for at in aten:
|
||||
adapters_map[at] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@register_adapter(["_slow_conv2d_forward"])
|
||||
def _slow_conv2d_adapter(
|
||||
shapes: tuple[Any, ...], concrete: tuple[Any, ...]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
tmp = list(shapes)
|
||||
tmp.append(False)
|
||||
tmp2 = list(concrete)
|
||||
tmp2[3] = tmp2[4]
|
||||
return conv_adapter(tuple(tmp), tuple(tmp2))
|
||||
|
||||
|
||||
@register_adapter(["convolution", "_convolution", "cudnn_convolution"])
|
||||
def conv_adapter(
|
||||
shapes: tuple[Any, ...], concrete: tuple[Any, ...]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
tmp = list(shapes)
|
||||
if len(tmp) == 4:
|
||||
transposed = False
|
||||
else:
|
||||
transposed = bool(tmp[6])
|
||||
tmp[6] = transposed
|
||||
|
||||
kwargs: dict[Any, Any] = {}
|
||||
if not transposed:
|
||||
# calculate output shape if not transposed.
|
||||
def conv_out_dims(x: int, kernel: int, stride: int) -> int:
|
||||
return (x - kernel) // stride + 1
|
||||
|
||||
stride = parse_list(concrete[3])
|
||||
inp = shapes[0]
|
||||
w = shapes[1]
|
||||
out_x_y = [conv_out_dims(*args) for args in zip(inp[2:], w[2:], stride)]
|
||||
out = [inp[0], w[0]] + out_x_y # we only need the xy values
|
||||
kwargs["out_val"] = out
|
||||
|
||||
return tuple(tmp), kwargs
|
||||
|
||||
|
||||
def default_adapter(
|
||||
shapes: tuple[Any], concrete: tuple[Any]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
return shapes, {}
|
||||
|
||||
|
||||
@register_adapter("addmm")
|
||||
def addmm_adapter(
|
||||
shapes: tuple[Any], concrete: tuple[Any]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
tmp = list(shapes)[:3]
|
||||
return tuple(tmp), {}
|
||||
|
||||
|
||||
@register_adapter("bmm")
|
||||
def bmm_adapter(
|
||||
shapes: tuple[Any], concrete: tuple[Any]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
tmp = list(shapes)
|
||||
return tuple(tmp[:2]), {}
|
||||
|
||||
|
||||
@register_adapter("baddbmm")
|
||||
def baddbmm_adapter(
|
||||
shapes: tuple[Any], concrete: tuple[Any]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
tmp = list(shapes)[:3]
|
||||
return tuple(tmp), {}
|
||||
|
||||
|
||||
@register_adapter("mm")
|
||||
def mm_adapter(
|
||||
shapes: tuple[Any], concrete: tuple[Any]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
return shapes, {}
|
||||
|
||||
|
||||
def _parse_kernel_name(name: str) -> Optional[str]:
|
||||
"""
|
||||
parse the name of the kernel from the event name.
|
||||
"""
|
||||
if name.startswith(ATEN_PREFIX):
|
||||
return name[len(ATEN_PREFIX) :]
|
||||
elif "conv" in name:
|
||||
return "convolution"
|
||||
elif "addmm" in name:
|
||||
return "addmm"
|
||||
elif "bmm" in name:
|
||||
return "bmm"
|
||||
elif "baddbmm" in name:
|
||||
return "baddbmm"
|
||||
elif "_mm" in name:
|
||||
return "mm"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _calculate_flops(event: dict[str, Any]) -> int:
|
||||
"""
|
||||
This function has to parse the kernel name, which is error prone. There doesn't seem to be another solution that
|
||||
will support all the different backends that can generate kernels, so make sure to update this function when new
|
||||
ops and backends are desired.
|
||||
"""
|
||||
name = event["name"]
|
||||
if "kernel_flop" in event["args"] and event["args"]["kernel_flop"] != 0:
|
||||
return event["args"]["kernel_flop"]
|
||||
op_name = _parse_kernel_name(name)
|
||||
if op_name is None:
|
||||
return 0
|
||||
|
||||
op_obj = getattr(torch.ops.aten, op_name, None)
|
||||
if op_obj is None or op_obj not in flop_registry:
|
||||
return 0
|
||||
|
||||
flop_function = flop_registry[op_obj]
|
||||
|
||||
assert "Input Dims" in event["args"] and "Concrete Inputs" in event["args"]
|
||||
input_shapes = event["args"]["Input Dims"]
|
||||
concrete = event["args"]["Concrete Inputs"]
|
||||
if op_name in adapters_map:
|
||||
args, kwargs = adapters_map[op_name](input_shapes, concrete)
|
||||
else:
|
||||
args, kwargs = default_adapter(input_shapes, concrete)
|
||||
return flop_function(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_size_from_string(type_string: str) -> int:
|
||||
if not hasattr(torch, type_string):
|
||||
return 1
|
||||
else:
|
||||
return getattr(torch, type_string).itemsize
|
||||
|
||||
|
||||
def _default_estimate_gb(event: dict[str, Any]) -> float:
|
||||
sizes_and_types = zip(event["args"]["Input Dims"], event["args"]["Input type"])
|
||||
bw = 0
|
||||
for size, typ in sizes_and_types:
|
||||
isize = _get_size_from_string(typ)
|
||||
bw += isize * math.prod(pytree.tree_flatten(size)[0])
|
||||
return bw / 1e9
|
||||
|
||||
|
||||
def _estimate_gb(event: dict[str, Any]) -> float:
|
||||
"""
|
||||
Our best effort to estimate the gb, should be refactored soon with MemoryCounter.
|
||||
"""
|
||||
name = event["name"]
|
||||
if "kernel_num_gb" in event["args"] and event["args"]["kernel_num_gb"] != 0:
|
||||
return event["args"]["kernel_num_gb"]
|
||||
if "Input type" not in event["args"] or "Input Dims" not in event["args"]:
|
||||
return 0
|
||||
op_name = _parse_kernel_name(name)
|
||||
if op_name is None:
|
||||
return _default_estimate_gb(event)
|
||||
|
||||
op_obj = getattr(torch.ops.aten, op_name, None)
|
||||
if op_obj is None:
|
||||
return _default_estimate_gb(event)
|
||||
|
||||
assert "Input Dims" in event["args"] and "Concrete Inputs" in event["args"]
|
||||
input_shapes = event["args"]["Input Dims"]
|
||||
|
||||
# NOTE these will be refactored into a similar object to FlopCounter soon
|
||||
def mm_formula(M: int, N: int, K: int, size: int) -> int:
|
||||
return 2 * (M * K + N * K + M * N) * size
|
||||
|
||||
if op_name == "addmm":
|
||||
add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
|
||||
add_type_size = _get_size_from_string(event["args"]["Input type"][0])
|
||||
M = input_shapes[1][0]
|
||||
N = input_shapes[1][1]
|
||||
assert input_shapes[1][1] == input_shapes[2][0]
|
||||
K = input_shapes[2][1]
|
||||
mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
|
||||
return (mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size) / 1e9
|
||||
elif op_name == "mm":
|
||||
M = input_shapes[0][0]
|
||||
N = input_shapes[0][1]
|
||||
assert input_shapes[0][1] == input_shapes[1][0]
|
||||
K = input_shapes[1][1]
|
||||
type_size = _get_size_from_string(event["args"]["Input type"][0])
|
||||
return mm_formula(M, N, K, type_size) / 1e9
|
||||
elif op_name == "baddbmm":
|
||||
add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
|
||||
add_type_size = _get_size_from_string(event["args"]["Input type"][0])
|
||||
B = input_shapes[0][0]
|
||||
M = input_shapes[1][1]
|
||||
N = input_shapes[1][2]
|
||||
K = input_shapes[2][2]
|
||||
mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
|
||||
return (
|
||||
B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size
|
||||
) / 1e9
|
||||
elif op_name == "bmm":
|
||||
add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
|
||||
add_type_size = _get_size_from_string(event["args"]["Input type"][0])
|
||||
B = input_shapes[0][0]
|
||||
M = input_shapes[0][1]
|
||||
N = input_shapes[0][2]
|
||||
K = input_shapes[1][2]
|
||||
mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
|
||||
return (
|
||||
B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size
|
||||
) / 1e9
|
||||
elif op_name in [
|
||||
"convolution",
|
||||
"_convolution",
|
||||
"cudnn_convolution",
|
||||
"_slow_conv2d_forward",
|
||||
]:
|
||||
concrete = event["args"]["Concrete Inputs"]
|
||||
|
||||
def conv_out_dim(x: int, kernel: int, stride: int) -> int:
|
||||
return (x - kernel) // stride + 1
|
||||
|
||||
stride = parse_list(
|
||||
concrete[3] if op_name != "_slow_conv2d_forward" else concrete[4]
|
||||
)
|
||||
inp = input_shapes[0]
|
||||
w = input_shapes[1]
|
||||
out_x_y = [conv_out_dim(*args) for args in zip(inp[2:], w[2:], stride)]
|
||||
out = [inp[0], w[0]] + out_x_y
|
||||
# each output element reads in * w * w chunk
|
||||
input_reads = out[0] * out[1] * out[2] * out[3] * inp[1] * w[2] * w[3]
|
||||
# Assume weights are in cache, so only read once
|
||||
weight_reads = w[0] * w[1] * w[2] * w[3]
|
||||
return (input_reads + weight_reads) / 1e9
|
||||
|
||||
return _default_estimate_gb(event)
|
||||
|
||||
|
||||
def _create_extern_mapping(
|
||||
data: dict[str, Any],
|
||||
) -> defaultdict[int, list[dict[str, Any]]]:
|
||||
"""
|
||||
compute a mapping from exteral ids to non kernels, which contain the information we need to estimate flops etc
|
||||
"""
|
||||
extern_mapping: defaultdict[int, list[dict[str, Any]]] = defaultdict(list)
|
||||
for event in data["traceEvents"]:
|
||||
if (
|
||||
"args" not in event
|
||||
or "External id" not in event["args"]
|
||||
or event["cat"] != "cpu_op"
|
||||
):
|
||||
continue
|
||||
if len(extern_mapping[event["args"]["External id"]]) > 0:
|
||||
raise ParseException("duplicate external id in event")
|
||||
extern_mapping[event["args"]["External id"]].append(event)
|
||||
return extern_mapping
|
||||
|
||||
|
||||
def _augment_trace_helper(data: dict[str, Any]) -> dict[str, Any]:
|
||||
extern_mapping = _create_extern_mapping(data)
|
||||
|
||||
for event in data["traceEvents"]:
|
||||
if "cat" not in event or event["cat"] != "kernel":
|
||||
continue
|
||||
if "args" not in event:
|
||||
raise ParseException(f"kernel has no args: {event}")
|
||||
if "External id" not in event["args"]:
|
||||
event_str = f"kernel has no External id: {event}"
|
||||
log.info(event_str)
|
||||
continue
|
||||
|
||||
external_op = extern_mapping[event["args"]["External id"]][0]
|
||||
flops = _calculate_flops(external_op)
|
||||
if flops == 0:
|
||||
flops = _calculate_flops(event)
|
||||
external_op["args"]["kernel_flop"] = flops
|
||||
external_op["args"]["kernel_num_gb"] = _estimate_gb(external_op)
|
||||
event["args"]["kernel_flop"] = external_op["args"]["kernel_flop"]
|
||||
event["args"]["kernel_num_gb"] = external_op["args"]["kernel_num_gb"]
|
||||
return data
|
||||
|
||||
|
||||
_dtype_map = {
|
||||
"float": torch.float,
|
||||
"float32": torch.float,
|
||||
"int": torch.int,
|
||||
"int8": torch.int8,
|
||||
"int16": torch.int16,
|
||||
"int32": torch.int,
|
||||
"long": torch.long,
|
||||
"long int": torch.long,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float16": torch.float16,
|
||||
"float64": torch.double,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KernelStats:
|
||||
flops: int
|
||||
bw: float
|
||||
latency: float # us
|
||||
achieved_flops: float
|
||||
achieved_bandwidth: float
|
||||
|
||||
|
||||
KernelNameMap = defaultdict[str, OrderedSet[KernelStats]]
|
||||
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class Device:
|
||||
name: str
|
||||
index: int
|
||||
info: Optional[DeviceInfo]
|
||||
stats: KernelNameMap
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Device({self.name}, {self.index}): {self.info}"
|
||||
|
||||
|
||||
DeviceMap = dict[int, Device]
|
||||
Table = tuple[list[str], dict[str, list[str]]]
|
||||
|
||||
|
||||
class JsonProfile:
|
||||
_devices: DeviceMap
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
benchmark_name: Optional[str] = None,
|
||||
dtype: Optional[Union[torch.dtype, str]] = None,
|
||||
):
|
||||
"""
|
||||
Convienence class for running common operations on chrome/perfetto json traces.
|
||||
"""
|
||||
self.path = path
|
||||
with open(path) as f:
|
||||
self.data = json.load(f)
|
||||
self.events = self.data["traceEvents"]
|
||||
self.benchmark_name = benchmark_name
|
||||
if dtype is None:
|
||||
self.dtype = None
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
self.dtype = dtype
|
||||
else:
|
||||
if dtype in _dtype_map:
|
||||
self.dtype = _dtype_map[dtype]
|
||||
else:
|
||||
self.dtype = None
|
||||
self._create_devices()
|
||||
|
||||
def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]:
|
||||
"""
|
||||
Each op has a list of dtypes for each input arg. We need to convert these into a single dtype for flop estimation.
|
||||
Issues:
|
||||
- converting the strings to concrete torch.dtypes
|
||||
- What if we have float32, float, float16 all in the inputs? Our choice is to use the largest buffer dtype.
|
||||
"""
|
||||
|
||||
if (
|
||||
"Input Dims" not in event["args"]
|
||||
or "Input type" not in event["args"]
|
||||
or "Concrete Inputs" not in event["args"]
|
||||
):
|
||||
if "bfloat16" in event["name"]:
|
||||
return torch.bfloat16
|
||||
elif "float16" in event["name"]:
|
||||
return torch.float16
|
||||
else:
|
||||
return None
|
||||
|
||||
input_sizes = event["args"]["Input Dims"]
|
||||
input_types = event["args"]["Input type"]
|
||||
concrete_inputs = event["args"]["Concrete Inputs"]
|
||||
assert len(input_sizes) == len(input_types)
|
||||
assert len(input_types) == len(concrete_inputs)
|
||||
|
||||
if len(input_sizes) == 0:
|
||||
raise RuntimeError("Empty input_sizes and input_types")
|
||||
|
||||
biggest_size = 0
|
||||
biggest_index = 0
|
||||
for i in range(len(input_sizes)):
|
||||
if concrete_inputs[i] != "":
|
||||
# concrete inputs are usually small tensors, so we can just skip
|
||||
continue
|
||||
my_size = input_sizes[i]
|
||||
total_size = sum(parse_list(my_size))
|
||||
if total_size > biggest_size:
|
||||
biggest_size = total_size
|
||||
biggest_index = i
|
||||
ret_type = input_types[biggest_index]
|
||||
if ret_type in _dtype_map:
|
||||
return _dtype_map[ret_type]
|
||||
raise RuntimeError(f"Unknown type: {ret_type}. Please add to _dtype_map.")
|
||||
|
||||
def _create_devices(self) -> None:
|
||||
self._devices = {}
|
||||
for dev in self.data["deviceProperties"]:
|
||||
name = dev["name"]
|
||||
device_info = lookup_device_info(name)
|
||||
|
||||
if device_info is None:
|
||||
log.info(
|
||||
"Unsupported device in profile: %s, please consider contributing to _device_mapping.",
|
||||
name,
|
||||
)
|
||||
self._devices[dev["id"]] = Device(
|
||||
name, dev["id"], device_info, defaultdict(OrderedSet)
|
||||
)
|
||||
|
||||
def calculate_flops(self, event: dict[str, Any]) -> int:
|
||||
return _calculate_flops(event)
|
||||
|
||||
def estimate_gb(self, event: dict[str, Any]) -> float:
|
||||
return _estimate_gb(event)
|
||||
|
||||
def augment_trace(self) -> None:
|
||||
self.data = _augment_trace_helper(self.data)
|
||||
|
||||
def _compute_stats(self) -> None:
|
||||
"""populates the name -> stats map"""
|
||||
for event in self.events:
|
||||
if "cat" not in event or "args" not in event or event["cat"] != "kernel":
|
||||
continue
|
||||
dev = self._devices[event["args"]["device"]]
|
||||
|
||||
dur = event["dur"] # us
|
||||
if "kernel_flop" in event["args"]:
|
||||
assert dur != 0
|
||||
# 1,000,000us/s * flop / us
|
||||
op_flops = event["args"]["kernel_flop"] / (dur / 1e6)
|
||||
else:
|
||||
op_flops = 0
|
||||
|
||||
if "kernel_num_gb" in event["args"]:
|
||||
assert dur != 0
|
||||
# 1,000,000us/s * gb = gb/s
|
||||
op_gbps = event["args"]["kernel_num_gb"] / (dur / 1e6)
|
||||
else:
|
||||
op_gbps = 0
|
||||
|
||||
if dev.info is not None:
|
||||
dtype = self.convert_dtype(event) or self.dtype
|
||||
if dtype is None:
|
||||
raise RuntimeError(
|
||||
"dtype is not found on tensor and default dtype is not set"
|
||||
)
|
||||
achieved_flops = 100 * op_flops / (1e12 * dev.info.tops[dtype])
|
||||
achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_gbs
|
||||
else:
|
||||
achieved_flops = 0
|
||||
achieved_bandwidth = 0
|
||||
|
||||
dev.stats[event["name"]].add(
|
||||
KernelStats(
|
||||
flops=op_flops,
|
||||
bw=op_gbps,
|
||||
latency=dur,
|
||||
achieved_bandwidth=achieved_bandwidth,
|
||||
achieved_flops=achieved_flops,
|
||||
)
|
||||
)
|
||||
|
||||
def _create_single_table(self, dev: Device) -> Table:
|
||||
"""Create a table with the devices mapped to indices."""
|
||||
headers = [
|
||||
"Kernel Name",
|
||||
"Kernel Count",
|
||||
"FLOPS",
|
||||
"Kernel Reads (GB)",
|
||||
"Dur (us)",
|
||||
"Achieved FLOPS %",
|
||||
"Achieved Bandwidth %",
|
||||
]
|
||||
rows: dict[str, list[str]] = {}
|
||||
|
||||
def safe_div_format(x: float, y: float) -> str:
|
||||
if y == 0:
|
||||
return "0.0"
|
||||
return f"{x / y:.4f}"
|
||||
|
||||
for kernel_name, stats_set in dev.stats.items():
|
||||
ker_count = 0
|
||||
flops = 0
|
||||
flops_count = 0
|
||||
achieved_flops = 0.0
|
||||
bw = 0.0
|
||||
bw_count = 0
|
||||
achieved_bandwidth = 0.0
|
||||
latency = 0.0
|
||||
for stats in stats_set:
|
||||
if stats.flops != 0:
|
||||
flops += stats.flops
|
||||
achieved_flops += stats.achieved_flops
|
||||
flops_count += 1
|
||||
if stats.bw != 0:
|
||||
bw += stats.bw
|
||||
achieved_bandwidth += stats.achieved_bandwidth
|
||||
bw_count += 1
|
||||
latency += stats.latency
|
||||
ker_count += 1
|
||||
assert ker_count != 0
|
||||
rows[kernel_name] = [
|
||||
str(ker_count),
|
||||
safe_div_format(flops, flops_count),
|
||||
safe_div_format(bw, bw_count),
|
||||
safe_div_format(latency, ker_count),
|
||||
safe_div_format(achieved_flops, flops_count),
|
||||
safe_div_format(achieved_bandwidth, bw_count),
|
||||
]
|
||||
|
||||
return headers, rows
|
||||
|
||||
def _create_tables(self, devs: DeviceMap) -> dict[int, Table]:
|
||||
return {idx: self._create_single_table(dev) for idx, dev in devs.items()}
|
||||
|
||||
def _combine_tables(
|
||||
self, table1: Table, table1_name: str, table2: Table, table2_name: str
|
||||
) -> Table:
|
||||
new_headers = (
|
||||
["Kernel Name"]
|
||||
+ [f"{table1_name} {head}" for head in table1[0][1:]]
|
||||
+ [f"{table2_name} {head}" for head in table2[0][1:]]
|
||||
)
|
||||
t1_length = len(table1[0][1:])
|
||||
t2_length = len(table2[0][1:])
|
||||
new_rows = {}
|
||||
|
||||
for key, row1, row2 in zip_dicts(
|
||||
table1[1],
|
||||
table2[1],
|
||||
d1_default=["Empty"] * t1_length,
|
||||
d2_default=["Empty"] * t2_length,
|
||||
):
|
||||
assert row1 is not None
|
||||
assert row2 is not None
|
||||
new_rows[key] = row1 + row2
|
||||
return new_headers, new_rows
|
||||
|
||||
def report(
|
||||
self, other: Optional["JsonProfile"] = None, name_limit: int = 40
|
||||
) -> str:
|
||||
def create_ret(
|
||||
table_headers: list[str], table_rows: dict[str, list[str]]
|
||||
) -> str:
|
||||
table_flattened = [
|
||||
[kernel_name[:name_limit], *kernel_vals]
|
||||
for kernel_name, kernel_vals in table_rows.items()
|
||||
]
|
||||
return tabulate_2d(table_flattened, headers=table_headers)
|
||||
|
||||
if other is not None:
|
||||
self._compute_stats()
|
||||
other._compute_stats()
|
||||
|
||||
self_tables = self._create_tables(self._devices)
|
||||
other_tables = self._create_tables(other._devices)
|
||||
|
||||
self_name = (
|
||||
self.benchmark_name if self.benchmark_name is not None else "Table 1"
|
||||
)
|
||||
other_name = (
|
||||
other.benchmark_name if other.benchmark_name is not None else "Table 2"
|
||||
)
|
||||
|
||||
ret = []
|
||||
assert self._devices.keys() == other._devices.keys()
|
||||
for device_idx, t1, t2 in zip_dicts(
|
||||
self_tables, other_tables, d1_default=None, d2_default=None
|
||||
):
|
||||
assert t1 is not None
|
||||
assert t2 is not None
|
||||
table_headers, table_rows = self._combine_tables(
|
||||
t1, self_name, t2, other_name
|
||||
)
|
||||
tab_string = create_ret(table_headers, table_rows)
|
||||
ret.append(f"{self._devices[device_idx]}:\n{tab_string}")
|
||||
return "\n".join(ret)
|
||||
self._compute_stats()
|
||||
|
||||
self_tables = self._create_tables(self._devices)
|
||||
|
||||
ret = []
|
||||
for idx, table in self_tables.items():
|
||||
table_headers, table_rows = table
|
||||
tab_string = create_ret(table_headers, table_rows)
|
||||
ret.append(f"{self._devices[idx]}:\n{tab_string}")
|
||||
return "\n".join(ret)
|
||||
|
||||
def dump(self, out: str) -> None:
|
||||
with open(out, "w") as f:
|
||||
json.dump(self.data, f)
|
||||
|
||||
|
||||
class ParseException(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--diff",
|
||||
nargs=5,
|
||||
metavar=(
|
||||
"input_file1",
|
||||
"name1",
|
||||
"input_file2",
|
||||
"name2",
|
||||
"dtype",
|
||||
),
|
||||
help="Two json traces to compare with, specified as <file1> <name1> <file2> <name2> <dtype>",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name_limit",
|
||||
type=int,
|
||||
help="the maximum name size in the final report",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--augment_trace",
|
||||
"-a",
|
||||
nargs=3,
|
||||
metavar=("input_file", "output_file", "dtype"),
|
||||
help="Augment a trace with inductor meta information. Provide input and output file paths.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--analysis",
|
||||
nargs=2,
|
||||
metavar=("input_file", "dtype"),
|
||||
help="Run analysis on a single trace, specified as <file> <dtype>",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.diff:
|
||||
p1 = JsonProfile(args.diff[0], args.diff[1], dtype=args.diff[4])
|
||||
p1.augment_trace()
|
||||
p2 = JsonProfile(args.diff[2], args.diff[3], dtype=args.diff[4])
|
||||
p2.augment_trace()
|
||||
if args.name_limit:
|
||||
print(p1.report(p2, name_limit=args.name_limit))
|
||||
else:
|
||||
print(p1.report(p2))
|
||||
if args.analysis:
|
||||
p1 = JsonProfile(
|
||||
args.analysis[0],
|
||||
dtype=args.analysis[1],
|
||||
)
|
||||
p1.augment_trace()
|
||||
if args.name_limit:
|
||||
print(p1.report(name_limit=args.name_limit))
|
||||
else:
|
||||
print(p1.report())
|
||||
if args.augment_trace:
|
||||
p = JsonProfile(args.augment_trace[0], dtype=args.augment_trace[2])
|
||||
p.augment_trace()
|
||||
p.dump(args.augment_trace[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -970,6 +970,13 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
return tuple(map(fn, value))
|
||||
return fn(value)
|
||||
|
||||
def estimate_flops(self) -> Optional[int]:
|
||||
flops = [
|
||||
node.estimate_flops()
|
||||
for node in NodeScheduleMarker.only_nodes(self.features.node_schedule)
|
||||
]
|
||||
return sum(filter(None, flops))
|
||||
|
||||
def estimate_kernel_num_bytes(self):
|
||||
"""
|
||||
Try the best to estimate the total size (in bytes) of the
|
||||
@ -1616,6 +1623,8 @@ class SIMDScheduling(BaseScheduling):
|
||||
kernel.cse.invalidate(OrderedSet())
|
||||
|
||||
if not isinstance(partial_code, str):
|
||||
# This is used to calculate flops in TritonTemplateKernels
|
||||
with ir.IRNode.current_origins(template_node.node.origins):
|
||||
partial_code.finalize_hook("<DEF_KERNEL>")
|
||||
partial_code.finalize_hook("<ARGDEFS>", strict=False)
|
||||
# finalize must be called after adding epilogue above
|
||||
|
@ -3686,7 +3686,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
num_gb = None
|
||||
if config.benchmark_kernel or config.profile_bandwidth:
|
||||
num_gb = self.estimate_kernel_num_bytes() / 1e9
|
||||
if num_gb is not None:
|
||||
inductor_meta["kernel_num_gb"] = num_gb
|
||||
if config.benchmark_kernel:
|
||||
flops = self.estimate_flops()
|
||||
if flops is not None:
|
||||
inductor_meta["kernel_flop"] = flops
|
||||
|
||||
triton_meta["configs"] = [config_of(signature)]
|
||||
|
||||
|
@ -314,7 +314,7 @@ def is_node_realized(node: torch.fx.Node) -> bool:
|
||||
|
||||
|
||||
def count_flops_fx(node: torch.fx.Node) -> Optional[int]:
|
||||
if isinstance(node.target, str):
|
||||
if not countable_fx(node) or isinstance(node.target, str):
|
||||
return None
|
||||
with FakeTensorMode(allow_non_fake_inputs=True):
|
||||
success, args, kwargs = get_fake_args_kwargs(node)
|
||||
|
@ -65,6 +65,7 @@ from .exc import (
|
||||
MissingOperatorWithDecomp,
|
||||
MissingOperatorWithoutDecomp,
|
||||
)
|
||||
from .fx_utils import count_flops_fx
|
||||
from .ir import (
|
||||
Constant,
|
||||
DonatedBuffer,
|
||||
@ -659,20 +660,12 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
# only grouped convolutions benchmarked as slower in conv samples for inference only
|
||||
if is_inference:
|
||||
from torch.utils.flop_counter import FlopCounterMode
|
||||
|
||||
flop_counts: dict[str, float] = defaultdict(float)
|
||||
for node in conv_nodes:
|
||||
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
|
||||
node
|
||||
)
|
||||
counted_flops = count_flops_fx(node)
|
||||
if counted_flops is None:
|
||||
continue
|
||||
|
||||
if success:
|
||||
with FlopCounterMode(display=False) as flop_counter_mode:
|
||||
with V.fake_mode:
|
||||
node.target(*args, **kwargs)
|
||||
|
||||
counted_flops = flop_counter_mode.get_total_flops()
|
||||
if is_grouped(node):
|
||||
node_type = "grouped"
|
||||
elif is_small_channel(node):
|
||||
|
@ -785,6 +785,20 @@ class CachingAutotuner(KernelInterface):
|
||||
# reset to zero before evaluating any config
|
||||
self.reset_to_zero_args(*args, **kwargs)
|
||||
args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher)
|
||||
if autograd_profiler._is_profiler_enabled:
|
||||
profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
|
||||
with torch._C._profiler._RecordFunctionFast(
|
||||
self.inductor_meta.get("kernel_name", "triton kernel"),
|
||||
args_with_constexprs,
|
||||
profiler_kwargs,
|
||||
):
|
||||
launcher(
|
||||
*args_with_constexprs,
|
||||
**cloned_kwargs,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
else:
|
||||
launcher(
|
||||
*args_with_constexprs,
|
||||
**cloned_kwargs,
|
||||
@ -792,7 +806,8 @@ class CachingAutotuner(KernelInterface):
|
||||
)
|
||||
self.restore_args_from_cpu(cpu_copies)
|
||||
|
||||
if with_profiler:
|
||||
# only use profiler when not already in a profiler instance
|
||||
if with_profiler and not autograd_profiler._is_profiler_enabled:
|
||||
from torch._inductor.utils import do_bench_using_profiling
|
||||
|
||||
return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
|
||||
@ -1100,6 +1115,28 @@ class CachingAutotuner(KernelInterface):
|
||||
).make_launcher()
|
||||
return config2launcher[best_config]
|
||||
|
||||
def get_profiler_kwargs(self, stream, launcher):
|
||||
kernel_kwargs_str = ",".join(
|
||||
f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
|
||||
)
|
||||
|
||||
ret = {
|
||||
"kernel_file": (self.filename or ""),
|
||||
"kernel_hash": self.kernel_hash,
|
||||
"kernel_backend": "triton",
|
||||
"stream": stream,
|
||||
"num_warps": launcher.config.num_warps,
|
||||
"num_stages": launcher.config.num_stages,
|
||||
"kernel_kwargs": kernel_kwargs_str,
|
||||
}
|
||||
if "kernel_name" in self.inductor_meta:
|
||||
ret["kernel_name"] = self.inductor_meta["kernel_name"]
|
||||
if "kernel_flop" in self.inductor_meta:
|
||||
ret["kernel_flop"] = self.inductor_meta["kernel_flop"]
|
||||
if "kernel_num_gb" in self.inductor_meta:
|
||||
ret["kernel_num_gb"] = self.inductor_meta["kernel_num_gb"]
|
||||
return ret
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args,
|
||||
@ -1152,19 +1189,7 @@ class CachingAutotuner(KernelInterface):
|
||||
# it is faster than entering and exiting a context manager, even if the context
|
||||
# manager is a nullcontext.
|
||||
if autograd_profiler._is_profiler_enabled:
|
||||
kernel_kwargs_str = ",".join(
|
||||
f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
|
||||
)
|
||||
|
||||
profiler_kwargs = {
|
||||
"kernel_file": (self.filename or ""),
|
||||
"kernel_hash": self.kernel_hash,
|
||||
"kernel_backend": "triton",
|
||||
"stream": stream,
|
||||
"num_warps": launcher.config.num_warps,
|
||||
"num_stages": launcher.config.num_stages,
|
||||
"kernel_kwargs": kernel_kwargs_str,
|
||||
}
|
||||
profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
|
||||
|
||||
with torch._C._profiler._RecordFunctionFast(
|
||||
self.inductor_meta.get("kernel_name", "triton kernel"),
|
||||
|
@ -40,7 +40,7 @@ from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
|
||||
from .comm_analysis import estimate_nccl_collective_runtime
|
||||
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
|
||||
from .exc import GPUTooOldForTriton, TritonMissing
|
||||
from .fx_utils import count_flops_fx, countable_fx
|
||||
from .fx_utils import count_flops_fx
|
||||
from .ir import (
|
||||
get_device_type,
|
||||
GraphPartitionSignature,
|
||||
@ -793,12 +793,12 @@ class BaseSchedulerNode:
|
||||
fx_node = self.node.get_origin_node()
|
||||
if fx_node is None:
|
||||
return None
|
||||
if not countable_fx(fx_node):
|
||||
return None
|
||||
|
||||
flops = count_flops_fx(fx_node)
|
||||
if flops is None:
|
||||
return None
|
||||
|
||||
resolved_flops = V.graph.sizevars.size_hints((flops,), fallback=0)[0]
|
||||
resolved_flops = V.graph.sizevars.size_hint(flops, fallback=0)
|
||||
counters["inductor"]["flop_count"] += resolved_flops
|
||||
return resolved_flops
|
||||
|
||||
|
@ -60,6 +60,7 @@ from .codegen.triton import (
|
||||
from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta
|
||||
from .codegen.wrapper import pexpr
|
||||
from .exc import CUDACompileError
|
||||
from .fx_utils import count_flops_fx
|
||||
from .ir import ChoiceCaller, PrimitiveInfoType
|
||||
from .ops_handler import StoreMode
|
||||
from .runtime.benchmarking import benchmarker
|
||||
@ -481,12 +482,20 @@ class TritonTemplateKernel(TritonKernel):
|
||||
ninplace_args = len(unique(self.args.inplace_buffers.values()))
|
||||
num_bytes = []
|
||||
for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
|
||||
size = V.graph.sizevars.size_hints(inp.get_size())
|
||||
size = V.graph.sizevars.size_hints(inp.get_size(), fallback=0)
|
||||
numel = functools.reduce(operator.mul, size, 1)
|
||||
dtype_size = get_dtype_size(inp.get_dtype())
|
||||
num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
|
||||
return sum(num_bytes)
|
||||
|
||||
def estimate_flops(self) -> int:
|
||||
for node in self.input_nodes:
|
||||
for fx_node in node._current_origins:
|
||||
f = count_flops_fx(fx_node)
|
||||
if f is not None:
|
||||
return V.graph.sizevars.size_hint(f, fallback=0)
|
||||
return 0
|
||||
|
||||
def jit_lines(self):
|
||||
if self.use_jit:
|
||||
return "@triton.jit"
|
||||
@ -525,6 +534,9 @@ class TritonTemplateKernel(TritonKernel):
|
||||
if config.profile_bandwidth or config.benchmark_kernel:
|
||||
num_gb = self.estimate_kernel_num_bytes() / 1e9
|
||||
inductor_meta["kernel_num_gb"] = num_gb
|
||||
if config.benchmark_kernel:
|
||||
flops = self.estimate_flops()
|
||||
inductor_meta["kernel_flop"] = flops
|
||||
|
||||
template_args = f"""
|
||||
num_stages={self.num_stages},
|
||||
|
@ -22,7 +22,14 @@ import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import unittest
|
||||
from collections.abc import Collection, Iterator, Mapping, MutableMapping, MutableSet
|
||||
from collections.abc import (
|
||||
Collection,
|
||||
Generator,
|
||||
Iterator,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
MutableSet,
|
||||
)
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from typing import (
|
||||
@ -51,6 +58,7 @@ from unittest import mock
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.analysis.device_info import datasheet_tops
|
||||
from torch._inductor.runtime.hints import DeviceProperties
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._pytree import tree_map_only
|
||||
@ -2156,9 +2164,19 @@ def get_backend_num_stages() -> int:
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_device_tflops(dtype: torch.dtype) -> int:
|
||||
def get_device_tflops(dtype: torch.dtype) -> float:
|
||||
"""
|
||||
We don't want to throw errors in this function. First check to see if the device is in device_info.py,
|
||||
then fall back to the inaccurate triton estimation.
|
||||
"""
|
||||
ds_tops = datasheet_tops(dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32)
|
||||
if ds_tops is not None:
|
||||
return ds_tops
|
||||
|
||||
from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
|
||||
|
||||
from torch.testing._internal.common_cuda import SM80OrLater
|
||||
|
||||
assert dtype in (torch.float16, torch.bfloat16, torch.float32)
|
||||
|
||||
if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
|
||||
@ -2166,7 +2184,7 @@ def get_device_tflops(dtype: torch.dtype) -> int:
|
||||
from torch._utils_internal import max_clock_rate
|
||||
|
||||
sm_clock = max_clock_rate()
|
||||
if dtype in (torch.float16, torch.bfloat16):
|
||||
if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
|
||||
return get_max_tensorcore_tflops(dtype, sm_clock)
|
||||
|
||||
if torch.backends.cuda.matmul.allow_tf32:
|
||||
@ -2174,7 +2192,7 @@ def get_device_tflops(dtype: torch.dtype) -> int:
|
||||
else:
|
||||
return get_max_simd_tflops(torch.float32, sm_clock)
|
||||
else:
|
||||
if dtype in (torch.float16, torch.bfloat16):
|
||||
if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
|
||||
return get_max_tensorcore_tflops(dtype)
|
||||
|
||||
if torch.backends.cuda.matmul.allow_tf32:
|
||||
@ -3204,3 +3222,54 @@ def is_mkldnn_fp16_supported(device_type: str) -> bool:
|
||||
# match "xpu", "xpu:0", "xpu:1", etc.
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
|
||||
widths = [len(str(e)) for e in headers]
|
||||
for row in elements:
|
||||
assert len(row) == len(headers)
|
||||
for i, e in enumerate(row):
|
||||
widths[i] = max(widths[i], len(str(e)))
|
||||
lines = []
|
||||
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths)))
|
||||
# widths whitespace horizontal separators
|
||||
total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
|
||||
lines.append("-" * total_width)
|
||||
for row in elements:
|
||||
lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths)))
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def zip_dicts(
|
||||
dict1: Mapping[KeyType, ValType],
|
||||
dict2: Mapping[KeyType, ValType],
|
||||
d1_default: ValType | None = None,
|
||||
d2_default: ValType | None = None,
|
||||
) -> Generator[tuple[KeyType, ValType | None, ValType | None], None, None]:
|
||||
"""
|
||||
Zip two dictionaries together, replacing missing keys with default values.
|
||||
|
||||
Args:
|
||||
dict1 (dict): The first dictionary.
|
||||
dict2 (dict): The second dictionary.
|
||||
d1_default (Any): the default value for the first dictionary
|
||||
d2_default (Any): the default value for the second dictionary
|
||||
|
||||
Yields:
|
||||
tuple: A tuple containing the key, the value from dict1 (or d1_default if missing),
|
||||
and the value from dict2 (or d2_default if missing).
|
||||
"""
|
||||
# Find the union of all keys
|
||||
all_keys = OrderedSet(dict1.keys()) | OrderedSet(dict2.keys())
|
||||
|
||||
# Iterate over all keys
|
||||
for key in all_keys:
|
||||
# Get the values from both dictionaries, or default if missing
|
||||
value1 = dict1.get(key)
|
||||
value2 = dict2.get(key)
|
||||
|
||||
yield (
|
||||
key,
|
||||
value1 if value1 is not None else d1_default,
|
||||
value2 if value2 is not None else d2_default,
|
||||
)
|
||||
|
@ -1,8 +1,8 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import datetime
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from types import ModuleType
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
@ -159,7 +159,7 @@ def benchmark_all_kernels(
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class ProfileEvent:
|
||||
category: str
|
||||
key: str
|
||||
@ -176,6 +176,10 @@ def parse_profile_event_list(
|
||||
nruns: int,
|
||||
device_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Parse and generate a report for an event_list.
|
||||
"""
|
||||
|
||||
def get_self_device_time(
|
||||
ev: torch.autograd.profiler_util.EventList,
|
||||
) -> float:
|
||||
@ -295,6 +299,10 @@ def parse_profile_event_list(
|
||||
report()
|
||||
|
||||
|
||||
PROFILE_DIR = tempfile.gettempdir()
|
||||
PROFILE_PATH = f"{PROFILE_DIR}/compiled_module_profile.json"
|
||||
|
||||
|
||||
def perf_profile(
|
||||
wall_time_ms: float,
|
||||
times: int,
|
||||
@ -305,14 +313,14 @@ def perf_profile(
|
||||
with torch.profiler.profile(record_shapes=True) as p:
|
||||
benchmark_compiled_module_fn(times=times, repeat=repeat)
|
||||
|
||||
path = f"{tempfile.gettempdir()}/compiled_module_profile.json"
|
||||
path = PROFILE_PATH
|
||||
p.export_chrome_trace(path)
|
||||
print(f"Profiling result for a compiled module of benchmark {benchmark_name}:")
|
||||
print(f"Chrome trace for the profile is written to {path}")
|
||||
event_list = p.key_averages(group_by_input_shape=True)
|
||||
print(event_list.table(sort_by="self_device_time_total", row_limit=10))
|
||||
parse_profile_event_list(
|
||||
benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device
|
||||
benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device or ""
|
||||
)
|
||||
|
||||
|
||||
|
@ -213,6 +213,9 @@ def is_fb_unit_test() -> bool:
|
||||
|
||||
@functools.cache
|
||||
def max_clock_rate():
|
||||
"""
|
||||
unit: MHz
|
||||
"""
|
||||
if not torch.version.hip:
|
||||
from triton.testing import nvsmi
|
||||
|
||||
|
@ -127,7 +127,6 @@ def conv_flop_count(
|
||||
Returns:
|
||||
int: the number of flops
|
||||
"""
|
||||
|
||||
batch_size = x_shape[0]
|
||||
conv_shape = (x_shape if transposed else out_shape)[2:]
|
||||
c_out, c_in, *filter_size = w_shape
|
||||
@ -146,7 +145,7 @@ def conv_flop_count(
|
||||
flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
|
||||
return flop
|
||||
|
||||
@register_flop_formula([aten.convolution, aten._convolution])
|
||||
@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
|
||||
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
|
||||
"""Count flops for convolution."""
|
||||
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
|
||||
@ -561,6 +560,8 @@ flop_registry = {
|
||||
aten._scaled_mm: _scaled_mm_flop,
|
||||
aten.convolution: conv_flop,
|
||||
aten._convolution: conv_flop,
|
||||
aten.cudnn_convolution: conv_flop,
|
||||
aten._slow_conv2d_forward: conv_flop,
|
||||
aten.convolution_backward: conv_backward_flop,
|
||||
aten._scaled_dot_product_efficient_attention: sdpa_flop,
|
||||
aten._scaled_dot_product_flash_attention: sdpa_flop,
|
||||
|
Reference in New Issue
Block a user