mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Wconstab/reland pysymint (#79795)
rebased https://github.com/pytorch/pytorch/pull/79617/ to see if issues are reproducible. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79795 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
a6b783e714
commit
f7ee061638
@ -1,3 +1,4 @@
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
|
||||
@ -11,6 +12,7 @@
|
||||
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
|
||||
#include <torch/csrc/jit/codegen/onednn/interface.h>
|
||||
#endif
|
||||
#include <c10/core/SymbolicIntNode.h>
|
||||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
@ -103,6 +105,7 @@
|
||||
|
||||
#include <ATen/core/function_schema.h>
|
||||
|
||||
#include <pybind11/cast.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/iostream.h>
|
||||
#include <pybind11/operators.h>
|
||||
@ -124,6 +127,98 @@ using ::c10::FunctionSchema;
|
||||
using caffe2::serialize::PyTorchStreamReader;
|
||||
using caffe2::serialize::PyTorchStreamWriter;
|
||||
|
||||
static std::shared_ptr<c10::SymbolicIntNode> toSymIntNode(
|
||||
std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) {
|
||||
return torch::is_symint_node(b)
|
||||
? b.cast<std::shared_ptr<c10::SymbolicIntNode>>()
|
||||
: a->wrap(b.cast<int64_t>());
|
||||
}
|
||||
|
||||
class PythonSymbolicIntNode : public c10::SymbolicIntNode {
|
||||
public:
|
||||
PythonSymbolicIntNode(py::object pyobj) : c10::SymbolicIntNode() {
|
||||
pyobj_ = std::make_shared<c10::SafePyObject>(
|
||||
pyobj.release().ptr(), getPyInterpreter());
|
||||
};
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> wrap(int64_t num) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("wrap")(num);
|
||||
return std::make_shared<PythonSymbolicIntNode>(r);
|
||||
}
|
||||
|
||||
virtual bool bool_() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__bool__")().is(py::handle(Py_True));
|
||||
}
|
||||
|
||||
virtual int64_t int_() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__int__")().cast<int64_t>();
|
||||
}
|
||||
|
||||
virtual std::string str() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__str__")().cast<std::string>();
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> dispatch_common_(
|
||||
const char* fname,
|
||||
const std::shared_ptr<SymbolicIntNode>& other) {
|
||||
auto pother = std::dynamic_pointer_cast<PythonSymbolicIntNode>(other);
|
||||
TORCH_CHECK(pother);
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr(fname)(pother->getPyObj());
|
||||
return std::make_shared<PythonSymbolicIntNode>(r);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> add(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> sub(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> mul(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> div(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> mod(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> eq(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> gt(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymbolicIntNode> lt(
|
||||
const std::shared_ptr<SymbolicIntNode>& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
py::handle getPyObj() {
|
||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
||||
}
|
||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
using autograd::variable_list;
|
||||
@ -1077,6 +1172,101 @@ void initJITBindings(PyObject* module) {
|
||||
}
|
||||
});
|
||||
|
||||
py::class_<c10::SymbolicIntNode, std::shared_ptr<c10::SymbolicIntNode>>(
|
||||
m, "SymbolicIntNode")
|
||||
.def_static(
|
||||
"new_symint",
|
||||
[](py::object obj) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
return std::make_shared<PythonSymbolicIntNode>(obj);
|
||||
})
|
||||
.def(
|
||||
"get_pyobj",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a) -> py::object {
|
||||
if (auto psn =
|
||||
std::dynamic_pointer_cast<PythonSymbolicIntNode>(a)) {
|
||||
return py::reinterpret_borrow<py::object>(psn->getPyObj());
|
||||
}
|
||||
return py::none();
|
||||
})
|
||||
.def(
|
||||
"__add__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__radd__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__sub__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->sub(snb);
|
||||
})
|
||||
.def(
|
||||
"__mul__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__rmul__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__div__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->div(snb);
|
||||
})
|
||||
.def(
|
||||
"__mod__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mod(snb);
|
||||
})
|
||||
.def(
|
||||
"__eq__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->eq(snb);
|
||||
})
|
||||
.def(
|
||||
"__gt__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a, py::object b) {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->gt(snb);
|
||||
})
|
||||
.def(
|
||||
"__lt__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->lt(snb);
|
||||
})
|
||||
.def(
|
||||
"__bool__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a) { return a->bool_(); })
|
||||
.def(
|
||||
"__int__",
|
||||
[](std::shared_ptr<c10::SymbolicIntNode> a) { return a->int_(); })
|
||||
.def("__str__", [](std::shared_ptr<c10::SymbolicIntNode> a) {
|
||||
return a->str();
|
||||
});
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-unused-raii)
|
||||
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
|
||||
.def("__repr__", [](CompleteArgumentSpec& self) {
|
||||
|
Reference in New Issue
Block a user