Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Back out "Revert D81959389" (#163905)
Summary: Original commit changeset: 06888d7ebff0. Original Phabricator Diff: D82932788. Restricted the test to SM90 for scaled_grouped_mm.

Test Plan: TBD (will share the linux CI results)

Differential Revision: D83283991

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163905
Approved by: https://github.com/angelayi
Commit 7afcb030d8 (parent bbf6816f35), committed by PyTorch MergeBot
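For context, the operator this change exercises is called as below in eager mode. This is a minimal sketch mirroring the new test in the diff, not part of the commit itself; it assumes a CUDA device with SM90 (H100-class) hardware.

    import torch

    device = "cuda"  # torch._scaled_grouped_mm requires SM90
    num_groups, batch, in_features = 3, 64, 128
    out_features = [64, 128, 256]
    total_out = sum(out_features)

    # FP8 activations stacked per group; FP8 weights concatenated along the output dim
    x = torch.randn(num_groups, batch, in_features, device=device, dtype=torch.float16).to(torch.float8_e4m3fn)
    w = torch.randn(total_out, in_features, device=device, dtype=torch.float16).to(torch.float8_e4m3fn)
    offs = torch.cumsum(torch.tensor(out_features), dim=0).to(device, dtype=torch.int32)
    scale_a = torch.ones(num_groups, batch, device=device, dtype=torch.float32)
    scale_b = torch.ones(total_out, device=device, dtype=torch.float32)

    out = torch._scaled_grouped_mm(
        x, w.t(), scale_a=scale_a, scale_b=scale_b, offs=offs, use_fast_accum=True
    )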
@@ -42,6 +42,7 @@ from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
     _get_torch_cuda_version,
+    IS_SM90,
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -1238,6 +1239,72 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
+    @unittest.skipIf(
+        TEST_WITH_ROCM or not IS_SM90,
+        "scaled_grouped_mm is only supported on SM90",
+    )
+    @skipIfXpu
+    def test_scaled_grouped_mm(self):
+        # Test torch._scaled_grouped_mm AOTI lowering
+        # cuda only
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, weight, scale_a, scale_b, offsets):
+                # x: [num_groups, batch, in_features] - FP8 inputs
+                # weight: [total_out_features, in_features] - FP8 weights (transposed)
+                # scale_a: [num_groups] - input scales
+                # scale_b: [num_groups] - weight scales
+                # offsets: [num_groups] - cumulative output sizes
+                output = torch._scaled_grouped_mm(
+                    x,
+                    weight.t(),
+                    scale_a=scale_a,
+                    scale_b=scale_b,
+                    offs=offsets,
+                    use_fast_accum=True,
+                )
+                return output.half()
+
+        dtype = torch.float16
+        num_groups = 3
+        batch_size = 64
+        in_features = 128
+        out_features_list = [64, 128, 256]  # Different output sizes for each group
+
+        device = GPU_TYPE
+
+        # Calculate offsets (cumulative output sizes)
+        offsets = torch.cumsum(torch.tensor(out_features_list), dim=0).to(
+            device, dtype=torch.int32
+        )
+        total_out_features = sum(out_features_list)
+
+        # Create FP8 input tensors - stacked for all groups
+        x_fp16 = torch.randn(
+            num_groups, batch_size, in_features, dtype=dtype, device=device
+        )
+        x_fp8 = x_fp16.to(torch.float8_e4m3fn)
+
+        # Create FP8 weight tensor - concatenated and transposed
+        weight_fp16 = torch.randn(
+            total_out_features, in_features, dtype=dtype, device=device
+        )
+        weight_fp8 = weight_fp16.to(torch.float8_e4m3fn)
+
+        # Create scales
+        scale_a = torch.ones(num_groups, batch_size, device=device, dtype=torch.float32)
+        scale_b = torch.ones(total_out_features, device=device, dtype=torch.float32)
+
+        self.check_model(
+            Model(),
+            (x_fp8, weight_fp8, scale_a, scale_b, offsets),
+        )
+
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
         "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
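The check_model call above compiles the module through AOTInductor and compares against eager. Outside the test harness, the same path can be driven roughly as follows; this is a sketch only, reusing the Model class from the test and the tensors from the earlier snippet, and assuming a recent PyTorch where torch._inductor.aoti_compile_and_package accepts an ExportedProgram.

    # Hypothetical standalone flow; check_model performs an equivalent
    # compile-and-compare internally.
    ep = torch.export.export(Model(), (x, w, scale_a, scale_b, offs))
    package_path = torch._inductor.aoti_compile_and_package(ep)
    compiled = torch._inductor.aoti_load_package(package_path)
    aoti_out = compiled(x, w, scale_a, scale_b, offs)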
@@ -469,7 +469,7 @@ def grouped_mm_args(
 aten__grouped_mm = ExternKernelChoice(
     torch._grouped_mm,
     "at::_grouped_mm",
-    op_overload=aten._grouped_mm,
+    op_overload=aten._grouped_mm.default,
     has_out_variant=False,
 )
 
@@ -477,7 +477,7 @@ aten__grouped_mm = ExternKernelChoice(
 aten__scaled_grouped_mm = ExternKernelChoice(
     torch._scaled_grouped_mm,
     "at::_scaled_grouped_mm",
-    op_overload=aten._scaled_grouped_mm,
+    op_overload=aten._scaled_grouped_mm.default,
     has_out_variant=False,
 )
 
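The two hunks above switch op_overload from the overload packet to the concrete .default overload, which is the object that carries a single schema. A quick illustration of the distinction (not part of the diff):

    packet = torch.ops.aten._scaled_grouped_mm            # OpOverloadPacket: all overloads
    overload = torch.ops.aten._scaled_grouped_mm.default  # OpOverload: one concrete schema
    print(type(packet).__name__)    # OpOverloadPacket
    print(type(overload).__name__)  # OpOverload
    print(overload._schema)         # the aten::_scaled_grouped_mm signature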
@@ -735,6 +735,9 @@ def tuned_scaled_grouped_mm(
 ) -> TensorBox:
     """Auto-tuning for _scaled_grouped_mm() operator."""
 
+    # matching _scaled_grouped_mm_cuda Blas.cpp implementation
+    out_dtype = out_dtype or torch.bfloat16
+
     return _tuned_grouped_mm_common(
         "aten._scaled_grouped_mm.default",
         "scaled_grouped_mm",
@@ -7642,6 +7642,9 @@ def meta_scaled_grouped_mm(
     out_dtype: Optional[torch.dtype] = None,
     use_fast_accum: bool = False,
 ):
+    # matching _scaled_grouped_mm_cuda Blas.cpp implementation
+    out_dtype = out_dtype or torch.bfloat16
+
     return _meta_grouped_mm_common(
         mat_a,
         mat_b,
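Both the lowering and this meta function default out_dtype to bfloat16 to match the CUDA kernel (_scaled_grouped_mm_cuda in Blas.cpp), so fake-tensor propagation and AOTI code generation see the same output dtype as eager. Reusing the tensors from the first sketch (illustrative, assumes SM90):

    out = torch._scaled_grouped_mm(x, w.t(), scale_a=scale_a, scale_b=scale_b, offs=offs)
    assert out.dtype == torch.bfloat16  # default when out_dtype is not passed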
@@ -44,6 +44,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_atten
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_grouped_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* offs, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
@@ -56,6 +56,7 @@ inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
     "aten._scaled_dot_product_fused_attention_overrideable_backward.default": {},
     "aten._scaled_dot_product_fused_attention_overrideable.default": {},
     "aten._scaled_mm.default": {},
+    "aten._scaled_grouped_mm.default": {},
     "aten._scaled_mm.out": {},
     "aten._segment_reduce_backward.default": {},
     "aten._thnn_fused_lstm_cell.default": {},