[JIT] Add SchemaCheckMode OpInfo test (#82442)

- Move test_schema_check to torch/test directory.
- Add opInfo test for SchemaCheckMode to check all operator schemas
- Add various changes (using isClose instead of equals, skipping complex number cases for certain ops, etc...) in order to have test_schema_check pass.

Differential Revision: [D38437946](https://our.internmc.facebook.com/intern/diff/D38437946)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82442
Approved by: https://github.com/davidberard98
This commit is contained in:
goldenxuett
2022-08-08 10:46:15 -07:00
committed by PyTorch MergeBot
parent a0b3854548
commit 2b6905413e
5 changed files with 210 additions and 56 deletions

View File

@ -237,10 +237,14 @@ bool loadPythonClasses() {
return true;
}
bool isEmptyContainer(const py::handle self) {
bool is_empty_list =
PySequence_Check(self.ptr()) && !PySequence_Size(self.ptr());
return is_empty_list;
c10::optional<IValue> toTypeInferredIValueOptional(py::handle input) {
// Errors need to be caught here because toTypeInferredIValue errors out
// on various object types, but we want it to work with all types.
try {
return toTypeInferredIValue(input);
} catch (const c10::Error& e) {
return c10::nullopt;
}
}
} // anonymous namespace
@ -1712,38 +1716,39 @@ void initJITBindings(PyObject* module) {
[](SchemaInfo& self,
const std::string& name,
const py::object& value) {
if (isEmptyContainer(value)) {
return;
}
// For normalization purposes there is an inconsistency within
// torch.fx that turns all arguments named "self" into "input". Thus
// this check ensures that those arguments are checked correctly.
if (name == "input" && !self.hasInputArgumentNamed("input")) {
self.addArgumentValue("self", toTypeInferredIValue(value));
} else {
self.addArgumentValue(name, toTypeInferredIValue(value));
c10::optional<IValue> i_value = toTypeInferredIValueOptional(value);
if (i_value) {
// For normalization purposes there is an inconsistency within
// torch.fx that turns all arguments named "self" into "input".
// Thus this check ensures that those arguments are checked
// correctly.
if (name == "input" && !self.hasInputArgumentNamed("input")) {
self.addArgumentValue("self", *i_value);
} else {
self.addArgumentValue(name, *i_value);
}
}
})
.def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
std::unordered_map<std::string, IValue> value_map;
for (const auto& key_pair : values) {
IValue key = toTypeInferredIValue(key_pair.first);
if (isEmptyContainer(key_pair.second)) {
continue;
}
IValue value = toTypeInferredIValue(key_pair.second);
TORCH_INTERNAL_ASSERT(
key.isString(),
"Add argument value keys types should be strings.");
// For normalization purposes there is an inconsistency within
// torch.fx that
// turns all arguments named "self" into "input". Thus this check
// ensures that those arguments are checked correctly.
if (key.toStringRef() == "input" &&
!self.hasInputArgumentNamed("input")) {
self.addArgumentValue("self", value);
} else {
value_map[key.toStringRef()] = value;
c10::optional<IValue> value =
toTypeInferredIValueOptional(key_pair.second);
if (value) {
// For normalization purposes there is an inconsistency within
// torch.fx that
// turns all arguments named "self" into "input". Thus this check
// ensures that those arguments are checked correctly.
if (key.toStringRef() == "input" &&
!self.hasInputArgumentNamed("input")) {
self.addArgumentValue("self", *value);
} else {
value_map[key.toStringRef()] = *value;
}
}
}
self.addArgumentValues(value_map);
@ -1915,16 +1920,24 @@ void initJITBindings(PyObject* module) {
}),
py::call_guard<py::gil_scoped_release>());
m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
if (isEmptyContainer(self) || isEmptyContainer(other)) {
c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
// Only return true if we are certain that self and other are aliasing.
if (!self_value || !other_value) {
return false;
}
return toTypeInferredIValue(self).isAliasOf(toTypeInferredIValue(other));
return self_value->isAliasOf(*other_value);
});
m.def("_overlaps", [](const py::object& self, const py::object& other) {
if (isEmptyContainer(self) || isEmptyContainer(other)) {
return true;
c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
// Only return true if we are certain that self and other are overlapping.
if (!self_value || !other_value) {
return false;
}
return toTypeInferredIValue(self).overlaps(toTypeInferredIValue(other));
return self_value->overlaps(*other_value);
});
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
AT_ASSERT(args.size() >= 1);