Files
pytorch/test/export/testing.py
Tugsbayasgalan Manlaibaatar d8b6622bb6 testing infra and some fixes (#162183)
This PR is quite large because it covers most of the rough edges in the new strict export flow:

1. Handle nn_module_stack correctly now that we are tracing the wrapper module.
2. Query module_call_spec directly from the source, because we no longer run the bytecode.
3. Correct input and output handling.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162167
2025-09-09 02:42:11 +00:00

330 lines
8.9 KiB
Python

import functools
import inspect
import unittest
from unittest.mock import patch

import torch
# Shorthand for ``torch.ops.aten``, used to spell the operator overloads below.
aten = torch.ops.aten
# Composite ATen operator overloads, each named as ``aten.<op>.<overload>``.
# Judging by the name, tests use this list to ask export to preserve these ops
# instead of decomposing them -- NOTE(review): confirm against the consumers.
# Every entry is resolved when this module is imported, so an overload missing
# from the installed torch build would presumably fail the import itself.
# This list is not meant to be comprehensive
_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
    aten.arctan2.default,
    aten.divide.Tensor,
    aten.divide.Scalar,
    aten.divide.Tensor_mode,
    aten.divide.Scalar_mode,
    aten.multiply.Tensor,
    aten.multiply.Scalar,
    aten.subtract.Tensor,
    aten.subtract.Scalar,
    aten.true_divide.Tensor,
    aten.true_divide.Scalar,
    aten.greater.Tensor,
    aten.greater.Scalar,
    aten.greater_equal.Tensor,
    aten.greater_equal.Scalar,
    aten.less_equal.Tensor,
    aten.less_equal.Scalar,
    aten.less.Tensor,
    aten.less.Scalar,
    aten.not_equal.Tensor,
    aten.not_equal.Scalar,
    aten.cat.names,
    aten.sum.dim_DimnameList,
    aten.mean.names_dim,
    aten.prod.dim_Dimname,
    aten.all.dimname,
    aten.norm.names_ScalarOpt_dim,
    aten.norm.names_ScalarOpt_dim_dtype,
    aten.var.default,
    aten.var.dim,
    aten.var.names_dim,
    aten.var.correction_names,
    aten.std.default,
    aten.std.dim,
    aten.std.names_dim,
    aten.std.correction_names,
    aten.absolute.default,
    aten.arccos.default,
    aten.arccosh.default,
    aten.arcsin.default,
    aten.arcsinh.default,
    aten.arctan.default,
    aten.arctanh.default,
    aten.clip.default,
    aten.clip.Tensor,
    aten.fix.default,
    aten.negative.default,
    aten.square.default,
    aten.size.int,
    aten.size.Dimname,
    aten.stride.int,
    aten.stride.Dimname,
    aten.repeat_interleave.self_Tensor,
    aten.repeat_interleave.self_int,
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.atleast_1d.Sequence,
    aten.atleast_2d.Sequence,
    aten.atleast_3d.Sequence,
    aten.linear.default,
    aten.conv2d.default,
    aten.conv2d.padding,
    aten.mish_backward.default,
    aten.silu_backward.default,
    aten.index_add.dimname,
    aten.pad_sequence.default,
    aten.index_copy.dimname,
    aten.upsample_nearest1d.vec,
    aten.upsample_nearest2d.vec,
    aten.upsample_nearest3d.vec,
    aten._upsample_nearest_exact1d.vec,
    aten._upsample_nearest_exact2d.vec,
    aten._upsample_nearest_exact3d.vec,
    aten.rnn_tanh.input,
    aten.rnn_tanh.data,
    aten.rnn_relu.input,
    aten.rnn_relu.data,
    aten.lstm.input,
    aten.lstm.data,
    aten.gru.input,
    aten.gru.data,
    aten._upsample_bilinear2d_aa.vec,
    aten._upsample_bicubic2d_aa.vec,
    aten.upsample_bilinear2d.vec,
    aten.upsample_trilinear3d.vec,
    aten.upsample_linear1d.vec,
    aten.matmul.default,
    aten.upsample_bicubic2d.vec,
    aten.__and__.Scalar,
    aten.__and__.Tensor,
    aten.__or__.Tensor,
    aten.__or__.Scalar,
    aten.__xor__.Tensor,
    aten.__xor__.Scalar,
    aten.scatter.dimname_src,
    aten.scatter.dimname_value,
    aten.scatter_add.dimname,
    aten.is_complex.default,
    aten.logsumexp.names,
    aten.where.ScalarOther,
    aten.where.ScalarSelf,
    aten.where.Scalar,
    aten.where.default,
    aten.item.default,
    aten.any.dimname,
    aten.std_mean.default,
    aten.std_mean.dim,
    aten.std_mean.names_dim,
    aten.std_mean.correction_names,
    aten.var_mean.default,
    aten.var_mean.dim,
    aten.var_mean.names_dim,
    aten.var_mean.correction_names,
    aten.broadcast_tensors.default,
    aten.stft.default,
    aten.stft.center,
    aten.istft.default,
    aten.index_fill.Dimname_Scalar,
    aten.index_fill.Dimname_Tensor,
    aten.index_select.dimname,
    aten.diag.default,
    aten.cumsum.dimname,
    aten.cumprod.dimname,
    aten.meshgrid.default,
    aten.meshgrid.indexing,
    aten.fft_fft.default,
    aten.fft_ifft.default,
    aten.fft_rfft.default,
    aten.fft_irfft.default,
    aten.fft_hfft.default,
    aten.fft_ihfft.default,
    aten.fft_fftn.default,
    aten.fft_ifftn.default,
    aten.fft_rfftn.default,
    aten.fft_ihfftn.default,
    aten.fft_irfftn.default,
    aten.fft_hfftn.default,
    aten.fft_fft2.default,
    aten.fft_ifft2.default,
    aten.fft_rfft2.default,
    aten.fft_irfft2.default,
    aten.fft_hfft2.default,
    aten.fft_ihfft2.default,
    aten.fft_fftshift.default,
    aten.fft_ifftshift.default,
    aten.selu.default,
    aten.margin_ranking_loss.default,
    aten.hinge_embedding_loss.default,
    aten.nll_loss.default,
    aten.prelu.default,
    aten.relu6.default,
    aten.pairwise_distance.default,
    aten.pdist.default,
    aten.special_ndtr.default,
    aten.cummax.dimname,
    aten.cummin.dimname,
    aten.logcumsumexp.dimname,
    aten.max.other,
    aten.max.names_dim,
    aten.min.other,
    aten.min.names_dim,
    aten.linalg_eigvals.default,
    aten.median.names_dim,
    aten.nanmedian.names_dim,
    aten.mode.dimname,
    aten.gather.dimname,
    aten.sort.dimname,
    aten.sort.dimname_stable,
    aten.argsort.default,
    aten.argsort.dimname,
    aten.rrelu.default,
    aten.conv_transpose1d.default,
    aten.conv_transpose2d.input,
    aten.conv_transpose3d.input,
    aten.conv1d.default,
    aten.conv1d.padding,
    aten.conv3d.default,
    aten.conv3d.padding,
    aten.float_power.Tensor_Tensor,
    aten.float_power.Tensor_Scalar,
    aten.float_power.Scalar,
    aten.ldexp.Tensor,
    aten._version.default,
]
def make_test_cls_with_mocked_export(
    cls,
    cls_prefix,
    fn_suffix,
    mocked_export_fn,
    xfail_prop=None,
    test_only_if_no_xfail=False,
):
    """Build a copy of test class ``cls`` whose tests run with ``export`` mocked.

    A new class named ``f"{cls_prefix}{cls.__name__}"`` is created with the
    same bases as ``cls`` (it does not subclass ``cls`` itself).

    Every callable ``test_*`` attribute of ``cls`` is wrapped by
    ``_make_fn_with_mocked_export``, so that ``export`` is replaced by
    ``mocked_export_fn`` while it runs. The wrapper is registered under
    ``f"{name}{fn_suffix}"``. Any other attribute of ``cls`` that the new
    class does not already have is copied over unchanged.

    Args:
        cls: The test class to copy.
        cls_prefix: Prefix prepended to ``cls.__name__`` to name the new class.
        fn_suffix: Suffix appended to each generated test method name.
        mocked_export_fn: Replacement for ``export`` inside the generated tests.
        xfail_prop: Optional name of a marker attribute, such as one set by the
            ``expectedFailure*`` decorators in this module. Tests carrying it are
            wrapped in ``unittest.expectedFailure``.
        test_only_if_no_xfail: If true, tests that are not matched by
            ``xfail_prop`` but carry any ``_expected_failure*`` attribute are
            wrapped in ``unittest.skip``.

    Returns:
        The newly created test class.
    """
    MockedTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    MockedTestClass.__qualname__ = MockedTestClass.__name__

    for name in dir(cls):
        if name.startswith("test_"):
            fn = getattr(cls, name)
            if not callable(fn):
                # A non-callable ``test_*`` attribute is copied as-is, unrenamed.
                setattr(MockedTestClass, name, getattr(cls, name))
                continue
            new_name = f"{name}{fn_suffix}"
            new_fn = _make_fn_with_mocked_export(fn, mocked_export_fn)
            new_fn.__name__ = new_name
            if xfail_prop is not None and hasattr(fn, xfail_prop):
                new_fn = unittest.expectedFailure(new_fn)
            elif test_only_if_no_xfail and any(
                x.startswith("_expected_failure") for x in dir(fn)
            ):
                new_fn = unittest.skip(
                    "Will only be tested if no other tests are failing"
                )(new_fn)
            setattr(MockedTestClass, new_name, new_fn)
        # NB: Doesn't handle slots correctly, but whatever
        elif not hasattr(MockedTestClass, name):
            # Copy the raw class attribute instead of ``getattr(cls, name)``.
            # ``getattr`` unwraps staticmethod/classmethod descriptors. A
            # staticmethod would then turn into an instance method, so
            # ``self.helper(x)`` would receive ``self``. A classmethod would
            # stay bound to ``cls`` instead of the new class.
            # NOTE(review): names ``cls`` overrides from its bases (e.g.
            # ``setUp``) already exist on MockedTestClass via the shared
            # bases, so the override is not copied here.
            setattr(MockedTestClass, name, inspect.getattr_static(cls, name))

    return MockedTestClass
