pytorch/test/inductor/test_analysis.py
Gabriel Ferns 47f10d0ad0 Inductor logging + analysis of torch.profile (#149697)
Prereqs:
 - https://github.com/pytorch/pytorch/pull/152708

Features:
1. Adds Inductor's estimates of FLOPS and bandwidth to the JSON trace events that Perfetto uses.
2. Only uses the TFLOPS estimate from Triton when we don't have the info from the datasheet, because Triton's estimates are inaccurate (there is a backlog item to fix Triton's FLOPS estimation upstream). Adds a new `DeviceInfo` class and a new `get_device_tflops` function.
3. New helpers `countable_fx` and `count_flops_fx` help compute the FLOPS of an `fx.Node`.
4. Extends Triton `torch.profiler` logging to `DebugAutotuner`.
5. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any Perfetto JSON trace, `--analysis` creates a summary table of these perf estimates, and `--diff` compares two traces side by side (see the usage sketch after the sample output):
```
Device(NVIDIA H100, 0):
 Kernel Name                              | resnet Kernel Count | resnet FLOPS       | resnet bw gbps        | resnet Dur (ms)    | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS    | newresnet bw gbps     | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth %
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 triton_poi_fused__native_batch_norm_legi | 24                  | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                       | 0.003401572611382541        | 24                     | 0                  | 0.11395268248131513   | 2.5919166666666666 | 0                          | 0.003401572611382541
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142                 | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583     | 0.007716441266265022        | 142                    | 16932673552.422373 | 0.2585007824198784    | 12.441619718309857 | 0.08683422334575583        | 0.007716441266265022
 triton_red_fused__native_batch_norm_legi | 39                  | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                       | 0.004176126863316074        | 39                     | 0                  | 0.13990024992108846   | 5.752589743589743  | 0                          | 0.004176126863316074
 triton_poi_fused__native_batch_norm_legi | 25                  | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                       | 0.009499718184339253        | 25                     | 0                  | 0.31824055917536503   | 2.5291999999999994 | 0                          | 0.009499718184339253
 void cutlass::Kernel2<cutlass_80_tensoro | 98                  | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874     | 0.012827592254037562        | 98                     | 16211056473.596165 | 0.42972434051025826   | 7.130408163265306  | 0.08313362294151874        | 0.012827592254037562
 triton_red_fused__native_batch_norm_legi | 73                  | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                       | 0.009628003963020014        | 73                     | 0                  | 0.3225381327611705    | 9.987068493150682  | 0                          | 0.009628003963020014
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                       | 0.043257347302946926        | 15                     | 0                  | 1.4491211346487216    | 4.439333333333333  | 0                          | 0.043257347302946926
 void cutlass::Kernel2<cutlass_80_tensoro | 186                 | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027     | 0.007961586274361157        | 186                    | 14501701145.337954 | 0.2667131401910989    | 7.873865591397849  | 0.07436769818122027        | 0.007961586274361157
 triton_poi_fused__native_batch_norm_legi | 33                  | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                       | 0.044550915039384846        | 33                     | 0                  | 1.4924556538193923    | 4.3101515151515155 | 0                          | 0.044550915039384846
 triton_red_fused__native_batch_norm_legi | 29                  | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                       | 0.007630624036606301        | 29                     | 0                  | 0.25562590522631107   | 6.296275862068965  | 0                          | 0.007630624036606301
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                       | 0.01752406619162008         | 13                     | 0                  | 0.5870562174192726    | 2.7397692307692307 | 0                          | 0.01752406619162008
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 0.41409928846284      | 2.853588235294117  | 0                       | 0.012361172789935523        | 34                     | 0                  | 0.41409928846284      | 2.853588235294117  | 0                          | 0.012361172789935523
 triton_per_fused__native_batch_norm_legi | 34                  | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                       | 0.0034941238826919864       | 34                     | 0                  | 0.11705315007018151   | 3.460647058823529  | 0                          | 0.0034941238826919864
 triton_poi_fused__native_batch_norm_legi | 16                  | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                       | 0.005136672596156592        | 16                     | 0                  | 0.17207853197124584   | 2.3459375000000002 | 0                          | 0.005136672596156592
 triton_per_fused__native_batch_norm_legi | 30                  | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                       | 0.007879744244842555        | 30                     | 0                  | 0.2639714322022256    | 6.131199999999999  | 0                          | 0.007879744244842555
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100                 | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531     | 0.005819245035648175        | 100                    | 11875430356.891787 | 0.19494470869421385   | 16.36534           | 0.06089964285585531        | 0.005819245035648175
 triton_poi_fused__native_batch_norm_legi | 8                   | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                       | 0.029415213809625928        | 8                      | 0                  | 0.9854096626224687    | 3.2757500000000004 | 0                          | 0.029415213809625928
 void cublasLt::splitKreduce_kernel<32, 1 | 56                  | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628     | 0.024806865808245714        | 56                     | 34377923395.147064 | 0.8310300045762317    | 3.4199999999999986 | 0.17629704305203628        | 0.024806865808245714
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                       | 0.02968359094286896         | 23                     | 0                  | 0.9944002965861103    | 3.2431304347826084 | 0                          | 0.02968359094286896
 triton_per_fused__native_batch_norm_legi | 10                  | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                       | 0.00545313748934644         | 10                     | 0                  | 0.1826801058931057    | 4.428800000000001  | 0                          | 0.00545313748934644
 triton_poi_fused__native_batch_norm_legi | 10                  | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                       | 0.009459622642884923        | 10                     | 0                  | 0.3168973585366449    | 2.5471999999999997 | 0                          | 0.009459622642884923
 triton_poi_fused__native_batch_norm_legi | 34                  | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                       | 0.03421974596124114         | 34                     | 0                  | 1.1463614897015777    | 4.124323529411764  | 0                          | 0.03421974596124114
 void cask_plugin_cudnn::xmma_cudnn::init | 44                  | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194     | 0.06167532194133924         | 44                     | 44045510816.64277  | 2.0661232850348643    | 3.6887499999999993 | 0.22587441444432194        | 0.06167532194133924
 sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95                  | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802     | 0.014014750913273854        | 95                     | 7876855400.165316  | 0.4694941555946739    | 18.224315789473682 | 0.04039413025725802        | 0.014014750913273854
 triton_per_fused__native_batch_norm_legi | 41                  | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                       | 0.002037513395819492        | 41                     | 0                  | 0.06825669875995298   | 3.0384146341463416 | 0                          | 0.002037513395819492
 triton_poi_fused__native_batch_norm_legi | 23                  | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                       | 0.0026292999141582997       | 23                     | 0                  | 0.08808154712430301   | 2.3275652173913044 | 0                          | 0.0026292999141582997
 triton_per_fused__native_batch_norm_legi | 40                  | 0                  | 0.18179321034952417   | 4.556825           | 0                       | 0.005426662995508183        | 40                     | 0                  | 0.18179321034952417   | 4.556825           | 0                          | 0.005426662995508183
 triton_poi_fused__native_batch_norm_legi | 15                  | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                       | 0.017574373598370836        | 15                     | 0                  | 0.5887415155454232    | 2.783866666666667  | 0                          | 0.017574373598370836
 void cutlass::Kernel2<cutlass_80_tensoro | 38                  | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546      | 0.007659474756834           | 38                     | 14242013806.264643 | 0.256592404353939     | 7.217631578947369  | 0.0730359682372546         | 0.007659474756834
 triton_poi_fused__native_batch_norm_legi | 21                  | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                       | 0.017441376040091088        | 21                     | 0                  | 0.5842860973430516    | 2.7779047619047623 | 0                          | 0.017441376040091088
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                       | 0.0034356313950705724       | 16                     | 0                  | 0.11509365173486417   | 3.5959375000000002 | 0                          | 0.0034356313950705724
 triton_poi_fused__native_batch_norm_legi | 14                  | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                       | 0.00508857313505646         | 14                     | 0                  | 0.1704672000243914    | 2.4044285714285714 | 0                          | 0.00508857313505646
 triton_poi_fused__native_batch_norm_legi | 58                  | 0                  | 2.307520779930795     | 8.190706896551722  | 0                       | 0.06888121731136704         | 58                     | 0                  | 2.307520779930795     | 8.190706896551722  | 0                          | 0.06888121731136704
 triton_per_fused__native_batch_norm_legi | 29                  | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                       | 0.001111738775280038        | 29                     | 0                  | 0.037243248971881276  | 3.0277586206896556 | 0                          | 0.001111738775280038
 triton_poi_fused__native_batch_norm_legi | 20                  | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                       | 0.0014154327747549007       | 20                     | 0                  | 0.04741699795428918   | 2.2911500000000005 | 0                          | 0.0014154327747549007
 triton_per_fused__native_batch_norm_legi | 25                  | 0                  | 0.13357016893727824   | 3.37536            | 0                       | 0.003987169222008305        | 25                     | 0                  | 0.13357016893727824   | 3.37536            | 0                          | 0.003987169222008305
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                       | 0.009223469457612694        | 13                     | 0                  | 0.3089862268300253    | 2.8111538461538457 | 0                          | 0.009223469457612694
 triton_poi_fused__native_batch_norm_legi | 17                  | 0                  | 0.3129385387909844    | 2.673              | 0                       | 0.009341448919133863        | 17                     | 0                  | 0.3129385387909844    | 2.673              | 0                          | 0.009341448919133863
 triton_per_fused__native_batch_norm_legi | 19                  | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                       | 0.0066136363060691275       | 19                     | 0                  | 0.2215568162533158    | 3.8837368421052636 | 0                          | 0.0066136363060691275
 std::enable_if<!(false), void>::type int | 23                  | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447   | 0.030203868944223014        | 23                     | 504916805.19297093 | 1.0118296096314707    | 8.113913043478261  | 0.0025893169497075447      | 0.030203868944223014
 triton_poi_fused_add_copy__38            | 56                  | 0                  | 0                     | 2.132482142857143  | 0                       | 0                           | 56                     | 0                  | 0                     | 2.132482142857143  | 0                          | 0
 triton_poi_fused_convolution_0           | 18                  | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                       | 0.012972719640279667        | 18                     | 0                  | 0.43458610794936897   | 2.773333333333334  | 0                          | 0.012972719640279667
 triton_poi_fused_convolution_1           | 17                  | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                       | 0.0008601884319153051       | 17                     | 0                  | 0.028816312469162712  | 2.6145882352941174 | 0                          | 0.0008601884319153051
 void convolve_common_engine_float_NHWC<f | 44                  | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169     | 0.0007382250748795709       | 44                     | 8641868995.31118   | 0.024730540008465626  | 25.87327272727273  | 0.04431727689903169        | 0.0007382250748795709
 triton_per_fused__native_batch_norm_legi | 12                  | 0                  | 0.6809930918986744    | 4.82675            | 0                       | 0.020328151996975356        | 12                     | 0                  | 0.6809930918986744    | 4.82675            | 0                          | 0.020328151996975356
 triton_per_fused__native_batch_norm_legi | 14                  | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                       | 0.0008606061486377935       | 14                     | 0                  | 0.02883030597936608   | 2.6651428571428575 | 0                          | 0.0008606061486377935
 triton_per_fused__native_batch_norm_legi | 16                  | 0                  | 0.0014658988233201874 | 2.098              | 0                       | 4.375817383045335e-05       | 16                     | 0                  | 0.0014658988233201874 | 2.098              | 0                          | 4.375817383045335e-05
 triton_poi_fused__native_batch_norm_legi | 13                  | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                       | 0.02963073785159611         | 13                     | 0                  | 0.9926297180284697    | 3.2367692307692306 | 0                          | 0.02963073785159611
 triton_poi_fused__native_batch_norm_legi | 9                   | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                       | 0.03883228983781048         | 9                      | 0                  | 1.3008817095666507    | 3.0863333333333336 | 0                          | 0.03883228983781048
 void at::native::(anonymous namespace):: | 98                  | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                       | 0.0027386076458833994       | 98                     | 0                  | 0.09174335613709389   | 4.408520408163265  | 0                          | 0.0027386076458833994
 void at::native::vectorized_elementwise_ | 7                   | 0                  | 0                     | 1.7278571428571428 | 0                       | 0                           | 7                      | 0                  | 0                     | 1.7278571428571428 | 0                          | 0
```
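
For reference, here is a minimal sketch of driving the script the way the unit tests below do, by patching `sys.argv` and calling its `main()`. The trace file names are hypothetical placeholders, and `resnet`/`newresnet` in the table above are user-supplied benchmark labels:

```python
# Sketch only: exercises the three profile_analysis.py modes via sys.argv,
# mirroring test/inductor/test_analysis.py. File names are placeholders.
from unittest.mock import patch

from torch._inductor.analysis.profile_analysis import main

# Add kernel_flop / kernel_num_gb estimates to an existing perfetto trace:
with patch("sys.argv", ["profile.py", "--augment_trace", "in.json", "out.json", "float"]):
    main()

# Print a summary table of the perf estimates in one trace:
with patch("sys.argv", ["profile.py", "--analysis", "out.json", "float"]):
    main()

# Compare two traces side by side; the labels name the table columns:
with patch("sys.argv", ["profile.py", "--diff", "a.json", "resnet", "b.json", "newresnet", "float", "--name_limit", "30"]):
    main()
```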

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-07-01 16:51:03 +00:00


# Owner(s): ["module: inductor"]
import json
import tempfile
import unittest
import uuid
from io import StringIO
from unittest.mock import patch

import torch
import torch.nn.functional as F
from torch._inductor.analysis.profile_analysis import (
    _augment_trace_helper,
    _create_extern_mapping,
    JsonProfile,
    main,
)
from torch._inductor.utils import fresh_inductor_cache, tabulate_2d, zip_dicts
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    skipIf,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
from torch.testing._internal.inductor_utils import IS_BIG_GPU


example_profile = """
{
  "schemaVersion": 1,
  "deviceProperties": [
    {
      "id": 0, "name": "NVIDIA H100", "totalGlobalMem": 101997215744,
      "computeMajor": 9, "computeMinor": 0,
      "maxThreadsPerBlock": 1024, "maxThreadsPerMultiprocessor": 2048,
      "regsPerBlock": 65536, "warpSize": 32,
      "sharedMemPerBlock": 49152, "numSms": 132,
      "regsPerMultiprocessor": 65536, "sharedMemPerBlockOptin": 232448,
      "sharedMemPerMultiprocessor": 233472
    }
  ],
  "cupti_version": 24,
  "cuda_runtime_version": 12060,
  "with_flops": 1,
  "record_shapes": 1,
  "cuda_driver_version": 12040,
  "profile_memory": 1,
  "trace_id": "301995E163ED42048FBD783860E6E7DC",
  "displayTimeUnit": "ms",
  "baseTimeNanoseconds": 1743521598000000000,
  "traceEvents": [
    {
      "ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
      "ts": 198093488368.463, "dur": 425.453,
      "args": {
        "External id": 1340, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0,
        "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", "False", "[0, 0]", "1"],
        "Input type": ["float", "float", "", "ScalarList", "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar"],
        "Input Strides": [[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], []],
        "Input Dims": [[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], []],
        "Ev Idx": 1339
      }
    },
    {
      "ph": "X", "cat": "cpu_op", "name": "aten::_convolution", "pid": 1147039, "tid": 1147039,
      "ts": 198093488444.498, "dur": 341.867,
      "args": {
        "External id": 1341, "Record function id": 0,
        "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", "False", "[0, 0]", "1", "False", "False", "True", "True"],
        "Input type": ["float", "float", "", "ScalarList", "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"],
        "Input Strides": [[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []],
        "Input Dims": [[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []],
        "Ev Idx": 1340
      }
    },
    {
      "ph": "X", "cat": "cpu_op", "name": "aten::addmm", "pid": 1147039, "tid": 1147039,
      "ts": 198093513655.849, "dur": 251.130,
      "args": {
        "External id": 1619, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0,
        "Concrete Inputs": ["", "", "", "1", "1", ""],
        "Input type": ["float", "float", "float", "Scalar", "Scalar", "float"],
        "Input Strides": [[1], [0, 1], [1, 2048], [], [], [1000, 1]],
        "Input Dims": [[1000], [1, 2048], [2048, 1000], [], [], [1, 1000]],
        "Ev Idx": 1618
      }
    },
    {
      "ph": "X", "cat": "kernel", "name": "void cutlass_addmm", "pid": 1147039, "tid": 1147039,
      "ts": 198093513655.849, "dur": 251.130,
      "args": {
        "External id": 1619, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
      }
    },
    {
      "ph": "X", "cat": "kernel", "name": "void convolution_kernel", "pid": 1147039, "tid": 1147039,
      "ts": 198093513655.849, "dur": 200.130,
      "args": {
        "External id": 1342, "Sequence number": 0, "Fwd thread id": 0, "Record function id": 0, "Ev Idx": 1618
      }
    },
    {
      "ph": "X", "cat": "cpu_op", "name": "aten::convolution", "pid": 1147039, "tid": 1147039,
      "ts": 198093488444.498, "dur": 341.867,
      "args": {
        "External id": 1342, "Record function id": 0,
        "Concrete Inputs": ["", "", "", "[2, 2]", "[3, 3]", "[1, 1]", "False", "[0, 0]", "1", "False", "False", "True", "True"],
        "Input type": ["float", "float", "", "ScalarList", "ScalarList", "ScalarList", "Scalar", "ScalarList", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"],
        "Input Strides": [[150528, 1, 672, 3], [147, 1, 21, 3], [], [], [], [], [], [], [], [], [], [], []],
        "Input Dims": [[1, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []],
        "Ev Idx": 1340
      }
    }
  ],
  "traceName": "/tmp/compiled_module_profile.json"
}
"""
def verify_flops(self, expected_flops, out_profile):
    # Walk the augmented trace in order, checking each kernel_flop
    # annotation against the next expected value.
    j = 0
    for i in range(len(out_profile["traceEvents"])):
        if "kernel_flop" in out_profile["traceEvents"][i]["args"]:
            self.assertEqual(
                out_profile["traceEvents"][i]["args"]["kernel_flop"],
                expected_flops[j],
            )
            j += 1


def random_tensor(size, dtype, **kwargs):
    if dtype in [torch.half, torch.bfloat16, torch.float, torch.double]:
        return torch.randn(size, dtype=dtype, **kwargs)
    elif dtype in [torch.uint8, torch.int8, torch.short, torch.int, torch.long]:
        return torch.randint(0, 100, size, dtype=dtype, **kwargs)
    else:
        raise ValueError("Unsupported data type")


def cT(device, dtype):
    # Returns a tensor factory bound to the given device and dtype.
    def T(*shape, requires_grad=False):
        return random_tensor(
            shape, requires_grad=requires_grad, device=device, dtype=dtype
        )

    return T


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


TMP_DIR = tempfile.mkdtemp()


def trace_files():
    TRACE1 = f"{TMP_DIR}/trace1-{uuid.uuid4()}.json"
    TRACE2 = f"{TMP_DIR}/trace2-{uuid.uuid4()}.json"
    return TRACE1, TRACE2


def _test_model(device, dtype, compile=True, addmm=True, bmm=True):
    T = cT(device, dtype)

    def model():
        input_conv = T(1, 3, 56, 56)
        conv_weight = T(12, 3, 5, 5)

        # Increased matrix sizes
        B = 8
        M = 256
        N = 512
        K = 768
        mat1 = T(M, N)
        mat2 = T(N, K)
        batch_mat1 = T(B, M, N)
        batch_mat2 = T(B, N, K)

        conv_output = F.conv2d(input_conv, conv_weight)
        conv_output = conv_output * 10
        mm_output = torch.mm(mat1, mat2)
        ret = [
            conv_output.flatten(),
            mm_output.flatten(),
        ]
        if addmm:
            addmm_output = torch.addmm(
                torch.zeros(mm_output.shape, device=mat1.device, dtype=mat1.dtype),
                mat1,
                mat2,
            )
            ret.append(addmm_output.flatten())
        if bmm:
            bmm_output = torch.bmm(batch_mat1, batch_mat2)
            ret.append(bmm_output.flatten())
        if bmm and addmm:
            baddbmm_output = torch.baddbmm(
                torch.zeros(
                    1,
                    *mm_output.shape,
                    device=batch_mat1.device,
                    dtype=batch_mat1.dtype,
                ),
                batch_mat1,
                batch_mat2,
            )
            ret.append(baddbmm_output.flatten())
        return torch.cat(ret)

    if compile:
        return torch.compile(
            model, options={"benchmark_kernel": True, "profile_bandwidth": True}
        )
    return model


def _pointwise_test_model(device, dtype, compile=True):
    T = cT(device, dtype)

    def model():
        M = 1024
        N = 512
        mat3 = T(M, N)
        mat4 = T(M, N)
        pointwise_output = torch.add(mat3, mat4).sin()
        return pointwise_output

    if compile:
        return torch.compile(
            model, options={"benchmark_kernel": True, "profile_bandwidth": True}
        )
    return model


prefix = ["profile.py"]


class TestUtils(TestCase):
    def test_tabulate2d(self):
        headers = ["Kernel", "Self H100 TIME (ms)", "Count", "Percent"]
        rows = [
            ["aten::mm", 0.500, 7, 0.0],
            ["aten::bmm", 0.400, 6, 0.0],
            ["aten::baddbmm", 0.300, 5, 0.0],
            ["aten::convolution", 0.200, 4, 0.0],
            ["aten::cudnn_convolution", 0.100, 3, 0.0],
        ]
        # Expected output: left-aligned cells padded to each column's width.
        table = [
            " Kernel                  | Self H100 TIME (ms) | Count | Percent ",
            "-----------------------------------------------------------------",
            " aten::mm                | 0.5                 | 7     | 0.0     ",
            " aten::bmm               | 0.4                 | 6     | 0.0     ",
            " aten::baddbmm           | 0.3                 | 5     | 0.0     ",
            " aten::convolution       | 0.2                 | 4     | 0.0     ",
            " aten::cudnn_convolution | 0.1                 | 3     | 0.0     ",
        ]
        res = tabulate_2d(rows, headers)
        for r, t in zip(res.split("\n"), table):
            self.assertEqual(r, t)

    def test_zip_dicts(self):
        d1 = {"a": 1, "b": 2}
        d2 = {"a": 3, "c": 4}
        # zip_dicts yields (key, d1_value, d2_value), substituting the given
        # defaults (or None) when a key is missing from one of the dicts.
        res1 = zip_dicts(d1, d2, d1_default=32, d2_default=48)
        self.assertEqual(set(res1), {("a", 1, 3), ("b", 2, 48), ("c", 32, 4)})
        res2 = zip_dicts(d1, d2)
        self.assertEqual(set(res2), {("a", 1, 3), ("b", 2, None), ("c", None, 4)})


class TestAnalysis(TestCase):
    @skipIf(not SM80OrLater, "Requires SM80")
    def test_noop(self):
        with (
            patch("sys.stdout", new_callable=StringIO) as mock_stdout,
            patch("sys.argv", [*prefix]),
        ):
            main()
            self.assertEqual(mock_stdout.getvalue(), "")

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.double, torch.float16)
    def test_diff(self, device, dtype):
        """
        Test the --diff flow, exercising the nruns feature too.
        """
        if device == "cpu":
            # TODO cpu support
            return
        om = _test_model(device, dtype)
        REPEAT = 5
        trace1, trace2 = trace_files()

        print(f"first trace {trace1}")
        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as p:
                om()
        p.export_chrome_trace(trace1)

        print(f"second trace {trace2}")
        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as p:
                for _ in range(REPEAT):
                    om()
        p.export_chrome_trace(trace2)

        print("diffing...")
        with patch(
            "sys.argv",
            [
                *prefix,
                "--diff",
                trace1,
                "foo",
                trace2,
                "bar",
                str(dtype).split(".")[-1],
                "--name_limit",
                "30",
            ],
        ):
            main()

    @skipIf(not SM80OrLater, "Requires SM80")
    def test_augment_trace_helper_unit(self):
        js = json.loads(example_profile)
        out_profile = _augment_trace_helper(js)
        expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0]
        verify_flops(self, expected_flops, out_profile)

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.double, torch.float16)
    def test_augment_trace_helper_args(self, device, dtype):
        if device == "cpu":
            # cpu doesn't produce traces currently
            return
        om = _test_model(device, dtype)
        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as p:
                om()

        trace1, trace2 = trace_files()
        p.export_chrome_trace(trace1)
        print(f"first trace {trace1}")
        with patch(
            "sys.argv",
            [*prefix, "--augment_trace", trace1, trace2, str(dtype).split(".")[-1]],
        ):
            main()

        profile = JsonProfile(
            trace2, benchmark_name="foo", dtype=str(dtype).split(".")[-1]
        )
        rep = profile.report()
        self.assertTrue(len(rep.split("\n")) > 3, f"Error, empty table:\n{rep}")
        # If these fail, just update them. They could change over time.
        self.assertIn("Kernel Name", rep)
        self.assertIn("Kernel Count", rep)
        self.assertIn("FLOPS", rep)
        self.assertIn("Kernel Reads", rep)
        self.assertIn("Dur", rep)
        self.assertIn("Achieved", rep)
        self.assertIn("|", rep)
        self.assertIn("-----", rep)

        tables = profile._create_tables(profile._devices)
        # Check that all percentage values are between 0% and 100%.
        percents = []
        for tab in tables.values():
            header, rows = tab
            for i, h in enumerate(header):
                if "%" in h:
                    percents.append(i)
            self.assertTrue(len(percents) > 0, "There are no headers with % in them")
            for row in rows.values():
                for p in percents:
                    # row values are offset by one relative to headers
                    idx = p - 1
                    self.assertTrue(
                        float(row[idx]) <= 100.0,
                        f"column value from column {idx} with header '{header[idx]}' is greater than 100%: {row[idx]}",
                    )
                    self.assertTrue(
                        float(row[idx]) >= 0.0,
                        f"column value from column {idx} with header '{header[idx]}' is less than 0%: {row[idx]}",
                    )

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.double, torch.float16)
    @parametrize(
        "maxat",
        [
            (True, "TRITON"),
        ],
    )
    @skipIf(not IS_BIG_GPU, "we can't use Triton as the only backend for max autotune")
    def test_triton_has_metadata(self, device, dtype, maxat):
        """
        Make sure that the chrome trace of triton kernels contains certain values.
        """
        if device == "cpu":
            return
        T = cT(device, dtype)
        input_conv = T(1, 3, 56, 56)
        conv_weight = T(12, 3, 5, 5)

        def om(i, w):
            # Convolution operation
            conv_output = F.conv2d(i, w)
            return conv_output

        max_autotune, backends = maxat
        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "force_disable_caches": True,
                "max_autotune": max_autotune,
            },
        )

        def verify_triton(comp):
            torch._dynamo.reset()  # reset the cache
            with fresh_inductor_cache():
                with torch.profiler.profile(record_shapes=True) as profile:
                    comp(input_conv, conv_weight)

            trace1, _ = trace_files()
            profile.export_chrome_trace(trace1)
            with open(trace1) as f:
                out_profile = json.load(f)
            seen = False
            for event in out_profile["traceEvents"]:
                if "triton" in event["name"] and "conv" in event["name"]:
                    seen = True
            self.assertTrue(seen, "no triton conv found")

        verify_triton(comp_omni)

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.float16)
    @parametrize(
        "maxat",
        [
            (False, "ATEN,TRITON"),
            (True, "ATEN,TRITON"),
            (True, "ATEN"),
            (True, "TRITON"),
        ],
    )
    @unittest.skipIf(
        not IS_BIG_GPU, "we can't use Triton as the only backend for max autotune"
    )
    def test_augment_trace_against_flop_counter(self, device, dtype, maxat):
        # This also tests whether max autotune can run with a Triton-only backend.
        max_autotune, backends = maxat
        if device == "cpu":
            return
        om = _test_model(device, dtype, compile=False)

        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "force_disable_caches": True,
                "max_autotune": max_autotune,
            },
        )
        comp_omni()  # compile once up front

        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as profile:
                comp_omni()

        torch._dynamo.reset()  # reset the cache
        with fresh_inductor_cache():
            with FlopCounterMode() as mode:
                comp_omni()

        trace1, trace2 = trace_files()
        profile.export_chrome_trace(trace1)
        with patch(
            "sys.argv",
            [*prefix, "--augment_trace", trace1, trace2, str(dtype).split(".")[-1]],
        ):
            main()

        with open(trace2) as f:
            out_profile = json.load(f)
        flop_counts = mode.flop_counts
        extern_mapping = _create_extern_mapping(out_profile)

        seen_mm = False
        seen_bmm = False
        seen_baddbmm = False
        seen_conv = False
        for event in out_profile["traceEvents"]:
            if (
                "cat" not in event
                or event["cat"] != "kernel"
                or "args" not in event
                or "External id" not in event["args"]
            ):
                continue

            external_op = extern_mapping[event["args"]["External id"]][0]
            name: str = external_op["name"]
            self.assertNotEqual(name, None)
            self.assertEqual(type(name), str)
            if name.startswith("aten::mm") or "_mm_" in name:
                seen_mm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.mm],
                )
            if (
                name.startswith(
                    (
                        "aten::cudnn_convolution",
                        "aten::convolution",
                        "aten::_convolution",
                    )
                )
                or "conv" in name
            ):
                seen_conv = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.convolution],
                )
            if name.startswith("aten::baddbmm") or "_baddbmm_" in name:
                seen_baddbmm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.baddbmm],
                )
            if name.startswith("aten::bmm") or "_bmm_" in name:
                seen_bmm = True
                self.assertEqual(
                    event["args"]["kernel_flop"],
                    flop_counts["Global"][torch.ops.aten.bmm],
                )
        self.assertTrue(seen_mm)
        self.assertTrue(seen_bmm)
        self.assertTrue(seen_baddbmm)
        self.assertTrue(seen_conv)

    @skipIf(not SM80OrLater, "Requires SM80")
    @dtypes(torch.float, torch.float16)
    @parametrize(
        "maxat",
        [
            (False, "ATEN,TRITON"),
            (True, "ATEN,TRITON"),
            (True, "ATEN"),
            (True, "TRITON"),
        ],
    )
    @unittest.skipIf(
        not IS_BIG_GPU, "we can't use Triton as the only backend for max autotune"
    )
    def test_pointwise_bandwidth(self, device, dtype, maxat):
        # This also tests whether max autotune can run with a Triton-only backend.
        max_autotune, backends = maxat
        if device == "cpu":
            return
        om = _pointwise_test_model(device, dtype, compile=False)

        comp_omni = torch.compile(
            om,
            options={
                "benchmark_kernel": True,
                "max_autotune_gemm_backends": backends,
                "force_disable_caches": True,
                "max_autotune": max_autotune,
            },
        )
        comp_omni()  # compile once up front

        with fresh_inductor_cache():
            with torch.profiler.profile(record_shapes=True) as profile:
                comp_omni()

        trace1, _ = trace_files()
        profile.export_chrome_trace(trace1)
        with patch(
            "sys.argv",
            [*prefix, "--analysis", trace1, str(dtype).split(".")[-1]],
        ):
            main()

        with open(trace1) as f:
            out_profile = json.load(f)
        for event in out_profile["traceEvents"]:
            if event["name"] == "triton_poi_fused_add_randn_sin_0":
                # The fused pointwise kernel should carry Inductor's estimate
                # of the number of gigabytes it moves.
                self.assertEqual(event["args"]["kernel_num_gb"], 0.002097168)


instantiate_device_type_tests(TestAnalysis, globals())

if __name__ == "__main__":
    run_tests()