mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add SymFloat, support SymInt to SymFloat conversion (#84284)
Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/84284 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
7f5da70ef0
commit
2a332afbf4
@ -99,6 +99,7 @@
|
||||
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
|
||||
#include <torch/csrc/utils/cpp_stacktraces.h>
|
||||
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/signal_handler.h>
|
||||
@ -125,6 +126,8 @@ using c10::Argument;
|
||||
using c10::FunctionSchema;
|
||||
using c10::SchemaArgType;
|
||||
using c10::SchemaArgument;
|
||||
using c10::SymFloat;
|
||||
using c10::SymFloatNode;
|
||||
using c10::SymIntNode;
|
||||
using caffe2::serialize::PyTorchStreamReader;
|
||||
using caffe2::serialize::PyTorchStreamWriter;
|
||||
@ -158,6 +161,9 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
return getPyObj().attr("__int__")().cast<int64_t>();
|
||||
}
|
||||
|
||||
// TODO: virtualize
|
||||
SymFloat sym_float();
|
||||
|
||||
virtual std::string str() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__str__")().cast<std::string>();
|
||||
@ -223,6 +229,45 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
||||
};
|
||||
|
||||
class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
|
||||
public:
|
||||
PythonSymFloatNodeImpl(py::object pyobj) : c10::SymFloatNodeImpl() {
|
||||
pyobj_ = std::make_shared<c10::SafePyObject>(
|
||||
pyobj.release().ptr(), getPyInterpreter());
|
||||
};
|
||||
|
||||
virtual SymFloatNode wrap(double num) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("wrap")(num);
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
|
||||
}
|
||||
|
||||
virtual std::string str() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__str__")().cast<std::string>();
|
||||
}
|
||||
|
||||
SymFloatNode dispatch_common_(const char* fname, const SymFloatNode& other) {
|
||||
auto pother = dynamic_cast<PythonSymFloatNodeImpl*>(other.get());
|
||||
TORCH_CHECK(pother);
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr(fname)(pother->getPyObj());
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
|
||||
}
|
||||
|
||||
py::handle getPyObj() {
|
||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
||||
}
|
||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
||||
};
|
||||
|
||||
SymFloat PythonSymIntNodeImpl::sym_float() {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(
|
||||
getPyObj().attr("__sym_float__")())
|
||||
->toSymFloat();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using autograd::variable_list;
|
||||
@ -1232,114 +1277,139 @@ void initJITBindings(PyObject* module) {
|
||||
}
|
||||
});
|
||||
|
||||
py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode")
|
||||
auto symint_class =
|
||||
py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode")
|
||||
.def_static(
|
||||
"new_symint",
|
||||
[](py::object obj) -> c10::SymIntNode {
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
|
||||
})
|
||||
.def(
|
||||
"get_pyobj",
|
||||
[](c10::SymIntNode a) -> py::object {
|
||||
if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
|
||||
return py::reinterpret_borrow<py::object>(psn->getPyObj());
|
||||
}
|
||||
return py::none();
|
||||
})
|
||||
.def(
|
||||
"__add__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__radd__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__sub__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->sub(snb);
|
||||
})
|
||||
.def(
|
||||
"__mul__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__rmul__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__truediv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->truediv(snb);
|
||||
})
|
||||
.def(
|
||||
"__rtruediv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->truediv(a);
|
||||
})
|
||||
.def(
|
||||
"__floordiv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->floordiv(snb);
|
||||
})
|
||||
.def(
|
||||
"__rfloordiv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->floordiv(a);
|
||||
})
|
||||
.def(
|
||||
"__mod__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mod(snb);
|
||||
})
|
||||
.def(
|
||||
"__eq__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->eq(snb);
|
||||
})
|
||||
.def(
|
||||
"__gt__",
|
||||
[](c10::SymIntNode a, py::object b) {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->gt(snb);
|
||||
})
|
||||
.def(
|
||||
"__lt__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->lt(snb);
|
||||
})
|
||||
.def(
|
||||
"__le__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->le(snb);
|
||||
})
|
||||
.def(
|
||||
"__ge__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->ge(snb);
|
||||
})
|
||||
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
|
||||
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
|
||||
.def(
|
||||
"__sym_float__",
|
||||
[](c10::SymIntNode a) {
|
||||
// TODO: remove dynamic cast when sym_float is in base class
|
||||
auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get());
|
||||
TORCH_INTERNAL_ASSERT(psn);
|
||||
return psn->sym_float();
|
||||
})
|
||||
.def("__str__", [](c10::SymIntNode a) { return a->str(); })
|
||||
.def("__repr__", [](c10::SymIntNode a) { return a->str(); });
|
||||
|
||||
py::class_<c10::SymFloatNodeImpl, c10::SymFloatNode>(m, "SymFloatNode")
|
||||
.def_static(
|
||||
"new_symint",
|
||||
[](py::object obj) -> c10::SymIntNode {
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
|
||||
"new_symfloat",
|
||||
[](py::object obj) -> c10::SymFloatNode {
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(obj);
|
||||
})
|
||||
.def(
|
||||
"get_pyobj",
|
||||
[](c10::SymIntNode a) -> py::object {
|
||||
if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
|
||||
[](c10::SymFloatNode a) -> py::object {
|
||||
if (auto* psn = dynamic_cast<PythonSymFloatNodeImpl*>(a.get())) {
|
||||
return py::reinterpret_borrow<py::object>(psn->getPyObj());
|
||||
}
|
||||
return py::none();
|
||||
})
|
||||
.def(
|
||||
"__add__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__radd__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__sub__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->sub(snb);
|
||||
})
|
||||
.def(
|
||||
"__mul__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__rmul__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__truediv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->truediv(snb);
|
||||
})
|
||||
.def(
|
||||
"__rtruediv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->truediv(a);
|
||||
})
|
||||
.def(
|
||||
"__floordiv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->floordiv(snb);
|
||||
})
|
||||
.def(
|
||||
"__rfloordiv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->floordiv(a);
|
||||
})
|
||||
.def(
|
||||
"__mod__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mod(snb);
|
||||
})
|
||||
.def(
|
||||
"__eq__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->eq(snb);
|
||||
})
|
||||
.def(
|
||||
"__gt__",
|
||||
[](c10::SymIntNode a, py::object b) {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->gt(snb);
|
||||
})
|
||||
.def(
|
||||
"__lt__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->lt(snb);
|
||||
})
|
||||
.def(
|
||||
"__le__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->le(snb);
|
||||
})
|
||||
.def(
|
||||
"__ge__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->ge(snb);
|
||||
})
|
||||
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
|
||||
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
|
||||
.def("__str__", [](c10::SymIntNode a) { return a->str(); })
|
||||
.def("__repr__", [](c10::SymIntNode a) { return a->str(); });
|
||||
.def("__str__", [](c10::SymFloatNode a) { return a->str(); });
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-unused-raii)
|
||||
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
|
||||
|
Reference in New Issue
Block a user