diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 6d7d22398989..6424bc6192ea 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -42,6 +42,7 @@ from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
     _get_torch_cuda_version,
+    IS_SM90,
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -1238,6 +1239,72 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
+    @unittest.skipIf(
+        TEST_WITH_ROCM or not IS_SM90,
+        "scaled_grouped_mm is only supported on SM90",
+    )
+    @skipIfXpu
+    def test_scaled_grouped_mm(self):
+        # Test torch._scaled_grouped_mm AOTI lowering
+        # cuda only
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, weight, scale_a, scale_b, offsets):
+                # x: [num_groups, batch, in_features] - FP8 inputs
+                # weight: [total_out_features, in_features] - FP8 weights (transposed)
+                # scale_a: [num_groups, batch] - per-row input scales
+                # scale_b: [total_out_features] - per-column weight scales
+                # offsets: [num_groups] - cumulative output sizes
+                output = torch._scaled_grouped_mm(
+                    x,
+                    weight.t(),
+                    scale_a=scale_a,
+                    scale_b=scale_b,
+                    offs=offsets,
+                    use_fast_accum=True,
+                )
+                return output.half()
+
+        dtype = torch.float16
+        num_groups = 3
+        batch_size = 64
+        in_features = 128
+        out_features_list = [64, 128, 256]  # Different output sizes for each group
+
+        device = GPU_TYPE
+
+        # Calculate offsets (cumulative output sizes)
+        offsets = torch.cumsum(torch.tensor(out_features_list), dim=0).to(
+            device, dtype=torch.int32
+        )
+        total_out_features = sum(out_features_list)
+
+        # Create FP8 input tensors - stacked for all groups
+        x_fp16 = torch.randn(
+            num_groups, batch_size, in_features, dtype=dtype, device=device
+        )
+        x_fp8 = x_fp16.to(torch.float8_e4m3fn)
+
+        # Create FP8 weight tensor - concatenated and transposed
+        weight_fp16 = torch.randn(
+            total_out_features, in_features, dtype=dtype, device=device
+        )
+        weight_fp8 = weight_fp16.to(torch.float8_e4m3fn)
+
+        # Create scales
+        scale_a = torch.ones(num_groups, batch_size, device=device, dtype=torch.float32)
+        scale_b = torch.ones(total_out_features, device=device, dtype=torch.float32)
+
+        self.check_model(
+            Model(),
+            (x_fp8, weight_fp8, scale_a, scale_b, offsets),
+        )
+
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
         "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py
index 1467224f6b9f..c082ef337714 100644
--- a/torch/_inductor/kernel/mm_grouped.py
+++ b/torch/_inductor/kernel/mm_grouped.py
@@ -469,7 +469,7 @@ def grouped_mm_args(
 aten__grouped_mm = ExternKernelChoice(
     torch._grouped_mm,
     "at::_grouped_mm",
-    op_overload=aten._grouped_mm,
+    op_overload=aten._grouped_mm.default,
     has_out_variant=False,
 )
 
@@ -477,7 +477,7 @@ aten__grouped_mm = ExternKernelChoice(
 aten__scaled_grouped_mm = ExternKernelChoice(
     torch._scaled_grouped_mm,
     "at::_scaled_grouped_mm",
-    op_overload=aten._scaled_grouped_mm,
+    op_overload=aten._scaled_grouped_mm.default,
     has_out_variant=False,
 )
 
@@ -735,6 +735,9 @@ def tuned_scaled_grouped_mm(
 ) -> TensorBox:
     """Auto-tuning for _scaled_grouped_mm() operator."""
 
+    # matching _scaled_grouped_mm_cuda Blas.cpp implementation
+    out_dtype = out_dtype or torch.bfloat16
+
     return _tuned_grouped_mm_common(
"aten._scaled_grouped_mm.default", "scaled_grouped_mm", diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 6c30f01d9fe7..b15d1eb04830 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -7642,6 +7642,9 @@ def meta_scaled_grouped_mm( out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): + # matching _scaled_grouped_mm_cuda Blas.cpp implementation + out_dtype = out_dtype or torch.bfloat16 + return _meta_grouped_mm_common( mat_a, mat_b, diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 470919cf389c..c41487ae6dd8 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -44,6 +44,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_atten AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_grouped_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* offs, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle 
data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0); diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index 611400d271d9..a66151a31bb1 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -56,6 +56,7 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = { "aten._scaled_dot_product_fused_attention_overrideable_backward.default": {}, "aten._scaled_dot_product_fused_attention_overrideable.default": {}, "aten._scaled_mm.default": {}, + "aten._scaled_grouped_mm.default": {}, "aten._scaled_mm.out": {}, "aten._segment_reduce_backward.default": {}, "aten._thnn_fused_lstm_cell.default": {},
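
For context, here is a minimal end-to-end sketch (not part of the diff) of how the new AOTI fallback can be exercised outside the test suite. It assumes an SM90 GPU and a PyTorch build that includes this change; shapes and scale layouts mirror test_scaled_grouped_mm above, the ScaledGroupedMM module and variable names are illustrative, and torch._inductor.aoti_compile_and_package / aoti_load_package are used as the packaging entry points.

# Standalone sketch: compile a module that calls torch._scaled_grouped_mm with
# AOT Inductor and compare against eager. Requires SM90 and FP8 support.
import torch
import torch._inductor


class ScaledGroupedMM(torch.nn.Module):
    def forward(self, x, weight, scale_a, scale_b, offsets):
        out = torch._scaled_grouped_mm(
            x,                # [num_groups, batch, in_features], float8_e4m3fn
            weight.t(),       # [in_features, total_out_features], float8_e4m3fn
            scale_a=scale_a,  # [num_groups, batch], float32, per-row scales
            scale_b=scale_b,  # [total_out_features], float32, per-column scales
            offs=offsets,     # [num_groups], int32, cumulative output sizes
            use_fast_accum=True,
        )
        # out_dtype is left unset, so the op produces bfloat16 (the default the
        # meta registration and lowering now match); cast down for comparison.
        return out.half()


device = "cuda"
num_groups, batch, in_features = 3, 64, 128
out_sizes = [64, 128, 256]
total_out = sum(out_sizes)

x = torch.randn(num_groups, batch, in_features, device=device, dtype=torch.float16)
w = torch.randn(total_out, in_features, device=device, dtype=torch.float16)
args = (
    x.to(torch.float8_e4m3fn),
    w.to(torch.float8_e4m3fn),
    torch.ones(num_groups, batch, device=device, dtype=torch.float32),
    torch.ones(total_out, device=device, dtype=torch.float32),
    torch.cumsum(torch.tensor(out_sizes), dim=0).to(device, dtype=torch.int32),
)

model = ScaledGroupedMM()
eager_out = model(*args)

# Export, AOT-compile into a package, load it back, and check numerics.
ep = torch.export.export(model, args)
package_path = torch._inductor.aoti_compile_and_package(ep)
runner = torch._inductor.aoti_load_package(package_path)
torch.testing.assert_close(runner(*args), eager_out)

The out_dtype = out_dtype or torch.bfloat16 defaulting added to both the lowering and the meta registration mirrors what _scaled_grouped_mm_cuda in Blas.cpp already does, so the compile-time output dtype agrees with what the kernel produces at runtime.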