Add flop formula for _scaled_mm (#144973)

This will make it work correctly with the partitioner's AutoAC
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144973
Approved by: https://github.com/jeffdaily
Author: Luca Wehrstedt, 2025-01-16 17:33:33 +00:00
Committed by: PyTorch MergeBot
parent 96c0dbbe97
commit a0d2c09115
2 changed files with 36 additions and 0 deletions

test/test_flop_counter.py

@@ -9,6 +9,7 @@ import torch.utils.flop_counter
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION
 )
@@ -835,5 +836,23 @@ class TestFlopCounter(TestCase):
         ]
         self.assertEqual(layer1_conv_flops_standard, layer1_conv_flops_inference)
 
+    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "Does not support fp8 (pre-SM90 hardware on CUDA)",
+    )
+    def test_scaled_mm(self):
+        dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
+        with FlopCounterMode() as mode:
+            torch._scaled_mm(
+                torch.randn((3 * 16, 5 * 16), device="cuda").to(dtype),
+                torch.randn((7 * 16, 5 * 16), device="cuda").to(dtype).t(),
+                scale_a=torch.ones((), device="cuda"),
+                scale_b=torch.ones((), device="cuda"),
+                out_dtype=torch.bfloat16,
+            )
+
+        self.assertExpectedInline(get_total_flops(mode), """860160""")
+
 if __name__ == "__main__":
     run_tests()
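
For reference, the expected count in the new test follows from treating _scaled_mm as an ordinary matmul: the first operand is (48, 80), the second is (80, 112) after the transpose, and an (M, K) x (K, N) matmul is counted as 2 * M * K * N flops. A quick sanity check of that arithmetic, using the shapes from the test above:

# Shapes from test_scaled_mm: (3*16, 5*16) @ (5*16, 7*16) after the .t()
M, K, N = 3 * 16, 5 * 16, 7 * 16   # 48, 80, 112
assert 2 * M * K * N == 860160     # matches the assertExpectedInline value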

torch/utils/flop_counter.py

@@ -89,6 +89,22 @@ def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
     # Inputs contains the shapes of three tensors.
     return bmm_flop(a_shape, b_shape)
 
+
+@register_flop_formula(aten._scaled_mm)
+def _scaled_mm_flop(
+    a_shape,
+    b_shape,
+    scale_a_shape,
+    scale_b_shape,
+    bias_shape=None,
+    scale_result_shape=None,
+    out_dtype=None,
+    use_fast_accum=False,
+    out_shape=None,
+    **kwargs,
+) -> int:
+    """Count flops for _scaled_mm."""
+    return mm_flop(a_shape, b_shape)
+
 
 def conv_flop_count(
     x_shape: List[int],
@@ -541,6 +557,7 @@ flop_registry = {
     aten.addmm: addmm_flop,
     aten.bmm: bmm_flop,
     aten.baddbmm: baddbmm_flop,
+    aten._scaled_mm: _scaled_mm_flop,
     aten.convolution: conv_flop,
     aten._convolution: conv_flop,
     aten.convolution_backward: conv_backward_flop,
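
The new registry entry makes FlopCounterMode attribute flops to aten._scaled_mm by delegating to mm_flop, which is defined earlier in flop_counter.py (outside these hunks) and counts an (M, K) x (K, N) matmul as 2 * M * K * N. A minimal sketch of that behavior, paraphrasing mm_flop rather than quoting it, with the scale, bias, and dtype arguments deliberately ignored just as in _scaled_mm_flop:

# Sketch of the counting logic (paraphrase, not the library source):
# _scaled_mm is costed exactly like the underlying matmul; scales,
# bias, out_dtype, and use_fast_accum do not change the count.
def mm_flop_sketch(a_shape, b_shape) -> int:
    m, k = a_shape
    k2, n = b_shape
    assert k == k2, "inner dimensions must match"
    return 2 * m * k * n  # one multiply-add counted as 2 flops

def scaled_mm_flop_sketch(a_shape, b_shape, *scale_shapes, **kwargs) -> int:
    return mm_flop_sketch(a_shape, b_shape)

# Example: a (128, 64) @ (64, 256) scaled matmul -> 4194304 counted flops
print(scaled_mm_flop_sketch((128, 64), (64, 256), (), ()))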