mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Added a couple more symint magic methods + symbolic shape infra (#81086)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81086 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7ccf693cf6
commit
97938d872e
@ -213,6 +213,16 @@ class PythonSymbolicIntNode : public c10::SymbolicIntNode {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> le(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> ge(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
py::handle getPyObj() {
|
||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
||||
}
|
||||
@ -1257,6 +1267,20 @@ void initJITBindings(PyObject* module) {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->lt(snb);
|
||||
})
|
||||
.def(
|
||||
"__le__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->le(snb);
|
||||
})
|
||||
.def(
|
||||
"__ge__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->ge(snb);
|
||||
})
|
||||
.def(
|
||||
"__bool__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a) { return a->bool_(); })
|
||||
|
Reference in New Issue
Block a user