Add SymFloat, support SymInt to SymFloat conversion (#84284)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84284
Approved by: https://github.com/albanD
This commit is contained in:
Edward Z. Yang
2022-09-02 08:53:59 -07:00
committed by PyTorch MergeBot
parent 7f5da70ef0
commit 2a332afbf4
22 changed files with 548 additions and 138 deletions

View File

@ -99,6 +99,7 @@
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <c10/core/SymFloat.h>
#include <c10/macros/Export.h>
#include <c10/util/irange.h>
#include <c10/util/signal_handler.h>
@ -125,6 +126,8 @@ using c10::Argument;
using c10::FunctionSchema;
using c10::SchemaArgType;
using c10::SchemaArgument;
using c10::SymFloat;
using c10::SymFloatNode;
using c10::SymIntNode;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
@ -158,6 +161,9 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
return getPyObj().attr("__int__")().cast<int64_t>();
}
// TODO: virtualize
SymFloat sym_float();
virtual std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__str__")().cast<std::string>();
@ -223,6 +229,45 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
public:
PythonSymFloatNodeImpl(py::object pyobj) : c10::SymFloatNodeImpl() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};
virtual SymFloatNode wrap(double num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap")(num);
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
virtual std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__str__")().cast<std::string>();
}
SymFloatNode dispatch_common_(const char* fname, const SymFloatNode& other) {
auto pother = dynamic_cast<PythonSymFloatNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
SymFloat PythonSymIntNodeImpl::sym_float() {
py::gil_scoped_acquire acquire;
return c10::make_intrusive<PythonSymFloatNodeImpl>(
getPyObj().attr("__sym_float__")())
->toSymFloat();
}
namespace {
using autograd::variable_list;
@ -1232,114 +1277,139 @@ void initJITBindings(PyObject* module) {
}
});
py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode")
auto symint_class =
py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode")
.def_static(
"new_symint",
[](py::object obj) -> c10::SymIntNode {
return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
})
.def(
"get_pyobj",
[](c10::SymIntNode a) -> py::object {
if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
return py::reinterpret_borrow<py::object>(psn->getPyObj());
}
return py::none();
})
.def(
"__add__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__radd__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__sub__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->sub(snb);
})
.def(
"__mul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__rmul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__truediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->truediv(snb);
})
.def(
"__rtruediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->truediv(a);
})
.def(
"__floordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->floordiv(snb);
})
.def(
"__rfloordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->floordiv(a);
})
.def(
"__mod__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mod(snb);
})
.def(
"__eq__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->eq(snb);
})
.def(
"__gt__",
[](c10::SymIntNode a, py::object b) {
auto snb = toSymIntNode(a, b);
return a->gt(snb);
})
.def(
"__lt__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->lt(snb);
})
.def(
"__le__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->le(snb);
})
.def(
"__ge__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->ge(snb);
})
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
.def(
"__sym_float__",
[](c10::SymIntNode a) {
// TODO: remove dynamic cast when sym_float is in base class
auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get());
TORCH_INTERNAL_ASSERT(psn);
return psn->sym_float();
})
.def("__str__", [](c10::SymIntNode a) { return a->str(); })
.def("__repr__", [](c10::SymIntNode a) { return a->str(); });
py::class_<c10::SymFloatNodeImpl, c10::SymFloatNode>(m, "SymFloatNode")
.def_static(
"new_symint",
[](py::object obj) -> c10::SymIntNode {
return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
"new_symfloat",
[](py::object obj) -> c10::SymFloatNode {
return c10::make_intrusive<PythonSymFloatNodeImpl>(obj);
})
.def(
"get_pyobj",
[](c10::SymIntNode a) -> py::object {
if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
[](c10::SymFloatNode a) -> py::object {
if (auto* psn = dynamic_cast<PythonSymFloatNodeImpl*>(a.get())) {
return py::reinterpret_borrow<py::object>(psn->getPyObj());
}
return py::none();
})
.def(
"__add__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__radd__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__sub__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->sub(snb);
})
.def(
"__mul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__rmul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__truediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->truediv(snb);
})
.def(
"__rtruediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->truediv(a);
})
.def(
"__floordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->floordiv(snb);
})
.def(
"__rfloordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->floordiv(a);
})
.def(
"__mod__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mod(snb);
})
.def(
"__eq__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->eq(snb);
})
.def(
"__gt__",
[](c10::SymIntNode a, py::object b) {
auto snb = toSymIntNode(a, b);
return a->gt(snb);
})
.def(
"__lt__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->lt(snb);
})
.def(
"__le__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->le(snb);
})
.def(
"__ge__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->ge(snb);
})
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
.def("__str__", [](c10::SymIntNode a) { return a->str(); })
.def("__repr__", [](c10::SymIntNode a) { return a->str(); });
.def("__str__", [](c10::SymFloatNode a) { return a->str(); });
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")