Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Prereqs:

- https://github.com/pytorch/pytorch/pull/152708

Features:

1. Adds Inductor's estimates of flops and bandwidth to the JSON trace events that Perfetto uses.
2. Only uses the tflops estimate from Triton if we don't have the info from the datasheet, because Triton's estimates are inaccurate. I have a backlog item to fix Triton flops estimation upstream. Adds a new `DeviceInfo` class and a new function `get_device_tflops`.
3. New helpers `countable_fx` and `count_flops_fx` help get the flops of an `fx.Node`.
4. Extends Triton `torch.profiler` logging to `DebugAutotuner`.
5. New script `profile_analysis.py`: `--augment_trace` adds perf estimates to any Perfetto JSON trace, `--analyze` creates a summary table of these perf estimates, and `--diff` compares two traces side by side:

```
Device(NVIDIA H100, 0):
Kernel Name | resnet Kernel Count | resnet FLOPS | resnet bw gbps | resnet Dur (ms) | resnet Achieved FLOPS % | resnet Achieved Bandwidth % | newresnet Kernel Count | newresnet FLOPS | newresnet bw gbps | newresnet Dur (ms) | newresnet Achieved FLOPS % | newresnet Achieved Bandwidth %
----------------------------------------------------------------------------------------------------------------------------------------------------------------
triton_poi_fused__native_batch_norm_legi | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541 | 24 | 0 | 0.11395268248131513 | 2.5919166666666666 | 0 | 0.003401572611382541
sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022 | 142 | 16932673552.422373 | 0.2585007824198784 | 12.441619718309857 | 0.08683422334575583 | 0.007716441266265022
triton_red_fused__native_batch_norm_legi | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074 | 39 | 0 | 0.13990024992108846 | 5.752589743589743 | 0 | 0.004176126863316074
triton_poi_fused__native_batch_norm_legi | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253 | 25 | 0 | 0.31824055917536503 | 2.5291999999999994 | 0 | 0.009499718184339253
void cutlass::Kernel2<cutlass_80_tensoro | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562 | 98 | 16211056473.596165 | 0.42972434051025826 | 7.130408163265306 | 0.08313362294151874 | 0.012827592254037562
triton_red_fused__native_batch_norm_legi | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014 | 73 | 0 | 0.3225381327611705 | 9.987068493150682 | 0 | 0.009628003963020014
triton_poi_fused__native_batch_norm_legi | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926 | 15 | 0 | 1.4491211346487216 | 4.439333333333333 | 0 | 0.043257347302946926
void cutlass::Kernel2<cutlass_80_tensoro | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157 | 186 | 14501701145.337954 | 0.2667131401910989 | 7.873865591397849 | 0.07436769818122027 | 0.007961586274361157
triton_poi_fused__native_batch_norm_legi | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846 | 33 | 0 | 1.4924556538193923 | 4.3101515151515155 | 0 | 0.044550915039384846
triton_red_fused__native_batch_norm_legi | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301 | 29 | 0 | 0.25562590522631107 | 6.296275862068965 | 0 | 0.007630624036606301
triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008 | 13 | 0 | 0.5870562174192726 | 2.7397692307692307 | 0 | 0.01752406619162008
triton_poi_fused__native_batch_norm_legi | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523 | 34 | 0 | 0.41409928846284 | 2.853588235294117 | 0 | 0.012361172789935523
triton_per_fused__native_batch_norm_legi | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864 | 34 | 0 | 0.11705315007018151 | 3.460647058823529 | 0 | 0.0034941238826919864
triton_poi_fused__native_batch_norm_legi | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592 | 16 | 0 | 0.17207853197124584 | 2.3459375000000002 | 0 | 0.005136672596156592
triton_per_fused__native_batch_norm_legi | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555 | 30 | 0 | 0.2639714322022256 | 6.131199999999999 | 0 | 0.007879744244842555
sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175 | 100 | 11875430356.891787 | 0.19494470869421385 | 16.36534 | 0.06089964285585531 | 0.005819245035648175
triton_poi_fused__native_batch_norm_legi | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928 | 8 | 0 | 0.9854096626224687 | 3.2757500000000004 | 0 | 0.029415213809625928
void cublasLt::splitKreduce_kernel<32, 1 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714 | 56 | 34377923395.147064 | 0.8310300045762317 | 3.4199999999999986 | 0.17629704305203628 | 0.024806865808245714
triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896 | 23 | 0 | 0.9944002965861103 | 3.2431304347826084 | 0 | 0.02968359094286896
triton_per_fused__native_batch_norm_legi | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644 | 10 | 0 | 0.1826801058931057 | 4.428800000000001 | 0 | 0.00545313748934644
triton_poi_fused__native_batch_norm_legi | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923 | 10 | 0 | 0.3168973585366449 | 2.5471999999999997 | 0 | 0.009459622642884923
triton_poi_fused__native_batch_norm_legi | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114 | 34 | 0 | 1.1463614897015777 | 4.124323529411764 | 0 | 0.03421974596124114
void cask_plugin_cudnn::xmma_cudnn::init | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924 | 44 | 44045510816.64277 | 2.0661232850348643 | 3.6887499999999993 | 0.22587441444432194 | 0.06167532194133924
sm90_xmma_fprop_implicit_gemm_f32f32_tf3 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854 | 95 | 7876855400.165316 | 0.4694941555946739 | 18.224315789473682 | 0.04039413025725802 | 0.014014750913273854
triton_per_fused__native_batch_norm_legi | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492 | 41 | 0 | 0.06825669875995298 | 3.0384146341463416 | 0 | 0.002037513395819492
triton_poi_fused__native_batch_norm_legi | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997 | 23 | 0 | 0.08808154712430301 | 2.3275652173913044 | 0 | 0.0026292999141582997
triton_per_fused__native_batch_norm_legi | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183 | 40 | 0 | 0.18179321034952417 | 4.556825 | 0 | 0.005426662995508183
triton_poi_fused__native_batch_norm_legi | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836 | 15 | 0 | 0.5887415155454232 | 2.783866666666667 | 0 | 0.017574373598370836
void cutlass::Kernel2<cutlass_80_tensoro | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834 | 38 | 14242013806.264643 | 0.256592404353939 | 7.217631578947369 | 0.0730359682372546 | 0.007659474756834
triton_poi_fused__native_batch_norm_legi | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088 | 21 | 0 | 0.5842860973430516 | 2.7779047619047623 | 0 | 0.017441376040091088
triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724 | 16 | 0 | 0.11509365173486417 | 3.5959375000000002 | 0 | 0.0034356313950705724
triton_poi_fused__native_batch_norm_legi | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646 | 14 | 0 | 0.1704672000243914 | 2.4044285714285714 | 0 | 0.00508857313505646
triton_poi_fused__native_batch_norm_legi | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704 | 58 | 0 | 2.307520779930795 | 8.190706896551722 | 0 | 0.06888121731136704
triton_per_fused__native_batch_norm_legi | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038 | 29 | 0 | 0.037243248971881276 | 3.0277586206896556 | 0 | 0.001111738775280038
triton_poi_fused__native_batch_norm_legi | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007 | 20 | 0 | 0.04741699795428918 | 2.2911500000000005 | 0 | 0.0014154327747549007
triton_per_fused__native_batch_norm_legi | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305 | 25 | 0 | 0.13357016893727824 | 3.37536 | 0 | 0.003987169222008305
triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694 | 13 | 0 | 0.3089862268300253 | 2.8111538461538457 | 0 | 0.009223469457612694
triton_poi_fused__native_batch_norm_legi | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863 | 17 | 0 | 0.3129385387909844 | 2.673 | 0 | 0.009341448919133863
triton_per_fused__native_batch_norm_legi | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275 | 19 | 0 | 0.2215568162533158 | 3.8837368421052636 | 0 | 0.0066136363060691275
std::enable_if<!(false), void>::type int | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014 | 23 | 504916805.19297093 | 1.0118296096314707 | 8.113913043478261 | 0.0025893169497075447 | 0.030203868944223014
triton_poi_fused_add_copy__38 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0 | 56 | 0 | 0 | 2.132482142857143 | 0 | 0
triton_poi_fused_convolution_0 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667 | 18 | 0 | 0.43458610794936897 | 2.773333333333334 | 0 | 0.012972719640279667
triton_poi_fused_convolution_1 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051 | 17 | 0 | 0.028816312469162712 | 2.6145882352941174 | 0 | 0.0008601884319153051
void convolve_common_engine_float_NHWC<f | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709 | 44 | 8641868995.31118 | 0.024730540008465626 | 25.87327272727273 | 0.04431727689903169 | 0.0007382250748795709
triton_per_fused__native_batch_norm_legi | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356 | 12 | 0 | 0.6809930918986744 | 4.82675 | 0 | 0.020328151996975356
triton_per_fused__native_batch_norm_legi | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935 | 14 | 0 | 0.02883030597936608 | 2.6651428571428575 | 0 | 0.0008606061486377935
triton_per_fused__native_batch_norm_legi | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05 | 16 | 0 | 0.0014658988233201874 | 2.098 | 0 | 4.375817383045335e-05
triton_poi_fused__native_batch_norm_legi | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611 | 13 | 0 | 0.9926297180284697 | 3.2367692307692306 | 0 | 0.02963073785159611
triton_poi_fused__native_batch_norm_legi | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048 | 9 | 0 | 1.3008817095666507 | 3.0863333333333336 | 0 | 0.03883228983781048
void at::native::(anonymous namespace):: | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994 | 98 | 0 | 0.09174335613709389 | 4.408520408163265 | 0 | 0.0027386076458833994
void at::native::vectorized_elementwise_ | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0 | 7 | 0 | 0 | 1.7278571428571428 | 0 | 0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149697
Approved by: https://github.com/eellison, https://github.com/shunting314
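As a rough orientation, the workflow above could be driven from Python roughly like the sketch below. Only the flag names (`--augment_trace`, `--analyze`, `--diff`) come from the description; the script location, argument order, and file names are assumptions for illustration.

```python
# Hedged sketch: flag names come from the PR description; script path, argument
# order, and trace file names are assumptions, not the actual CLI contract.
import subprocess
import sys

# Add Inductor's flops/bandwidth estimates to a Perfetto trace, then diff two traces.
subprocess.run(
    [sys.executable, "profile_analysis.py", "--augment_trace", "resnet_trace.json"],
    check=True,
)
subprocess.run(
    [sys.executable, "profile_analysis.py", "--diff", "resnet_trace.json", "newresnet_trace.json"],
    check=True,
)
```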
794 lines
28 KiB
Python
# mypy: allow-untyped-defs
import torch
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from .module_tracker import ModuleTracker
from typing import Any, Optional, Union, TypeVar, Callable
from collections.abc import Iterator
from typing_extensions import ParamSpec
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode
from math import prod
from functools import wraps
import warnings

__all__ = ["FlopCounterMode", "register_flop_formula"]

_T = TypeVar("_T")
_P = ParamSpec("_P")

aten = torch.ops.aten
|
|
|
|
def get_shape(i):
|
|
if isinstance(i, torch.Tensor):
|
|
return i.shape
|
|
return i
|
|
|
|
flop_registry: dict[Any, Any] = {}
|
|
|
|
def shape_wrapper(f):
|
|
@wraps(f)
|
|
def nf(*args, out_val=None, **kwargs):
|
|
args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val))
|
|
return f(*args, out_shape=out_shape, **kwargs)
|
|
return nf
|
|
|
|
def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
|
def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]:
|
|
if not get_raw:
|
|
flop_formula = shape_wrapper(flop_formula)
|
|
|
|
def register(target):
|
|
if not isinstance(target, torch._ops.OpOverloadPacket):
|
|
raise ValueError(
|
|
f"register_flop_formula(targets): expected each target to be "
|
|
f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got "
|
|
f"{target} which is of type {type(target)}")
|
|
if target in flop_registry:
|
|
raise RuntimeError(f"duplicate registrations for {target}")
|
|
flop_registry[target] = flop_formula
|
|
|
|
# To handle allowing multiple aten_ops at once
|
|
torch.utils._pytree.tree_map_(register, targets)
|
|
|
|
return flop_formula
|
|
|
|
return register_fun
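# Illustrative sketch (not part of this module): how a formula could be registered for
# an op that has no default entry. aten.mv is a real OpOverloadPacket, but the formula
# below is only an example of the registration pattern, not an endorsed count.
#
#     @register_flop_formula(aten.mv)
#     def mv_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
#         # one multiply and one add per element of the (m, n) matrix
#         m, n = a_shape
#         return 2 * m * n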
|
|
|
|
@register_flop_formula(aten.mm)
|
|
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
|
|
"""Count flops for matmul."""
|
|
# Inputs should be a list of length 2.
|
|
# Inputs contains the shapes of two matrices.
|
|
m, k = a_shape
|
|
k2, n = b_shape
|
|
assert k == k2
|
|
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
|
|
return m * n * 2 * k
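# Worked example (illustrative): an (8, 16) @ (16, 32) matmul performs 8 * 32 dot
# products of length 16, so mm_flop counts 8 * 32 * 2 * 16 = 8192 flops.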
|
|
|
|
@register_flop_formula(aten.addmm)
|
|
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
|
|
"""Count flops for addmm."""
|
|
return mm_flop(a_shape, b_shape)
|
|
|
|
@register_flop_formula(aten.bmm)
|
|
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
|
|
"""Count flops for the bmm operation."""
|
|
# Inputs should be a list of length 2.
|
|
# Inputs contains the shapes of two tensor.
|
|
b, m, k = a_shape
|
|
b2, k2, n = b_shape
|
|
assert b == b2
|
|
assert k == k2
|
|
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
|
|
flop = b * m * n * 2 * k
|
|
return flop
|
|
|
|
@register_flop_formula(aten.baddbmm)
|
|
def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
|
|
"""Count flops for the baddbmm operation."""
|
|
# Inputs should be a list of length 3.
|
|
# Inputs contains the shapes of three tensors.
|
|
return bmm_flop(a_shape, b_shape)
|
|
|
|
@register_flop_formula(aten._scaled_mm)
|
|
def _scaled_mm_flop(
|
|
a_shape,
|
|
b_shape,
|
|
scale_a_shape,
|
|
scale_b_shape,
|
|
bias_shape=None,
|
|
scale_result_shape=None,
|
|
out_dtype=None,
|
|
use_fast_accum=False,
|
|
out_shape=None,
|
|
**kwargs,
|
|
) -> int:
|
|
"""Count flops for _scaled_mm."""
|
|
return mm_flop(a_shape, b_shape)
|
|
|
|
|
|
def conv_flop_count(
|
|
x_shape: list[int],
|
|
w_shape: list[int],
|
|
out_shape: list[int],
|
|
transposed: bool = False,
|
|
) -> int:
|
|
"""Count flops for convolution.
|
|
|
|
    Note that only multiplications are counted; computation for bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(x_shape[2:]) * prod(w_shape) * 2.
|
|
Args:
|
|
x_shape (list(int)): The input shape before convolution.
|
|
w_shape (list(int)): The filter shape.
|
|
out_shape (list(int)): The output shape after convolution.
|
|
transposed (bool): is the convolution transposed
|
|
Returns:
|
|
int: the number of flops
|
|
"""
|
|
batch_size = x_shape[0]
|
|
conv_shape = (x_shape if transposed else out_shape)[2:]
|
|
c_out, c_in, *filter_size = w_shape
|
|
|
|
"""
|
|
General idea here is that for a regular conv, for each point in the output
|
|
spatial dimension we convolve the filter with something (hence
|
|
`prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by
|
|
1. batch_size, 2. the cross product of input and weight channels.
|
|
|
|
For the transpose, it's not each point in the *output* spatial dimension but
|
|
each point in the *input* spatial dimension.
|
|
"""
|
|
# NB(chilli): I don't think this properly accounts for padding :think:
|
|
# NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
|
|
flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
|
|
return flop
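# Worked example (illustrative): a regular 3x3 conv with x_shape = [1, 3, 8, 8],
# w_shape = [4, 3, 3, 3] and out_shape = [1, 4, 6, 6] gives
# prod((6, 6)) * prod((3, 3)) * 1 * 4 * 3 * 2 = 7776 flops.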
|
|
|
|
@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
|
|
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
|
|
"""Count flops for convolution."""
|
|
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
|
|
|
|
|
|
@register_flop_formula(aten.convolution_backward)
|
|
def conv_backward_flop(
|
|
grad_out_shape,
|
|
x_shape,
|
|
w_shape,
|
|
_bias,
|
|
_stride,
|
|
_padding,
|
|
_dilation,
|
|
transposed,
|
|
_output_padding,
|
|
_groups,
|
|
output_mask,
|
|
out_shape) -> int:
|
|
|
|
def t(shape):
|
|
return [shape[1], shape[0]] + list(shape[2:])
|
|
flop_count = 0
|
|
|
|
"""
|
|
Let's say we have a regular 1D conv
|
|
{A, B, C} [inp]
|
|
{i, j} [weight]
|
|
=> (conv)
|
|
{Ai + Bj, Bi + Cj} [out]
|
|
|
|
And as a reminder, the transposed conv of the above is
|
|
=> {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
|
|
|
|
For the backwards of conv, we now have
|
|
{D, E} [grad_out]
|
|
{A, B, C} [inp]
|
|
{i, j} [weight]
|
|
|
|
# grad_inp as conv_transpose(grad_out, weight)
|
|
Let's first compute grad_inp. To do so, we can simply look at all the
|
|
multiplications that each element of inp is involved in. For example, A is
|
|
only involved in the first element of the output (and thus only depends upon
|
|
D in grad_out), and C is only involved in the last element of the output
|
|
(and thus only depends upon E in grad_out)
|
|
|
|
{Di, Dj + Ei, Ej} [grad_inp]
|
|
|
|
Note that this corresponds to the below conv_transpose. This gives us the
|
|
output_mask[0] branch, which is grad_inp.
|
|
|
|
{D, E} [inp (grad_out)]
|
|
{i, j} [weight]
|
|
=> (conv_transpose)
|
|
{Di, Dj + Ei, Ej} [out (grad_inp)]
|
|
|
|
I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
|
|
weight) as an exercise for the reader.
|
|
|
|
# grad_weight as conv(inp, grad_out)
|
|
To compute grad_weight, we again look at the terms in the output, which as
|
|
a reminder is:
|
|
=> {Ai + Bj, Bi + Cj} [out]
|
|
=> {D, E} [grad_out]
|
|
If we manually compute the gradient for the weights, we see it's
|
|
{AD + BE, BD + CE} [grad_weight]
|
|
|
|
This corresponds to the below conv
|
|
{A, B, C} [inp]
|
|
{D, E} [weight (grad_out)]
|
|
=> (conv)
|
|
{AD + BE, BD + CE} [out (grad_weight)]
|
|
|
|
# grad_weight of transposed conv as conv(grad_out, inp)
|
|
As a reminder, the terms of the output of a transposed conv are:
|
|
=> {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
|
|
=> {D, E, F, G} [grad_out]
|
|
|
|
Manually computing the gradient for the weights, we see it's
|
|
{AD + BE + CF, AE + BF + CG} [grad_weight]
|
|
|
|
This corresponds to the below conv
|
|
{D, E, F, G} [inp (grad_out)]
|
|
{A, B, C} [weight (inp)]
|
|
=> (conv)
|
|
{AD + BE + CF, AE + BF + CG} [out (grad_weight)]
|
|
|
|
For the full backwards formula, there are also some details involving
|
|
transpose of the batch/channel dimensions and groups, but I skip those for
|
|
the sake of brevity (and they're pretty similar to matmul backwards)
|
|
|
|
Check [conv backwards decomposition as conv forwards]
|
|
"""
|
|
# grad_inp as conv_transpose(grad_out, weight)
|
|
if output_mask[0]:
|
|
grad_input_shape = get_shape(out_shape[0])
|
|
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)
|
|
|
|
if output_mask[1]:
|
|
grad_weight_shape = get_shape(out_shape[1])
|
|
if transposed:
|
|
# grad_weight of transposed conv as conv(grad_out, inp)
|
|
flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False)
|
|
else:
|
|
# grad_weight as conv(inp, grad_out)
|
|
flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False)
|
|
|
|
return flop_count
|
|
|
|
def sdpa_flop_count(query_shape, key_shape, value_shape):
|
|
"""
|
|
Count flops for self-attention.
|
|
|
|
NB: We can assume that value_shape == key_shape
|
|
"""
|
|
b, h, s_q, d_q = query_shape
|
|
_b2, _h2, s_k, _d2 = key_shape
|
|
_b3, _h3, _s3, d_v = value_shape
|
|
    assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3
|
|
total_flops = 0
|
|
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
|
|
total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
|
|
# scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
|
|
total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
|
|
return total_flops
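# Worked example (illustrative): query = key = value of shape (1, 2, 4, 8) gives
# bmm_flop((2, 4, 8), (2, 8, 4)) + bmm_flop((2, 4, 4), (2, 4, 8)) = 512 + 512 = 1024.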
|
|
|
|
|
|
@register_flop_formula([aten._scaled_dot_product_efficient_attention,
|
|
aten._scaled_dot_product_flash_attention,
|
|
aten._scaled_dot_product_cudnn_attention])
|
|
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
|
|
"""Count flops for self-attention."""
|
|
# NB: We aren't accounting for causal attention here
|
|
return sdpa_flop_count(query_shape, key_shape, value_shape)
|
|
|
|
|
|
def _offsets_to_lengths(offsets, max_len):
|
|
"""
|
|
If the offsets tensor is fake, then we don't know the actual lengths.
|
|
In that case, we can just assume the worst case; each batch has max length.
|
|
"""
|
|
from torch._subclasses.fake_tensor import FakeTensor
|
|
from torch._subclasses.functional_tensor import FunctionalTensor
|
|
if not isinstance(offsets, (FakeTensor, FunctionalTensor)) and offsets.device.type != "meta":
|
|
return offsets.diff().tolist()
|
|
return [max_len] * (offsets.size(0) - 1)
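# Example (illustrative): _offsets_to_lengths(torch.tensor([0, 3, 5]), 5) returns
# [3, 2]; for a fake/meta offsets tensor with 3 entries it returns the worst
# case [5, 5].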
|
|
|
|
|
|
def _unpack_flash_attention_nested_shapes(
|
|
*,
|
|
query,
|
|
key,
|
|
value,
|
|
grad_out=None,
|
|
cum_seq_q,
|
|
cum_seq_k,
|
|
max_q,
|
|
max_k,
|
|
) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], Optional[tuple[int, ...]]]]:
|
|
"""
|
|
Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for
|
|
NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
|
|
each batch element.
|
|
|
|
In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
|
|
"""
|
|
if cum_seq_q is not None:
|
|
# This means we should be dealing with a Nested Jagged Tensor query.
|
|
# The inputs will have shape (sum(sequence len), heads, dimension)
|
|
# In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
|
|
# To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
|
|
# So the flops calculation in this case is an overestimate of the actual flops.
|
|
assert len(key.shape) == 3
|
|
assert len(value.shape) == 3
|
|
assert grad_out is None or grad_out.shape == query.shape
|
|
_, h_q, d_q = query.shape
|
|
_, h_k, d_k = key.shape
|
|
_, h_v, d_v = value.shape
|
|
assert cum_seq_q is not None
|
|
assert cum_seq_k is not None
|
|
assert cum_seq_q.shape == cum_seq_k.shape
|
|
seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
|
|
seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
|
|
for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
|
|
new_query_shape = (1, h_q, seq_q_len, d_q)
|
|
new_key_shape = (1, h_k, seq_k_len, d_k)
|
|
new_value_shape = (1, h_v, seq_k_len, d_v)
|
|
new_grad_out_shape = new_query_shape if grad_out is not None else None
|
|
yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
|
|
return
|
|
|
|
yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
|
|
|
|
|
|
def _unpack_efficient_attention_nested_shapes(
|
|
*,
|
|
query,
|
|
key,
|
|
value,
|
|
grad_out=None,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], Optional[tuple[int, ...]]]]:
|
|
"""
|
|
Given inputs to a efficient_attention_(forward|backward) kernel, this will handle behavior for
|
|
NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for
|
|
each batch element.
|
|
|
|
In the case that this isn't a NestedTensor kernel, then it just yields the original shapes.
|
|
"""
|
|
if cu_seqlens_q is not None:
|
|
# Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention.
|
|
#
|
|
# This means we should be dealing with a Nested Jagged Tensor query.
|
|
# The inputs will have shape (sum(sequence len), heads, dimension)
|
|
# In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension)
|
|
# To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension)
|
|
# So the flops calculation in this case is an overestimate of the actual flops.
|
|
assert len(key.shape) == 4
|
|
assert len(value.shape) == 4
|
|
assert grad_out is None or grad_out.shape == query.shape
|
|
_, _, h_q, d_q = query.shape
|
|
_, _, h_k, d_k = key.shape
|
|
_, _, h_v, d_v = value.shape
|
|
assert cu_seqlens_q is not None
|
|
assert cu_seqlens_k is not None
|
|
assert cu_seqlens_q.shape == cu_seqlens_k.shape
|
|
seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
|
|
seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
|
|
for len_q, len_k in zip(seqlens_q, seqlens_k):
|
|
new_query_shape = (1, h_q, len_q, d_q)
|
|
new_key_shape = (1, h_k, len_k, d_k)
|
|
new_value_shape = (1, h_v, len_k, d_v)
|
|
new_grad_out_shape = new_query_shape if grad_out is not None else None
|
|
yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape
|
|
return
|
|
|
|
yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None
|
|
|
|
|
|
@register_flop_formula(aten._flash_attention_forward, get_raw=True)
|
|
def _flash_attention_forward_flop(
|
|
query,
|
|
key,
|
|
value,
|
|
cum_seq_q,
|
|
cum_seq_k,
|
|
max_q,
|
|
max_k,
|
|
*args,
|
|
out_shape=None,
|
|
**kwargs
|
|
) -> int:
|
|
"""Count flops for self-attention."""
|
|
# NB: We aren't accounting for causal attention here
|
|
# in case this is a nested tensor, we unpack the individual batch elements
|
|
# and then sum the flops per batch element
|
|
sizes = _unpack_flash_attention_nested_shapes(
|
|
query=query,
|
|
key=key,
|
|
value=value,
|
|
cum_seq_q=cum_seq_q,
|
|
cum_seq_k=cum_seq_k,
|
|
max_q=max_q,
|
|
max_k=max_k,
|
|
)
|
|
return sum(
|
|
sdpa_flop_count(query_shape, key_shape, value_shape)
|
|
for query_shape, key_shape, value_shape, _ in sizes
|
|
)
|
|
|
|
|
|
@register_flop_formula(aten._efficient_attention_forward, get_raw=True)
|
|
def _efficient_attention_forward_flop(
|
|
query,
|
|
key,
|
|
value,
|
|
bias,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
*args,
|
|
**kwargs
|
|
) -> int:
|
|
"""Count flops for self-attention."""
|
|
# NB: We aren't accounting for causal attention here
|
|
# in case this is a nested tensor, we unpack the individual batch elements
|
|
# and then sum the flops per batch element
|
|
sizes = _unpack_efficient_attention_nested_shapes(
|
|
query=query,
|
|
key=key,
|
|
value=value,
|
|
cu_seqlens_q=cu_seqlens_q,
|
|
cu_seqlens_k=cu_seqlens_k,
|
|
max_seqlen_q=max_seqlen_q,
|
|
max_seqlen_k=max_seqlen_k,
|
|
)
|
|
return sum(
|
|
sdpa_flop_count(query_shape, key_shape, value_shape)
|
|
for query_shape, key_shape, value_shape, _ in sizes
|
|
)
|
|
|
|
|
|
def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
|
|
total_flops = 0
|
|
b, h, s_q, d_q = query_shape
|
|
_b2, _h2, s_k, _d2 = key_shape
|
|
_b3, _h3, _s3, d_v = value_shape
|
|
_b4, _h4, _s4, _d4 = grad_out_shape
|
|
assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
|
|
assert d_v == _d4 and s_k == _s3 and s_q == _s4
|
|
total_flops = 0
|
|
# Step 1: We recompute the scores matrix.
|
|
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
|
|
total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
|
|
|
|
# Step 2: We propagate the gradients through the score @ v operation.
|
|
# gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
|
|
total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
|
|
# scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
|
|
total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
|
|
|
|
    # Step 3: We propagate the gradients through the q @ k operation
|
|
# gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
|
|
total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
|
|
# q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
|
|
total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
|
|
return total_flops
|
|
|
|
|
|
@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward,
|
|
aten._scaled_dot_product_flash_attention_backward,
|
|
aten._scaled_dot_product_cudnn_attention_backward])
|
|
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
|
|
"""Count flops for self-attention backward."""
|
|
return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
|
|
|
|
@register_flop_formula(aten._flash_attention_backward, get_raw=True)
|
|
def _flash_attention_backward_flop(
|
|
grad_out,
|
|
query,
|
|
key,
|
|
value,
|
|
out, # named _out_shape to avoid kwarg collision with out_shape created in wrapper
|
|
logsumexp,
|
|
cum_seq_q,
|
|
cum_seq_k,
|
|
max_q,
|
|
max_k,
|
|
*args,
|
|
**kwargs,
|
|
) -> int:
|
|
# in case this is a nested tensor, we unpack the individual batch elements
|
|
# and then sum the flops per batch element
|
|
shapes = _unpack_flash_attention_nested_shapes(
|
|
query=query,
|
|
key=key,
|
|
value=value,
|
|
grad_out=grad_out,
|
|
cum_seq_q=cum_seq_q,
|
|
cum_seq_k=cum_seq_k,
|
|
max_q=max_q,
|
|
max_k=max_k,
|
|
)
|
|
return sum(
|
|
sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
|
|
for query_shape, key_shape, value_shape, grad_out_shape in shapes
|
|
)
|
|
|
|
|
|
@register_flop_formula(aten._efficient_attention_backward, get_raw=True)
|
|
def _efficient_attention_backward_flop(
|
|
grad_out,
|
|
query,
|
|
key,
|
|
value,
|
|
bias,
|
|
out, # named _out to avoid kwarg collision with out created in wrapper
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
*args,
|
|
**kwargs,
|
|
) -> int:
|
|
# in case this is a nested tensor, we unpack the individual batch elements
|
|
# and then sum the flops per batch element
|
|
shapes = _unpack_efficient_attention_nested_shapes(
|
|
query=query,
|
|
key=key,
|
|
value=value,
|
|
grad_out=grad_out,
|
|
cu_seqlens_q=cu_seqlens_q,
|
|
cu_seqlens_k=cu_seqlens_k,
|
|
max_seqlen_q=max_seqlen_q,
|
|
max_seqlen_k=max_seqlen_k,
|
|
)
|
|
return sum(
|
|
sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)
|
|
for query_shape, key_shape, value_shape, grad_out_shape in shapes
|
|
)
|
|
|
|
|
|
flop_registry = {
|
|
aten.mm: mm_flop,
|
|
aten.addmm: addmm_flop,
|
|
aten.bmm: bmm_flop,
|
|
aten.baddbmm: baddbmm_flop,
|
|
aten._scaled_mm: _scaled_mm_flop,
|
|
aten.convolution: conv_flop,
|
|
aten._convolution: conv_flop,
|
|
aten.cudnn_convolution: conv_flop,
|
|
aten._slow_conv2d_forward: conv_flop,
|
|
aten.convolution_backward: conv_backward_flop,
|
|
aten._scaled_dot_product_efficient_attention: sdpa_flop,
|
|
aten._scaled_dot_product_flash_attention: sdpa_flop,
|
|
aten._scaled_dot_product_cudnn_attention: sdpa_flop,
|
|
aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
|
|
aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
|
|
aten._scaled_dot_product_cudnn_attention_backward: sdpa_backward_flop,
|
|
aten._flash_attention_forward: _flash_attention_forward_flop,
|
|
aten._efficient_attention_forward: _efficient_attention_forward_flop,
|
|
aten._flash_attention_backward: _flash_attention_backward_flop,
|
|
aten._efficient_attention_backward: _efficient_attention_backward_flop,
|
|
}
|
|
|
|
def normalize_tuple(x):
|
|
if not isinstance(x, tuple):
|
|
return (x,)
|
|
return x
|
|
|
|
|
|
# Define the suffixes for different orders of magnitude
|
|
suffixes = ["", "K", "M", "B", "T"]
|
|
# Thanks BingChat!
|
|
def get_suffix_str(number):
|
|
# Find the index of the appropriate suffix based on the number of digits
|
|
# with some additional overflow.
|
|
# i.e. 1.01B should be displayed as 1001M, not 1.001B
|
|
index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3))
|
|
return suffixes[index]
|
|
|
|
def convert_num_with_suffix(number, suffix):
|
|
index = suffixes.index(suffix)
|
|
# Divide the number by 1000^index and format it to two decimal places
|
|
value = f"{number / 1000 ** index:.3f}"
|
|
# Return the value and the suffix as a string
|
|
return value + suffixes[index]
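# Example (illustrative): for 1_001_000_000 total flops, get_suffix_str picks "M"
# (see the overflow note above), and convert_num_with_suffix then renders it as
# "1001.000M" rather than "1.001B".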
|
|
|
|
def convert_to_percent_str(num, denom):
|
|
if denom == 0:
|
|
return "0%"
|
|
return f"{num / denom:.2%}"
|
|
|
|
def _pytreeify_preserve_structure(f):
|
|
@wraps(f)
|
|
def nf(args):
|
|
flat_args, spec = tree_flatten(args)
|
|
out = f(*flat_args)
|
|
return tree_unflatten(out, spec)
|
|
|
|
return nf
|
|
|
|
|
|
class FlopCounterMode:
|
|
"""
|
|
``FlopCounterMode`` is a context manager that counts the number of flops within its context.
|
|
|
|
It does this using a ``TorchDispatchMode``.
|
|
|
|
It also supports hierarchical output by passing a module (or list of
|
|
modules) to FlopCounterMode on construction. If you do not need hierarchical
|
|
output, you do not need to use it with a module.
|
|
|
|
Example usage
|
|
|
|
.. code-block:: python
|
|
|
|
mod = ...
|
|
with FlopCounterMode(mod) as flop_counter:
|
|
mod.sum().backward()
|
|
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
mods: Optional[Union[torch.nn.Module, list[torch.nn.Module]]] = None,
|
|
depth: int = 2,
|
|
display: bool = True,
|
|
custom_mapping: Optional[dict[Any, Any]] = None):
|
|
super().__init__()
|
|
self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))
|
|
self.depth = depth
|
|
self.display = display
|
|
self.mode: Optional[_FlopCounterMode] = None
|
|
if custom_mapping is None:
|
|
custom_mapping = {}
|
|
if mods is not None:
|
|
warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2)
|
|
self.flop_registry = {
|
|
**flop_registry,
|
|
**{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
|
|
}
|
|
self.mod_tracker = ModuleTracker()
|
|
|
|
def get_total_flops(self) -> int:
|
|
return sum(self.flop_counts['Global'].values())
|
|
|
|
def get_flop_counts(self) -> dict[str, dict[Any, int]]:
|
|
"""Return the flop counts as a dictionary of dictionaries.
|
|
|
|
The outer
|
|
dictionary is keyed by module name, and the inner dictionary is keyed by
|
|
operation name.
|
|
|
|
Returns:
|
|
Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
|
|
"""
|
|
return {k: dict(v) for k, v in self.flop_counts.items()}
|
|
|
|
def get_table(self, depth=None):
|
|
if depth is None:
|
|
depth = self.depth
|
|
if depth is None:
|
|
depth = 999999
|
|
|
|
import tabulate
|
|
tabulate.PRESERVE_WHITESPACE = True
|
|
header = ["Module", "FLOP", "% Total"]
|
|
values = []
|
|
global_flops = self.get_total_flops()
|
|
global_suffix = get_suffix_str(global_flops)
|
|
is_global_subsumed = False
|
|
|
|
def process_mod(mod_name, depth):
|
|
nonlocal is_global_subsumed
|
|
|
|
total_flops = sum(self.flop_counts[mod_name].values())
|
|
|
|
is_global_subsumed |= total_flops >= global_flops
|
|
|
|
padding = " " * depth
|
|
values = []
|
|
values.append([
|
|
padding + mod_name,
|
|
convert_num_with_suffix(total_flops, global_suffix),
|
|
convert_to_percent_str(total_flops, global_flops)
|
|
])
|
|
for k, v in self.flop_counts[mod_name].items():
|
|
values.append([
|
|
padding + " - " + str(k),
|
|
convert_num_with_suffix(v, global_suffix),
|
|
convert_to_percent_str(v, global_flops)
|
|
])
|
|
return values
|
|
|
|
for mod in sorted(self.flop_counts.keys()):
|
|
if mod == 'Global':
|
|
continue
|
|
mod_depth = mod.count(".") + 1
|
|
if mod_depth > depth:
|
|
continue
|
|
|
|
cur_values = process_mod(mod, mod_depth - 1)
|
|
values.extend(cur_values)
|
|
|
|
# We do a bit of messing around here to only output the "Global" value
|
|
# if there are any FLOPs in there that aren't already fully contained by
|
|
# a module.
|
|
if 'Global' in self.flop_counts and not is_global_subsumed:
|
|
for value in values:
|
|
value[0] = " " + value[0]
|
|
|
|
values = process_mod('Global', 0) + values
|
|
|
|
if len(values) == 0:
|
|
values = [["Global", "0", "0%"]]
|
|
|
|
return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))
|
|
|
|
# NB: This context manager is NOT reentrant
|
|
def __enter__(self):
|
|
self.flop_counts.clear()
|
|
self.mod_tracker.__enter__()
|
|
self.mode = _FlopCounterMode(self)
|
|
self.mode.__enter__()
|
|
return self
|
|
|
|
def __exit__(self, *args):
|
|
assert self.mode is not None
|
|
b = self.mode.__exit__(*args)
|
|
self.mode = None # break cycles
|
|
self.mod_tracker.__exit__()
|
|
if self.display:
|
|
print(self.get_table(self.depth))
|
|
return b
|
|
|
|
def _count_flops(self, func_packet, out, args, kwargs):
|
|
if func_packet in self.flop_registry:
|
|
flop_count_func = self.flop_registry[func_packet]
|
|
flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator]
|
|
for par in set(self.mod_tracker.parents):
|
|
self.flop_counts[par][func_packet] += flop_count
|
|
|
|
return out
|
|
|
|
|
|
class _FlopCounterMode(TorchDispatchMode):
|
|
def __init__(self, counter: FlopCounterMode):
|
|
self.counter = counter
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
kwargs = kwargs if kwargs else {}
|
|
|
|
# Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
|
|
if func in {torch.ops.aten.is_contiguous.default,
|
|
torch.ops.aten.is_contiguous.memory_format,
|
|
torch.ops.aten.is_strides_like_format.default,
|
|
torch.ops.aten.is_non_overlapping_and_dense.default,
|
|
torch.ops.aten.size.default,
|
|
torch.ops.aten.sym_size.default,
|
|
torch.ops.aten.stride.default,
|
|
torch.ops.aten.sym_stride.default,
|
|
torch.ops.aten.storage_offset.default,
|
|
torch.ops.aten.sym_storage_offset.default,
|
|
torch.ops.aten.numel.default,
|
|
torch.ops.aten.sym_numel.default,
|
|
torch.ops.aten.dim.default,
|
|
torch.ops.prim.layout.default}:
|
|
|
|
return NotImplemented
|
|
|
|
# If we don't have func in flop_registry, see if it can decompose
|
|
if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default:
|
|
with self:
|
|
r = func.decompose(*args, **kwargs)
|
|
if r is not NotImplemented:
|
|
return r
|
|
|
|
# no further decomposition; execute & count flops
|
|
out = func(*args, **kwargs)
|
|
return self.counter._count_flops(func._overloadpacket, out, args, kwargs)
|