Very limited pow support (#87042)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87042
Approved by: https://github.com/ezyang
This commit is contained in:
albanD
2022-10-16 22:16:14 -04:00
committed by PyTorch MergeBot
parent 37e9e89afb
commit c21dcffc00
4 changed files with 58 additions and 7 deletions

View File

@ -139,8 +139,13 @@ static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) {
}
static c10::SymFloatNode toSymFloatNode(c10::SymFloatNode a, py::object b) {
return torch::is_symfloat_node(b) ? b.cast<c10::SymFloatNode>()
: a->wrap(b.cast<double>());
if (torch::is_symfloat_node(b)) {
return b.cast<c10::SymFloatNode>();
} else if (torch::is_symint_node(b)) {
return b.cast<c10::SymIntNode>()->sym_float();
} else {
return a->wrap(b.cast<double>());
}
}
class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
@ -301,6 +306,10 @@ class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode pow(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode eq(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
@ -1393,7 +1402,7 @@ void initJITBindings(PyObject* module) {
"__radd__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->add(snb);
return snb->add(a);
})
.def(
"__sub__",
@ -1417,7 +1426,7 @@ void initJITBindings(PyObject* module) {
"__rmul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
return snb->mul(a);
})
.def(
"__truediv__",
@ -1455,6 +1464,18 @@ void initJITBindings(PyObject* module) {
auto snb = toSymIntNode(a, b);
return snb->mod(a);
})
.def(
"__pow__",
[](c10::SymIntNode a, py::object b) -> py::object {
if (PyFloat_Check(b.ptr())) {
auto float_a = a->sym_float();
return py::cast(
float_a->pow(float_a->wrap(py::cast<double>(b))));
}
// TODO: integer pow
return py::reinterpret_borrow<py::object>(Py_NotImplemented);
})
// TODO: rpow
.def(
"__eq__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
@ -1534,7 +1555,7 @@ void initJITBindings(PyObject* module) {
"__radd__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->add(snb);
return snb->add(a);
})
.def(
"__sub__",
@ -1552,7 +1573,7 @@ void initJITBindings(PyObject* module) {
"__rmul__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->mul(snb);
return snb->mul(a);
})
.def(
"__truediv__",
@ -1596,6 +1617,18 @@ void initJITBindings(PyObject* module) {
auto snb = toSymFloatNode(a, b);
return a->ge(snb);
})
.def(
"__pow__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->pow(snb);
})
.def(
"__rpow__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->pow(a);
})
.def(
"__ceil__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })