mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Add torch._scaled_mm for CPU (#139975)"
This reverts commit f0bdc27f74f8b1d4ab6789156691ee0fd5cbb30f. Reverted https://github.com/pytorch/pytorch/pull/139975 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it looks like internal ideep version is too old to support this ([comment](https://github.com/pytorch/pytorch/pull/139975#issuecomment-2660008996))
This commit is contained in:
@@ -22,7 +22,7 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_dtype import (
|
||||
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
|
||||
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
|
||||
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
|
||||
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
|
||||
@@ -8743,21 +8743,18 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
|
||||
scale1 = make_scale((1,))
|
||||
scale2 = make_scale((1,))
|
||||
samples.append(SampleInput(mat1, mat2, scale1, scale2))
|
||||
# TODO: Will remove this after oneDNN v3.6
|
||||
# now oneDNN v3.5.3 only supports mat1 * mat2 with the same data types.
|
||||
if device != 'cpu':
|
||||
# mat1 e4m3 mat2 e5m2
|
||||
mat1 = make_mat_e4m3((M, K))
|
||||
mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
|
||||
scale1 = make_scale((1,))
|
||||
scale2 = make_scale((1,))
|
||||
samples.append(SampleInput(mat1, mat2, scale1, scale2))
|
||||
# mat1 e5m2 mat2 e4m3
|
||||
mat1 = make_mat_e5m2((M, K))
|
||||
mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
|
||||
scale1 = make_scale((1,))
|
||||
scale2 = make_scale((1,))
|
||||
samples.append(SampleInput(mat1, mat2, scale1, scale2))
|
||||
# mat1 e4m3 mat2 e5m2
|
||||
mat1 = make_mat_e4m3((M, K))
|
||||
mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
|
||||
scale1 = make_scale((1,))
|
||||
scale2 = make_scale((1,))
|
||||
samples.append(SampleInput(mat1, mat2, scale1, scale2))
|
||||
# mat1 e5m2 mat2 e4m3
|
||||
mat1 = make_mat_e5m2((M, K))
|
||||
mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
|
||||
scale1 = make_scale((1,))
|
||||
scale2 = make_scale((1,))
|
||||
samples.append(SampleInput(mat1, mat2, scale1, scale2))
|
||||
|
||||
yield from samples
|
||||
|
||||
@@ -16217,7 +16214,7 @@ op_db: list[OpInfo] = [
|
||||
OpInfo(
|
||||
'torch._scaled_mm',
|
||||
sample_inputs_func=sample_inputs_scaled_mm,
|
||||
dtypes=float8_types(),
|
||||
dtypes=empty_types(),
|
||||
dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
|
||||
supports_out=True,
|
||||
supports_forward_ad=False,
|
||||
@@ -16225,20 +16222,12 @@ op_db: list[OpInfo] = [
|
||||
decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
|
||||
skips=(
|
||||
# Sample inputs isn't really parametrized on dtype
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
|
||||
# "add_stub" not implemented for 'Float8_e4m3fn'
|
||||
# "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
|
||||
# https://github.com/pytorch/pytorch/issues/107256
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
|
||||
device_type='cuda'),
|
||||
# "mul_cuda" not implemented for float8_e4m3fn
|
||||
# "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
|
||||
# https://github.com/pytorch/pytorch/issues/107256
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
|
||||
# aten::_scaled_mm hit the vmap fallback which is currently disabled
|
||||
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
|
||||
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
|
||||
dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
|
||||
dtypes=(torch.float8_e4m3fn,)),
|
||||
)
|
||||
),
|
||||
OpInfo(
|
||||
|
Reference in New Issue
Block a user