mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[JIT] Add SchemaCheckMode OpInfo test (#82442)
- Move test_schema_check to torch/test directory. - Add opInfo test for SchemaCheckMode to check all operator schemas - Add various changes (using isClose instead of equals, skipping complex number cases for certain ops, etc...) in order to have test_schema_check pass. Differential Revision: [D38437946](https://our.internmc.facebook.com/intern/diff/D38437946) Pull Request resolved: https://github.com/pytorch/pytorch/pull/82442 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
a0b3854548
commit
2b6905413e
@ -237,10 +237,14 @@ bool loadPythonClasses() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isEmptyContainer(const py::handle self) {
|
||||
bool is_empty_list =
|
||||
PySequence_Check(self.ptr()) && !PySequence_Size(self.ptr());
|
||||
return is_empty_list;
|
||||
c10::optional<IValue> toTypeInferredIValueOptional(py::handle input) {
|
||||
// Errors need to be caught here because toTypeInferredIValue errors out
|
||||
// on various object types, but we want it to work with all types.
|
||||
try {
|
||||
return toTypeInferredIValue(input);
|
||||
} catch (const c10::Error& e) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
@ -1712,38 +1716,39 @@ void initJITBindings(PyObject* module) {
|
||||
[](SchemaInfo& self,
|
||||
const std::string& name,
|
||||
const py::object& value) {
|
||||
if (isEmptyContainer(value)) {
|
||||
return;
|
||||
}
|
||||
// For normalization purposes there is an inconsistency within
|
||||
// torch.fx that turns all arguments named "self" into "input". Thus
|
||||
// this check ensures that those arguments are checked correctly.
|
||||
if (name == "input" && !self.hasInputArgumentNamed("input")) {
|
||||
self.addArgumentValue("self", toTypeInferredIValue(value));
|
||||
} else {
|
||||
self.addArgumentValue(name, toTypeInferredIValue(value));
|
||||
c10::optional<IValue> i_value = toTypeInferredIValueOptional(value);
|
||||
if (i_value) {
|
||||
// For normalization purposes there is an inconsistency within
|
||||
// torch.fx that turns all arguments named "self" into "input".
|
||||
// Thus this check ensures that those arguments are checked
|
||||
// correctly.
|
||||
if (name == "input" && !self.hasInputArgumentNamed("input")) {
|
||||
self.addArgumentValue("self", *i_value);
|
||||
} else {
|
||||
self.addArgumentValue(name, *i_value);
|
||||
}
|
||||
}
|
||||
})
|
||||
.def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
|
||||
std::unordered_map<std::string, IValue> value_map;
|
||||
for (const auto& key_pair : values) {
|
||||
IValue key = toTypeInferredIValue(key_pair.first);
|
||||
if (isEmptyContainer(key_pair.second)) {
|
||||
continue;
|
||||
}
|
||||
IValue value = toTypeInferredIValue(key_pair.second);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
key.isString(),
|
||||
"Add argument value keys types should be strings.");
|
||||
// For normalization purposes there is an inconsistency within
|
||||
// torch.fx that
|
||||
// turns all arguments named "self" into "input". Thus this check
|
||||
// ensures that those arguments are checked correctly.
|
||||
if (key.toStringRef() == "input" &&
|
||||
!self.hasInputArgumentNamed("input")) {
|
||||
self.addArgumentValue("self", value);
|
||||
} else {
|
||||
value_map[key.toStringRef()] = value;
|
||||
c10::optional<IValue> value =
|
||||
toTypeInferredIValueOptional(key_pair.second);
|
||||
if (value) {
|
||||
// For normalization purposes there is an inconsistency within
|
||||
// torch.fx that
|
||||
// turns all arguments named "self" into "input". Thus this check
|
||||
// ensures that those arguments are checked correctly.
|
||||
if (key.toStringRef() == "input" &&
|
||||
!self.hasInputArgumentNamed("input")) {
|
||||
self.addArgumentValue("self", *value);
|
||||
} else {
|
||||
value_map[key.toStringRef()] = *value;
|
||||
}
|
||||
}
|
||||
}
|
||||
self.addArgumentValues(value_map);
|
||||
@ -1915,16 +1920,24 @@ void initJITBindings(PyObject* module) {
|
||||
}),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
m.def("_is_alias_of", [](const py::object& self, const py::object& other) {
|
||||
if (isEmptyContainer(self) || isEmptyContainer(other)) {
|
||||
c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
|
||||
c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
|
||||
|
||||
// Only return true if we are certain that self and other are aliasing.
|
||||
if (!self_value || !other_value) {
|
||||
return false;
|
||||
}
|
||||
return toTypeInferredIValue(self).isAliasOf(toTypeInferredIValue(other));
|
||||
return self_value->isAliasOf(*other_value);
|
||||
});
|
||||
m.def("_overlaps", [](const py::object& self, const py::object& other) {
|
||||
if (isEmptyContainer(self) || isEmptyContainer(other)) {
|
||||
return true;
|
||||
c10::optional<IValue> self_value = toTypeInferredIValueOptional(self);
|
||||
c10::optional<IValue> other_value = toTypeInferredIValueOptional(other);
|
||||
|
||||
// Only return true if we are certain that self and other are overlapping.
|
||||
if (!self_value || !other_value) {
|
||||
return false;
|
||||
}
|
||||
return toTypeInferredIValue(self).overlaps(toTypeInferredIValue(other));
|
||||
return self_value->overlaps(*other_value);
|
||||
});
|
||||
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
|
||||
AT_ASSERT(args.size() >= 1);
|
||||
|
Reference in New Issue
Block a user