[JIT] Add SchemaCheckMode OpInfo test (#82442)

- Move test_schema_check to torch/test directory.
- Add opInfo test for SchemaCheckMode to check all operator schemas
- Add various changes (using isClose instead of equals, skipping complex number cases for certain ops, etc...) in order to have test_schema_check pass.

Differential Revision: [D38437946](https://our.internmc.facebook.com/intern/diff/D38437946)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82442
Approved by: https://github.com/davidberard98
This commit is contained in:
goldenxuett
2022-08-08 10:46:15 -07:00
committed by PyTorch MergeBot
parent a0b3854548
commit 2b6905413e
5 changed files with 210 additions and 56 deletions

View File

@ -55,7 +55,6 @@ from jit.test_typing import TestTyping # noqa: F401
from jit.test_hash import TestHash # noqa: F401
from jit.test_complex import TestComplex # noqa: F401
from jit.test_jit_utils import TestJitUtils # noqa: F401
from jit.test_schema_check import TestSchemaCheck # noqa: F401
from jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation # noqa: F401
from jit.test_types import TestTypesAndAnnotation # noqa: F401
from jit.test_misc import TestMisc # noqa: F401

View File

@ -5,19 +5,16 @@ import sys
import torch
from torch.utils._pytree import tree_map
from torch.testing._internal.common_utils import run_tests
from torch.fx.operator_schemas import normalize_function
from torch.testing._internal.schema_check_mode import SchemaCheckMode
from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
# This TorchDispatchTensor Subclass is used to simulate an incorrect schema
# which is then used to test that SchemaCheckMode behaves as expected
@ -403,8 +400,8 @@ class TestSchemaCheck(JitTestCase):
def test_overlaps_empty_container(self):
x = []
y = [torch.rand((3, 3), requires_grad=True)]
# Anything overlaps nothing
self.assertTrue(torch._C._overlaps(y, x))
# Empty containers return false
self.assertFalse(torch._C._overlaps(y, x))
self.assertTrue(torch._C._overlaps(y, y))
# Tests that SchemaInfo Bindings work as expected
@ -443,3 +440,20 @@ class TestSchemaCheck(JitTestCase):
schemaInfoCheck = SchemaInfoBindTestMode(self)
with enable_torch_dispatch_mode(schemaInfoCheck):
x.add(x)
class TestSchemaCheckModeOpInfo(JitTestCase):
@ops(op_db, dtypes=OpDTypes.supported)
def test_schema_correctness(self, device, dtype, op):
# Currently torch.equal isn't supported with torch.complex32
# There's also errors with complex64 and complex128
if (dtype == torch.complex32):
return
for sample in op.sample_inputs(device, dtype, requires_grad=False):
with enable_torch_dispatch_mode(SchemaCheckMode()):
op(sample.input, *sample.args, **sample.kwargs)
instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))
if __name__ == '__main__':
run_tests()

View File

