Reland 2 min/max support for SymInt/Floats, finish as_strided/scatter/squeeze() backward symint support (#86797)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86797
Approved by: https://github.com/bdhirsh
This commit is contained in:
albanD
2022-10-12 11:24:51 -04:00
committed by PyTorch MergeBot
parent 894c4218dd
commit 66cab5245f
12 changed files with 93 additions and 33 deletions

View File

@ -242,6 +242,13 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode min(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode max(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode ceil() override {
return dispatch_common_(__FUNCTION__);
}
@ -1474,6 +1481,18 @@ void initJITBindings(PyObject* module) {
.def(
"__ceil__",
[](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); })
.def(
"__min__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->min(snb);
})
.def(
"__max__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->max(snb);
})
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
// Intentionally don't set file line, as the Python backtrace matters