Change SymIntNode into an intrusive pointer (#82432)

This will make the pointer type a single word, which is important
for packing it into an int64_t

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82432
Approved by: https://github.com/albanD, https://github.com/Krovatkin
This commit is contained in:
Edward Z. Yang
2022-07-28 16:18:15 -07:00
committed by PyTorch MergeBot
parent 7eed83e016
commit 7be44f8158
10 changed files with 29 additions and 23 deletions

View File

@ -145,7 +145,7 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
virtual SymIntNode wrap(int64_t num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap")(num);
return std::make_shared<PythonSymIntNodeImpl>(r);
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual bool bool_() override {
@ -166,11 +166,11 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
virtual SymIntNode dispatch_common_(
const char* fname,
const SymIntNode& other) {
auto pother = std::dynamic_pointer_cast<PythonSymIntNodeImpl>(other);
auto* pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return std::make_shared<PythonSymIntNodeImpl>(r);
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual SymIntNode add(const SymIntNode& other) override {
@ -1182,12 +1182,12 @@ void initJITBindings(PyObject* module) {
.def_static(
"new_symint",
[](py::object obj) -> c10::SymIntNode {
return std::make_shared<PythonSymIntNodeImpl>(obj);
return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
})
.def(
"get_pyobj",
[](c10::SymIntNode a) -> py::object {
if (auto psn = std::dynamic_pointer_cast<PythonSymIntNodeImpl>(a)) {
if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
return py::reinterpret_borrow<py::object>(psn->getPyObj());
}
return py::none();