@ -237,10 +237,14 @@ bool loadPythonClasses() {
return true;
}
bool isEmptyContainer(const py::handle self) {
bool is_empty_list =
PySequence_Check(self.ptr()) && !PySequence_Size(self.ptr());
return is_empty_list;
c10::optional<IValue> toTypeInferredIValueOptional(py::handle input) {
// Errors need to be caught here because toTypeInferredIValue errors out
// on various object types, but we want it to work with all types.
try {
return toTypeInferredIValue(input);
} catch (const c10::Error& e) {
return c10::nullopt;
}
}
} // anonymous namespace
@ -1712,38 +1716,39 @@ void initJITBindings(PyObject* module) {
[](SchemaInfo& self,
const std::string& name,
const py::object& value) {
if (isEmptyContainer(value)) {
return;
}
// For normalization purposes there is an inconsistency within
// torch.fx that turns all arguments named "self" into "input". Thus
// this check ensures that those arguments are checked correctly.
if (name == "input" && !self.hasInputArgumentNamed("input")) {
self.addArgumentValue("self", toTypeInferredIValue(value));
} else {
self.addArgumentValue(name, toTypeInferredIValue(value));
c10::optional<IValue> i_value = toTypeInferredIValueOptional(value);
if (i_value) {
// For normalization purposes there is an inconsistency within
// torch.fx that turns all arguments named "self" into "input".
// Thus this check ensures that those arguments are checked
// correctly.
if (name == "input" && !self.hasInputArgumentNamed("input")) {
self.addArgumentValue("self", *i_value);
} else {
self.addArgumentValue(name, *i_value);
}
}
})
.def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
std::unordered_map<std::string, IValue> value_map;
for (const auto& key_pair : values) {
IValue key = toTypeInferredIValue(key_pair.first);
if (isEmptyContainer(key_pair.second)) {
continue;
}
IValue value = toTypeInferredIValue(key_pair.second);
TORCH_INTERNAL_ASSERT(
key.isString(),
"Add argument value keys types should be strings.");
// For normalization purposes there is an inconsistency within
// torch.fx that
// turns all arguments named "self" into "input". Thus this check
// ensures that those arguments are checked correctly.
if (key.toStringRef() == "input" &&
!self.hasInputArgumentNamed("input")) {
self.addArgumentValue("self", value);
} else {
value_map[key.toStringRef()] = value;
c10::optional<IValue> value =
toTypeInferredIValueOptional(key_pair.second);
if (value) {
// For normalization purposes there is an inconsistency within
// torch.fx that
// turns all arguments named "self" into "input". Thus this check
// ensures that those arguments are checked correctly.
if (key.toStringRef() == "input" &&
!self.hasInputArgumentNamed("input")) {
self.addArgumentValue("self", *value);
} else {
value_map[key.toStringRef()] = *value;
}
}
}
self.addArgumentValues(value_map);
@ -1915,16 +1920,24 @@ void initJITBindings(PyObject* module) {
}),
py::call_guard<py::gil_scoped_release>());
m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
if (isEmptyContainer(self) || isEmptyContainer(other)) {
c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
// Only return true if we are certain that self and other are aliasing.
if (!self_value || !other_value) {
return false;
}
return toTypeInferredIValue(self).isAliasOf(toTypeInferredIValue(other));
return self_value->isAliasOf(*other_value);
});
m.def("_overlaps", [](const py::object& self, const py::object& other) {
if (isEmptyContainer(self) || isEmptyContainer(other)) {
return true;
c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
// Only return true if we are certain that self and other are overlapping.
if (!self_value || !other_value) {
return false;
}
return toTypeInferredIValue(self).overlaps(toTypeInferredIValue(other));
return self_value->overlaps(*other_value);
});
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
AT_ASSERT(args.size() >= 1);

View File

@ -8788,7 +8788,15 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
sample_inputs_func=sample_inputs_addmm),
sample_inputs_func=sample_inputs_addmm,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
)),
OpInfo('addmm',
# When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
variant_test_name='decomposed',
@ -8802,6 +8810,12 @@ op_db: List[OpInfo] = [
autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
# https://github.com/pytorch/pytorch/issues/71784
DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
device_type='cpu', dtypes=(torch.float16,)),
@ -8858,7 +8872,15 @@ op_db: List[OpInfo] = [
DecorateInfo(
toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
'TestMathBits', 'test_conj_view', device_type='cuda')],
sample_inputs_func=sample_inputs_baddbmm),
sample_inputs_func=sample_inputs_baddbmm,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
)),
OpInfo('dot',
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@ -8867,7 +8889,14 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_dot_vdot,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
)),
OpInfo('vdot',
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@ -8875,7 +8904,14 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_dot_vdot,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
)),
OpInfo('bmm',
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@ -9446,6 +9482,12 @@ op_db: List[OpInfo] = [
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
# Pre-existing condition (calls .item); needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
),
@ -9519,6 +9561,12 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
# Pre-existing condition (calls .item); needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
# Pre-existing condition (calls .item); needs to be fixed
@ -9914,8 +9962,7 @@ op_db: List[OpInfo] = [
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
decorators=[precisionOverride(
{torch.float: 1e-4, torch.cfloat: 1e-4})],
),
{torch.float: 1e-4, torch.cfloat: 1e-4})]),
SpectralFuncInfo('fft.hfft',
aten_name='fft_hfft',
decomp_aten_name='_fft_c2r',
@ -9932,7 +9979,16 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
check_batched_gradgrad=False),
check_batched_gradgrad=False,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)
),
)),
SpectralFuncInfo('fft.hfft2',
aten_name='fft_hfft2',
decomp_aten_name='_fft_c2r',
@ -9954,7 +10010,14 @@ op_db: List[OpInfo] = [
DecorateInfo(
precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
'TestFFT', 'test_reference_nd')],
),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness'
),
)),
SpectralFuncInfo('fft.hfftn',
aten_name='fft_hfftn',
decomp_aten_name='_fft_c2r',
@ -9976,7 +10039,14 @@ op_db: List[OpInfo] = [
DecorateInfo(
precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
'TestFFT', 'test_reference_nd'), ],
),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness'
),
)),
SpectralFuncInfo('fft.rfft',
aten_name='fft_rfft',
decomp_aten_name='_fft_r2c',
@ -10627,7 +10697,15 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_linalg_vecdot,
check_batched_forward_grad=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True),
supports_fwgrad_bwgrad=True,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
)),
OpInfo('linalg.cond',
aten_name='linalg_cond',
dtypes=floating_and_complex_types(),
@ -13373,7 +13451,15 @@ op_db: List[OpInfo] = [
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_mm),
sample_inputs_func=sample_inputs_mm,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
)),
OpInfo('mode',
op=torch.mode,
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
@ -14604,6 +14690,12 @@ op_db: List[OpInfo] = [
check_batched_gradgrad=False,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='mps', dtypes=[torch.float32]),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
@ -16591,6 +16683,12 @@ op_db: List[OpInfo] = [
check_batched_forward_grad=False,
supports_fwgrad_bwgrad=True,
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
# Expected RuntimeError when calling with input.device=cpu and out.device=cuda
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
# Arguments for call are not valid.
@ -17479,6 +17577,12 @@ op_db: List[OpInfo] = [
promotes_int_to_float=True,
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
@ -17517,6 +17621,12 @@ op_db: List[OpInfo] = [
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
unittest.skip("Skipped!"),
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
@ -17831,6 +17941,13 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
# https://github.com/pytorch/pytorch/issues/82235
DecorateInfo(
unittest.expectedFailure,
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
device_type='cuda',
),
DecorateInfo(
unittest.skip("Skipped!"),
"TestJit",
@ -17847,6 +17964,13 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
# https://github.com/pytorch/pytorch/issues/82235
DecorateInfo(
unittest.expectedFailure,
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
device_type='cuda',
),
DecorateInfo(
unittest.skip("Skipped!"),
"TestJit",

View File

@ -41,9 +41,11 @@ class SchemaCheckMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
def has_mutated(before, after, md):
if type(before) == torch.Tensor and type(after) == torch.Tensor:
are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
if are_tensors and before.layout != torch.sparse_csr and after.layout != torch.sparse_csr:
return not (
torch.equal(before, after) and
before.size() == after.size() and
torch.allclose(before, after, equal_nan=True) and
md[0] == after.stride() and
md[1] == after.storage()._cdata
)
@ -77,7 +79,8 @@ class SchemaCheckMode(TorchDispatchMode):
return (deepcopy(current.stride()), current.storage()._cdata)
except AttributeError as t:
return None
else:
# Sparse CSR tensors do not have strides or storage
elif (e.layout != torch.sparse_csr):
return (deepcopy(e.stride()), e.storage()._cdata)
return None
@ -112,7 +115,8 @@ class SchemaCheckMode(TorchDispatchMode):
md = cloned_metadata.get(name)
after = arguments.get(name)
for j in range(len(tuple_out)):
if has_aliased(tuple_out[j], after):
# aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
if has_aliased(tuple_out[j], after) and func._schema.name != 'aten::_unsafe_view':
if not schema_info.may_contain_alias(
SchemaArgument(SchemaArgType.output, j),
SchemaArgument(SchemaArgType.input, i)):