add integer divison for symints (#82791)

### Description
This PR brings integer division (floor) to symints + tests.

### Issue

https://github.com/orgs/pytorch/projects/17/views/2

### Testing
added two tests to TestPySymInts

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82791
Approved by: https://github.com/ezyang
This commit is contained in:
Nikolay Korovaiko
2022-08-04 20:00:47 +00:00
committed by PyTorch MergeBot
parent c08092fdf2
commit 8b20e47974
6 changed files with 67 additions and 4 deletions

View File

@ -185,7 +185,11 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode div(const SymIntNode& other) override {
virtual SymIntNode truediv(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode floordiv(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
@ -1223,10 +1227,28 @@ void initJITBindings(PyObject* module) {
return a->mul(snb);
})
.def(
"__div__",
"__truediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->div(snb);
return a->truediv(snb);
})
.def(
"__rtruediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->truediv(a);
})
.def(
"__floordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->floordiv(snb);
})
.def(
"__rfloordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->floordiv(a);
})
.def(
"__mod__",