mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Differential Revision: D63206258. This diff introduces a mechanism to generate a JSON-compatible deserializer in C++ using nlohmann json (already used by AOTI). Why do we need this? Because there will be many cases where people don't want to use Python to load the graph (e.g. in a C++ runtime); instead, they can use this header to deserialize the JSON graph. Every time update_schema.py is called to update the schema, the header is auto-generated and included in the source files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398 Approved by: https://github.com/angelayi
292 lines
8.0 KiB
Python
292 lines
8.0 KiB
Python
import functools
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
|
|
|
|
aten = torch.ops.aten
|
|
|
|
# Operator overloads under ``torch.ops.aten`` collected for testing purposes only.
# This list is not meant to be comprehensive.
# NOTE(review): going by the name, these are presumably composite ops that can
# be preserved (kept un-decomposed) during export -- confirm against the tests
# that consume this list.
_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
    aten.arctan2.default,
    aten.divide.Tensor,
    aten.divide.Scalar,
    aten.divide.Tensor_mode,
    aten.divide.Scalar_mode,
    aten.multiply.Tensor,
    aten.multiply.Scalar,
    aten.subtract.Tensor,
    aten.subtract.Scalar,
    aten.true_divide.Tensor,
    aten.true_divide.Scalar,
    aten.greater.Tensor,
    aten.greater.Scalar,
    aten.greater_equal.Tensor,
    aten.greater_equal.Scalar,
    aten.less_equal.Tensor,
    aten.less_equal.Scalar,
    aten.less.Tensor,
    aten.less.Scalar,
    aten.not_equal.Tensor,
    aten.not_equal.Scalar,
    aten.cat.names,
    aten.sum.dim_DimnameList,
    aten.mean.names_dim,
    aten.prod.dim_Dimname,
    aten.all.dimname,
    aten.norm.names_ScalarOpt_dim,
    aten.norm.names_ScalarOpt_dim_dtype,
    aten.var.default,
    aten.var.dim,
    aten.var.names_dim,
    aten.var.correction_names,
    aten.std.default,
    aten.std.dim,
    aten.std.names_dim,
    aten.std.correction_names,
    aten.absolute.default,
    aten.arccos.default,
    aten.arccosh.default,
    aten.arcsin.default,
    aten.arcsinh.default,
    aten.arctan.default,
    aten.arctanh.default,
    aten.clip.default,
    aten.clip.Tensor,
    aten.fix.default,
    aten.negative.default,
    aten.square.default,
    aten.size.int,
    aten.size.Dimname,
    aten.stride.int,
    aten.stride.Dimname,
    aten.repeat_interleave.self_Tensor,
    aten.repeat_interleave.self_int,
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.atleast_1d.Sequence,
    aten.atleast_2d.Sequence,
    aten.atleast_3d.Sequence,
    aten.linear.default,
    aten.conv2d.default,
    aten.conv2d.padding,
    aten.mish_backward.default,
    aten.silu_backward.default,
    aten.index_add.dimname,
    aten.pad_sequence.default,
    aten.index_copy.dimname,
    aten.upsample_nearest1d.vec,
    aten.upsample_nearest2d.vec,
    aten.upsample_nearest3d.vec,
    aten._upsample_nearest_exact1d.vec,
    aten._upsample_nearest_exact2d.vec,
    aten._upsample_nearest_exact3d.vec,
    aten.rnn_tanh.input,
    aten.rnn_tanh.data,
    aten.rnn_relu.input,
    aten.rnn_relu.data,
    aten.lstm.input,
    aten.lstm.data,
    aten.gru.input,
    aten.gru.data,
    aten._upsample_bilinear2d_aa.vec,
    aten._upsample_bicubic2d_aa.vec,
    aten.upsample_bilinear2d.vec,
    aten.upsample_trilinear3d.vec,
    aten.upsample_linear1d.vec,
    aten.matmul.default,
    aten.upsample_bicubic2d.vec,
    aten.__and__.Scalar,
    aten.__and__.Tensor,
    aten.__or__.Tensor,
    aten.__or__.Scalar,
    aten.__xor__.Tensor,
    aten.__xor__.Scalar,
    aten.scatter.dimname_src,
    aten.scatter.dimname_value,
    aten.scatter_add.dimname,
    aten.is_complex.default,
    aten.logsumexp.names,
    aten.where.ScalarOther,
    aten.where.ScalarSelf,
    aten.where.Scalar,
    aten.where.default,
    aten.item.default,
    aten.any.dimname,
    aten.std_mean.default,
    aten.std_mean.dim,
    aten.std_mean.names_dim,
    aten.std_mean.correction_names,
    aten.var_mean.default,
    aten.var_mean.dim,
    aten.var_mean.names_dim,
    aten.var_mean.correction_names,
    aten.broadcast_tensors.default,
    aten.stft.default,
    aten.stft.center,
    aten.istft.default,
    aten.index_fill.Dimname_Scalar,
    aten.index_fill.Dimname_Tensor,
    aten.index_select.dimname,
    aten.diag.default,
    aten.cumsum.dimname,
    aten.cumprod.dimname,
    aten.meshgrid.default,
    aten.meshgrid.indexing,
    aten.fft_fft.default,
    aten.fft_ifft.default,
    aten.fft_rfft.default,
    aten.fft_irfft.default,
    aten.fft_hfft.default,
    aten.fft_ihfft.default,
    aten.fft_fftn.default,
    aten.fft_ifftn.default,
    aten.fft_rfftn.default,
    aten.fft_ihfftn.default,
    aten.fft_irfftn.default,
    aten.fft_hfftn.default,
    aten.fft_fft2.default,
    aten.fft_ifft2.default,
    aten.fft_rfft2.default,
    aten.fft_irfft2.default,
    aten.fft_hfft2.default,
    aten.fft_ihfft2.default,
    aten.fft_fftshift.default,
    aten.fft_ifftshift.default,
    aten.selu.default,
    aten.margin_ranking_loss.default,
    aten.hinge_embedding_loss.default,
    aten.nll_loss.default,
    aten.prelu.default,
    aten.relu6.default,
    aten.pairwise_distance.default,
    aten.pdist.default,
    aten.special_ndtr.default,
    aten.cummax.dimname,
    aten.cummin.dimname,
    aten.logcumsumexp.dimname,
    aten.max.other,
    aten.max.names_dim,
    aten.min.other,
    aten.min.names_dim,
    aten.linalg_eigvals.default,
    aten.median.names_dim,
    aten.nanmedian.names_dim,
    aten.mode.dimname,
    aten.gather.dimname,
    aten.sort.dimname,
    aten.sort.dimname_stable,
    aten.argsort.default,
    aten.argsort.dimname,
    aten.rrelu.default,
    aten.conv_transpose1d.default,
    aten.conv_transpose2d.input,
    aten.conv_transpose3d.input,
    aten.conv1d.default,
    aten.conv1d.padding,
    aten.conv3d.default,
    aten.conv3d.padding,
    aten.float_power.Tensor_Tensor,
    aten.float_power.Tensor_Scalar,
    aten.float_power.Scalar,
    aten.ldexp.Tensor,
    aten._version.default,
]
|
|
|
|
|
|
def make_test_cls_with_mocked_export(
    cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None
):
    """Build a copy of test class ``cls`` whose tests run against a mocked export.

    A new class named ``f"{cls_prefix}{cls.__name__}"`` is created with the
    same bases as ``cls`` and an empty namespace, then populated from
    ``dir(cls)``:

    * callable ``test_*`` attributes are wrapped via
      ``_make_fn_with_mocked_export`` and stored under ``name + fn_suffix``;
    * non-callable ``test_*`` attributes are copied under their original name;
    * every other attribute is copied only if the new class does not already
      expose one with that name.

    Args:
        cls: The test class to mirror.
        cls_prefix: String prepended to ``cls.__name__`` to name the new class.
        fn_suffix: String appended to the name of each wrapped test method.
        mocked_export_fn: Replacement forwarded to ``_make_fn_with_mocked_export``.
        xfail_prop: Optional attribute name; a test function carrying this
            attribute is additionally wrapped in ``unittest.expectedFailure``.

    Returns:
        The newly created class.
    """
    mocked_cls = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    mocked_cls.__qualname__ = mocked_cls.__name__

    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            # NB: Doesn't handle slots correctly, but whatever
            if not hasattr(mocked_cls, attr_name):
                setattr(mocked_cls, attr_name, getattr(cls, attr_name))
            continue

        original = getattr(cls, attr_name)
        if not callable(original):
            # Non-callable "test_*" members are carried over under the same name.
            setattr(mocked_cls, attr_name, getattr(cls, attr_name))
            continue

        mocked_name = f"{attr_name}{fn_suffix}"
        wrapped = _make_fn_with_mocked_export(original, mocked_export_fn)
        wrapped.__name__ = mocked_name

        marked_xfail = xfail_prop is not None and hasattr(original, xfail_prop)
        if marked_xfail:
            wrapped = unittest.expectedFailure(wrapped)

        setattr(mocked_cls, mocked_name, wrapped)

    return mocked_cls
