Revert "Add torch._scaled_mm for CPU (#139975)"

This reverts commit f0bdc27f74f8b1d4ab6789156691ee0fd5cbb30f.

Reverted https://github.com/pytorch/pytorch/pull/139975 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it looks like the internal ideep version is too old to support this ([comment](https://github.com/pytorch/pytorch/pull/139975#issuecomment-2660008996))
PyTorch MergeBot
2025-02-14 18:31:54 +00:00
parent 20a9938069
commit aac5d1a289
12 changed files with 586 additions and 915 deletions

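For context, the reverted change extended torch._scaled_mm (the FP8 matmul with per-tensor scales exercised by the sample inputs below) to CPU via oneDNN/ideep. A minimal sketch of the CUDA call pattern that remains supported after this revert, assuming a device with compute capability >= 8.9 (the requirement the skipCUDAIf decorator below encodes) and purely illustrative shapes and scale values:

    import torch

    # Illustrative shapes only; the FP8 GEMM expects mat1 row-major and mat2 column-major.
    M, K, N = 32, 64, 16
    device = "cuda"  # the CPU path is what this commit reverts

    mat1 = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    # .t().contiguous().t() produces a column-major (K, N) tensor, mirroring the sample inputs below.
    mat2 = torch.randn(K, N, device=device).to(torch.float8_e4m3fn).t().contiguous().t()

    # Per-tensor scales: single-element float32 tensors, as make_scale((1,)) builds below.
    scale_a = torch.tensor(1.0, device=device)
    scale_b = torch.tensor(1.0, device=device)

    out = torch._scaled_mm(mat1, mat2, scale_a, scale_b, out_dtype=torch.bfloat16)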

@@ -22,7 +22,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
)
from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -8743,21 +8743,18 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
# TODO: Will remove this after oneDNN v3.6
# now oneDNN v3.5.3 only supports mat1 * mat2 with the same data types.
if device != 'cpu':
# mat1 e4m3 mat2 e5m2
mat1 = make_mat_e4m3((M, K))
mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
# mat1 e5m2 mat2 e4m3
mat1 = make_mat_e5m2((M, K))
mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
# mat1 e4m3 mat2 e5m2
mat1 = make_mat_e4m3((M, K))
mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
# mat1 e5m2 mat2 e4m3
mat1 = make_mat_e5m2((M, K))
mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
yield from samples
@@ -16217,7 +16214,7 @@ op_db: list[OpInfo] = [
OpInfo(
'torch._scaled_mm',
sample_inputs_func=sample_inputs_scaled_mm,
dtypes=float8_types(),
dtypes=empty_types(),
dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
supports_out=True,
supports_forward_ad=False,
@@ -16225,20 +16222,12 @@ op_db: list[OpInfo] = [
decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
skips=(
# Sample inputs isn't really parametrized on dtype
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
# "add_stub" not implemented for 'Float8_e4m3fn'
# "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
# https://github.com/pytorch/pytorch/issues/107256
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
device_type='cuda'),
# "mul_cuda" not implemented for float8_e4m3fn
# "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
# https://github.com/pytorch/pytorch/issues/107256
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
# aten::_scaled_mm hit the vmap fallback which is currently disabled
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
dtypes=(torch.float8_e4m3fn,)),
)
),
OpInfo(