[JIT] Add SchemaCheckMode OpInfo test (#82442)
- Move test_schema_check to the torch/test directory.
- Add an OpInfo test for SchemaCheckMode that checks all operator schemas.
- Add various changes (using isClose instead of equals, skipping complex number cases for certain ops, etc.) so that test_schema_check passes.

Differential Revision: [D38437946](https://our.internmc.facebook.com/intern/diff/D38437946)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82442
Approved by: https://github.com/davidberard98
Committed by: PyTorch MergeBot
Parent: a0b3854548
Commit: 2b6905413e
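In essence, the new OpInfo-driven test replays every operator's sample inputs under SchemaCheckMode, which intercepts each dispatched call and raises if the observed mutation or aliasing behavior disagrees with the op's registered schema. A condensed, illustrative sketch of that loop for a single operator, using the same internal helpers the test imports (SchemaCheckMode, enable_torch_dispatch_mode, op_db); this is a sketch of the idea, not the exact test code:

# Illustrative sketch only: replay one operator's sample inputs under SchemaCheckMode.
import torch
from torch.testing._internal.schema_check_mode import SchemaCheckMode
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch.testing._internal.common_methods_invocations import op_db

op = next(o for o in op_db if o.name == "add")  # any OpInfo entry
for sample in op.sample_inputs("cpu", torch.float32, requires_grad=False):
    # SchemaCheckMode raises if the op mutates or aliases arguments in a way
    # its schema does not declare (or vice versa).
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        op(sample.input, *sample.args, **sample.kwargs)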
@@ -55,7 +55,6 @@ from jit.test_typing import TestTyping # noqa: F401
 from jit.test_hash import TestHash # noqa: F401
 from jit.test_complex import TestComplex # noqa: F401
 from jit.test_jit_utils import TestJitUtils # noqa: F401
-from jit.test_schema_check import TestSchemaCheck # noqa: F401
 from jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation # noqa: F401
 from jit.test_types import TestTypesAndAnnotation # noqa: F401
 from jit.test_misc import TestMisc # noqa: F401
@@ -5,19 +5,16 @@ import sys
 import torch
 from torch.utils._pytree import tree_map
 
+from torch.testing._internal.common_utils import run_tests
 from torch.fx.operator_schemas import normalize_function
 from torch.testing._internal.schema_check_mode import SchemaCheckMode
 from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
+from torch.testing._internal.common_methods_invocations import op_db
 from torch.testing._internal.jit_utils import JitTestCase
 
+from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
 
-if __name__ == '__main__':
-    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
-                       "\tpython test/test_jit.py TESTNAME\n\n"
-                       "instead.")
-
 # This TorchDispatchTensor Subclass is used to simulate an incorrect schema
 # which is then used to test that SchemaCheckMode behaves as expected
-
@@ -403,8 +400,8 @@ class TestSchemaCheck(JitTestCase):
     def test_overlaps_empty_container(self):
         x = []
         y = [torch.rand((3, 3), requires_grad=True)]
-        # Anything overlaps nothing
-        self.assertTrue(torch._C._overlaps(y, x))
+        # Empty containers return false
+        self.assertFalse(torch._C._overlaps(y, x))
         self.assertTrue(torch._C._overlaps(y, y))
 
     # Tests that SchemaInfo Bindings work as expected
@@ -443,3 +440,20 @@ class TestSchemaCheck(JitTestCase):
         schemaInfoCheck = SchemaInfoBindTestMode(self)
         with enable_torch_dispatch_mode(schemaInfoCheck):
             x.add(x)
+
+
+class TestSchemaCheckModeOpInfo(JitTestCase):
+    @ops(op_db, dtypes=OpDTypes.supported)
+    def test_schema_correctness(self, device, dtype, op):
+        # Currently torch.equal isn't supported with torch.complex32
+        # There's also errors with complex64 and complex128
+        if (dtype == torch.complex32):
+            return
+        for sample in op.sample_inputs(device, dtype, requires_grad=False):
+            with enable_torch_dispatch_mode(SchemaCheckMode()):
+                op(sample.input, *sample.args, **sample.kwargs)
+
+instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))
+
+if __name__ == '__main__':
+    run_tests()
@@ -237,10 +237,14 @@ bool loadPythonClasses() {
   return true;
 }
 
-bool isEmptyContainer(const py::handle self) {
-  bool is_empty_list =
-      PySequence_Check(self.ptr()) && !PySequence_Size(self.ptr());
-  return is_empty_list;
+c10::optional<IValue> toTypeInferredIValueOptional(py::handle input) {
+  // Errors need to be caught here because toTypeInferredIValue errors out
+  // on various object types, but we want it to work with all types.
+  try {
+    return toTypeInferredIValue(input);
+  } catch (const c10::Error& e) {
+    return c10::nullopt;
+  }
 }
 } // anonymous namespace
 
@@ -1712,38 +1716,39 @@ void initJITBindings(PyObject* module) {
           [](SchemaInfo& self,
              const std::string& name,
              const py::object& value) {
-            if (isEmptyContainer(value)) {
-              return;
-            }
-            // For normalization purposes there is an inconsistency within
-            // torch.fx that turns all arguments named "self" into "input". Thus
-            // this check ensures that those arguments are checked correctly.
-            if (name == "input" && !self.hasInputArgumentNamed("input")) {
-              self.addArgumentValue("self", toTypeInferredIValue(value));
-            } else {
-              self.addArgumentValue(name, toTypeInferredIValue(value));
+            c10::optional<IValue> i_value = toTypeInferredIValueOptional(value);
+            if (i_value) {
+              // For normalization purposes there is an inconsistency within
+              // torch.fx that turns all arguments named "self" into "input".
+              // Thus this check ensures that those arguments are checked
+              // correctly.
+              if (name == "input" && !self.hasInputArgumentNamed("input")) {
+                self.addArgumentValue("self", *i_value);
+              } else {
+                self.addArgumentValue(name, *i_value);
+              }
             }
           })
       .def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
         std::unordered_map<std::string, IValue> value_map;
         for (const auto& key_pair : values) {
           IValue key = toTypeInferredIValue(key_pair.first);
-          if (isEmptyContainer(key_pair.second)) {
-            continue;
-          }
-          IValue value = toTypeInferredIValue(key_pair.second);
           TORCH_INTERNAL_ASSERT(
               key.isString(),
               "Add argument value keys types should be strings.");
-          // For normalization purposes there is an inconsistency within
-          // torch.fx that
-          // turns all arguments named "self" into "input". Thus this check
-          // ensures that those arguments are checked correctly.
-          if (key.toStringRef() == "input" &&
-              !self.hasInputArgumentNamed("input")) {
-            self.addArgumentValue("self", value);
-          } else {
-            value_map[key.toStringRef()] = value;
+          c10::optional<IValue> value =
+              toTypeInferredIValueOptional(key_pair.second);
+          if (value) {
+            // For normalization purposes there is an inconsistency within
+            // torch.fx that
+            // turns all arguments named "self" into "input". Thus this check
+            // ensures that those arguments are checked correctly.
+            if (key.toStringRef() == "input" &&
+                !self.hasInputArgumentNamed("input")) {
+              self.addArgumentValue("self", *value);
+            } else {
+              value_map[key.toStringRef()] = *value;
+            }
           }
         }
         self.addArgumentValues(value_map);
@@ -1915,16 +1920,24 @@ void initJITBindings(PyObject* module) {
       }),
       py::call_guard<py::gil_scoped_release>());
   m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
-    if (isEmptyContainer(self) || isEmptyContainer(other)) {
+    c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
+    c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
+
+    // Only return true if we are certain that self and other are aliasing.
+    if (!self_value || !other_value) {
       return false;
     }
-    return toTypeInferredIValue(self).isAliasOf(toTypeInferredIValue(other));
+    return self_value->isAliasOf(*other_value);
  });
  m.def("_overlaps", [](const py::object& self, const py::object& other) {
-    if (isEmptyContainer(self) || isEmptyContainer(other)) {
-      return true;
+    c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
+    c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
+
+    // Only return true if we are certain that self and other are overlapping.
+    if (!self_value || !other_value) {
+      return false;
     }
-    return toTypeInferredIValue(self).overlaps(toTypeInferredIValue(other));
+    return self_value->overlaps(*other_value);
  });
  m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
    AT_ASSERT(args.size() >= 1);
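On the Python side, the net effect of these binding changes matches the updated test_overlaps_empty_container above: inputs that cannot be converted to a typed IValue (for example, empty containers) are now treated as definitely not aliasing or overlapping rather than erroring. A quick illustrative check with the same private torch._C helpers the test exercises (internal APIs, behavior as of this change):

import torch

x = []
y = [torch.rand((3, 3), requires_grad=True)]

# Empty containers cannot be converted to a typed IValue, so the bindings
# conservatively report no overlap / no alias instead of raising.
print(torch._C._overlaps(y, x))     # False: empty containers return false
print(torch._C._overlaps(y, y))     # True: a list trivially overlaps itself
print(torch._C._is_alias_of(y, x))  # False, for the same reason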
@@ -8788,7 +8788,15 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-           sample_inputs_func=sample_inputs_addmm),
+           sample_inputs_func=sample_inputs_addmm,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
     OpInfo('addmm',
            # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
            variant_test_name='decomposed',
@@ -8802,6 +8810,12 @@ op_db: List[OpInfo] = [
            autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
            sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1),
            skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
                # https://github.com/pytorch/pytorch/issues/71784
                DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
                             device_type='cpu', dtypes=(torch.float16,)),
@@ -8858,7 +8872,15 @@ op_db: List[OpInfo] = [
                DecorateInfo(
                    toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
                    'TestMathBits', 'test_conj_view', device_type='cuda')],
-           sample_inputs_func=sample_inputs_baddbmm),
+           sample_inputs_func=sample_inputs_baddbmm,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
     OpInfo('dot',
            dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@@ -8867,7 +8889,14 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_dot_vdot,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
-           ),
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
     OpInfo('vdot',
            dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@@ -8875,7 +8904,14 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_dot_vdot,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
-           ),
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
     OpInfo('bmm',
            dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@@ -9446,6 +9482,12 @@ op_db: List[OpInfo] = [
            # See https://github.com/pytorch/pytorch/pull/78358
            check_batched_forward_grad=False,
            skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
                # Pre-existing condition (calls .item); needs to be fixed
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
            ),
@@ -9519,6 +9561,12 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
                # Pre-existing condition (calls .item); needs to be fixed
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
                # Pre-existing condition (calls .item); needs to be fixed
@@ -9914,8 +9962,7 @@ op_db: List[OpInfo] = [
                      # See https://github.com/pytorch/pytorch/pull/78358
                      check_batched_forward_grad=False,
                      decorators=[precisionOverride(
-                         {torch.float: 1e-4, torch.cfloat: 1e-4})],
-                     ),
+                         {torch.float: 1e-4, torch.cfloat: 1e-4})]),
     SpectralFuncInfo('fft.hfft',
                      aten_name='fft_hfft',
                      decomp_aten_name='_fft_c2r',
@@ -9932,7 +9979,16 @@ op_db: List[OpInfo] = [
                      supports_fwgrad_bwgrad=True,
                      # See https://github.com/pytorch/pytorch/pull/78358
                      check_batched_forward_grad=False,
-                     check_batched_gradgrad=False),
+                     check_batched_gradgrad=False,
+                     skips=(
+                         # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+                         DecorateInfo(
+                             unittest.skip("Skipped!"),
+                             'TestSchemaCheckModeOpInfo',
+                             'test_schema_correctness',
+                             dtypes=(torch.complex64, torch.complex128)
+                         ),
+                     )),
     SpectralFuncInfo('fft.hfft2',
                      aten_name='fft_hfft2',
                      decomp_aten_name='_fft_c2r',
@@ -9954,7 +10010,14 @@ op_db: List[OpInfo] = [
                          DecorateInfo(
                              precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                              'TestFFT', 'test_reference_nd')],
-                     ),
+                     skips=(
+                         # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+                         DecorateInfo(
+                             unittest.skip("Skipped!"),
+                             'TestSchemaCheckModeOpInfo',
+                             'test_schema_correctness'
+                         ),
+                     )),
     SpectralFuncInfo('fft.hfftn',
                      aten_name='fft_hfftn',
                      decomp_aten_name='_fft_c2r',
@@ -9976,7 +10039,14 @@ op_db: List[OpInfo] = [
                          DecorateInfo(
                              precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                              'TestFFT', 'test_reference_nd'), ],
-                     ),
+                     skips=(
+                         # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+                         DecorateInfo(
+                             unittest.skip("Skipped!"),
+                             'TestSchemaCheckModeOpInfo',
+                             'test_schema_correctness'
+                         ),
+                     )),
     SpectralFuncInfo('fft.rfft',
                      aten_name='fft_rfft',
                      decomp_aten_name='_fft_r2c',
@@ -10627,7 +10697,15 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_linalg_vecdot,
            check_batched_forward_grad=False,
            supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True),
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
     OpInfo('linalg.cond',
            aten_name='linalg_cond',
            dtypes=floating_and_complex_types(),
@@ -13373,7 +13451,15 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
-           sample_inputs_func=sample_inputs_mm),
+           sample_inputs_func=sample_inputs_mm,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
     OpInfo('mode',
            op=torch.mode,
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
@@ -14604,6 +14690,12 @@ op_db: List[OpInfo] = [
            check_batched_gradgrad=False,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
            skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
                             device_type='mps', dtypes=[torch.float32]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
@@ -16591,6 +16683,12 @@ op_db: List[OpInfo] = [
            check_batched_forward_grad=False,
            supports_fwgrad_bwgrad=True,
            skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
                # Expected RuntimeError when calling with input.device=cpu and out.device=cuda
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
                # Arguments for call are not valid.
@@ -17479,6 +17577,12 @@ op_db: List[OpInfo] = [
            promotes_int_to_float=True,
            dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
            skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
                DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
                # FIXME: sum reduces all dimensions when dim=[]
                DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
@@ -17517,6 +17621,12 @@ op_db: List[OpInfo] = [
            dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
            skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
                DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
                # FIXME: sum reduces all dimensions when dim=[]
                DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
@@ -17831,6 +17941,13 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
+               # https://github.com/pytorch/pytorch/issues/82235
+               DecorateInfo(
+                   unittest.expectedFailure,
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   device_type='cuda',
+               ),
                DecorateInfo(
                    unittest.skip("Skipped!"),
                    "TestJit",
@@ -17847,6 +17964,13 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
+               # https://github.com/pytorch/pytorch/issues/82235
+               DecorateInfo(
+                   unittest.expectedFailure,
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   device_type='cuda',
+               ),
                DecorateInfo(
                    unittest.skip("Skipped!"),
                    "TestJit",
@@ -41,9 +41,11 @@ class SchemaCheckMode(TorchDispatchMode):
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         def has_mutated(before, after, md):
-            if type(before) == torch.Tensor and type(after) == torch.Tensor:
+            are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
+            if are_tensors and before.layout != torch.sparse_csr and after.layout != torch.sparse_csr:
                 return not (
-                    torch.equal(before, after) and
+                    before.size() == after.size() and
+                    torch.allclose(before, after, equal_nan=True) and
                     md[0] == after.stride() and
                     md[1] == after.storage()._cdata
                 )
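The switch from torch.equal to a size check plus torch.allclose(..., equal_nan=True) lines up with the commit message's "using isClose instead of equals"; one concrete difference is NaN handling, since torch.equal treats NaN as unequal to itself and would report an untouched tensor containing NaN as mutated. A small illustration:

import torch

t = torch.tensor([float("nan"), 1.0])
u = t.clone()  # bit-identical copy

# torch.equal uses elementwise equality and NaN != NaN, so this is False
print(torch.equal(t, u))                     # False
# equal_nan=True treats matching NaNs as equal, so an unmodified tensor
# containing NaN is not flagged as mutated by has_mutated above
print(torch.allclose(t, u, equal_nan=True))  # True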
@@ -77,7 +79,8 @@ class SchemaCheckMode(TorchDispatchMode):
                     return (deepcopy(current.stride()), current.storage()._cdata)
                 except AttributeError as t:
                     return None
-            else:
+            # Sparse CSR tensors do not have strides or storage
+            elif (e.layout != torch.sparse_csr):
                 return (deepcopy(e.stride()), e.storage()._cdata)
             return None
 
@@ -112,7 +115,8 @@ class SchemaCheckMode(TorchDispatchMode):
                 md = cloned_metadata.get(name)
                 after = arguments.get(name)
                 for j in range(len(tuple_out)):
-                    if has_aliased(tuple_out[j], after):
+                    # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
+                    if has_aliased(tuple_out[j], after) and func._schema.name != 'aten::_unsafe_view':
                         if not schema_info.may_contain_alias(
                                 SchemaArgument(SchemaArgType.output, j),
                                 SchemaArgument(SchemaArgType.input, i)):
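The aten::_unsafe_view carve-out exists because that op intentionally returns a tensor sharing the input's storage while its schema declares no alias (hence "unsafe"), so the aliasing check would otherwise flag it. A hedged illustration, assuming the aten op is reachable via torch.ops as in recent builds:

import torch

t = torch.arange(6)
v = torch.ops.aten._unsafe_view(t, [2, 3])

# The output shares the input's storage even though the schema declares no
# alias, which is why SchemaCheckMode special-cases this op above.
print(v.data_ptr() == t.data_ptr())  # True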