mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 00:54:56 +08:00
Rename shared_ptr<SymIntNodeImpl> to SymIntNode (#82355)
Makes code a lot more compact! It also makes it possible to swap out the shared ptr implementation, which I am about to do next. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/82355 Approved by: https://github.com/Krovatkin
This commit is contained in:
committed by
PyTorch MergeBot
parent
fd5ac1e6b5
commit
34bdd46e6e
@ -125,16 +125,14 @@ using c10::Argument;
|
||||
using c10::FunctionSchema;
|
||||
using c10::SchemaArgType;
|
||||
using c10::SchemaArgument;
|
||||
using c10::SymIntNode;
|
||||
using caffe2::serialize::PyTorchStreamReader;
|
||||
using caffe2::serialize::PyTorchStreamWriter;
|
||||
using torch::utils::SchemaInfo;
|
||||
|
||||
static std::shared_ptr<c10::SymIntNodeImpl> toSymIntNode(
|
||||
std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) {
|
||||
return torch::is_symint_node(b)
|
||||
? b.cast<std::shared_ptr<c10::SymIntNodeImpl>>()
|
||||
: a->wrap(b.cast<int64_t>());
|
||||
static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) {
|
||||
return torch::is_symint_node(b) ? b.cast<c10::SymIntNode>()
|
||||
: a->wrap(b.cast<int64_t>());
|
||||
}
|
||||
|
||||
class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
@ -144,7 +142,7 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
pyobj.release().ptr(), getPyInterpreter());
|
||||
};
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> wrap(int64_t num) override {
|
||||
virtual SymIntNode wrap(int64_t num) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("wrap")(num);
|
||||
return std::make_shared<PythonSymIntNodeImpl>(r);
|
||||
@ -165,9 +163,9 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
return getPyObj().attr("__str__")().cast<std::string>();
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> dispatch_common_(
|
||||
virtual SymIntNode dispatch_common_(
|
||||
const char* fname,
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) {
|
||||
const SymIntNode& other) {
|
||||
auto pother = std::dynamic_pointer_cast<PythonSymIntNodeImpl>(other);
|
||||
TORCH_CHECK(pother);
|
||||
py::gil_scoped_acquire acquire;
|
||||
@ -175,53 +173,43 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
return std::make_shared<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> add(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode add(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> sub(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode sub(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> mul(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode mul(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> div(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode div(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> mod(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode mod(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> eq(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode eq(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> gt(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode gt(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> lt(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode lt(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> le(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode le(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SymIntNodeImpl> ge(
|
||||
const std::shared_ptr<SymIntNodeImpl>& other) override {
|
||||
virtual SymIntNode ge(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
@ -1190,16 +1178,15 @@ void initJITBindings(PyObject* module) {
|
||||
}
|
||||
});
|
||||
|
||||
py::class_<c10::SymIntNodeImpl, std::shared_ptr<c10::SymIntNodeImpl>>(
|
||||
m, "SymIntNode")
|
||||
py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode")
|
||||
.def_static(
|
||||
"new_symint",
|
||||
[](py::object obj) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](py::object obj) -> c10::SymIntNode {
|
||||
return std::make_shared<PythonSymIntNodeImpl>(obj);
|
||||
})
|
||||
.def(
|
||||
"get_pyobj",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a) -> py::object {
|
||||
[](c10::SymIntNode a) -> py::object {
|
||||
if (auto psn = std::dynamic_pointer_cast<PythonSymIntNodeImpl>(a)) {
|
||||
return py::reinterpret_borrow<py::object>(psn->getPyObj());
|
||||
}
|
||||
@ -1207,96 +1194,79 @@ void initJITBindings(PyObject* module) {
|
||||
})
|
||||
.def(
|
||||
"__add__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__radd__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__sub__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->sub(snb);
|
||||
})
|
||||
.def(
|
||||
"__mul__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__rmul__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__div__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->div(snb);
|
||||
})
|
||||
.def(
|
||||
"__mod__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mod(snb);
|
||||
})
|
||||
.def(
|
||||
"__eq__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->eq(snb);
|
||||
})
|
||||
.def(
|
||||
"__gt__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a, py::object b) {
|
||||
[](c10::SymIntNode a, py::object b) {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->gt(snb);
|
||||
})
|
||||
.def(
|
||||
"__lt__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->lt(snb);
|
||||
})
|
||||
.def(
|
||||
"__le__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->le(snb);
|
||||
})
|
||||
.def(
|
||||
"__ge__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a,
|
||||
py::object b) -> std::shared_ptr<c10::SymIntNodeImpl> {
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->ge(snb);
|
||||
})
|
||||
.def(
|
||||
"__bool__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a) { return a->bool_(); })
|
||||
.def(
|
||||
"__int__",
|
||||
[](std::shared_ptr<c10::SymIntNodeImpl> a) { return a->int_(); })
|
||||
.def("__str__", [](std::shared_ptr<c10::SymIntNodeImpl> a) {
|
||||
return a->str();
|
||||
});
|
||||
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
|
||||
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
|
||||
.def("__str__", [](c10::SymIntNode a) { return a->str(); });
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-unused-raii)
|
||||
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
|
||||
|
||||
Reference in New Issue
Block a user