[JIT] Bind AliasInfo to decrease differences in interfaces across languages

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79661

Approved by: https://github.com/davidberard98
This commit is contained in:
goldenxuett
2022-06-17 14:59:45 -07:00
committed by PyTorch MergeBot
parent ee715e0a65
commit f6d9a9a952
2 changed files with 36 additions and 40 deletions

View File

@ -1491,42 +1491,32 @@ void initJITBindings(PyObject* module) {
[](Argument& self) -> py::bool_ {
return self.default_value().has_value();
})
.def_property_readonly(
"alias_info", [](Argument& self) { return self.alias_info(); })
.def_property_readonly(
"is_out", [](Argument& self) { return self.is_out(); })
.def_property_readonly(
"is_mutable",
[](Argument& self) {
const AliasInfo* aliasInfo = self.alias_info();
return aliasInfo && aliasInfo->isWrite();
})
.def_property_readonly(
"before_set",
[](Argument& self) {
const AliasInfo* aliasInfo = self.alias_info();
std::set<py::str> before_set_python;
if (aliasInfo) {
for (const auto& set : aliasInfo->beforeSets()) {
before_set_python.insert(py::str(set.toUnqualString()));
}
}
return before_set_python;
})
.def_property_readonly(
"after_set",
[](Argument& self) {
const AliasInfo* aliasInfo = self.alias_info();
std::set<py::str> after_set_python;
if (aliasInfo) {
for (const auto& set : aliasInfo->afterSets()) {
after_set_python.insert(py::str(set.toUnqualString()));
}
}
return after_set_python;
})
.def_property_readonly("kwarg_only", [](Argument& self) -> bool {
return self.kwarg_only();
});
py::class_<AliasInfo>(m, "_AliasInfo")
.def_property_readonly(
"is_write", [](AliasInfo& self) { return self.isWrite(); })
.def_property_readonly(
"before_set",
[](AliasInfo& self) {
std::set<py::str> before_set_python;
for (const auto& set : self.beforeSets()) {
before_set_python.insert(py::str(set.toUnqualString()));
}
return before_set_python;
})
.def_property_readonly("after_set", [](AliasInfo& self) {
std::set<py::str> after_set_python;
for (const auto& set : self.afterSets()) {
after_set_python.insert(py::str(set.toUnqualString()));
}
return after_set_python;
});
m.def("_jit_get_all_schemas", []() {
const std::vector<std::shared_ptr<Operator>>& operations =
getAllOperators();