def _make_fn_with_mocked_export(fn, mocked_export_fn):
    """Return a wrapper around ``fn`` that runs it with ``export`` patched.

    On each call, the wrapper imports the ``test_export`` module. It tries a
    relative import from this package first and falls back to a top-level
    import. It then patches that module's ``export`` attribute with
    ``mocked_export_fn`` for the duration of the call and returns ``fn``'s
    result. ``functools.wraps`` copies ``fn``'s metadata onto the wrapper.
    """

    def wrapper(*args, **kwargs):
        # Imported at call time rather than at module load; presumably this
        # avoids an import cycle with test_export (unconfirmed).
        try:
            from . import test_export
        except ImportError:
            import test_export  # @manual=fbcode//caffe2/test:test_export-library

        export_target = f"{test_export.__name__}.export"
        with patch(export_target, mocked_export_fn):
            return fn(*args, **kwargs)

    return functools.wraps(fn)(wrapper)
def _mark_expected_failure(fn, marker):
    """Set attribute ``marker`` to ``True`` on ``fn`` and hand ``fn`` back."""
    setattr(fn, marker, True)
    return fn


# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
def expectedFailureTrainingIRToRunDecomp(fn):
    """Tag ``fn`` with ``_expected_failure_training_ir_to_run_decomp``."""
    return _mark_expected_failure(fn, "_expected_failure_training_ir_to_run_decomp")


# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
def expectedFailureTrainingIRToRunDecompNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_training_ir_to_run_decomp_non_strict``."""
    return _mark_expected_failure(
        fn, "_expected_failure_training_ir_to_run_decomp_non_strict"
    )


# Controls tests generated in test/export/test_export_strict_v2.py
def expectedFailureStrictV2(fn):
    """Tag ``fn`` with ``_expected_failure_strict_v2``."""
    return _mark_expected_failure(fn, "_expected_failure_strict_v2")


# Controls tests generated in test/export/test_export_strict.py
def expectedFailureStrict(fn):
    """Tag ``fn`` with ``_expected_failure_strict``."""
    return _mark_expected_failure(fn, "_expected_failure_strict")


# Controls tests generated in test/export/test_retraceability.py
def expectedFailureRetraceability(fn):
    """Tag ``fn`` with ``_expected_failure_retrace``."""
    return _mark_expected_failure(fn, "_expected_failure_retrace")


# Controls tests generated in test/export/test_retraceability.py
def expectedFailureRetraceabilityNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_retrace_non_strict``."""
    return _mark_expected_failure(fn, "_expected_failure_retrace_non_strict")


# Controls tests generated in test/export/test_serdes.py
def expectedFailureSerDer(fn):
    """Tag ``fn`` with ``_expected_failure_serdes``."""
    return _mark_expected_failure(fn, "_expected_failure_serdes")


# Controls tests generated in test/export/test_serdes.py
def expectedFailureSerDerNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_serdes_non_strict``."""
    return _mark_expected_failure(fn, "_expected_failure_serdes_non_strict")


def expectedFailureSerDerPreDispatch(fn):
    """Tag ``fn`` with ``_expected_failure_serdes_pre_dispatch``."""
    return _mark_expected_failure(fn, "_expected_failure_serdes_pre_dispatch")


def expectedFailurePreDispatchRunDecomp(fn):
    """Tag ``fn`` with ``_expected_failure_pre_dispatch``."""
    return _mark_expected_failure(fn, "_expected_failure_pre_dispatch")


def expectedFailureCppSerDes(fn):
    """Tag ``fn`` with ``_expected_failure_cpp_serdes``."""
    return _mark_expected_failure(fn, "_expected_failure_cpp_serdes")


def expectedFailureCppRuntime(fn):
    """Tag ``fn`` with ``_expected_failure_cpp_runtime``."""
    return _mark_expected_failure(fn, "_expected_failure_cpp_runtime")


def expectedFailureCppRuntimeNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_cpp_runtime_non_strict``."""
    return _mark_expected_failure(fn, "_expected_failure_cpp_runtime_non_strict")


# Controls tests generated in test/export/test_export_legacy.py
def expectedFailureLegacyExportStrict(fn):
    """Tag ``fn`` with ``_expected_failure_legacy_export``."""
    return _mark_expected_failure(fn, "_expected_failure_legacy_export")


def expectedFailureLegacyExportNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_legacy_export_non_strict``."""
    return _mark_expected_failure(fn, "_expected_failure_legacy_export_non_strict")