Add flop formula for _scaled_mm (#144973)
This will make it work correctly with the partitioner's AutoAC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144973
Approved by: https://github.com/jeffdaily
Committed by: PyTorch MergeBot
Parent: 96c0dbbe97
Commit: a0d2c09115
@@ -9,6 +9,7 @@ import torch.utils.flop_counter
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION
 )
@@ -835,5 +836,23 @@ class TestFlopCounter(TestCase):
         ]
         self.assertEqual(layer1_conv_flops_standard, layer1_conv_flops_inference)
 
+    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "Does not support fp8 (pre-SM90 hardware on CUDA)",
+    )
+    def test_scaled_mm(self):
+        dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
+        with FlopCounterMode() as mode:
+            torch._scaled_mm(
+                torch.randn((3 * 16, 5 * 16), device="cuda").to(dtype),
+                torch.randn((7 * 16, 5 * 16), device="cuda").to(dtype).t(),
+                scale_a=torch.ones((), device="cuda"),
+                scale_b=torch.ones((), device="cuda"),
+                out_dtype=torch.bfloat16,
+            )
+
+        self.assertExpectedInline(get_total_flops(mode), """860160""")
+
 if __name__ == "__main__":
     run_tests()
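The expected inline value follows from treating the scaled matmul as an ordinary matmul: the first operand is (3 * 16, 5 * 16) = (48, 80) and, after the transpose, the second is (80, 112), so the count is 2 * M * K * N = 2 * 48 * 80 * 112 = 860160.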
@@ -89,6 +89,22 @@ def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
     # Inputs contains the shapes of three tensors.
     return bmm_flop(a_shape, b_shape)
 
+@register_flop_formula(aten._scaled_mm)
+def _scaled_mm_flop(
+    a_shape,
+    b_shape,
+    scale_a_shape,
+    scale_b_shape,
+    bias_shape=None,
+    scale_result_shape=None,
+    out_dtype=None,
+    use_fast_accum=False,
+    out_shape=None,
+    **kwargs,
+) -> int:
+    """Count flops for _scaled_mm."""
+    return mm_flop(a_shape, b_shape)
+
 
 def conv_flop_count(
     x_shape: List[int],
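The new formula simply forwards the operand shapes to mm_flop, which lives earlier in the same file but is not shown in this hunk; it counts 2 * M * K * N for an (M, K) @ (K, N) product. A sketch of that helper for reference (an approximation of the existing code, not part of this diff):

def mm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for matmul: one multiply and one add per (m, n, k) triple."""
    m, k = a_shape
    k2, n = b_shape
    assert k == k2
    return m * n * 2 * k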
@@ -541,6 +557,7 @@ flop_registry = {
     aten.addmm: addmm_flop,
     aten.bmm: bmm_flop,
     aten.baddbmm: baddbmm_flop,
+    aten._scaled_mm: _scaled_mm_flop,
     aten.convolution: conv_flop,
     aten._convolution: conv_flop,
     aten.convolution_backward: conv_backward_flop,
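Because _scaled_mm_flop delegates to mm_flop, a plain matmul of the same logical shapes is counted identically, which gives a quick way to sanity-check the expected value without fp8 hardware. A minimal sketch, assuming a PyTorch build that exposes torch.utils.flop_counter:

import torch
from torch.utils.flop_counter import FlopCounterMode

# Same logical shapes as the new test: (48, 80) @ (80, 112).
a = torch.randn(3 * 16, 5 * 16)
b = torch.randn(5 * 16, 7 * 16)

with FlopCounterMode(display=False) as mode:
    torch.mm(a, b)

print(mode.get_total_flops())  # 2 * 48 * 80 * 112 = 860160, matching the fp8 path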