Back out "Revert D81959389" (#163905)

Summary:
Original commit changeset: 06888d7ebff0

Original Phabricator Diff: D82932788

Restricted the test to SM90 for scaled_grouped_mm
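
For context, a minimal sketch of what gating a test to SM90 looks like; this is illustrative only, and the diff below reuses the existing IS_SM90 flag from torch.testing._internal.common_cuda rather than a hand-rolled check:

import unittest

import torch

def _is_sm90() -> bool:
    # SM90 == CUDA compute capability 9.0 (Hopper-class GPUs such as H100).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)

@unittest.skipIf(not _is_sm90(), "scaled_grouped_mm is only supported on SM90")
def test_scaled_grouped_mm_smoke():
    ...  # body elided; see the full AOTInductor test in the diff below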

Test Plan: TBD (will share the Linux CI results)

Differential Revision: D83283991

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163905
Approved by: https://github.com/angelayi
Author: Yavuz Yetim
Date: 2025-09-30 07:05:09 +00:00
Committed by: PyTorch MergeBot
Parent: bbf6816f35
Commit: 7afcb030d8
5 changed files with 77 additions and 2 deletions


@@ -42,6 +42,7 @@ from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
IS_SM90,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -1238,6 +1239,72 @@ class AOTInductorTestsTemplate:
dynamic_shapes=dynamic_shapes,
)
@unittest.skipIf(
TEST_WITH_ROCM or not IS_SM90,
"scaled_grouped_mm is only supported on SM90",
)
@skipIfXpu
def test_scaled_grouped_mm(self):
# Test torch._scaled_grouped_mm AOTI lowering
# cuda only
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight, scale_a, scale_b, offsets):
# x: [num_groups, batch, in_features] - FP8 inputs
# weight: [total_out_features, in_features] - FP8 weights (transposed)
# scale_a: [num_groups, batch] - input scales (one per row)
# scale_b: [total_out_features] - weight scales (one per output column)
# offsets: [num_groups] - cumulative output sizes
output = torch._scaled_grouped_mm(
x,
weight.t(),
scale_a=scale_a,
scale_b=scale_b,
offs=offsets,
use_fast_accum=True,
)
return output.half()
dtype = torch.float16
num_groups = 3
batch_size = 64
in_features = 128
out_features_list = [64, 128, 256] # Different output sizes for each group
device = GPU_TYPE
# Calculate offsets (cumulative output sizes)
offsets = torch.cumsum(torch.tensor(out_features_list), dim=0).to(
device, dtype=torch.int32
)
total_out_features = sum(out_features_list)
# Create FP8 input tensors - stacked for all groups
x_fp16 = torch.randn(
num_groups, batch_size, in_features, dtype=dtype, device=device
)
x_fp8 = x_fp16.to(torch.float8_e4m3fn)
# Create FP8 weight tensor - concatenated and transposed
weight_fp16 = torch.randn(
total_out_features, in_features, dtype=dtype, device=device
)
weight_fp8 = weight_fp16.to(torch.float8_e4m3fn)
# Create scales
scale_a = torch.ones(num_groups, batch_size, device=device, dtype=torch.float32)
scale_b = torch.ones(total_out_features, device=device, dtype=torch.float32)
self.check_model(
Model(),
(x_fp8, weight_fp8, scale_a, scale_b, offsets),
)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",


@@ -469,7 +469,7 @@ def grouped_mm_args(
aten__grouped_mm = ExternKernelChoice(
torch._grouped_mm,
"at::_grouped_mm",
op_overload=aten._grouped_mm,
op_overload=aten._grouped_mm.default,
has_out_variant=False,
)
@@ -477,7 +477,7 @@ aten__grouped_mm = ExternKernelChoice(
aten__scaled_grouped_mm = ExternKernelChoice(
torch._scaled_grouped_mm,
"at::_scaled_grouped_mm",
op_overload=aten._scaled_grouped_mm,
op_overload=aten._scaled_grouped_mm.default,
has_out_variant=False,
)
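
The two hunks above switch op_overload from the OpOverloadPacket to its concrete .default OpOverload. A minimal sketch of that distinction, assuming aten is bound to torch.ops.aten as is conventional in Inductor code:

import torch

aten = torch.ops.aten

packet = aten._scaled_grouped_mm            # OpOverloadPacket: the bundle of all overloads
overload = aten._scaled_grouped_mm.default  # OpOverload: one concrete registered schema

print(type(packet).__name__)    # OpOverloadPacket
print(type(overload).__name__)  # OpOverload
print(overload._schema)         # the aten schema the ExternKernelChoice now points at
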
@@ -735,6 +735,9 @@ def tuned_scaled_grouped_mm(
) -> TensorBox:
"""Auto-tuning for _scaled_grouped_mm() operator."""
# matching _scaled_grouped_mm_cuda Blas.cpp implementation
out_dtype = out_dtype or torch.bfloat16
return _tuned_grouped_mm_common(
"aten._scaled_grouped_mm.default",
"scaled_grouped_mm",


@@ -7642,6 +7642,9 @@ def meta_scaled_grouped_mm(
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
):
# matching _scaled_grouped_mm_cuda Blas.cpp implementation
out_dtype = out_dtype or torch.bfloat16
return _meta_grouped_mm_common(
mat_a,
mat_b,

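The lowering and the meta registration add the same default so that compile-time dtype inference agrees with the eager CUDA kernel (_scaled_grouped_mm_cuda in Blas.cpp), which produces bfloat16 when no out_dtype is requested. A trivial sketch of the defaulting rule, illustrative only:

import torch

def pick_out_dtype(out_dtype=None):
    # Mirrors the Blas.cpp behavior referenced in the comments above: no explicit
    # out_dtype means a bfloat16 result.
    return out_dtype or torch.bfloat16

assert pick_out_dtype() is torch.bfloat16
assert pick_out_dtype(torch.float16) is torch.float16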

@@ -44,6 +44,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_atten
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_grouped_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* offs, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
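
The aoti_torch_cuda__scaled_grouped_mm declaration added above and the "aten._scaled_grouped_mm.default" entry added to inductor_fallback_ops in the last hunk below refer to the same operator; the C shim header appears to be the generated counterpart of that fallback list. A small sketch, not part of the diff, of resolving such a qualified name back to the live OpOverload:

import torch

def resolve(qualified_name: str):
    # "aten._scaled_grouped_mm.default" -> torch.ops.aten._scaled_grouped_mm.default
    namespace, op_name, overload = qualified_name.split(".")
    return getattr(getattr(getattr(torch.ops, namespace), op_name), overload)

print(resolve("aten._scaled_grouped_mm.default"))  # the OpOverload the fallback entry names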


@@ -56,6 +56,7 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
"aten._scaled_dot_product_fused_attention_overrideable_backward.default": {},
"aten._scaled_dot_product_fused_attention_overrideable.default": {},
"aten._scaled_mm.default": {},
"aten._scaled_grouped_mm.default": {},
"aten._scaled_mm.out": {},
"aten._segment_reduce_backward.default": {},
"aten._thnn_fused_lstm_cell.default": {},