mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 06:11:27 +08:00
Very limited pow support (#87042)
Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/87042 Approved by: https://github.com/ezyang
This commit is contained in:
@ -139,8 +139,13 @@ static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) {
|
||||
}
|
||||
|
||||
static c10::SymFloatNode toSymFloatNode(c10::SymFloatNode a, py::object b) {
|
||||
return torch::is_symfloat_node(b) ? b.cast<c10::SymFloatNode>()
|
||||
: a->wrap(b.cast<double>());
|
||||
if (torch::is_symfloat_node(b)) {
|
||||
return b.cast<c10::SymFloatNode>();
|
||||
} else if (torch::is_symint_node(b)) {
|
||||
return b.cast<c10::SymIntNode>()->sym_float();
|
||||
} else {
|
||||
return a->wrap(b.cast<double>());
|
||||
}
|
||||
}
|
||||
|
||||
class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
@ -301,6 +306,10 @@ class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode pow(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode eq(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
@ -1393,7 +1402,7 @@ void initJITBindings(PyObject* module) {
|
||||
"__radd__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
return snb->add(a);
|
||||
})
|
||||
.def(
|
||||
"__sub__",
|
||||
@ -1417,7 +1426,7 @@ void initJITBindings(PyObject* module) {
|
||||
"__rmul__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
return snb->mul(a);
|
||||
})
|
||||
.def(
|
||||
"__truediv__",
|
||||
@ -1455,6 +1464,18 @@ void initJITBindings(PyObject* module) {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->mod(a);
|
||||
})
|
||||
.def(
|
||||
"__pow__",
|
||||
[](c10::SymIntNode a, py::object b) -> py::object {
|
||||
if (PyFloat_Check(b.ptr())) {
|
||||
auto float_a = a->sym_float();
|
||||
return py::cast(
|
||||
float_a->pow(float_a->wrap(py::cast<double>(b))));
|
||||
}
|
||||
// TODO: integer pow
|
||||
return py::reinterpret_borrow<py::object>(Py_NotImplemented);
|
||||
})
|
||||
// TODO: rpow
|
||||
.def(
|
||||
"__eq__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
@ -1534,7 +1555,7 @@ void initJITBindings(PyObject* module) {
|
||||
"__radd__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->add(snb);
|
||||
return snb->add(a);
|
||||
})
|
||||
.def(
|
||||
"__sub__",
|
||||
@ -1552,7 +1573,7 @@ void initJITBindings(PyObject* module) {
|
||||
"__rmul__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->mul(snb);
|
||||
return snb->mul(a);
|
||||
})
|
||||
.def(
|
||||
"__truediv__",
|
||||
@ -1596,6 +1617,18 @@ void initJITBindings(PyObject* module) {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->ge(snb);
|
||||
})
|
||||
.def(
|
||||
"__pow__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->pow(snb);
|
||||
})
|
||||
.def(
|
||||
"__rpow__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return snb->pow(a);
|
||||
})
|
||||
.def(
|
||||
"__ceil__",
|
||||
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })
|
||||
|
Reference in New Issue
Block a user