mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reland 2 min/max support for SymInt/Floats, finish as_strided/scatter/squeeze() backward symint support (#86797)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86797 Approved by: https://github.com/bdhirsh
This commit is contained in:
@ -242,6 +242,13 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode min(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
virtual SymIntNode max(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode ceil() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
@ -1474,6 +1481,18 @@ void initJITBindings(PyObject* module) {
|
||||
.def(
|
||||
"__ceil__",
|
||||
[](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); })
|
||||
.def(
|
||||
"__min__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->min(snb);
|
||||
})
|
||||
.def(
|
||||
"__max__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->max(snb);
|
||||
})
|
||||
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
|
||||
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
|
||||
// Intentionally don't set file line, as the Python backtrace matters
|
||||
|
Reference in New Issue
Block a user