Added a couple more symint magic methods + symbolic shape infra (#81086)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81086
Approved by: https://github.com/ezyang
This commit is contained in:
Horace He
2022-07-16 03:22:27 +00:00
committed by PyTorch MergeBot
parent 7ccf693cf6
commit 97938d872e
4 changed files with 166 additions and 98 deletions

View File

@ -213,6 +213,16 @@ class PythonSymbolicIntNode : public c10::SymbolicIntNode {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> le(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> ge(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
@ -1257,6 +1267,20 @@ void initJITBindings(PyObject* module) {
auto snb = toSymIntNode(a, b);
return a->lt(snb);
})
.def(
"__le__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->le(snb);
})
.def(
"__ge__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->ge(snb);
})
.def(
"__bool__",
[](std::shared_ptr<c10::SymbolicIntNode> a) { return a->bool_(); })