Inductor logging + analysis of torch.profile (#149697)

Prereqs:
 - https://github.com/pytorch/pytorch/pull/152708

Features:
1. Adds Inductor's estimate of flops and bandwidth to the JSON trace events that Perfetto uses.
1. Uses the TFLOPS estimate from Triton only when the datasheet numbers are unavailable, because Triton's estimates are inaccurate; I have a backlog item to fix Triton's flops estimation upstream. Adds a new `DeviceInfo` class and a new `get_device_tflops` function.
1. New helpers `countable_fx` and `count_flops_fx` help compute the flops of an `fx.Node`.
1. Extends Triton `torch.profiler` logging to `DebugAutotuner`.
1. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any Perfetto JSON trace, `--analysis` creates a summary table of these perf estimates, and `--diff` compares two traces side by side:
```
Device(NVIDIA H100, 0):
 Kernel Name                              | resnet Kernel Count | resnet FLOPS       | resnet bw gbps        | resnet Dur (ms)    | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS    | newresnet bw gbps     | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth %
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 triton_poi_fused__native_batch_norm_legi | 24                  | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                       | 0.003401572611382541        | 24                     | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                          | 0.003401572611382541
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142                 | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583     | 0.007716441266265022        | 142                    | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583        | 0.007716441266265022
 triton_red_fused__native_batch_norm_legi | 39                  | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                       | 0.004176126863316074        | 39                     | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                          | 0.004176126863316074
 triton_poi_fused__native_batch_norm_legi | 25                  | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                       | 0.009499718184339253        | 25                     | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                          | 0.009499718184339253
 void cutlass::Kernel2<cutlass_80_tensoro | 98                  | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874     | 0.012827592254037562        | 98                     | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874        | 0.012827592254037562
 triton_red_fused__native_batch_norm_legi | 73                  | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                       | 0.009628003963020014        | 73                     | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                          | 0.009628003963020014
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                       | 0.043257347302946926        | 15                     | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                          | 0.043257347302946926
 void cutlass::Kernel2<cutlass_80_tensoro | 186                 | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027     | 0.007961586274361157        | 186                    | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027        | 0.007961586274361157
 triton_poi_fused__native_batch_norm_legi | 33                  | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                       | 0.044550915039384846        | 33                     | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                          | 0.044550915039384846
 triton_red_fused__native_batch_norm_legi | 29                  | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                       | 0.007630624036606301        | 29                     | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                          | 0.007630624036606301
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                       | 0.01752406619162008         | 13                     | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                          | 0.01752406619162008
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 0.41409928846284      | 2.853588235294117  | 0                       | 0.012361172789935523        | 34                     | 0                  | 0.41409928846284      | 2.853588235294117  | 0                          | 0.012361172789935523
 triton_per_fused__native_batch_norm_legi | 34                  | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                       | 0.0034941238826919864       | 34                     | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                          | 0.0034941238826919864
 triton_poi_fused__native_batch_norm_legi | 16                  | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                       | 0.005136672596156592        | 16                     | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                          | 0.005136672596156592
 triton_per_fused__native_batch_norm_legi | 30                  | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                       | 0.007879744244842555        | 30                     | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                          | 0.007879744244842555
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100                 | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531     | 0.005819245035648175        | 100                    | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531        | 0.005819245035648175
 triton_poi_fused__native_batch_norm_legi | 8                   | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                       | 0.029415213809625928        | 8                      | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                          | 0.029415213809625928
 void cublasLt::splitKreduce_kernel<32, 1 | 56                  | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628     | 0.024806865808245714        | 56                     | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628        | 0.024806865808245714
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                       | 0.02968359094286896         | 23                     | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                          | 0.02968359094286896
 triton_per_fused__native_batch_norm_legi | 10                  | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                       | 0.00545313748934644         | 10                     | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                          | 0.00545313748934644
 triton_poi_fused__native_batch_norm_legi | 10                  | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                       | 0.009459622642884923        | 10                     | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                          | 0.009459622642884923
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                       | 0.03421974596124114         | 34                     | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                          | 0.03421974596124114
 void cask_plugin_cudnn::xmma_cudnn::init | 44                  | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194     | 0.06167532194133924         | 44                     | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194        | 0.06167532194133924
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95                  | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802     | 0.014014750913273854        | 95                     | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802        | 0.014014750913273854
 triton_per_fused__native_batch_norm_legi | 41                  | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                       | 0.002037513395819492        | 41                     | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                          | 0.002037513395819492
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                       | 0.0026292999141582997       | 23                     | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                          | 0.0026292999141582997
 triton_per_fused__native_batch_norm_legi | 40                  | 0                  | 0.18179321034952417   | 4.556825           | 0                       | 0.005426662995508183        | 40                     | 0                  | 0.18179321034952417   | 4.556825           | 0                          | 0.005426662995508183
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                       | 0.017574373598370836        | 15                     | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                          | 0.017574373598370836
 void cutlass::Kernel2<cutlass_80_tensoro | 38                  | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546      | 0.007659474756834           | 38                     | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546         | 0.007659474756834
 triton_poi_fused__native_batch_norm_legi | 21                  | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                       | 0.017441376040091088        | 21                     | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                          | 0.017441376040091088
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                       | 0.0034356313950705724       | 16                     | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                          | 0.0034356313950705724
 triton_poi_fused__native_batch_norm_legi | 14                  | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                       | 0.00508857313505646         | 14                     | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                          | 0.00508857313505646
 triton_poi_fused__native_batch_norm_legi | 58                  | 0                  | 2.307520779930795     | 8.190706896551722  | 0                       | 0.06888121731136704         | 58                     | 0                  | 2.307520779930795     | 8.190706896551722  | 0                          | 0.06888121731136704
 triton_per_fused__native_batch_norm_legi | 29                  | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                       | 0.001111738775280038        | 29                     | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                          | 0.001111738775280038
 triton_poi_fused__native_batch_norm_legi | 20                  | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                       | 0.0014154327747549007       | 20                     | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                          | 0.0014154327747549007
 triton_per_fused__native_batch_norm_legi | 25                  | 0                  | 0.13357016893727824   | 3.37536            | 0                       | 0.003987169222008305        | 25                     | 0                  | 0.13357016893727824   | 3.37536            | 0                          | 0.003987169222008305
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                       | 0.009223469457612694        | 13                     | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                          | 0.009223469457612694
 triton_poi_fused__native_batch_norm_legi | 17                  | 0                  | 0.3129385387909844    | 2.673              | 0                       | 0.009341448919133863        | 17                     | 0                  | 0.3129385387909844    | 2.673              | 0                          | 0.009341448919133863
 triton_per_fused__native_batch_norm_legi | 19                  | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                       | 0.0066136363060691275       | 19                     | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                          | 0.0066136363060691275
 std::enable_if<!(false), void>::type int | 23                  | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447   | 0.030203868944223014        | 23                     | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447      | 0.030203868944223014
 triton_poi_fused_add_copy__38            | 56                  | 0                  | 0                     | 2.132482142857143  | 0                       | 0                           | 56                     | 0                  | 0                     | 2.132482142857143  | 0                          | 0
 triton_poi_fused_convolution_0           | 18                  | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                       | 0.012972719640279667        | 18                     | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                          | 0.012972719640279667
 triton_poi_fused_convolution_1           | 17                  | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                       | 0.0008601884319153051       | 17                     | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                          | 0.0008601884319153051
 void convolve_common_engine_float_NHWC<f | 44                  | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169     | 0.0007382250748795709       | 44                     | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169        | 0.0007382250748795709
 triton_per_fused__native_batch_norm_legi | 12                  | 0                  | 0.6809930918986744    | 4.82675            | 0                       | 0.020328151996975356        | 12                     | 0                  | 0.6809930918986744    | 4.82675            | 0                          | 0.020328151996975356
 triton_per_fused__native_batch_norm_legi | 14                  | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                       | 0.0008606061486377935       | 14                     | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                          | 0.0008606061486377935
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.0014658988233201874 | 2.098              | 0                       | 4.375817383045335e-05       | 16                     | 0                  | 0.0014658988233201874 | 2.098              | 0                          | 4.375817383045335e-05
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                       | 0.02963073785159611         | 13                     | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                          | 0.02963073785159611
 triton_poi_fused__native_batch_norm_legi | 9                   | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                       | 0.03883228983781048         | 9                      | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                          | 0.03883228983781048
 void at::native::(anonymous namespace):: | 98                  | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                       | 0.0027386076458833994       | 98                     | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                          | 0.0027386076458833994
 void at::native::vectorized_elementwise_ | 7                   | 0                  | 0                     | 1.7278571428571428 | 0                       | 0                           | 7                      | 0                  | 0                     | 1.7278571428571428 | 0                          | 0
```
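
For reference, an end-to-end workflow might look like the following. This is a rough sketch that assumes a CUDA device; the paths are placeholders and the compiled function is just an example, while the CLI flags match the `torch/_inductor/analysis` README added in this PR:

```python
# Sketch: produce a Chrome/Perfetto trace with torch.profiler, then post-process it
# with the new profile_analysis.py script. Paths below are placeholders.
import torch

a = torch.randn(256, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 768, device="cuda", dtype=torch.float16)
fn = torch.compile(lambda: torch.mm(a, b), options={"benchmark_kernel": True})

with torch.profiler.profile(record_shapes=True) as prof:
    fn()
prof.export_chrome_trace("/tmp/trace.json")

# Then, from the command line:
#   python profile_analysis.py --augment_trace /tmp/trace.json /tmp/trace_aug.json float16
#   python profile_analysis.py --analysis /tmp/trace_aug.json float16
#   python profile_analysis.py --diff /tmp/a.json run_a /tmp/b.json run_b float16 --name_limit 30
```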

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697
Approved by: https://github.com/eellison, https://github.com/shunting314
Author: Gabriel Ferns
Date: 2025-07-01 16:51:03 +00:00
Committed by: PyTorch MergeBot
Parent: 0f9c1b374f
Commit: 47f10d0ad0
19 changed files with 1814 additions and 104 deletions

@@ -0,0 +1,608 @@
# Owner(s): ["module: inductor"]
import json
import tempfile
import unittest
import uuid
from io import StringIO
from unittest.mock import patch
import torch
import torch.nn.functional as F
from torch._inductor.analysis.profile_analysis import (
_augment_trace_helper,
_create_extern_mapping,
JsonProfile,
main,
)
from torch._inductor.utils import fresh_inductor_cache, tabulate_2d, zip_dicts
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
skipIf,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
from torch.testing._internal.inductor_utils import IS_BIG_GPU
example_profile = """
{
"schemaVersion": 1,
"deviceProperties": [
{
"id": 0, "name": "NVIDIA H100", "totalGlobalMem": 101997215744,
"computeMajor": 9, "computeMinor": 0,
"maxThreadsPerBlock": 1024, "maxThreadsPerMultiprocessor": 2048,
"regsPerBlock": 65536, "warpSize": 32,
"sharedMemPerBlock": 49152, "numSms": 132
, "regsPerMultiprocessor": 65536, "sharedMemPerBlockOptin": 232448, "sharedMemPerMultiprocessor": 233472
}
],
"cupti_version": 24,
"cuda_runtime_version": 12060,
"with_flops": 1,
"record_shapes": 1,
"cuda_driver_version": 12040,
"profile_memory": 1,
"trace_id": "301995E163ED42048FBD783860E6E7DC",
"displayTimeUnit": "ms",
"baseTimeNanoseconds": 1743521598000000000,
"traceEvents": [
{
"ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
"ts": 198093488368.463, "dur": 425.453,
"args": {
"External id": 1340,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \
["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", "False", "[0, 0]", "1"], "Input type": ["float", "float", "", \
"ScalarList", "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar"], "Input Strides": [[150528, 1, 672, 3],\
[147, 1, 21, 3], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], \
[], []], "Ev Idx": 1339
}
},
{
"ph": "X", "cat": "cpu_op", "name": "aten::_convolution", "pid": 1147039, "tid": 1147039,
"ts": 198093488444.498, "dur": 341.867,
"args": {
"External id": 1341,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]",\
"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList",\
"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input Strides": \
[[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": [[1, 3, 224, 224], \
[64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340
}
},
{
"ph": "X", "cat": "cpu_op", "name": "aten::addmm", "pid": 1147039, "tid": 1147039,
"ts": 198093513655.849, "dur": 251.130,
"args": {
"External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Concrete Inputs": \
["", "", "", "1", "1", ""], "Input type": ["float", "float", "float", "Scalar", "Scalar", "float"], "Input Strides":\
[[1], [0, 1], [1, 2048], [], [], [1000, 1]], "Input Dims": [[1000], [1, 2048], [2048, 1000], [], [], [1, 1000]], \
"Ev Idx": 1618
}
},
{
"ph": "X", "cat": "kernel", "name": "void cutlass_addmm", "pid": 1147039, "tid": 1147039,
"ts": 198093513655.849, "dur": 251.130,
"args": {
"External id": 1619,"Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
}
},
{
"ph": "X", "cat": "kernel", "name": "void convolution_kernel", "pid": 1147039, "tid": 1147039,
"ts": 198093513655.849, "dur": 200.130,
"args": {
"External id": 1342, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
}
},
{
"ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
"ts": 198093488444.498, "dur": 341.867,
"args": {
"External id": 1342,"Record function id": 0, "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", \
"False", "[0, 0]", "1", "False", "False", "True", "True"], "Input type": ["float", "float", "", "ScalarList", \
"ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"], "Input \
Strides": [[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []], "Input Dims": \
[[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []], "Ev Idx": 1340
}
}
],
"traceName": "/tmp/compiled_module_profile.json"
}
"""
def verify_flops(self, expected_flops, out_profile):
j = 0
for i in range(len(out_profile["traceEvents"])):
if "kernel_flop" in out_profile["traceEvents"][i]["args"]:
self.assertEqual(
out_profile["traceEvents"][i]["args"]["kernel_flop"],
expected_flops[j],
)
j += 1
def random_tensor(size, dtype, **kwargs):
if dtype in [torch.half, torch.bfloat16, torch.float, torch.double]:
return torch.randn(size, dtype=dtype, **kwargs)
elif dtype in [torch.uint8, torch.int8, torch.short, torch.int, torch.long]:
return torch.randint(0, 100, size, dtype=dtype, **kwargs)
else:
raise ValueError("Unsupported data type")
def cT(device, dtype):
def T(*shape, requires_grad=False):
return random_tensor(
shape, requires_grad=requires_grad, device=device, dtype=dtype
)
return T
def FlopCounterMode(*args, **kwargs):
return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)
TMP_DIR = tempfile.mkdtemp()
def trace_files():
TRACE1 = f"{TMP_DIR}/trace1-{uuid.uuid4()}.json"
TRACE2 = f"{TMP_DIR}/trace2-{uuid.uuid4()}.json"
return TRACE1, TRACE2
def _test_model(device, dtype, compile=True, addmm=True, bmm=True):
T = cT(device, dtype)
def model():
input_conv = T(1, 3, 56, 56)
conv_weight = T(12, 3, 5, 5)
# Increased matrix sizes
B = 8
M = 256
N = 512
K = 768
mat1 = T(M, N)
mat2 = T(N, K)
batch_mat1 = T(B, M, N)
batch_mat2 = T(B, N, K)
conv_output = F.conv2d(input_conv, conv_weight)
conv_output = conv_output * 10
mm_output = torch.mm(mat1, mat2)
ret = [
conv_output.flatten(),
mm_output.flatten(),
]
if addmm:
addmm_output = torch.addmm(
torch.zeros(mm_output.shape, device=mat1.device, dtype=mat1.dtype),
mat1,
mat2,
)
ret.append(addmm_output.flatten())
if bmm:
bmm_output = torch.bmm(batch_mat1, batch_mat2)
ret.append(bmm_output.flatten())
if bmm and addmm:
baddbmm_output = torch.baddbmm(
torch.zeros(
1,
*mm_output.shape,
device=batch_mat1.device,
dtype=batch_mat1.dtype,
),
batch_mat1,
batch_mat2,
)
ret.append(baddbmm_output.flatten())
return torch.cat(ret)
if compile:
return torch.compile(
model, options={"benchmark_kernel": True, "profile_bandwidth": True}
)
return model
def _pointwise_test_model(device, dtype, compile=True):
T = cT(device, dtype)
def model():
M = 1024
N = 512
mat3 = T(M, N)
mat4 = T(M, N)
pointwise_output = torch.add(mat3, mat4).sin()
return pointwise_output
if compile:
return torch.compile(
model, options={"benchmark_kernel": True, "profile_bandwidth": True}
)
return model
prefix = ["profile.py"]
class TestUtils(TestCase):
def test_tabulate2d(self):
headers = ["Kernel", "Self H100 TIME (ms)", "Count", "Percent"]
rows = [
["aten::mm", 0.500, 7, 0.0],
["aten::bmm", 0.400, 6, 0.0],
["aten::baddbmm", 0.300, 5, 0.0],
["aten::convolution", 0.200, 4, 0.0],
["aten::cudnn_convolution", 0.100, 3, 0.0],
]
table = [
" Kernel | Self H100 TIME (ms) | Count | Percent ",
"-----------------------------------------------------------------",
" aten::mm | 0.5 | 7 | 0.0 ",
" aten::bmm | 0.4 | 6 | 0.0 ",
" aten::baddbmm | 0.3 | 5 | 0.0 ",
" aten::convolution | 0.2 | 4 | 0.0 ",
" aten::cudnn_convolution | 0.1 | 3 | 0.0 ",
]
res = tabulate_2d(rows, headers)
for r, t in zip(res.split("\n"), table):
self.assertEqual(r, t)
def test_zip_dicts(self):
d1 = {"a": 1, "b": 2}
d2 = {"a": 3, "c": 4}
res1 = zip_dicts(d1, d2, d1_default=32, d2_default=48)
self.assertEqual(set(res1), {("a", 1, 3), ("b", 2, 48), ("c", 32, 4)})
res2 = zip_dicts(d1, d2)
self.assertEqual(set(res2), {("a", 1, 3), ("b", 2, None), ("c", None, 4)})
class TestAnalysis(TestCase):
@skipIf(not SM80OrLater, "Requires SM80")
def test_noop(self):
with (
patch("sys.stdout", new_callable=StringIO) as mock_stdout,
patch("sys.argv", [*prefix]),
):
main()
self.assertEqual(mock_stdout.getvalue(), "")
@skipIf(not SM80OrLater, "Requires SM80")
@dtypes(torch.float, torch.double, torch.float16)
def test_diff(self, device, dtype):
"""
diff, testing out the nruns feature too.
"""
if device == "cpu":
# TODO cpu support
return
om = _test_model(device, dtype)
REPEAT = 5
trace1, trace2 = trace_files()
print(f"first trace {trace1}")
torch._dynamo.reset() # reset the cache
with fresh_inductor_cache():
with torch.profiler.profile(record_shapes=True) as p:
om()
p.export_chrome_trace(trace1)
print(f"second trace {trace2}")
torch._dynamo.reset() # reset the cache
with fresh_inductor_cache():
with torch.profiler.profile(record_shapes=True) as p:
for _ in range(REPEAT):
om()
p.export_chrome_trace(trace2)
print("diffing...")
with patch(
"sys.argv",
[
*prefix,
"--diff",
trace1,
"foo",
trace2,
"bar",
str(dtype).split(".")[-1],
"--name_limit",
"30",
],
):
main()
@skipIf(not SM80OrLater, "Requires SM80")
def test_augment_trace_helper_unit(self):
js = json.loads(example_profile)
out_profile = _augment_trace_helper(js)
expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0]
verify_flops(self, expected_flops, out_profile)
@skipIf(not SM80OrLater, "Requires SM80")
@dtypes(torch.float, torch.double, torch.float16)
def test_augment_trace_helper_args(self, device, dtype):
if device == "cpu":
# cpu doesn't produce traces currently
return
om = _test_model(device, dtype)
torch._dynamo.reset() # reset the cache
with fresh_inductor_cache():
with torch.profiler.profile(record_shapes=True) as p:
om()
trace1, trace2 = trace_files()
p.export_chrome_trace(trace1)
print(f"first trace {trace1}")
with patch(
"sys.argv",
[*prefix, "--augment_trace", trace1, trace2, str(dtype).split(".")[-1]],
):
main()
profile = JsonProfile(
trace2, benchmark_name="foo", dtype=str(dtype).split(".")[-1]
)
rep = profile.report()
self.assertTrue(len(rep.split("\n")) > 3, f"Error, empty table:\n{rep}")
# If these fail, just update them. They could change over time
self.assertIn("Kernel Name", rep)
self.assertIn("Kernel Count", rep)
self.assertIn("FLOPS", rep)
self.assertIn("Kernel Reads", rep)
self.assertIn("Dur", rep)
self.assertIn("Achieved", rep)
self.assertIn("|", rep)
self.assertIn("-----", rep)
tables = profile._create_tables(profile._devices)
# check to make sure all % values are less than 100%
percents = []
for tab in tables.values():
header, rows = tab
for i, h in enumerate(header):
if "%" in h:
percents.append(i)
self.assertTrue(len(percents) > 0, "There are no headers with % in them")
for row in rows.values():
for p in percents:
idx = p - 1
self.assertTrue(
float(row[idx]) <= 100.0,
f"column values from column {idx} with header '{header[idx]}' is greater than 100%: {row[idx]}",
)
self.assertTrue(
float(row[idx]) >= 0.0,
f"column values from column {idx} with header '{header[idx]}' is less than 0%: {row[idx]}",
)
@skipIf(not SM80OrLater, "Requires SM80")
@dtypes(torch.float, torch.double, torch.float16)
@parametrize(
"maxat",
[
(True, "TRITON"),
],
)
@skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune")
def test_triton_has_metadata(self, device, dtype, maxat):
"""
make sure that the chrome trace of triton kernels contains certain values
"""
if device == "cpu":
return
T = cT(device, dtype)
input_conv = T(1, 3, 56, 56)
conv_weight = T(12, 3, 5, 5)
def om(i, w):
# Convolution operation
conv_output = F.conv2d(i, w)
return conv_output
max_autotune, backends = maxat
comp_omni = torch.compile(
om,
options={
"benchmark_kernel": True,
"max_autotune_gemm_backends": backends,
"force_disable_caches": True,
"max_autotune": max_autotune,
},
)
def verify_triton(comp):
torch._dynamo.reset() # reset the cache
with fresh_inductor_cache():
with torch.profiler.profile(record_shapes=True) as profile:
comp(input_conv, conv_weight)
trace1, _ = trace_files()
profile.export_chrome_trace(trace1)
with open(trace1) as f:
out_profile = json.load(f)
seen = False
for event in out_profile["traceEvents"]:
if "triton" in event["name"] and "conv" in event["name"]:
seen = True
self.assertTrue(seen, "no triton conv found")
verify_triton(comp_omni)
@skipIf(not SM80OrLater, "Requires SM80")
@dtypes(torch.float, torch.float16)
@parametrize(
"maxat",
[
(False, "ATEN,TRITON"),
(True, "ATEN,TRITON"),
(True, "ATEN"),
(True, "TRITON"),
],
)
@unittest.skipIf(
not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune"
)
def test_augment_trace_against_flop_counter(self, device, dtype, maxat):
# this tests to see if we can only use a Triton backend for max autotune
max_autotune, backends = maxat
if device == "cpu":
return
om = _test_model(device, dtype, compile=False)
comp_omni = torch.compile(
om,
options={
"benchmark_kernel": True,
"max_autotune_gemm_backends": backends,
"force_disable_caches": True,
"max_autotune": max_autotune,
},
)
comp_omni()
torch._dynamo.reset() # reset the cache
with fresh_inductor_cache():
with torch.profiler.profile(record_shapes=True) as profile:
comp_omni()
torch._dynamo.reset() # reset the cache
with fresh_inductor_cache():
with FlopCounterMode() as mode:
comp_omni()
trace1, trace2 = trace_files()
profile.export_chrome_trace(trace1)
with patch(
"sys.argv",
[*prefix, "--augment_trace", trace1, trace2, str(dtype).split(".")[-1]],
):
main()
with open(trace2) as f:
out_profile = json.load(f)
flop_counts = mode.flop_counts
extern_mapping = _create_extern_mapping(out_profile)
seen_mm = False
seen_bmm = False
seen_baddbmm = False
seen_conv = False
for event in out_profile["traceEvents"]:
if (
"cat" not in event
or event["cat"] != "kernel"
or "args" not in event
or "External id" not in event["args"]
):
continue
external_op = extern_mapping[event["args"]["External id"]][0]
name: str = external_op["name"]
self.assertNotEqual(name, None)
self.assertEqual(type(name), str)
if name.startswith("aten::mm") or "_mm_" in name:
seen_mm = True
self.assertEqual(
event["args"]["kernel_flop"],
flop_counts["Global"][torch.ops.aten.mm],
)
if (
name.startswith(
(
"aten::cudnn_convolution",
"aten::convolution",
"aten::_convolution",
)
)
or "conv" in name
):
seen_conv = True
self.assertEqual(
event["args"]["kernel_flop"],
flop_counts["Global"][torch.ops.aten.convolution],
)
if name.startswith("aten::baddbmm") or "_baddbmm_" in name:
seen_baddbmm = True
self.assertEqual(
event["args"]["kernel_flop"],
flop_counts["Global"][torch.ops.aten.baddbmm],
)
if name.startswith("aten::bmm") or "_bmm_" in name:
seen_bmm = True
self.assertEqual(
event["args"]["kernel_flop"],
flop_counts["Global"][torch.ops.aten.bmm],
)
self.assertTrue(seen_mm)
self.assertTrue(seen_bmm)
self.assertTrue(seen_baddbmm)
self.assertTrue(seen_conv)
@skipIf(not SM80OrLater, "Requires SM80")
@dtypes(torch.float, torch.float16)
@parametrize(
"maxat",
[
(False, "ATEN,TRITON"),
(True, "ATEN,TRITON"),
(True, "ATEN"),
(True, "TRITON"),
],
)
@unittest.skipIf(
not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune"
)
def test_pointwise_bandwidth(self, device, dtype, maxat):
# this tests to see if we can only use a Triton backend for max autotune
max_autotune, backends = maxat
if device == "cpu":
return
om = _pointwise_test_model(device, dtype, compile=False)
comp_omni = torch.compile(
om,
options={
"benchmark_kernel": True,
"max_autotune_gemm_backends": backends,
"force_disable_caches": True,
"max_autotune": max_autotune,
},
)
comp_omni()
with fresh_inductor_cache():
with torch.profiler.profile(record_shapes=True) as profile:
comp_omni()
trace1, _ = trace_files()
profile.export_chrome_trace(trace1)
with patch(
"sys.argv",
[*prefix, "--analysis", trace1, str(dtype).split(".")[-1]],
):
main()
with open(trace1) as f:
out_profile = json.load(f)
for event in out_profile["traceEvents"]:
if event["name"] == "triton_poi_fused_add_randn_sin_0":
event["args"]["kernel_num_gb"] = 0.002097168
instantiate_device_type_tests(TestAnalysis, globals())
if __name__ == "__main__":
run_tests()

@@ -1,11 +1,12 @@
 # Owner(s): ["module: inductor"]
+from unittest import skipIf
+
 import torch
 import torch._inductor.metrics as metrics
 import torch.utils.flop_counter
 from torch._dynamo.utils import counters
-from torch._inductor.ir import FixedLayout
-from torch._inductor.utils import fresh_cache
+from torch._inductor.utils import fresh_inductor_cache
 from torch.testing._internal.common_cuda import SM70OrLater
 from torch.testing._internal.common_device_type import (
     dtypes,
@@ -13,6 +14,7 @@ from torch.testing._internal.common_device_type import (
     skipCUDAIf,
 )
 from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
+from torch.testing._internal.inductor_utils import IS_BIG_GPU


 def FlopCounterMode(*args, **kwargs):
@@ -77,7 +79,7 @@ class TestScheduler(TestCase):
         for op, example_inputs, kwargs in tc:
             comp = torch.compile(op)
             torch._dynamo.reset()
-            with fresh_cache():
+            with fresh_inductor_cache():
                 comp(*example_inputs, **kwargs)
             self.assertEqual(metrics.num_bytes_accessed, 0)
             self.assertEqual(any(m[1] for m in metrics.node_runtimes), False)
@@ -85,37 +87,6 @@ class TestScheduler(TestCase):
         metrics.reset()
         torch._logging.set_logs()

-    @dtypes(torch.float, torch.float16)
-    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
-    def test_get_estimated_runtime_logging(self, device, dtype):
-        if device == "cpu":
-            return
-        tc = _test_cases(device, dtype)
-        expected_metrics = [
-            # num_bytes_accessed, number of nonzero node_runtimes
-            (74 * dtype.itemsize, 1),
-            (60 * dtype.itemsize, 1),
-            (222 * dtype.itemsize, 4),
-            (77 * dtype.itemsize, 2),
-        ]
-        tc_plus_metrics = zip(tc, expected_metrics)
-        metrics.reset()
-        torch._logging.set_logs(inductor_metrics=True)
-        for test_case, met in tc_plus_metrics:
-            op, example_inputs, kwargs = test_case
-            enba, enr = met
-            comp = torch.compile(op)
-            torch._dynamo.reset()
-            with fresh_cache():
-                comp(*example_inputs, **kwargs)
-            self.assertEqual(enba, metrics.num_bytes_accessed)
-            nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0)
-            self.assertEqual(enr, nonzero_node_runtimes)
-            metrics.reset()
-            torch._logging.set_logs()
-
     @dtypes(torch.float, torch.float16)
     @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
     @parametrize(
@@ -133,17 +104,10 @@ class TestScheduler(TestCase):
             },
         ],
     )
+    @skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune")
     def test_flop_counter_op(self, device, dtype, options):
         if device == "cpu":
             return
-        if (
-            options["max_autotune_gemm_backends"] == "TRITON"
-            and torch.cuda.is_available()
-            and not torch._inductor.utils.use_triton_template(
-                FixedLayout(torch.device("cuda"), torch.float16, [400, 800])
-            )
-        ):
-            return

         tc = _test_cases(device, dtype)
@@ -152,7 +116,7 @@ class TestScheduler(TestCase):
             comp = torch.compile(op, options=options)
             # next two lines are required, otherwise the flops will be cached from pervious runs of this function.
             torch._dynamo.reset()
-            with fresh_cache():
+            with fresh_inductor_cache():
                 # actually run to set the counters
                 comp(*example_inputs, **kwargs)
             with FlopCounterMode() as mode:

@@ -1,12 +1,18 @@
 # Owner(s): ["module: inductor"]
+import unittest
+
 from sympy import Symbol, sympify

 import torch
 from torch._inductor.fx_utils import count_flops_fx, countable_fx
-from torch._inductor.test_case import run_tests, TestCase
-from torch._inductor.utils import sympy_str, sympy_subs
+from torch._inductor.utils import get_device_tflops, sympy_str, sympy_subs
 from torch._inductor.virtualized import V
+from torch.testing._internal.common_device_type import (
+    dtypes,
+    instantiate_device_type_tests,
+)
+from torch.testing._internal.common_utils import run_tests, TestCase


 class TestUtils(TestCase):
@@ -188,6 +194,14 @@ class TestUtils(TestCase):
             countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}"
         )

+    @unittest.skipIf(not torch.cuda.is_available(), "skip if no device")
+    @dtypes(torch.float16, torch.bfloat16, torch.float32)
+    def test_get_device_tflops(self, dtype):
+        ret = get_device_tflops(dtype)
+        self.assertTrue(type(ret) == float)
+
+
+instantiate_device_type_tests(TestUtils, globals())
+
 if __name__ == "__main__":
     run_tests()

@@ -27,6 +27,7 @@ import torch.nn as nn
 import torch.optim
 import torch.utils.data
 from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
+from torch._inductor.utils import is_big_gpu
 from torch.autograd.profiler import KinetoStepTracker, profile as _profile
 from torch.autograd.profiler_legacy import profile as _profile_legacy
 from torch.profiler import (
@@ -3045,6 +3046,53 @@ aten::mm""",
         assert "Overload Name" in key_averages.table()
         validate_json(prof)

+    @unittest.skipIf(not torch.cuda.is_available(), "requries CUDA")
+    def test_profiler_debug_autotuner(self):
+        """
+        This test makes sure that profiling events will be present when the kernel is run using the DebugAutotuner.
+        """
+        if not is_big_gpu():
+            raise unittest.SkipTest("requires large gpu to max-autotune")
+        in1 = torch.randn((256, 512), device="cuda", dtype=torch.float16)
+        in2 = torch.randn((512, 768), device="cuda", dtype=torch.float16)
+
+        def mm():
+            return torch.mm(in1, in2)
+
+        pb_mm = torch.compile(
+            mm,
+            options={
+                "benchmark_kernel": True,
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+                "profile_bandwidth": True,
+            },
+        )
+        comp_mm = torch.compile(
+            mm,
+            options={
+                "benchmark_kernel": True,
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            },
+        )
+
+        with profile() as prof1:
+            pb_mm()
+        with profile() as prof2:
+            comp_mm()
+
+        def names(prof):
+            return {
+                ev.name
+                for ev in prof.events()
+                if "mm" in ev.name or "triton" in ev.name
+            }
+
+        n1 = names(prof1)
+        n2 = names(prof2)
+        self.assertEqual(n1, n2)
+
+
 if __name__ == "__main__":
     run_tests()

@@ -0,0 +1,41 @@
# `torch._inductor.analysis`

Contains scripts for Inductor performance analysis.

## Analysis

This mode analyzes a Chrome trace and creates a table useful for performance work; the quantities we mainly care about are achieved flops and achieved memory bandwidth. Currently, it adds the flops and the memory reads of each kernel via formula (it does not read hardware performance counters). These, combined with the kernel duration, can be used to calculate achieved flops, achieved memory bandwidth, and roofline estimates.

### Usage
```
python profile_analysis.py --analysis <input_json_profile> <default_dtype>
```

### Arguments
- `input_json_profile`: The JSON profile file generated by `torch.profiler`'s `export_chrome_trace()`.
- `default_dtype`: The default dtype of the model. Sometimes the dtypes of the kernel inputs are not available in the profile, so we use the default dtype to infer the dtypes of the inputs.

## Diff

This mode diffs two profiles and outputs a table of the differences. It groups by kernel name, which can fail to match properly across hardware vendors. More intelligent grouping is coming soon.

### Usage
```
python profile_analysis.py --diff <json_profile_1> <profile_name_1> <json_profile_2> <profile_name_2> <default_dtype> --name_limit 50
```

### Arguments
- `json_profile_1`, `json_profile_2`: The JSON profile files generated by `torch.profiler`'s `export_chrome_trace()`.
- `profile_name_1`, `profile_name_2`: The names of the profiles, used to identify them in the output table.
- `default_dtype`: The default dtype of the model. Sometimes the dtypes of the kernel inputs are not available in the profile, so we use the default dtype to infer the dtypes of the inputs.
- `name_limit`: The maximum number of characters of the kernel name to display (kernel names can be quite lengthy and hard to read).

## Augment

This mode adds post-hoc analysis to a profile. Currently, it adds the flops and the memory reads of each kernel via formula (it does not read hardware performance counters). These, combined with the kernel duration, can be used to calculate achieved flops, achieved memory bandwidth, and roofline estimates.

### Usage
```
python profile_analysis.py --augment_trace <input_json_profile> <output_json_profile> <default_dtype>
```

### Arguments
- `input_json_profile`: The JSON profile file generated by `torch.profiler`'s `export_chrome_trace()`.
- `output_json_profile`: Where the augmented profile is written.
- `default_dtype`: The default dtype of the model. Sometimes the dtypes of the kernel inputs are not available in the profile, so we use the default dtype to infer the dtypes of the inputs.
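
Besides the CLI, the same report can also be generated programmatically. A minimal sketch, based on how the new tests drive it (the trace path is a placeholder; `JsonProfile` lives in `torch._inductor.analysis.profile_analysis`):

```python
# Sketch: load a (possibly augmented) chrome trace and print the summary table.
from torch._inductor.analysis.profile_analysis import JsonProfile

profile = JsonProfile("/tmp/trace_aug.json", benchmark_name="resnet", dtype="float16")
print(profile.report())
```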

@@ -0,0 +1,193 @@
import logging
from dataclasses import dataclass
from typing import Optional, Union
import torch
log = logging.getLogger(__name__)
@dataclass(frozen=True)
class DeviceInfo:
"""
Theoretical Numbers from data sheet. If two numbers are given, Tensor/Matrix Core vs not,
then the higher number is reported. Sparsity is not considered.
Bandwidth numbers are tricky, because there are platform differences that may not show up in the profiler trace.
For example,
"""
tops: dict[Union[torch.dtype, str], float]
dram_bw_gbs: float
dram_gb: float
# Indexing is based on `torch.cuda.get_device_name()`
# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
_device_mapping: dict[str, DeviceInfo] = {
# Source:
# @lint-ignore https://www.nvidia.com/en-us/data-center/h100/
"NVIDIA H100": DeviceInfo(
tops={
torch.float64: 67.0,
torch.float32: 67.5,
"torch.tf32": 156.0,
torch.bfloat16: 1979.0,
torch.float16: 1979.0,
torch.float8_e8m0fnu: 3958.0,
torch.float8_e4m3fnuz: 3958.0,
torch.float8_e5m2: 3958.0,
torch.float8_e5m2fnuz: 3958.0,
torch.int8: 3958.0,
},
dram_bw_gbs=3350,
dram_gb=80,
),
# Source:
# @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/
# nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
"NVIDIA A100": DeviceInfo(
tops={
torch.float64: 19.5,
torch.float32: 19.5,
torch.bfloat16: 312.5,
torch.float16: 312.5,
# Not in datasheet: float8
torch.int8: 624.0,
"torch.tf32": 156.0,
},
dram_bw_gbs=2039.0,
dram_gb=80.0,
),
# Source:
# @lint-ignore https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet
"NVIDIA L4": DeviceInfo(
tops={
# This is a guess, not in datasheet
torch.float64: 15.1,
torch.float32: 30.3,
"torch.tf32": 120.0,
torch.bfloat16: 242.0,
torch.float16: 242.0,
torch.float8_e8m0fnu: 485.0,
torch.float8_e4m3fnuz: 485.0,
torch.float8_e5m2: 485.0,
torch.float8_e5m2fnuz: 485.0,
torch.int8: 485.0,
},
dram_bw_gbs=3350,
dram_gb=24,
),
# Source:
# @lint-ignore https://www.amd.com/content/dam/amd/en/documents\
# /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
"AMD MI300A": DeviceInfo(
tops={
torch.float64: 122.6,
torch.float32: 122.6,
"torch.tf32": 490.3,
torch.bfloat16: 980.6,
torch.float16: 980.6,
torch.float8_e8m0fnu: 1961.2,
torch.float8_e4m3fnuz: 1961.2,
torch.float8_e5m2: 1961.2,
torch.float8_e5m2fnuz: 1961.2,
torch.int8: 1961.2,
},
dram_bw_gbs=5300.0,
dram_gb=128.0,
),
# Source:
# @lint-ignore https://www.amd.com/content/dam/amd/en/documents/\
# instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
"AMD MI300X": DeviceInfo(
tops={
torch.float64: 163.4,
torch.float32: 163.4,
"torch.tf32": 653.7,
torch.bfloat16: 1307.4,
torch.float16: 1307.4,
torch.float8_e8m0fnu: 2614.9,
torch.float8_e4m3fnuz: 2614.9,
torch.float8_e5m2: 2614.9,
torch.float8_e5m2fnuz: 2614.9,
torch.int8: 2614.9,
},
dram_bw_gbs=5300.0,
dram_gb=192.0,
),
# Source:
# @lint-ignore https://www.amd.com/content/dam/amd/\
# en/documents/instinct-business-docs/product-briefs/instinct-mi210-brochure.pdf
"AMD MI210X": DeviceInfo(
tops={
torch.float64: 45.3,
torch.float32: 45.3,
# not specified, fall back to float32 numbers
"torch.tf32": 45.3,
torch.bfloat16: 181.0,
torch.float16: 181.0,
# not specified, fall back to float16 numbers
torch.float8_e8m0fnu: 181.0,
torch.float8_e4m3fnuz: 181.0,
torch.float8_e5m2: 181.0,
torch.float8_e5m2fnuz: 181.0,
torch.int8: 181.0,
},
# pcie4.0x16
dram_bw_gbs=1600.0,
dram_gb=64.0,
),
}
_device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"]
_device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"]
def lookup_device_info(name: str) -> Optional[DeviceInfo]:
"""
Problem: when diffing profiles between amd and nvidia, we don't have access to the device information
of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated
to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices.
If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.
name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name().
"""
return _device_mapping.get(name, None)
def datasheet_tops(dtype: torch.dtype, is_tf32: bool = False) -> Optional[float]:
"""
Get the theoretical TFLOPS of the device for a given dtype. Returns None (with a log message)
if the device is not in the datasheet list above or the dtype has no datasheet entry.
"""
name: Optional[str] = torch.cuda.get_device_name()
if name is None:
log.info("No device found, returning None")
return None
device_info = lookup_device_info(name)
if device_info is None:
log_str = f"Device {name} not in datasheet, returning None"
log.info(log_str)
return None
if dtype not in device_info.tops:
log.info(
"Device %s does not have a datasheet entry for %s, returning None",
name,
dtype,
)
return None
return device_info.tops[
"torch.tf32" if dtype == torch.float32 and is_tf32 else dtype
]
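
For a sense of how these helpers are intended to be used, here is a small sketch. It assumes a CUDA device whose `torch.cuda.get_device_name()` appears in `_device_mapping`:

```python
import torch

from torch._inductor.analysis.device_info import datasheet_tops, lookup_device_info

# Datasheet entry for the current device, or None if it is unknown.
info = lookup_device_info(torch.cuda.get_device_name())
if info is not None:
    print(f"DRAM: {info.dram_gb} GB at {info.dram_bw_gbs} GB/s")

# Peak TFLOPS for fp16, and for fp32 when TF32 is enabled.
print(datasheet_tops(torch.float16))
print(datasheet_tops(torch.float32, is_tf32=True))
```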

@@ -0,0 +1,717 @@
import json
import logging
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
from torch._inductor.utils import tabulate_2d, zip_dicts
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils.flop_counter import flop_registry
log = logging.getLogger(__name__)
ATEN_PREFIX = "aten::"
@dataclass
class ProfileEvent:
category: str
key: str
self_device_time_ms: float
# the benchmark is run multiple times and we average the count across all the
# runs. It should be an integer but define a float just in case.
count: float
# adapters convert the json trace into a format that works with flops_counter
ArgsType = tuple[tuple[Any, ...], dict[Any, Any]]
AdapterType = Callable[[tuple[Any, ...], tuple[Any, ...]], ArgsType]
adapters_map: dict[str, AdapterType] = {}
def parse_list(lst: str) -> list[int]:
lst = lst.replace("[", "").replace("]", "")
substrings = lst.split(",")
return [int(substring.strip()) for substring in substrings]
def register_adapter(
aten: Union[str, list[str]],
) -> Callable[
[AdapterType],
AdapterType,
]:
def decorator(func: AdapterType) -> AdapterType:
global adapters_map
if isinstance(aten, str):
adapters_map[aten] = func
else:
for at in aten:
adapters_map[at] = func
return func
return decorator
@register_adapter(["_slow_conv2d_forward"])
def _slow_conv2d_adapter(
shapes: tuple[Any, ...], concrete: tuple[Any, ...]
) -> tuple[tuple[Any], dict[Any, Any]]:
tmp = list(shapes)
tmp.append(False)
tmp2 = list(concrete)
tmp2[3] = tmp2[4]
return conv_adapter(tuple(tmp), tuple(tmp2))
@register_adapter(["convolution", "_convolution", "cudnn_convolution"])
def conv_adapter(
shapes: tuple[Any, ...], concrete: tuple[Any, ...]
) -> tuple[tuple[Any], dict[Any, Any]]:
tmp = list(shapes)
if len(tmp) == 4:
transposed = False
else:
transposed = bool(tmp[6])
tmp[6] = transposed
kwargs: dict[Any, Any] = {}
if not transposed:
# calculate output shape if not transposed.
def conv_out_dims(x: int, kernel: int, stride: int) -> int:
return (x - kernel) // stride + 1
stride = parse_list(concrete[3])
inp = shapes[0]
w = shapes[1]
out_x_y = [conv_out_dims(*args) for args in zip(inp[2:], w[2:], stride)]
out = [inp[0], w[0]] + out_x_y # we only need the xy values
kwargs["out_val"] = out
return tuple(tmp), kwargs
def default_adapter(
shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
return shapes, {}
@register_adapter("addmm")
def addmm_adapter(
shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
tmp = list(shapes)[:3]
return tuple(tmp), {}
@register_adapter("bmm")
def bmm_adapter(
shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
tmp = list(shapes)
return tuple(tmp[:2]), {}
@register_adapter("baddbmm")
def baddbmm_adapter(
shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
tmp = list(shapes)[:3]
return tuple(tmp), {}
@register_adapter("mm")
def mm_adapter(
shapes: tuple[Any], concrete: tuple[Any]
) -> tuple[tuple[Any], dict[Any, Any]]:
return shapes, {}
def _parse_kernel_name(name: str) -> Optional[str]:
"""
parse the name of the kernel from the event name.
"""
if name.startswith(ATEN_PREFIX):
return name[len(ATEN_PREFIX) :]
elif "conv" in name:
return "convolution"
elif "addmm" in name:
return "addmm"
elif "bmm" in name:
return "bmm"
elif "baddbmm" in name:
return "baddbmm"
elif "_mm" in name:
return "mm"
else:
return None
def _calculate_flops(event: dict[str, Any]) -> int:
"""
This function has to parse the kernel name, which is error prone. There doesn't seem to be another solution that
will support all the different backends that can generate kernels, so make sure to update this function when new
ops and backends are desired.
"""
name = event["name"]
if "kernel_flop" in event["args"] and event["args"]["kernel_flop"] != 0:
return event["args"]["kernel_flop"]
op_name = _parse_kernel_name(name)
if op_name is None:
return 0
op_obj = getattr(torch.ops.aten, op_name, None)
if op_obj is None or op_obj not in flop_registry:
return 0
flop_function = flop_registry[op_obj]
assert "Input Dims" in event["args"] and "Concrete Inputs" in event["args"]
input_shapes = event["args"]["Input Dims"]
concrete = event["args"]["Concrete Inputs"]
if op_name in adapters_map:
args, kwargs = adapters_map[op_name](input_shapes, concrete)
else:
args, kwargs = default_adapter(input_shapes, concrete)
return flop_function(*args, **kwargs)
def _get_size_from_string(type_string: str) -> int:
if not hasattr(torch, type_string):
return 1
else:
return getattr(torch, type_string).itemsize
def _default_estimate_gb(event: dict[str, Any]) -> float:
sizes_and_types = zip(event["args"]["Input Dims"], event["args"]["Input type"])
bw = 0
for size, typ in sizes_and_types:
isize = _get_size_from_string(typ)
bw += isize * math.prod(pytree.tree_flatten(size)[0])
return bw / 1e9
def _estimate_gb(event: dict[str, Any]) -> float:
"""
Our best effort to estimate the gb, should be refactored soon with MemoryCounter.
"""
name = event["name"]
if "kernel_num_gb" in event["args"] and event["args"]["kernel_num_gb"] != 0:
return event["args"]["kernel_num_gb"]
if "Input type" not in event["args"] or "Input Dims" not in event["args"]:
return 0
op_name = _parse_kernel_name(name)
if op_name is None:
return _default_estimate_gb(event)
op_obj = getattr(torch.ops.aten, op_name, None)
if op_obj is None:
return _default_estimate_gb(event)
assert "Input Dims" in event["args"] and "Concrete Inputs" in event["args"]
input_shapes = event["args"]["Input Dims"]
# NOTE these will be refactored into a similar object to FlopCounter soon
def mm_formula(M: int, N: int, K: int, size: int) -> int:
return 2 * (M * K + N * K + M * N) * size
if op_name == "addmm":
add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
add_type_size = _get_size_from_string(event["args"]["Input type"][0])
M = input_shapes[1][0]
N = input_shapes[1][1]
assert input_shapes[1][1] == input_shapes[2][0]
K = input_shapes[2][1]
mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
return (mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size) / 1e9
elif op_name == "mm":
M = input_shapes[0][0]
N = input_shapes[0][1]
assert input_shapes[0][1] == input_shapes[1][0]
K = input_shapes[1][1]
type_size = _get_size_from_string(event["args"]["Input type"][0])
return mm_formula(M, N, K, type_size) / 1e9
elif op_name == "baddbmm":
add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
add_type_size = _get_size_from_string(event["args"]["Input type"][0])
B = input_shapes[0][0]
M = input_shapes[1][1]
N = input_shapes[1][2]
K = input_shapes[2][2]
mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
return (
B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size
) / 1e9
elif op_name == "bmm":
add_in_size = math.prod(pytree.tree_flatten(input_shapes[0])[0])
add_type_size = _get_size_from_string(event["args"]["Input type"][0])
B = input_shapes[0][0]
M = input_shapes[0][1]
N = input_shapes[0][2]
K = input_shapes[1][2]
mul_type_size = _get_size_from_string(event["args"]["Input type"][1])
return (
B * mm_formula(M, N, K, mul_type_size) + add_in_size * add_type_size
) / 1e9
elif op_name in [
"convolution",
"_convolution",
"cudnn_convolution",
"_slow_conv2d_forward",
]:
concrete = event["args"]["Concrete Inputs"]
def conv_out_dim(x: int, kernel: int, stride: int) -> int:
return (x - kernel) // stride + 1
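        # e.g. a 56-wide spatial dim with a 3-wide kernel and stride 1 gives (56 - 3) // 1 + 1 = 54 outputs
        # (this simplified formula ignores padding and dilation)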
stride = parse_list(
concrete[3] if op_name != "_slow_conv2d_forward" else concrete[4]
)
inp = input_shapes[0]
w = input_shapes[1]
out_x_y = [conv_out_dim(*args) for args in zip(inp[2:], w[2:], stride)]
out = [inp[0], w[0]] + out_x_y
        # each output element reads an (in_channels x kernel_h x kernel_w) chunk of the input
input_reads = out[0] * out[1] * out[2] * out[3] * inp[1] * w[2] * w[3]
# Assume weights are in cache, so only read once
weight_reads = w[0] * w[1] * w[2] * w[3]
return (input_reads + weight_reads) / 1e9
return _default_estimate_gb(event)
def _create_extern_mapping(
data: dict[str, Any],
) -> defaultdict[int, list[dict[str, Any]]]:
"""
    Compute a mapping from external ids to non-kernel (cpu_op) events, which contain the information we need to
    estimate flops, bandwidth, etc.
"""
extern_mapping: defaultdict[int, list[dict[str, Any]]] = defaultdict(list)
for event in data["traceEvents"]:
if (
"args" not in event
or "External id" not in event["args"]
or event["cat"] != "cpu_op"
):
continue
if len(extern_mapping[event["args"]["External id"]]) > 0:
raise ParseException("duplicate external id in event")
extern_mapping[event["args"]["External id"]].append(event)
return extern_mapping
def _augment_trace_helper(data: dict[str, Any]) -> dict[str, Any]:
extern_mapping = _create_extern_mapping(data)
for event in data["traceEvents"]:
if "cat" not in event or event["cat"] != "kernel":
continue
if "args" not in event:
raise ParseException(f"kernel has no args: {event}")
if "External id" not in event["args"]:
event_str = f"kernel has no External id: {event}"
log.info(event_str)
continue
external_op = extern_mapping[event["args"]["External id"]][0]
flops = _calculate_flops(external_op)
if flops == 0:
flops = _calculate_flops(event)
external_op["args"]["kernel_flop"] = flops
external_op["args"]["kernel_num_gb"] = _estimate_gb(external_op)
event["args"]["kernel_flop"] = external_op["args"]["kernel_flop"]
event["args"]["kernel_num_gb"] = external_op["args"]["kernel_num_gb"]
return data
_dtype_map = {
"float": torch.float,
"float32": torch.float,
"int": torch.int,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int,
"long": torch.long,
"long int": torch.long,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float64": torch.double,
}
@dataclass(frozen=True)
class KernelStats:
flops: int
bw: float
latency: float # us
achieved_flops: float
achieved_bandwidth: float
KernelNameMap = defaultdict[str, OrderedSet[KernelStats]]
@dataclass(frozen=False)
class Device:
name: str
index: int
info: Optional[DeviceInfo]
stats: KernelNameMap
def __repr__(self) -> str:
return f"Device({self.name}, {self.index}): {self.info}"
DeviceMap = dict[int, Device]
Table = tuple[list[str], dict[str, list[str]]]
class JsonProfile:
_devices: DeviceMap
def __init__(
self,
path: str,
benchmark_name: Optional[str] = None,
dtype: Optional[Union[torch.dtype, str]] = None,
):
"""
        Convenience class for running common operations on chrome/perfetto json traces.
"""
self.path = path
with open(path) as f:
self.data = json.load(f)
self.events = self.data["traceEvents"]
self.benchmark_name = benchmark_name
if dtype is None:
self.dtype = None
elif isinstance(dtype, torch.dtype):
self.dtype = dtype
else:
if dtype in _dtype_map:
self.dtype = _dtype_map[dtype]
else:
self.dtype = None
self._create_devices()
def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]:
"""
        Each op records a dtype string for every input arg; we need to collapse these into a single dtype for flop estimation.
        Issues:
         - the strings have to be converted to concrete torch.dtypes
         - if the inputs mix dtypes (e.g. float32, float, and float16), we pick the dtype of the largest input buffer
"""
if (
"Input Dims" not in event["args"]
or "Input type" not in event["args"]
or "Concrete Inputs" not in event["args"]
):
if "bfloat16" in event["name"]:
return torch.bfloat16
elif "float16" in event["name"]:
return torch.float16
else:
return None
input_sizes = event["args"]["Input Dims"]
input_types = event["args"]["Input type"]
concrete_inputs = event["args"]["Concrete Inputs"]
assert len(input_sizes) == len(input_types)
assert len(input_types) == len(concrete_inputs)
if len(input_sizes) == 0:
raise RuntimeError("Empty input_sizes and input_types")
biggest_size = 0
biggest_index = 0
for i in range(len(input_sizes)):
if concrete_inputs[i] != "":
                # entries with a concrete value are scalars or small tensors, so skip them when looking for the largest buffer
continue
my_size = input_sizes[i]
total_size = sum(parse_list(my_size))
if total_size > biggest_size:
biggest_size = total_size
biggest_index = i
ret_type = input_types[biggest_index]
if ret_type in _dtype_map:
return _dtype_map[ret_type]
raise RuntimeError(f"Unknown type: {ret_type}. Please add to _dtype_map.")
def _create_devices(self) -> None:
self._devices = {}
for dev in self.data["deviceProperties"]:
name = dev["name"]
device_info = lookup_device_info(name)
if device_info is None:
log.info(
"Unsupported device in profile: %s, please consider contributing to _device_mapping.",
name,
)
self._devices[dev["id"]] = Device(
name, dev["id"], device_info, defaultdict(OrderedSet)
)
def calculate_flops(self, event: dict[str, Any]) -> int:
return _calculate_flops(event)
def estimate_gb(self, event: dict[str, Any]) -> float:
return _estimate_gb(event)
def augment_trace(self) -> None:
self.data = _augment_trace_helper(self.data)
def _compute_stats(self) -> None:
"""populates the name -> stats map"""
for event in self.events:
if "cat" not in event or "args" not in event or event["cat"] != "kernel":
continue
dev = self._devices[event["args"]["device"]]
dur = event["dur"] # us
if "kernel_flop" in event["args"]:
assert dur != 0
                # kernel_flop / (dur_us / 1e6) = flop per second
op_flops = event["args"]["kernel_flop"] / (dur / 1e6)
else:
op_flops = 0
if "kernel_num_gb" in event["args"]:
assert dur != 0
                # kernel_num_gb / (dur_us / 1e6) = GB per second
op_gbps = event["args"]["kernel_num_gb"] / (dur / 1e6)
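                # e.g. a kernel that moves 0.5 GB in 250 us achieves 0.5 / (250 / 1e6) = 2000 GB/s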
else:
op_gbps = 0
if dev.info is not None:
dtype = self.convert_dtype(event) or self.dtype
if dtype is None:
raise RuntimeError(
"dtype is not found on tensor and default dtype is not set"
)
achieved_flops = 100 * op_flops / (1e12 * dev.info.tops[dtype])
achieved_bandwidth = 100 * op_gbps / dev.info.dram_bw_gbs
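                # e.g. 2000 GB/s against a hypothetical datasheet peak of 3000 GB/s reports ~66.7% achieved bandwidth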
else:
achieved_flops = 0
achieved_bandwidth = 0
dev.stats[event["name"]].add(
KernelStats(
flops=op_flops,
bw=op_gbps,
latency=dur,
achieved_bandwidth=achieved_bandwidth,
achieved_flops=achieved_flops,
)
)
def _create_single_table(self, dev: Device) -> Table:
"""Create a table with the devices mapped to indices."""
headers = [
"Kernel Name",
"Kernel Count",
"FLOPS",
"Kernel Reads (GB)",
"Dur (us)",
"Achieved FLOPS %",
"Achieved Bandwidth %",
]
rows: dict[str, list[str]] = {}
def safe_div_format(x: float, y: float) -> str:
if y == 0:
return "0.0"
return f"{x / y:.4f}"
for kernel_name, stats_set in dev.stats.items():
ker_count = 0
flops = 0
flops_count = 0
achieved_flops = 0.0
bw = 0.0
bw_count = 0
achieved_bandwidth = 0.0
latency = 0.0
for stats in stats_set:
if stats.flops != 0:
flops += stats.flops
achieved_flops += stats.achieved_flops
flops_count += 1
if stats.bw != 0:
bw += stats.bw
achieved_bandwidth += stats.achieved_bandwidth
bw_count += 1
latency += stats.latency
ker_count += 1
assert ker_count != 0
rows[kernel_name] = [
str(ker_count),
safe_div_format(flops, flops_count),
safe_div_format(bw, bw_count),
safe_div_format(latency, ker_count),
safe_div_format(achieved_flops, flops_count),
safe_div_format(achieved_bandwidth, bw_count),
]
return headers, rows
def _create_tables(self, devs: DeviceMap) -> dict[int, Table]:
return {idx: self._create_single_table(dev) for idx, dev in devs.items()}
def _combine_tables(
self, table1: Table, table1_name: str, table2: Table, table2_name: str
) -> Table:
new_headers = (
["Kernel Name"]
+ [f"{table1_name} {head}" for head in table1[0][1:]]
+ [f"{table2_name} {head}" for head in table2[0][1:]]
)
t1_length = len(table1[0][1:])
t2_length = len(table2[0][1:])
new_rows = {}
for key, row1, row2 in zip_dicts(
table1[1],
table2[1],
d1_default=["Empty"] * t1_length,
d2_default=["Empty"] * t2_length,
):
assert row1 is not None
assert row2 is not None
new_rows[key] = row1 + row2
return new_headers, new_rows
def report(
self, other: Optional["JsonProfile"] = None, name_limit: int = 40
) -> str:
def create_ret(
table_headers: list[str], table_rows: dict[str, list[str]]
) -> str:
table_flattened = [
[kernel_name[:name_limit], *kernel_vals]
for kernel_name, kernel_vals in table_rows.items()
]
return tabulate_2d(table_flattened, headers=table_headers)
if other is not None:
self._compute_stats()
other._compute_stats()
self_tables = self._create_tables(self._devices)
other_tables = self._create_tables(other._devices)
self_name = (
self.benchmark_name if self.benchmark_name is not None else "Table 1"
)
other_name = (
other.benchmark_name if other.benchmark_name is not None else "Table 2"
)
ret = []
assert self._devices.keys() == other._devices.keys()
for device_idx, t1, t2 in zip_dicts(
self_tables, other_tables, d1_default=None, d2_default=None
):
assert t1 is not None
assert t2 is not None
table_headers, table_rows = self._combine_tables(
t1, self_name, t2, other_name
)
tab_string = create_ret(table_headers, table_rows)
ret.append(f"{self._devices[device_idx]}:\n{tab_string}")
return "\n".join(ret)
self._compute_stats()
self_tables = self._create_tables(self._devices)
ret = []
for idx, table in self_tables.items():
table_headers, table_rows = table
tab_string = create_ret(table_headers, table_rows)
ret.append(f"{self._devices[idx]}:\n{tab_string}")
return "\n".join(ret)
def dump(self, out: str) -> None:
with open(out, "w") as f:
json.dump(self.data, f)
class ParseException(RuntimeError):
pass
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--diff",
nargs=5,
metavar=(
"input_file1",
"name1",
"input_file2",
"name2",
"dtype",
),
help="Two json traces to compare with, specified as <file1> <name1> <file2> <name2> <dtype>",
)
parser.add_argument(
"--name_limit",
type=int,
help="the maximum name size in the final report",
)
parser.add_argument(
"--augment_trace",
"-a",
nargs=3,
metavar=("input_file", "output_file", "dtype"),
help="Augment a trace with inductor meta information. Provide input and output file paths.",
)
parser.add_argument(
"--analysis",
nargs=2,
metavar=("input_file", "dtype"),
help="Run analysis on a single trace, specified as <file> <dtype>",
)
args = parser.parse_args()
if args.diff:
p1 = JsonProfile(args.diff[0], args.diff[1], dtype=args.diff[4])
p1.augment_trace()
p2 = JsonProfile(args.diff[2], args.diff[3], dtype=args.diff[4])
p2.augment_trace()
if args.name_limit:
print(p1.report(p2, name_limit=args.name_limit))
else:
print(p1.report(p2))
if args.analysis:
p1 = JsonProfile(
args.analysis[0],
dtype=args.analysis[1],
)
p1.augment_trace()
if args.name_limit:
print(p1.report(name_limit=args.name_limit))
else:
print(p1.report())
if args.augment_trace:
p = JsonProfile(args.augment_trace[0], dtype=args.augment_trace[2])
p.augment_trace()
p.dump(args.augment_trace[1])
if __name__ == "__main__":
main()
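    # Example invocations (illustrative; the trace file names are placeholders):
    #   python profile_analysis.py --augment_trace trace.json trace_aug.json float16
    #   python profile_analysis.py --analysis trace_aug.json float16
    #   python profile_analysis.py --diff base.json baseline new.json candidate float16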

View File

@@ -970,6 +970,13 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
                return tuple(map(fn, value))
            return fn(value)

+    def estimate_flops(self) -> Optional[int]:
+        flops = [
+            node.estimate_flops()
+            for node in NodeScheduleMarker.only_nodes(self.features.node_schedule)
+        ]
+        return sum(filter(None, flops))
+
     def estimate_kernel_num_bytes(self):
         """
         Try the best to estimate the total size (in bytes) of the
@@ -1616,7 +1623,9 @@ class SIMDScheduling(BaseScheduling):
             kernel.cse.invalidate(OrderedSet())

         if not isinstance(partial_code, str):
-            partial_code.finalize_hook("<DEF_KERNEL>")
+            # This is used to calculate flops in TritonTemplateKernels
+            with ir.IRNode.current_origins(template_node.node.origins):
+                partial_code.finalize_hook("<DEF_KERNEL>")
             partial_code.finalize_hook("<ARGDEFS>", strict=False)
             # finalize must be called after adding epilogue above

View File

@@ -3686,7 +3686,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
         num_gb = None
         if config.benchmark_kernel or config.profile_bandwidth:
             num_gb = self.estimate_kernel_num_bytes() / 1e9
-            inductor_meta["kernel_num_gb"] = num_gb
+        if num_gb is not None:
+            inductor_meta["kernel_num_gb"] = num_gb
+        if config.benchmark_kernel:
+            flops = self.estimate_flops()
+            if flops is not None:
+                inductor_meta["kernel_flop"] = flops

         triton_meta["configs"] = [config_of(signature)]

View File

@@ -314,7 +314,7 @@ def is_node_realized(node: torch.fx.Node) -> bool:
 def count_flops_fx(node: torch.fx.Node) -> Optional[int]:
-    if isinstance(node.target, str):
+    if not countable_fx(node) or isinstance(node.target, str):
         return None
     with FakeTensorMode(allow_non_fake_inputs=True):
         success, args, kwargs = get_fake_args_kwargs(node)

View File

@@ -65,6 +65,7 @@ from .exc import (
     MissingOperatorWithDecomp,
     MissingOperatorWithoutDecomp,
 )
+from .fx_utils import count_flops_fx
 from .ir import (
     Constant,
     DonatedBuffer,
@@ -659,32 +660,24 @@ class GraphLowering(torch.fx.Interpreter):
         # only grouped convolutions benchmarked as slower in conv samples for inference only
         if is_inference:
-            from torch.utils.flop_counter import FlopCounterMode
-
             flop_counts: dict[str, float] = defaultdict(float)
             for node in conv_nodes:
-                success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
-                    node
-                )
-                if success:
-                    with FlopCounterMode(display=False) as flop_counter_mode:
-                        with V.fake_mode:
-                            node.target(*args, **kwargs)
-
-                    counted_flops = flop_counter_mode.get_total_flops()
-                    if is_grouped(node):
-                        node_type = "grouped"
-                    elif is_small_channel(node):
-                        node_type = "small"
-                    elif is_in_out_channel(node):
-                        node_type = "in_out"
-                    else:
-                        node_type = "default"
-                    flop_counts[node_type] += counted_flops
-                else:
-                    log.debug("Conv inputs meta not found")
+                counted_flops = count_flops_fx(node)
+                if counted_flops is None:
+                    continue
+                if is_grouped(node):
+                    node_type = "grouped"
+                elif is_small_channel(node):
+                    node_type = "small"
+                elif is_in_out_channel(node):
+                    node_type = "in_out"
+                else:
+                    node_type = "default"
+                flop_counts[node_type] += counted_flops
+            else:
+                log.debug("Conv inputs meta not found")

         # average benchmarked channels last speedup / slowdown, < 1 is speedup.
         # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/

View File

@@ -785,14 +785,29 @@ class CachingAutotuner(KernelInterface):
             # reset to zero before evaluating any config
             self.reset_to_zero_args(*args, **kwargs)
             args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher)
-            launcher(
-                *args_with_constexprs,
-                **cloned_kwargs,
-                stream=stream,
-            )
+            if autograd_profiler._is_profiler_enabled:
+                profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
+                with torch._C._profiler._RecordFunctionFast(
+                    self.inductor_meta.get("kernel_name", "triton kernel"),
+                    args_with_constexprs,
+                    profiler_kwargs,
+                ):
+                    launcher(
+                        *args_with_constexprs,
+                        **cloned_kwargs,
+                        stream=stream,
+                    )
+            else:
+                launcher(
+                    *args_with_constexprs,
+                    **cloned_kwargs,
+                    stream=stream,
+                )
             self.restore_args_from_cpu(cpu_copies)

-        if with_profiler:
+        # only use profiler when not already in a profiler instance
+        if with_profiler and not autograd_profiler._is_profiler_enabled:
             from torch._inductor.utils import do_bench_using_profiling

             return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
@@ -1100,6 +1115,28 @@ class CachingAutotuner(KernelInterface):
             ).make_launcher()
         return config2launcher[best_config]

+    def get_profiler_kwargs(self, stream, launcher):
+        kernel_kwargs_str = ",".join(
+            f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
+        )
+
+        ret = {
+            "kernel_file": (self.filename or ""),
+            "kernel_hash": self.kernel_hash,
+            "kernel_backend": "triton",
+            "stream": stream,
+            "num_warps": launcher.config.num_warps,
+            "num_stages": launcher.config.num_stages,
+            "kernel_kwargs": kernel_kwargs_str,
+        }
+        if "kernel_name" in self.inductor_meta:
+            ret["kernel_name"] = self.inductor_meta["kernel_name"]
+        if "kernel_flop" in self.inductor_meta:
+            ret["kernel_flop"] = self.inductor_meta["kernel_flop"]
+        if "kernel_num_gb" in self.inductor_meta:
+            ret["kernel_num_gb"] = self.inductor_meta["kernel_num_gb"]
+        return ret
+
     def run(
         self,
         *args,
@@ -1152,19 +1189,7 @@ class CachingAutotuner(KernelInterface):
         # it is faster than entering and exiting a context manager, even if the context
         # manager is a nullcontext.
         if autograd_profiler._is_profiler_enabled:
-            kernel_kwargs_str = ",".join(
-                f"{k}={v}" for (k, v) in launcher.config.kwargs.items()
-            )
-
-            profiler_kwargs = {
-                "kernel_file": (self.filename or ""),
-                "kernel_hash": self.kernel_hash,
-                "kernel_backend": "triton",
-                "stream": stream,
-                "num_warps": launcher.config.num_warps,
-                "num_stages": launcher.config.num_stages,
-                "kernel_kwargs": kernel_kwargs_str,
-            }
+            profiler_kwargs = self.get_profiler_kwargs(stream, launcher)

             with torch._C._profiler._RecordFunctionFast(
                 self.inductor_meta.get("kernel_name", "triton kernel"),

View File

@@ -40,7 +40,7 @@ from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
 from .comm_analysis import estimate_nccl_collective_runtime
 from .dependencies import Dep, MemoryDep, StarDep, WeakDep
 from .exc import GPUTooOldForTriton, TritonMissing
-from .fx_utils import count_flops_fx, countable_fx
+from .fx_utils import count_flops_fx
 from .ir import (
     get_device_type,
     GraphPartitionSignature,
@@ -793,12 +793,12 @@ class BaseSchedulerNode:
         fx_node = self.node.get_origin_node()
         if fx_node is None:
             return None
-        if not countable_fx(fx_node):
-            return None
         flops = count_flops_fx(fx_node)
-        resolved_flops = V.graph.sizevars.size_hints((flops,), fallback=0)[0]
+        if flops is None:
+            return None
+        resolved_flops = V.graph.sizevars.size_hint(flops, fallback=0)
         counters["inductor"]["flop_count"] += resolved_flops
         return resolved_flops

View File

@@ -60,6 +60,7 @@ from .codegen.triton import (
 from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta
 from .codegen.wrapper import pexpr
 from .exc import CUDACompileError
+from .fx_utils import count_flops_fx
 from .ir import ChoiceCaller, PrimitiveInfoType
 from .ops_handler import StoreMode
 from .runtime.benchmarking import benchmarker
@@ -481,12 +482,20 @@ class TritonTemplateKernel(TritonKernel):
         ninplace_args = len(unique(self.args.inplace_buffers.values()))
         num_bytes = []
         for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
-            size = V.graph.sizevars.size_hints(inp.get_size())
+            size = V.graph.sizevars.size_hints(inp.get_size(), fallback=0)
             numel = functools.reduce(operator.mul, size, 1)
             dtype_size = get_dtype_size(inp.get_dtype())
             num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
         return sum(num_bytes)

+    def estimate_flops(self) -> int:
+        for node in self.input_nodes:
+            for fx_node in node._current_origins:
+                f = count_flops_fx(fx_node)
+                if f is not None:
+                    return V.graph.sizevars.size_hint(f, fallback=0)
+        return 0
+
     def jit_lines(self):
         if self.use_jit:
             return "@triton.jit"
@@ -525,6 +534,9 @@ class TritonTemplateKernel(TritonKernel):
         if config.profile_bandwidth or config.benchmark_kernel:
             num_gb = self.estimate_kernel_num_bytes() / 1e9
             inductor_meta["kernel_num_gb"] = num_gb
+        if config.benchmark_kernel:
+            flops = self.estimate_flops()
+            inductor_meta["kernel_flop"] = flops

         template_args = f"""
             num_stages={self.num_stages},

View File

@@ -22,7 +22,14 @@ import tempfile
 import textwrap
 import time
 import unittest
-from collections.abc import Collection, Iterator, Mapping, MutableMapping, MutableSet
+from collections.abc import (
+    Collection,
+    Generator,
+    Iterator,
+    Mapping,
+    MutableMapping,
+    MutableSet,
+)
 from datetime import datetime
 from io import StringIO
 from typing import (
@@ -51,6 +58,7 @@ from unittest import mock
 import sympy

 import torch
+from torch._inductor.analysis.device_info import datasheet_tops
 from torch._inductor.runtime.hints import DeviceProperties
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_map_only
@@ -2156,9 +2164,19 @@ def get_backend_num_stages() -> int:
 @functools.cache
-def get_device_tflops(dtype: torch.dtype) -> int:
+def get_device_tflops(dtype: torch.dtype) -> float:
+    """
+    We don't want to throw errors in this function. First check to see if the device is in device_info.py,
+    then fall back to the inaccurate triton estimation.
+    """
+    ds_tops = datasheet_tops(dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32)
+    if ds_tops is not None:
+        return ds_tops
+
     from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops

+    from torch.testing._internal.common_cuda import SM80OrLater
+
     assert dtype in (torch.float16, torch.bfloat16, torch.float32)

     if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
@@ -2166,7 +2184,7 @@ def get_device_tflops(dtype: torch.dtype) -> int:
         from torch._utils_internal import max_clock_rate

         sm_clock = max_clock_rate()
-        if dtype in (torch.float16, torch.bfloat16):
+        if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
             return get_max_tensorcore_tflops(dtype, sm_clock)

         if torch.backends.cuda.matmul.allow_tf32:
@@ -2174,7 +2192,7 @@ def get_device_tflops(dtype: torch.dtype) -> int:
         else:
             return get_max_simd_tflops(torch.float32, sm_clock)
     else:
-        if dtype in (torch.float16, torch.bfloat16):
+        if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
             return get_max_tensorcore_tflops(dtype)

         if torch.backends.cuda.matmul.allow_tf32:
@@ -3204,3 +3222,54 @@ def is_mkldnn_fp16_supported(device_type: str) -> bool:
         # match "xpu", "xpu:0", "xpu:1", etc.
         return True
     return False
+
+
+def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
+    widths = [len(str(e)) for e in headers]
+    for row in elements:
+        assert len(row) == len(headers)
+        for i, e in enumerate(row):
+            widths[i] = max(widths[i], len(str(e)))
+    lines = []
+    lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths)))
+    # widths, whitespace, horizontal separators
+    total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
+    lines.append("-" * total_width)
+    for row in elements:
+        lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths)))
+    return "\n".join(lines)
+
+
+def zip_dicts(
+    dict1: Mapping[KeyType, ValType],
+    dict2: Mapping[KeyType, ValType],
+    d1_default: ValType | None = None,
+    d2_default: ValType | None = None,
+) -> Generator[tuple[KeyType, ValType | None, ValType | None], None, None]:
+    """
+    Zip two dictionaries together, replacing missing keys with default values.
+
+    Args:
+        dict1 (dict): The first dictionary.
+        dict2 (dict): The second dictionary.
+        d1_default (Any): the default value for the first dictionary
+        d2_default (Any): the default value for the second dictionary
+
+    Yields:
+        tuple: A tuple containing the key, the value from dict1 (or d1_default if missing),
+        and the value from dict2 (or d2_default if missing).
+    """
+    # Find the union of all keys
+    all_keys = OrderedSet(dict1.keys()) | OrderedSet(dict2.keys())
+
+    # Iterate over all keys
+    for key in all_keys:
+        # Get the values from both dictionaries, or default if missing
+        value1 = dict1.get(key)
+        value2 = dict2.get(key)
+        yield (
+            key,
+            value1 if value1 is not None else d1_default,
+            value2 if value2 is not None else d2_default,
+        )

View File

@@ -1,8 +1,8 @@
 import argparse
-import dataclasses
 import datetime
 import tempfile
 from collections import defaultdict
+from dataclasses import dataclass
 from types import ModuleType
 from typing import Any, Optional, Protocol
@@ -159,7 +159,7 @@ def benchmark_all_kernels(
     )


-@dataclasses.dataclass
+@dataclass
 class ProfileEvent:
     category: str
     key: str
@@ -176,6 +176,10 @@ def parse_profile_event_list(
     nruns: int,
     device_name: str,
 ) -> None:
+    """
+    Parse and generate a report for an event_list.
+    """
+
     def get_self_device_time(
         ev: torch.autograd.profiler_util.EventList,
     ) -> float:
@@ -295,6 +299,10 @@ def parse_profile_event_list(
     report()


+PROFILE_DIR = tempfile.gettempdir()
+PROFILE_PATH = f"{PROFILE_DIR}/compiled_module_profile.json"
+
+
 def perf_profile(
     wall_time_ms: float,
     times: int,
@@ -305,14 +313,14 @@ def perf_profile(
     with torch.profiler.profile(record_shapes=True) as p:
         benchmark_compiled_module_fn(times=times, repeat=repeat)

-    path = f"{tempfile.gettempdir()}/compiled_module_profile.json"
+    path = PROFILE_PATH
     p.export_chrome_trace(path)
     print(f"Profiling result for a compiled module of benchmark {benchmark_name}:")
     print(f"Chrome trace for the profile is written to {path}")
     event_list = p.key_averages(group_by_input_shape=True)
     print(event_list.table(sort_by="self_device_time_total", row_limit=10))
     parse_profile_event_list(
-        benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device
+        benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device or ""
     )

View File

@@ -213,6 +213,9 @@ def is_fb_unit_test() -> bool:
 @functools.cache
 def max_clock_rate():
+    """
+    unit: MHz
+    """
     if not torch.version.hip:
         from triton.testing import nvsmi

View File

@@ -127,7 +127,6 @@ def conv_flop_count(
     Returns:
         int: the number of flops
     """
-
     batch_size = x_shape[0]
     conv_shape = (x_shape if transposed else out_shape)[2:]
     c_out, c_in, *filter_size = w_shape
@@ -146,7 +145,7 @@ def conv_flop_count(
     flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
     return flop

-@register_flop_formula([aten.convolution, aten._convolution])
+@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
 def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
     """Count flops for convolution."""
     return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
@@ -561,6 +560,8 @@ flop_registry = {
    aten._scaled_mm: _scaled_mm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
+   aten.cudnn_convolution: conv_flop,
+   aten._slow_conv2d_forward: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    aten._scaled_dot_product_efficient_attention: sdpa_flop,
    aten._scaled_dot_product_flash_attention: sdpa_flop,