|
|
|
|
|
|
def _make_fn_with_mocked_export(fn, mocked_export_fn):
|
|
@functools.wraps(fn)
|
|
def _fn(*args, **kwargs):
|
|
try:
|
|
from . import test_export
|
|
except ImportError:
|
|
import test_export # @manual=fbcode//caffe2/test:test_export-library
|
|
|
|
with patch(f"{test_export.__name__}.export", mocked_export_fn):
|
|
return fn(*args, **kwargs)
|
|
|
|
return _fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
|
|
def expectedFailureTrainingIRToRunDecomp(fn):
    """Tag ``fn`` with ``_expected_failure_training_ir_to_run_decomp = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
|
|
def expectedFailureTrainingIRToRunDecompNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_training_ir_to_run_decomp_non_strict = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp_non_strict", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_export_nonstrict.py
|
|
def expectedFailureNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_non_strict = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    """
    setattr(fn, "_expected_failure_non_strict", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_retraceability.py
|
|
def expectedFailureRetraceability(fn):
    """Tag ``fn`` with ``_expected_failure_retrace = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    """
    setattr(fn, "_expected_failure_retrace", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_retraceability.py
|
|
def expectedFailureRetraceabilityNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_retrace_non_strict = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    """
    setattr(fn, "_expected_failure_retrace_non_strict", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_serdes.py
|
|
def expectedFailureSerDer(fn):
    """Tag ``fn`` with ``_expected_failure_serdes = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    """
    setattr(fn, "_expected_failure_serdes", True)
    return fn
|
|
|
|
|
|
# Controls tests generated in test/export/test_serdes.py
|
|
def expectedFailureSerDerNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_serdes_non_strict = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    """
    setattr(fn, "_expected_failure_serdes_non_strict", True)
    return fn
|
|
|
|
|
|
def expectedFailureSerDerPreDispatch(fn):
    """Tag ``fn`` with ``_expected_failure_serdes_pre_dispatch = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    NOTE(review): presumably consumed through the ``xfail_prop`` argument of
    ``make_test_cls_with_mocked_export`` -- confirm which generated test
    module reads this attribute.
    """
    setattr(fn, "_expected_failure_serdes_pre_dispatch", True)
    return fn
|
|
|
|
|
|
def expectedFailurePreDispatchRunDecomp(fn):
    """Tag ``fn`` with ``_expected_failure_pre_dispatch = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    NOTE(review): presumably consumed through the ``xfail_prop`` argument of
    ``make_test_cls_with_mocked_export`` -- confirm which generated test
    module reads this attribute.
    """
    setattr(fn, "_expected_failure_pre_dispatch", True)
    return fn
|
|
|
|
|
|
def expectedFailureCppSerDes(fn):
    """Tag ``fn`` with ``_expected_failure_cpp_serdes = True``.

    The same ``fn`` object is returned, so this can be used as a decorator.
    NOTE(review): presumably consumed through the ``xfail_prop`` argument of
    ``make_test_cls_with_mocked_export`` -- confirm which generated test
    module reads this attribute.
    """
    setattr(fn, "_expected_failure_cpp_serdes", True)
    return fn
|