Files
pytorch/test/export/testing.py
Zhengxu Chen 3ef2dfc1ba [export] Implement cpp deserializer. (#136398)
Differential Revision: D63206258

This diff introduces a mechanism to generate a json-compatible deserializer in cpp using nlohmann json (already being used by AOTI).

Why we need this? Because there will be a lot of cases where people don't want to use Python to load the graph (e.g. cpp runtime), and instead they can use this header to deserialize the JSON graph.

Every time we call update_schema.py to update the schema, the header will be auto generated and included into the source files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136398
Approved by: https://github.com/angelayi
2024-11-14 16:34:59 +00:00

292 lines
8.0 KiB
Python

import functools
import unittest
from unittest.mock import patch
import torch
aten = torch.ops.aten
# This list is not meant to be comprehensive
_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
aten.arctan2.default,
aten.divide.Tensor,
aten.divide.Scalar,
aten.divide.Tensor_mode,
aten.divide.Scalar_mode,
aten.multiply.Tensor,
aten.multiply.Scalar,
aten.subtract.Tensor,
aten.subtract.Scalar,
aten.true_divide.Tensor,
aten.true_divide.Scalar,
aten.greater.Tensor,
aten.greater.Scalar,
aten.greater_equal.Tensor,
aten.greater_equal.Scalar,
aten.less_equal.Tensor,
aten.less_equal.Scalar,
aten.less.Tensor,
aten.less.Scalar,
aten.not_equal.Tensor,
aten.not_equal.Scalar,
aten.cat.names,
aten.sum.dim_DimnameList,
aten.mean.names_dim,
aten.prod.dim_Dimname,
aten.all.dimname,
aten.norm.names_ScalarOpt_dim,
aten.norm.names_ScalarOpt_dim_dtype,
aten.var.default,
aten.var.dim,
aten.var.names_dim,
aten.var.correction_names,
aten.std.default,
aten.std.dim,
aten.std.names_dim,
aten.std.correction_names,
aten.absolute.default,
aten.arccos.default,
aten.arccosh.default,
aten.arcsin.default,
aten.arcsinh.default,
aten.arctan.default,
aten.arctanh.default,
aten.clip.default,
aten.clip.Tensor,
aten.fix.default,
aten.negative.default,
aten.square.default,
aten.size.int,
aten.size.Dimname,
aten.stride.int,
aten.stride.Dimname,
aten.repeat_interleave.self_Tensor,
aten.repeat_interleave.self_int,
aten.sym_size.int,
aten.sym_stride.int,
aten.atleast_1d.Sequence,
aten.atleast_2d.Sequence,
aten.atleast_3d.Sequence,
aten.linear.default,
aten.conv2d.default,
aten.conv2d.padding,
aten.mish_backward.default,
aten.silu_backward.default,
aten.index_add.dimname,
aten.pad_sequence.default,
aten.index_copy.dimname,
aten.upsample_nearest1d.vec,
aten.upsample_nearest2d.vec,
aten.upsample_nearest3d.vec,
aten._upsample_nearest_exact1d.vec,
aten._upsample_nearest_exact2d.vec,
aten._upsample_nearest_exact3d.vec,
aten.rnn_tanh.input,
aten.rnn_tanh.data,
aten.rnn_relu.input,
aten.rnn_relu.data,
aten.lstm.input,
aten.lstm.data,
aten.gru.input,
aten.gru.data,
aten._upsample_bilinear2d_aa.vec,
aten._upsample_bicubic2d_aa.vec,
aten.upsample_bilinear2d.vec,
aten.upsample_trilinear3d.vec,
aten.upsample_linear1d.vec,
aten.matmul.default,
aten.upsample_bicubic2d.vec,
aten.__and__.Scalar,
aten.__and__.Tensor,
aten.__or__.Tensor,
aten.__or__.Scalar,
aten.__xor__.Tensor,
aten.__xor__.Scalar,
aten.scatter.dimname_src,
aten.scatter.dimname_value,
aten.scatter_add.dimname,
aten.is_complex.default,
aten.logsumexp.names,
aten.where.ScalarOther,
aten.where.ScalarSelf,
aten.where.Scalar,
aten.where.default,
aten.item.default,
aten.any.dimname,
aten.std_mean.default,
aten.std_mean.dim,
aten.std_mean.names_dim,
aten.std_mean.correction_names,
aten.var_mean.default,
aten.var_mean.dim,
aten.var_mean.names_dim,
aten.var_mean.correction_names,
aten.broadcast_tensors.default,
aten.stft.default,
aten.stft.center,
aten.istft.default,
aten.index_fill.Dimname_Scalar,
aten.index_fill.Dimname_Tensor,
aten.index_select.dimname,
aten.diag.default,
aten.cumsum.dimname,
aten.cumprod.dimname,
aten.meshgrid.default,
aten.meshgrid.indexing,
aten.fft_fft.default,
aten.fft_ifft.default,
aten.fft_rfft.default,
aten.fft_irfft.default,
aten.fft_hfft.default,
aten.fft_ihfft.default,
aten.fft_fftn.default,
aten.fft_ifftn.default,
aten.fft_rfftn.default,
aten.fft_ihfftn.default,
aten.fft_irfftn.default,
aten.fft_hfftn.default,
aten.fft_fft2.default,
aten.fft_ifft2.default,
aten.fft_rfft2.default,
aten.fft_irfft2.default,
aten.fft_hfft2.default,
aten.fft_ihfft2.default,
aten.fft_fftshift.default,
aten.fft_ifftshift.default,
aten.selu.default,
aten.margin_ranking_loss.default,
aten.hinge_embedding_loss.default,
aten.nll_loss.default,
aten.prelu.default,
aten.relu6.default,
aten.pairwise_distance.default,
aten.pdist.default,
aten.special_ndtr.default,
aten.cummax.dimname,
aten.cummin.dimname,
aten.logcumsumexp.dimname,
aten.max.other,
aten.max.names_dim,
aten.min.other,
aten.min.names_dim,
aten.linalg_eigvals.default,
aten.median.names_dim,
aten.nanmedian.names_dim,
aten.mode.dimname,
aten.gather.dimname,
aten.sort.dimname,
aten.sort.dimname_stable,
aten.argsort.default,
aten.argsort.dimname,
aten.rrelu.default,
aten.conv_transpose1d.default,
aten.conv_transpose2d.input,
aten.conv_transpose3d.input,
aten.conv1d.default,
aten.conv1d.padding,
aten.conv3d.default,
aten.conv3d.padding,
aten.float_power.Tensor_Tensor,
aten.float_power.Tensor_Scalar,
aten.float_power.Scalar,
aten.ldexp.Tensor,
aten._version.default,
]
def make_test_cls_with_mocked_export(
cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None
):
MockedTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
MockedTestClass.__qualname__ = MockedTestClass.__name__
for name in dir(cls):
if name.startswith("test_"):
fn = getattr(cls, name)
if not callable(fn):
setattr(MockedTestClass, name, getattr(cls, name))
continue
new_name = f"{name}{fn_suffix}"
new_fn = _make_fn_with_mocked_export(fn, mocked_export_fn)
new_fn.__name__ = new_name
if xfail_prop is not None and hasattr(fn, xfail_prop):
new_fn = unittest.expectedFailure(new_fn)
setattr(MockedTestClass, new_name, new_fn)
# NB: Doesn't handle slots correctly, but whatever
elif not hasattr(MockedTestClass, name):
setattr(MockedTestClass, name, getattr(cls, name))
return MockedTestClass
def _make_fn_with_mocked_export(fn, mocked_export_fn):
@functools.wraps(fn)
def _fn(*args, **kwargs):
try:
from . import test_export
except ImportError:
import test_export # @manual=fbcode//caffe2/test:test_export-library
with patch(f"{test_export.__name__}.export", mocked_export_fn):
return fn(*args, **kwargs)
return _fn
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
def expectedFailureTrainingIRToRunDecomp(fn):
fn._expected_failure_training_ir_to_run_decomp = True
return fn
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
def expectedFailureTrainingIRToRunDecompNonStrict(fn):
fn._expected_failure_training_ir_to_run_decomp_non_strict = True
return fn
# Controls tests generated in test/export/test_export_nonstrict.py
def expectedFailureNonStrict(fn):
fn._expected_failure_non_strict = True
return fn
# Controls tests generated in test/export/test_retraceability.py
def expectedFailureRetraceability(fn):
fn._expected_failure_retrace = True
return fn
# Controls tests generated in test/export/test_retraceability.py
def expectedFailureRetraceabilityNonStrict(fn):
fn._expected_failure_retrace_non_strict = True
return fn
# Controls tests generated in test/export/test_serdes.py
def expectedFailureSerDer(fn):
fn._expected_failure_serdes = True
return fn
# Controls tests generated in test/export/test_serdes.py
def expectedFailureSerDerNonStrict(fn):
fn._expected_failure_serdes_non_strict = True
return fn
def expectedFailureSerDerPreDispatch(fn):
fn._expected_failure_serdes_pre_dispatch = True
return fn
def expectedFailurePreDispatchRunDecomp(fn):
fn._expected_failure_pre_dispatch = True
return fn
def expectedFailureCppSerDes(fn):
fn._expected_failure_cpp_serdes = True
return fn