Fix and improvements to toward 3.13t (#136319)

Small part of https://github.com/pytorch/pytorch/pull/130689
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136319
Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
albanD
2024-09-20 04:22:18 +00:00
committed by PyTorch MergeBot
parent e3ea5429f2
commit cf31724db7
5 changed files with 31 additions and 15 deletions

View File

@ -739,7 +739,7 @@ public:
static mpy::obj<Tensor> create() {
if (!TensorType) {
TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").ptr();
TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").release();
}
return Tensor::alloc(TensorType);
}

View File

@ -195,7 +195,9 @@ static bool THPStorage_tryPreserve(THPStorage* self) {
TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
storage_impl->pyobj_slot()->set_owns_pyobj(true);
Py_INCREF(self);
// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference((PyObject*)self);
self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
return true;

View File

@ -391,10 +391,10 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
// can't assume that some other code has taken care of it.
// NB: this will overreport _Py_RefTotal but based on inspection of object.c
// there is no way to avoid this
#ifdef Py_TRACE_REFS
_Py_AddToAllObjects(reinterpret_cast<PyObject*>(self), 1);
#endif
Py_INCREF(self);
// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference((PyObject*)self);
// Flip THPVariable to be non-owning
// (near use-after-free miss here: fresh MaybeOwned is created breaking

View File

@ -33,6 +33,17 @@ ToIValueAllowNumbersAsTensors::~ToIValueAllowNumbersAsTensors() {
// C++->Python. We need this because otherwise we may get the old Python object
// if C++ creates a new object at the memory location of the deleted object.
void clear_registered_instances(void* ptr) {
#if IS_PYBIND_2_13_PLUS
py::detail::with_instance_map(
ptr, [&](py::detail::instance_map& registered_instances) {
auto range = registered_instances.equal_range(ptr);
for (auto it = range.first; it != range.second; ++it) {
auto vh = it->second->get_value_and_holder();
vh.set_instance_registered(false);
}
registered_instances.erase(ptr);
});
#else
auto& registered_instances =
pybind11::detail::get_internals().registered_instances;
auto range = registered_instances.equal_range(ptr);
@ -41,6 +52,7 @@ void clear_registered_instances(void* ptr) {
vh.set_instance_registered(false);
}
registered_instances.erase(ptr);
#endif
}
// WARNING: Precondition for this function is that, e.g., you have tested if a

View File

@ -19,6 +19,8 @@
namespace py = pybind11;
#define IS_PYBIND_2_13_PLUS PYBIND11_VERSION_HEX >= 0x020D0000
// This makes intrusive_ptr to be available as a custom pybind11 holder type,
// see
// https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers