mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary:
- Flip the default value of the `strict` argument in torch.export.export from True to False.
- Update the test infrastructure to cope with the change; some tests assumed that strict mode was the default.
- Disable some tests that fail in non-strict mode.

Test Plan: Sandcastle
Differential Revision: D70228628
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148790
Approved by: https://github.com/angelayi
324 lines
8.8 KiB
Python
324 lines
8.8 KiB
Python
import functools
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
|
|
|
|
aten = torch.ops.aten
|
|
|
|
# Operator overloads, all looked up under ``torch.ops.aten``, collected for
# use by tests.
# NOTE(review): judging by the name, these are composite ops that export is
# allowed to keep un-decomposed -- confirm against the tests that consume
# this list.
# This list is not meant to be comprehensive
_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
    aten.arctan2.default,
    aten.divide.Tensor,
    aten.divide.Scalar,
    aten.divide.Tensor_mode,
    aten.divide.Scalar_mode,
    aten.multiply.Tensor,
    aten.multiply.Scalar,
    aten.subtract.Tensor,
    aten.subtract.Scalar,
    aten.true_divide.Tensor,
    aten.true_divide.Scalar,
    aten.greater.Tensor,
    aten.greater.Scalar,
    aten.greater_equal.Tensor,
    aten.greater_equal.Scalar,
    aten.less_equal.Tensor,
    aten.less_equal.Scalar,
    aten.less.Tensor,
    aten.less.Scalar,
    aten.not_equal.Tensor,
    aten.not_equal.Scalar,
    aten.cat.names,
    aten.sum.dim_DimnameList,
    aten.mean.names_dim,
    aten.prod.dim_Dimname,
    aten.all.dimname,
    aten.norm.names_ScalarOpt_dim,
    aten.norm.names_ScalarOpt_dim_dtype,
    aten.var.default,
    aten.var.dim,
    aten.var.names_dim,
    aten.var.correction_names,
    aten.std.default,
    aten.std.dim,
    aten.std.names_dim,
    aten.std.correction_names,
    aten.absolute.default,
    aten.arccos.default,
    aten.arccosh.default,
    aten.arcsin.default,
    aten.arcsinh.default,
    aten.arctan.default,
    aten.arctanh.default,
    aten.clip.default,
    aten.clip.Tensor,
    aten.fix.default,
    aten.negative.default,
    aten.square.default,
    aten.size.int,
    aten.size.Dimname,
    aten.stride.int,
    aten.stride.Dimname,
    aten.repeat_interleave.self_Tensor,
    aten.repeat_interleave.self_int,
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.atleast_1d.Sequence,
    aten.atleast_2d.Sequence,
    aten.atleast_3d.Sequence,
    aten.linear.default,
    aten.conv2d.default,
    aten.conv2d.padding,
    aten.mish_backward.default,
    aten.silu_backward.default,
    aten.index_add.dimname,
    aten.pad_sequence.default,
    aten.index_copy.dimname,
    aten.upsample_nearest1d.vec,
    aten.upsample_nearest2d.vec,
    aten.upsample_nearest3d.vec,
    aten._upsample_nearest_exact1d.vec,
    aten._upsample_nearest_exact2d.vec,
    aten._upsample_nearest_exact3d.vec,
    aten.rnn_tanh.input,
    aten.rnn_tanh.data,
    aten.rnn_relu.input,
    aten.rnn_relu.data,
    aten.lstm.input,
    aten.lstm.data,
    aten.gru.input,
    aten.gru.data,
    aten._upsample_bilinear2d_aa.vec,
    aten._upsample_bicubic2d_aa.vec,
    aten.upsample_bilinear2d.vec,
    aten.upsample_trilinear3d.vec,
    aten.upsample_linear1d.vec,
    aten.matmul.default,
    aten.upsample_bicubic2d.vec,
    aten.__and__.Scalar,
    aten.__and__.Tensor,
    aten.__or__.Tensor,
    aten.__or__.Scalar,
    aten.__xor__.Tensor,
    aten.__xor__.Scalar,
    aten.scatter.dimname_src,
    aten.scatter.dimname_value,
    aten.scatter_add.dimname,
    aten.is_complex.default,
    aten.logsumexp.names,
    aten.where.ScalarOther,
    aten.where.ScalarSelf,
    aten.where.Scalar,
    aten.where.default,
    aten.item.default,
    aten.any.dimname,
    aten.std_mean.default,
    aten.std_mean.dim,
    aten.std_mean.names_dim,
    aten.std_mean.correction_names,
    aten.var_mean.default,
    aten.var_mean.dim,
    aten.var_mean.names_dim,
    aten.var_mean.correction_names,
    aten.broadcast_tensors.default,
    aten.stft.default,
    aten.stft.center,
    aten.istft.default,
    aten.index_fill.Dimname_Scalar,
    aten.index_fill.Dimname_Tensor,
    aten.index_select.dimname,
    aten.diag.default,
    aten.cumsum.dimname,
    aten.cumprod.dimname,
    aten.meshgrid.default,
    aten.meshgrid.indexing,
    aten.fft_fft.default,
    aten.fft_ifft.default,
    aten.fft_rfft.default,
    aten.fft_irfft.default,
    aten.fft_hfft.default,
    aten.fft_ihfft.default,
    aten.fft_fftn.default,
    aten.fft_ifftn.default,
    aten.fft_rfftn.default,
    aten.fft_ihfftn.default,
    aten.fft_irfftn.default,
    aten.fft_hfftn.default,
    aten.fft_fft2.default,
    aten.fft_ifft2.default,
    aten.fft_rfft2.default,
    aten.fft_irfft2.default,
    aten.fft_hfft2.default,
    aten.fft_ihfft2.default,
    aten.fft_fftshift.default,
    aten.fft_ifftshift.default,
    aten.selu.default,
    aten.margin_ranking_loss.default,
    aten.hinge_embedding_loss.default,
    aten.nll_loss.default,
    aten.prelu.default,
    aten.relu6.default,
    aten.pairwise_distance.default,
    aten.pdist.default,
    aten.special_ndtr.default,
    aten.cummax.dimname,
    aten.cummin.dimname,
    aten.logcumsumexp.dimname,
    aten.max.other,
    aten.max.names_dim,
    aten.min.other,
    aten.min.names_dim,
    aten.linalg_eigvals.default,
    aten.median.names_dim,
    aten.nanmedian.names_dim,
    aten.mode.dimname,
    aten.gather.dimname,
    aten.sort.dimname,
    aten.sort.dimname_stable,
    aten.argsort.default,
    aten.argsort.dimname,
    aten.rrelu.default,
    aten.conv_transpose1d.default,
    aten.conv_transpose2d.input,
    aten.conv_transpose3d.input,
    aten.conv1d.default,
    aten.conv1d.padding,
    aten.conv3d.default,
    aten.conv3d.padding,
    aten.float_power.Tensor_Tensor,
    aten.float_power.Tensor_Scalar,
    aten.float_power.Scalar,
    aten.ldexp.Tensor,
    aten._version.default,
]
|
|
|
|
|
|
def make_test_cls_with_mocked_export(
    cls,
    cls_prefix,
    fn_suffix,
    mocked_export_fn,
    xfail_prop=None,
    test_only_if_no_xfail=False,
):
    """Build a copy of test class ``cls`` whose tests run with a mocked export.

    The new class is named ``f"{cls_prefix}{cls.__name__}"`` and has the same
    bases as ``cls``. Every callable attribute whose name starts with
    ``test_`` is wrapped via ``_make_fn_with_mocked_export`` and installed
    under ``f"{name}{fn_suffix}"``. All other attributes are copied over
    unless the new class already has an attribute of that name.

    Args:
        cls: The test class to derive from.
        cls_prefix: String prepended to ``cls.__name__`` to name the new class.
        fn_suffix: String appended to each wrapped test method's name.
        mocked_export_fn: Replacement export function handed to
            ``_make_fn_with_mocked_export``.
        xfail_prop: Optional attribute name. A test function carrying this
            attribute is wrapped in ``unittest.expectedFailure``.
        test_only_if_no_xfail: When true, a test that was not matched by
            ``xfail_prop`` but carries any attribute starting with
            ``_expected_failure`` is skipped.

    Returns:
        The newly created class.
    """
    mocked_cls = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    mocked_cls.__qualname__ = mocked_cls.__name__

    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            # NB: Doesn't handle slots correctly, but whatever
            if not hasattr(mocked_cls, attr_name):
                setattr(mocked_cls, attr_name, getattr(cls, attr_name))
            continue

        original = getattr(cls, attr_name)
        if not callable(original):
            # A ``test_``-prefixed non-callable is copied under its own name.
            setattr(mocked_cls, attr_name, original)
            continue

        mocked_name = f"{attr_name}{fn_suffix}"
        mocked_test = _make_fn_with_mocked_export(original, mocked_export_fn)
        mocked_test.__name__ = mocked_name

        marked_xfail = xfail_prop is not None and hasattr(original, xfail_prop)
        if marked_xfail:
            mocked_test = unittest.expectedFailure(mocked_test)
        elif test_only_if_no_xfail and any(
            attr.startswith("_expected_failure") for attr in dir(original)
        ):
            skip_decorator = unittest.skip(
                "Will only be tested if no other tests are failing"
            )
            mocked_test = skip_decorator(mocked_test)

        setattr(mocked_cls, mocked_name, mocked_test)

    return mocked_cls
|
|
|
|
|
|
def _make_fn_with_mocked_export(fn, mocked_export_fn):
    """Wrap ``fn`` so it runs while ``test_export.export`` is patched.

    Args:
        fn: The callable to wrap; its metadata is carried over with
            ``functools.wraps``.
        mocked_export_fn: Object installed as the ``export`` attribute of the
            ``test_export`` module for the duration of each call.

    Returns:
        A wrapper that forwards all arguments to ``fn`` and returns its result.
    """

    @functools.wraps(fn)
    def run_with_mocked_export(*args, **kwargs):
        # The module is imported on every call: first relative to this
        # package, then as a top-level module if that raises ImportError.
        try:
            from . import test_export
        except ImportError:
            import test_export  # @manual=fbcode//caffe2/test:test_export-library

        export_target = f"{test_export.__name__}.export"
        with patch(export_target, mocked_export_fn):
            result = fn(*args, **kwargs)
        return result

    return run_with_mocked_export
|
|
|
|
|
|
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
|
|
def expectedFailureTrainingIRToRunDecomp(fn):
    """Flag ``fn`` with ``_expected_failure_training_ir_to_run_decomp``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
|
|
def expectedFailureTrainingIRToRunDecompNonStrict(fn):
    """Flag ``fn`` with ``_expected_failure_training_ir_to_run_decomp_non_strict``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp_non_strict", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_export_strict.py
|
|
def expectedFailureStrict(fn):
    """Flag ``fn`` with ``_expected_failure_strict``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    """
    setattr(fn, "_expected_failure_strict", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_retraceability.py
|
|
def expectedFailureRetraceability(fn):
    """Flag ``fn`` with ``_expected_failure_retrace``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    """
    setattr(fn, "_expected_failure_retrace", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_retraceability.py
|
|
def expectedFailureRetraceabilityNonStrict(fn):
    """Flag ``fn`` with ``_expected_failure_retrace_non_strict``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    """
    setattr(fn, "_expected_failure_retrace_non_strict", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_serdes.py
|
|
def expectedFailureSerDer(fn):
    """Flag ``fn`` with ``_expected_failure_serdes``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    """
    setattr(fn, "_expected_failure_serdes", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_serdes.py
|
|
def expectedFailureSerDerNonStrict(fn):
    """Flag ``fn`` with ``_expected_failure_serdes_non_strict``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    """
    setattr(fn, "_expected_failure_serdes_non_strict", True)
    return fn
|
|
|
|
|
|
def expectedFailureSerDerPreDispatch(fn):
    """Flag ``fn`` with ``_expected_failure_serdes_pre_dispatch``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    NOTE(review): presumably read as an ``xfail_prop`` by a generated test
    suite -- confirm which one.
    """
    setattr(fn, "_expected_failure_serdes_pre_dispatch", True)
    return fn
|
|
|
|
|
|
def expectedFailurePreDispatchRunDecomp(fn):
    """Flag ``fn`` with ``_expected_failure_pre_dispatch``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    NOTE(review): presumably read as an ``xfail_prop`` by a generated test
    suite -- confirm which one.
    """
    setattr(fn, "_expected_failure_pre_dispatch", True)
    return fn
|
|
|
|
|
|
def expectedFailureCppSerDes(fn):
    """Flag ``fn`` with ``_expected_failure_cpp_serdes``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    NOTE(review): presumably read as an ``xfail_prop`` by a generated test
    suite -- confirm which one.
    """
    setattr(fn, "_expected_failure_cpp_serdes", True)
    return fn
|
|
|
|
|
|
def expectedFailureCppRuntime(fn):
    """Flag ``fn`` with ``_expected_failure_cpp_runtime``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    NOTE(review): presumably read as an ``xfail_prop`` by a generated test
    suite -- confirm which one.
    """
    setattr(fn, "_expected_failure_cpp_runtime", True)
    return fn
|
|
|
|
|
|
def expectedFailureCppRuntimeNonStrict(fn):
    """Flag ``fn`` with ``_expected_failure_cpp_runtime_non_strict``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    NOTE(review): presumably read as an ``xfail_prop`` by a generated test
    suite -- confirm which one.
    """
    setattr(fn, "_expected_failure_cpp_runtime_non_strict", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_export_legacy.py
|
|
def expectedFailureLegacyExportStrict(fn):
    """Flag ``fn`` with ``_expected_failure_legacy_export``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    """
    setattr(fn, "_expected_failure_legacy_export", True)
    return fn
|
|
|
|
|
|
def expectedFailureLegacyExportNonStrict(fn):
    """Flag ``fn`` with ``_expected_failure_legacy_export_non_strict``.

    Sets the attribute to ``True`` on ``fn`` and returns the same object, so
    this can be applied as a decorator.
    NOTE(review): presumably read as an ``xfail_prop`` by a generated test
    suite -- confirm which one.
    """
    setattr(fn, "_expected_failure_legacy_export_non_strict", True)
    return fn
|