[JIT] Add basic aliasing checks for tensor inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79474

Approved by: https://github.com/davidberard98
This commit is contained in:
goldenxuett
2022-06-16 17:48:58 -07:00
committed by PyTorch MergeBot
parent e8727994eb
commit 1432a3d6ac
3 changed files with 142 additions and 10 deletions

View File

@ -1499,9 +1499,34 @@ void initJITBindings(PyObject* module) {
const AliasInfo* aliasInfo = self.alias_info();
return aliasInfo && aliasInfo->isWrite();
})
.def_property_readonly(
"before_set",
[](Argument& self) {
const AliasInfo* aliasInfo = self.alias_info();
std::set<py::str> before_set_python;
if (aliasInfo) {
for (const auto& set : aliasInfo->beforeSets()) {
before_set_python.insert(py::str(set.toUnqualString()));
}
}
return before_set_python;
})
.def_property_readonly(
"after_set",
[](Argument& self) {
const AliasInfo* aliasInfo = self.alias_info();
std::set<py::str> after_set_python;
if (aliasInfo) {
for (const auto& set : aliasInfo->afterSets()) {
after_set_python.insert(py::str(set.toUnqualString()));
}
}
return after_set_python;
})
.def_property_readonly("kwarg_only", [](Argument& self) -> bool {
return self.kwarg_only();
});
m.def("_jit_get_all_schemas", []() {
const std::vector<std::shared_ptr<Operator>>& operations =
getAllOperators();
@ -1584,7 +1609,9 @@ void initJITBindings(PyObject* module) {
return nullptr;
}),
py::call_guard<py::gil_scoped_release>());
m.def("_is_alias_of", [](const at::Tensor& self, const at::Tensor& other) {
return self.is_alias_of(other);
});
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
AT_ASSERT(args.size() >= 1);