Reland 2: Add PyObject preservation for UntypedStorage (#109039)

Relands #103907 after it was reverted. This PR makes the new `ignore_hermetic_tls` argument of `check_pyobj` optional to avoid causing a compilation error in torchdistx.
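For context, the compatibility fix is just a defaulted parameter. The signature this PR adds to `PyObjectSlot` (see the PyObjectSlot.h hunk below) is:

    c10::optional<PyObject*> check_pyobj(
        PyInterpreter* self_interpreter,
        bool ignore_hermetic_tls = false) const;

so existing callers such as torchdistx, which pass only the interpreter, keep compiling unchanged.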

Part of #91395

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109039
Approved by: https://github.com/ezyang
Kurt Mohler
2023-09-12 22:26:05 +00:00
committed by PyTorch MergeBot
parent 6dc56d3490
commit 4c5e43574c
26 changed files with 1169 additions and 261 deletions

View File

@@ -896,6 +896,7 @@ libtorch_python_core_sources = [
    "torch/csrc/utils/python_dispatch.cpp",
    "torch/csrc/utils/python_symnode.cpp",
    "torch/csrc/utils/pybind.cpp",
+   "torch/csrc/utils/pyobject_preservation.cpp",
    "torch/csrc/utils/structseq.cpp",
    "torch/csrc/utils/tensor_apply.cpp",
    "torch/csrc/utils/tensor_dtypes.cpp",

View File

@@ -0,0 +1,78 @@
#include <c10/core/RefcountedDeleter.h>
#include <mutex>
namespace c10 {
void refcounted_deleter(void* ctx_) {
RefcountedDeleterContext& ctx =
*reinterpret_cast<RefcountedDeleterContext*>(ctx_);
ctx.refcount--;
if (ctx.refcount == 0) {
ctx.other_ctx = nullptr;
delete &ctx;
}
}
std::mutex replace_data_ptr_mutex;
void maybeApplyRefcountedDeleter(const c10::Storage& storage) {
std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
c10::DataPtr& data_ptr = storage.mutable_data_ptr();
if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
// Data pointer is already shared
return;
}
void* data = data_ptr.get();
void* other_ctx = data_ptr.get_context();
c10::DeleterFnPtr other_deleter = data_ptr.get_deleter();
c10::Device device = data_ptr.device();
// Release the context of the original DataPtr so that the data doesn't
// get deleted when the original DataPtr is replaced
data_ptr.release_context();
c10::RefcountedDeleterContext* refcount_ctx =
new c10::RefcountedDeleterContext(other_ctx, other_deleter);
c10::DataPtr new_data_ptr(
data,
reinterpret_cast<void*>(refcount_ctx),
&c10::refcounted_deleter,
device);
storage.set_data_ptr(std::move(new_data_ptr));
}
c10::Storage newStorageImplFromRefcountedDataPtr(const c10::Storage& storage) {
c10::maybeApplyRefcountedDeleter(storage);
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
c10::DataPtr& data_ptr = storage.mutable_data_ptr();
c10::DataPtr new_data_ptr(
data_ptr.get(),
data_ptr.get_context(),
data_ptr.get_deleter(),
data_ptr.device());
// NOTE: This refcount increment should always happen immediately after
// `new_data_ptr` is created. No other lines of code should be added between
// them in the future, unless there's a very good reason for it, because if
// any errors are raised and `new_data_ptr` is deleted before the refcount is
// incremented, the refcount will get decremented and end up being one less
// than it should be.
reinterpret_cast<c10::RefcountedDeleterContext*>(data_ptr.get_context())
->refcount++;
c10::Storage new_storage = c10::make_intrusive<c10::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
storage_impl->nbytes(),
std::move(new_data_ptr),
storage_impl->allocator(),
/*resizable=*/storage_impl->resizable());
return new_storage;
}
} // namespace c10
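The pattern above is a plain intrusive refcount layered over an arbitrary inner deleter. A minimal standalone sketch of the same idea (hypothetical names, not the c10 API):

    #include <atomic>
    #include <cstdio>
    #include <cstdlib>

    using DeleterFn = void (*)(void*);

    // Holds the real deleter plus a count of how many owners remain.
    struct RefcountedCtx {
      void* inner_ctx;
      DeleterFn inner_deleter;
      std::atomic<int> refcount{1};
    };

    void refcounted_delete(void* ctx_) {
      auto* ctx = static_cast<RefcountedCtx*>(ctx_);
      if (--ctx->refcount == 0) {
        ctx->inner_deleter(ctx->inner_ctx); // free the real data exactly once
        delete ctx;
      }
    }

    int main() {
      void* data = std::malloc(16);
      auto* ctx = new RefcountedCtx{data, std::free};
      ctx->refcount++;        // a second "storage" takes a reference
      refcounted_delete(ctx); // first owner gone: data still alive
      refcounted_delete(ctx); // last owner gone: data freed here
      std::printf("data freed\n");
    }

This also illustrates the hazard the NOTE above warns about: the increment must happen before any path that could run the deleter, or the count ends up one less than the true number of owners.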

View File

@@ -0,0 +1,51 @@
#pragma once
#include <c10/core/Storage.h>
#include <c10/util/UniqueVoidPtr.h>
#include <atomic>
#include <memory>
namespace c10 {
// A RefcountedDeleterContext object is used as the `ctx` argument for DataPtr
// to implement a shared DataPtr. Normally, a DataPtr is unique, but we use
// this custom context and the `refcounted_deleter` function below to make the
// DataPtr act like a non-unique DataPtr. This context object holds onto an
// inner context and deleter function which handle the actual deletion of the
// data when the refcount reaches 0.
//
// This shared DataPtr feature is only used when storages are shared between
// multiple Python interpreters in MultiPy. Before storages had PyObject
// preservation, interpreters could just share the same StorageImpl instance.
// But now a StorageImpl can only be associated with one interpreter in order
// to properly manage a zombie PyObject. So we share storages across Python
// interpreters by creating a different StorageImpl instance for each one, but
// they all point to the same data.
struct C10_API RefcountedDeleterContext {
RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter)
: other_ctx(other_ctx, other_deleter), refcount(1) {}
std::unique_ptr<void, c10::DeleterFnPtr> other_ctx;
std::atomic_int refcount;
};
// `refcounted_deleter` is used as the `ctx_deleter` for DataPtr to implement
// a shared DataPtr.
//
// Warning: This should only be called on a pointer to
// a RefcountedDeleterContext that was allocated on the heap with `new`,
// because when the refcount reaches 0, the context is deleted with `delete`
C10_API void refcounted_deleter(void* ctx_);
// If the storage's DataPtr does not use `refcounted_deleter`, replace it with
// a DataPtr that does, so it can be shared between multiple StorageImpls
C10_API void maybeApplyRefcountedDeleter(const c10::Storage& storage);
// Create a new StorageImpl that points to the same data. If the original
// StorageImpl's DataPtr does not use `refcounted_deleter`, it will be replaced
// with one that does
C10_API c10::Storage newStorageImplFromRefcountedDataPtr(
const c10::Storage& storage);
} // namespace c10

View File

@@ -33,7 +33,7 @@ struct C10_API SafePyObject {
  ~SafePyObject() {
    if (data_ != nullptr) {
-      (*pyinterpreter_)->decref(data_, /*is_tensor*/ false);
+      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
    }
  }

View File

@@ -1,3 +1,18 @@
+#include <c10/core/RefcountedDeleter.h>
#include <c10/core/Storage.h>

-namespace c10 {} // namespace c10
+namespace c10 {
+
+bool isSharedStorageAlias(const Storage& storage0, const Storage& storage1) {
+  c10::DeleterFnPtr deleter_expected = &c10::refcounted_deleter;
+  c10::DeleterFnPtr deleter0 = storage0.data_ptr().get_deleter();
+  c10::DeleterFnPtr deleter1 = storage1.data_ptr().get_deleter();
+  if ((deleter0 != deleter_expected) || (deleter1 != deleter_expected)) {
+    return false;
+  }
+  return storage0.data_ptr().get_context() == storage1.data_ptr().get_context();
+}
+
+} // namespace c10
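Taken together with `newStorageImplFromRefcountedDataPtr`, this gives the aliasing behavior that the updated `is_alias_of` in Storage.h relies on. A hypothetical call sequence using only APIs from this diff (`s0` is assumed to be some existing storage):

    c10::Storage s1 = c10::newStorageImplFromRefcountedDataPtr(s0);
    // Distinct StorageImpls, same underlying data:
    TORCH_INTERNAL_ASSERT(s0.unsafeGetStorageImpl() != s1.unsafeGetStorageImpl());
    TORCH_INTERNAL_ASSERT(s0.data_ptr().get() == s1.data_ptr().get());
    // Both now use refcounted_deleter with the same context:
    TORCH_INTERNAL_ASSERT(c10::isSharedStorageAlias(s0, s1));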

View File

@@ -1,12 +1,22 @@
#pragma once

#include <c10/core/StorageImpl.h>
+#include <c10/util/ExclusivelyOwned.h>

namespace c10 {

+struct Storage;
+
+C10_API bool isSharedStorageAlias(
+    const Storage& storage0,
+    const Storage& storage1);
+
struct C10_API Storage {
 public:
  struct use_byte_size_t {};
+  struct unsafe_borrow_t {
+    explicit unsafe_borrow_t() = default;
+  };

  Storage() = default;
  Storage(c10::intrusive_ptr<StorageImpl> ptr)
@@ -40,6 +50,14 @@ struct C10_API Storage {
            allocator,
            resizable)) {}

+ protected:
+  explicit Storage(unsafe_borrow_t, const Storage& rhs)
+      : storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(
+            rhs.storage_impl_.get())) {}
+
+  friend MaybeOwnedTraits<Storage>;
+
+ public:
  // Legacy constructor for partially initialized (dtype or memory) storages
  // that can be temporarily created with Caffe2 APIs. See the note on top of
  // TensorImpl.h for details.
@@ -144,7 +162,9 @@ struct C10_API Storage {
  }

  bool is_alias_of(const Storage& other) const {
-    return storage_impl_ == other.storage_impl_;
+    return (
+        storage_impl_ == other.storage_impl_ ||
+        isSharedStorageAlias(*this, other));
  }

  void UniqueStorageShareExternalPointer(
@@ -175,4 +195,67 @@ struct C10_API Storage {
  c10::intrusive_ptr<StorageImpl> storage_impl_;
};
template <>
struct MaybeOwnedTraits<c10::Storage> {
using owned_type = c10::Storage;
using borrow_type = c10::Storage;
static borrow_type createBorrow(const owned_type& from) {
return borrow_type(borrow_type::unsafe_borrow_t{}, from);
}
static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
lhs.unsafeReleaseStorageImpl();
lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
}
static void destroyBorrow(borrow_type& toDestroy) {
toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
}
static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
return borrow;
}
static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
return &borrow;
}
static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
return true;
}
};
template <>
struct ExclusivelyOwnedTraits<c10::Storage> {
using repr_type = c10::Storage;
using pointer_type = c10::Storage*;
using const_pointer_type = const c10::Storage*;
static repr_type nullRepr() {
return c10::Storage();
}
template <class... Args>
static repr_type createInPlace(Args&&... args) {
return c10::Storage(std::forward<Args>(args)...);
}
static repr_type moveToRepr(c10::Storage&& x) {
return std::move(x);
}
static c10::Storage take(c10::Storage& x) {
return std::move(x);
}
static pointer_type getImpl(repr_type& x) {
return &x;
}
static const_pointer_type getImpl(const repr_type& x) {
return &x;
}
};
} // namespace c10
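The borrow machinery above is subtle: `createBorrow` reclaims the impl pointer without bumping its refcount, so `destroyBorrow` must null it out rather than let a destructor decref it. A standalone sketch of that invariant (toy types, not the c10 implementation):

    #include <cassert>

    struct Impl {
      int refcount = 1;
    };

    struct Owned {
      Impl* impl;
      explicit Owned(Impl* i) : impl(i) {} // takes over one reference
      ~Owned() {
        if (impl && --impl->refcount == 0)
          delete impl;
      }
    };

    // A borrow aliases the impl at "+0" and must never run the decref.
    struct Borrow {
      Impl* impl;
    };

    Borrow create_borrow(const Owned& from) {
      return Borrow{from.impl}; // no refcount bump
    }

    void destroy_borrow(Borrow& b) {
      b.impl = nullptr; // "leak" it: it was already +0
    }

    int main() {
      Owned o(new Impl());
      Borrow b = create_borrow(o);
      assert(o.impl->refcount == 1); // borrowing did not touch the count
      destroy_borrow(b);             // defuse before the borrow goes away
    }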

View File

@@ -203,6 +203,14 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
    return received_cuda_;
  }

+  impl::PyObjectSlot* pyobj_slot() {
+    return &pyobj_slot_;
+  }
+
+  const impl::PyObjectSlot* pyobj_slot() const {
+    return &pyobj_slot_;
+  }
+
 private:
  DataPtr data_ptr_;
  SymInt size_bytes_;

View File

@@ -73,9 +73,7 @@ void TensorImpl::_set_fw_grad(
  autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}

-TensorImpl::~TensorImpl() {
-  pyobj_slot_.destroy_pyobj_if_needed();
-}
+TensorImpl::~TensorImpl() = default;

TensorImpl::TensorImpl(
    Storage&& storage,
@@ -582,7 +580,7 @@ void TensorImpl::release_resources() {
  if (storage_) {
    storage_ = {};
  }
-  pyobj_slot_.destroy_pyobj_if_needed();
+  pyobj_slot_.maybe_destroy_pyobj();
}

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY

View File

@@ -10,7 +10,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
    return "<unloaded interpreter>";
  }

-  void decref(PyObject* pyobj, bool is_tensor) const override {} // do nothing
+  void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
+  } // do nothing

#define PANIC(m) \
  TORCH_INTERNAL_ASSERT( \

View File

@@ -127,8 +127,8 @@ struct C10_API PyInterpreterVTable {
  virtual std::string name() const = 0;

  // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
-  // See NOTE [PyInterpreter::decref takes an `is_tensor` arg]
-  virtual void decref(PyObject* pyobj, bool is_tensor) const = 0;
+  // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
+  virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;

  // Perform a detach by deferring to the __torch_dispatch__ implementation of
  // detach, which will also arrange for the PyObject to get copied in this

View File

@@ -5,12 +5,16 @@ namespace impl {
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}

-void PyObjectSlot::destroy_pyobj_if_needed() {
+PyObjectSlot::~PyObjectSlot() {
+  maybe_destroy_pyobj();
+}
+
+void PyObjectSlot::maybe_destroy_pyobj() {
  if (owns_pyobj()) {
    TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
    TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
    (*pyobj_interpreter_.load(std::memory_order_acquire))
-        ->decref(_unchecked_untagged_pyobj(), /*is_tensor*/ true);
+        ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
    // NB: this destructor can only be entered when there are no
    // references to this C++ object (obviously), NOR any references
    // to the PyObject (if there are references to the PyObject,
@@ -47,6 +51,15 @@ PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
      (*pyobj_interpreter_.load())->name());
}

+bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) {
+  return interpreter == pyobj_interpreter();
+}
+
+bool PyObjectSlot::has_pyobj_nonhermetic() {
+  return check_pyobj(pyobj_interpreter(), /*ignore_hermetic_tls=*/true)
+      .has_value();
+}
+
bool PyObjectSlot::owns_pyobj() {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return reinterpret_cast<uintptr_t>(pyobj_) & 1;
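`owns_pyobj` works because the slot smuggles an ownership flag into bit 0 of the stored pointer. A standalone sketch of the tagging scheme (it assumes, as the real code does, that the pointers are at least 2-byte aligned so bit 0 is free):

    #include <cassert>
    #include <cstdint>

    struct Obj { int payload; }; // stand-in for PyObject

    Obj* tag(Obj* p) { // mark "slot owns the PyObject"
      return reinterpret_cast<Obj*>(reinterpret_cast<uintptr_t>(p) | 1);
    }
    bool owns(Obj* p) {
      return reinterpret_cast<uintptr_t>(p) & 1;
    }
    Obj* untag(Obj* p) { // _unchecked_untagged_pyobj equivalent
      return reinterpret_cast<Obj*>(reinterpret_cast<uintptr_t>(p) & ~uintptr_t(1));
    }

    int main() {
      Obj o{42};
      Obj* stored = tag(&o);
      assert(owns(stored));
      assert(untag(stored) == &o);
    }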

View File

@@ -14,7 +14,9 @@ struct C10_API PyObjectSlot {
 public:
  PyObjectSlot();

-  void destroy_pyobj_if_needed();
+  ~PyObjectSlot();
+
+  void maybe_destroy_pyobj();

  // Associate the TensorImpl with the specified PyObject, and, if necessary,
  // also tag the interpreter.
@@ -82,9 +84,20 @@ struct C10_API PyObjectSlot {
  //    a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
  //    returns a nullopt. If it is definitely invalid, raises an error.
  //
+  // If `ignore_hermetic_tls` is false and this function is called from a
+  // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
+  // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
+  // context is ignored, allowing you to check the interpreter tag of a
+  // nonhermetic PyObject from within a hermetic context. This is necessary
+  // because there are some cases where the deallocator function of a
+  // nonhermetic PyObject is called from within a hermetic context, so it must
+  // be properly treated as a nonhermetic PyObject.
+  //
  // NB: this lives in header so that we can avoid actually creating the
  // c10::optional
-  c10::optional<PyObject*> check_pyobj(PyInterpreter* self_interpreter) const {
+  c10::optional<PyObject*> check_pyobj(
+      PyInterpreter* self_interpreter,
+      bool ignore_hermetic_tls = false) const {
    // Note [Memory ordering on Python interpreter tag]
    impl::PyInterpreter* interpreter =
        pyobj_interpreter_.load(std::memory_order_acquire);
@@ -97,7 +110,7 @@ struct C10_API PyObjectSlot {
      return c10::nullopt;
    } else if (interpreter == self_interpreter) {
      // NB: pyobj_ could still be null!
-      if (c10::impl::HermeticPyObjectTLS::get_state()) {
+      if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
        return c10::nullopt;
      } else {
        return c10::make_optional(_unchecked_untagged_pyobj());
@@ -118,6 +131,13 @@ struct C10_API PyObjectSlot {
  PyInterpreter& load_pyobj_interpreter() const;

+  // Check if the PyObjectSlot's interpreter is the same as the specified
+  // interpreter
+  bool check_interpreter(PyInterpreter* interpreter);
+
+  // Check if the PyObjectSlot is holding a PyObject, owned or non-owned
+  bool has_pyobj_nonhermetic();
+
  bool owns_pyobj();
  void set_owns_pyobj(bool b);
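A minimal standalone model of the `ignore_hermetic_tls` gate (toy stand-ins for the c10 types; the real check also compares interpreter tags):

    #include <cassert>
    #include <optional>

    thread_local bool hermetic_state = false; // HermeticPyObjectTLS stand-in

    struct PyObj {}; // PyObject stand-in

    std::optional<PyObj*> check_pyobj(PyObj* slot_pyobj,
                                      bool ignore_hermetic_tls = false) {
      if (!ignore_hermetic_tls && hermetic_state) {
        return std::nullopt; // hermetic context: pretend there is no PyObject
      }
      return slot_pyobj; // may still be null in the real code
    }

    int main() {
      PyObj obj;
      hermetic_state = true;
      assert(!check_pyobj(&obj).has_value());      // normal lookup is gated
      assert(check_pyobj(&obj, true).has_value()); // deallocator path sees it
    }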

View File

@@ -1118,7 +1118,7 @@ int64_t _Tensor_ndim(mpy::handle h) {
mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
  // fast case: tensor is live in python
  c10::optional<PyObject*> mb_obj =
-      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
    return *mb_obj;
  }

View File

@@ -8832,6 +8832,16 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        T()
def test_storage_base_init(self):
# Direct construction not OK
self.assertRaises(RuntimeError, lambda: torch._C.StorageBase())
# But construction of subclass is OK
class T(torch._C.StorageBase):
pass
T()
    def test_tensor_base_new(self):
        # OK to call super().__new__, see
@@ -8844,6 +8854,18 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        x = torch.ones(5)
        test_tensor = TestTensor(x)
def test_storage_base_new(self):
# OK to call super().__new__, see
# https://github.com/pytorch/pytorch/issues/57421
class TestStorage(torch._C.StorageBase):
@staticmethod
def __new__(cls, x, *args, **kwargs):
return super().__new__(cls, x, *args, **kwargs)
x = torch.UntypedStorage(5)
test_storage = TestStorage(x)
    def test_pyobj_preserved(self):
        x = torch.empty(2)
        x.foo = 2 # put something on __dict__
@@ -8868,6 +8890,160 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        del z # it's dead again
        self.assertEqual(type(y.grad), MyTensor)
@skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
def test_storage_dealloc(self):
m, t = Tracker.make()
s0 = torch.UntypedStorage(10)
s1 = s0
s0._tracker = t
del t
self.assertFalse(m[0])
del s0
self.assertFalse(m[0])
del s1
self.assertTrue(m[0])
@skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
def test_storage_from_tensor_dealloc(self):
m, t = Tracker.make()
a = torch.randn(10)
s0 = a.untyped_storage()
s0._tracker = t
del t
s1 = a.untyped_storage()
self.assertTrue(s0 is s1)
self.assertTrue(hasattr(s1, '_tracker'))
del a
self.assertFalse(m[0])
del s0
self.assertFalse(m[0])
del s1
self.assertTrue(m[0])
@skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
def test_storage_from_tensor_dealloc_zombie(self):
m, t = Tracker.make()
a = torch.randn(10)
s0 = a.untyped_storage()
s0._tracker = t
del t
s1 = a.untyped_storage()
self.assertTrue(s0 is s1)
self.assertTrue(hasattr(s1, '_tracker'))
self.assertFalse(m[0])
del s0
self.assertFalse(m[0])
del s1
self.assertFalse(m[0])
del a
self.assertTrue(m[0])
@skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
def test_storage_from_tensor_dealloc_resurrected(self):
m, t = Tracker.make()
a = torch.randn(10)
s0 = a.untyped_storage()
s0._tracker = t
del t
s1 = a.untyped_storage()
self.assertTrue(s0 is s1)
self.assertTrue(hasattr(s1, '_tracker'))
self.assertFalse(m[0])
del s0
self.assertFalse(m[0])
del s1
self.assertFalse(m[0])
s0 = a.untyped_storage()
self.assertTrue(isinstance(s0, torch.UntypedStorage))
del a
self.assertFalse(m[0])
del s0
self.assertTrue(m[0])
@skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
def test_storage_dealloc_resurrected(self):
m, t = Tracker.make()
s = torch.UntypedStorage(10)
s._tracker = t
del t
a = torch.tensor(s)
self.assertFalse(m[0])
del s
self.assertFalse(m[0])
s = a.untyped_storage()
self.assertTrue(isinstance(s, torch.UntypedStorage))
del a
self.assertFalse(m[0])
del s
self.assertTrue(m[0])
@skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
def test_storage_dealloc_subclass_zombie(self):
class MyStorage(torch.UntypedStorage):
finalized_count = 0
def __del__(self):
MyStorage.finalized_count += 1
m, t = Tracker.make()
s = MyStorage(10)
s._tracker = t
del t
a = torch.tensor(s)
self.assertFalse(m[0])
del s
self.assertEqual(MyStorage.finalized_count, 0)
self.assertFalse(m[0])
del a
self.assertEqual(MyStorage.finalized_count, 1)
self.assertTrue(m[0])
@skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
def test_storage_dealloc_subclass_resurrected(self):
class MyStorage(torch.UntypedStorage):
finalized_count = 0
def __del__(self):
MyStorage.finalized_count += 1
m, t = Tracker.make()
s = MyStorage(10)
s._tracker = t
del t
a = torch.tensor(s)
self.assertFalse(m[0])
del s
self.assertEqual(MyStorage.finalized_count, 0)
self.assertFalse(m[0])
s = a.untyped_storage()
del a
self.assertFalse(m[0])
self.assertEqual(MyStorage.finalized_count, 0)
self.assertTrue(isinstance(s, MyStorage))
del s
self.assertEqual(MyStorage.finalized_count, 1)
self.assertTrue(m[0])
    def test_tensor_slot_dealloc(self):
        class SlotTensor1(torch._C._TensorBase):
@@ -8889,6 +9065,27 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        self.assertTrue(m1[0])
        self.assertTrue(m2[0])
def test_storage_slot_dealloc(self):
class SlotStorage1(torch._C.StorageBase):
__slots__ = ['slot1']
class SlotStorage2(SlotStorage1):
__slots__ = ['slot2']
m1, t1 = Tracker.make()
m2, t2 = Tracker.make()
slot_storage = SlotStorage2(torch.UntypedStorage(2))
slot_storage.slot1 = t1
slot_storage.slot2 = t2
del t1
del t2
self.assertFalse(m1[0])
self.assertFalse(m2[0])
del slot_storage
self.assertTrue(m1[0])
self.assertTrue(m2[0])
@skipIfTorchDynamo("Not a suitable test for TorchDynamo") @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_tensor_dict_dealloc(self): def test_tensor_dict_dealloc(self):
m, t = Tracker.make() m, t = Tracker.make()
@ -8899,6 +9096,16 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
del x del x
self.assertTrue(m[0]) self.assertTrue(m[0])
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_storage_dict_dealloc(self):
m, t = Tracker.make()
x = torch.UntypedStorage(2)
x.arf = t
del t
self.assertFalse(m[0])
del x
self.assertTrue(m[0])
    def test_tensor_finalizer_dealloc(self):
        m = [False]
@@ -8911,9 +9118,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        del fin_tensor
        self.assertTrue(m[0])
def test_storage_finalizer_dealloc(self):
m = [False]
class FinalizerStorage(torch._C.StorageBase):
def __del__(self):
m[0] = True
fin_storage = FinalizerStorage(torch.UntypedStorage(2))
self.assertFalse(m[0])
del fin_storage
self.assertTrue(m[0])
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_tensor_weakref_dealloc(self): def test_tensor_weakref_dealloc(self):
x = torch.empty(2) x = torch.empty(2)
m = [False] m = [False]
@ -8925,6 +9143,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertTrue(m[0]) self.assertTrue(m[0])
self.assertEqual(wref(), None) self.assertEqual(wref(), None)
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_storage_weakref_dealloc(self):
x = torch.UntypedStorage(2)
m = [False]
def cb(r):
m[0] = True
wref = weakref.ref(x, cb)
del x
self.assertTrue(m[0])
self.assertEqual(wref(), None)
@skipIfTorchDynamo("Not a suitable test for TorchDynamo") @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_tensor_cycle_via_dict(self): def test_tensor_cycle_via_dict(self):
m1, t1 = Tracker.make() m1, t1 = Tracker.make()
@ -8968,6 +9200,49 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertTrue(m1[0]) self.assertTrue(m1[0])
self.assertTrue(m2[0]) self.assertTrue(m2[0])
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_storage_cycle_via_dict(self):
m1, t1 = Tracker.make()
x = torch.UntypedStorage(2)
x._tracker = t1
del t1
m2, t2 = Tracker.make()
y = torch.UntypedStorage(2)
y._tracker = t2
del t2
x._loop = y
y._loop = x
# C++ reference should keep the cycle live!
        # This exercises THPVariable_subtype_traverse
# NB: Because z.grad is a reference done entirely in C++, cycles
# involving it directly are NOT broken by Python GC; you've
# set up a good old C++ reference cycle which we cannot safely
# break (because C++ references are allowed to be accessed
# multithreaded-ly) (TODO: except maybe if you can prove that
# only Python has access to the C++ object, in which case you can
# also prove that no multithreaded access occurs)
z = torch.UntypedStorage(2)
z.grad = x
del x
del y
gc.collect()
self.assertFalse(m1[0])
self.assertFalse(m2[0])
with disable_gc():
del z
self.assertFalse(m1[0])
self.assertFalse(m2[0])
gc.collect()
self.assertTrue(m1[0])
self.assertTrue(m2[0])
    def test_tensor_cycle_via_slots(self):
        m1 = [False]
        m2 = [False]
@@ -9000,6 +9275,67 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        self.assertTrue(m1[0])
        self.assertTrue(m2[0])
def test_storage_cycle_via_slots(self):
m1 = [False]
m2 = [False]
class SlotStorage1(torch._C.StorageBase):
__slots__ = ['slot1']
def __del__(self):
m1[0] = True
class SlotStorage2(SlotStorage1):
__slots__ = ['slot2']
def __del__(self):
m2[0] = True
x = SlotStorage1(torch.UntypedStorage(2))
y = SlotStorage2(torch.UntypedStorage(2))
x.slot1 = y
y.slot2 = x
del x
with disable_gc():
del y
self.assertFalse(m1[0])
self.assertFalse(m2[0])
gc.collect()
self.assertTrue(m1[0])
self.assertTrue(m2[0])
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_storage_preserve_nonhermetic_in_hermetic_context(self):
from torch.library import Library, impl
global _my_storage
my_lib = Library("my_lib", "DEF")
my_lib.define('my_func() -> None')
a = torch.tensor([1.])
_my_storage = a.untyped_storage()
m, t = Tracker.make()
_my_storage._tracker = t
del t
@impl(my_lib, 'my_func', '')
def my_func():
global _my_storage
del _my_storage
self.assertFalse(m[0])
torch.ops.my_lib.my_func()
self.assertFalse(m[0])
s = a.untyped_storage()
del a
del s
self.assertTrue(m[0])
    # FIXME: move to test_autograd?
    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
    def test_backward_hooks_traverse(self):
@@ -9028,7 +9364,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        self.assertTrue(m2[0])

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
-    def test_dead_weak_ref(self):
+    def test_tensor_dead_weak_ref(self):
        x = torch.empty(2)
        w_x = weakref.ref(x)
        y = torch.empty(2)
@@ -9044,7 +9380,24 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        self.assertRaises(RuntimeError, lambda: x.sigmoid())
-    def test_resurrected_weak_ref(self):
+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_storage_dead_weak_ref(self):
x = torch.UntypedStorage(2)
w_x = weakref.ref(x)
y = torch.tensor(x)
del x
x = w_x()
# Ideally, x would keep the storage live. But CPython doesn't
# provide enough hooks to do this. So it will go dead and x
        # will transmute into a storage with a null StorageImpl. Not great, but the
# best we can do.
del y
self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
def test_tensor_resurrected_weak_ref(self):
        x = torch.empty(2)
        w_x = weakref.ref(x)
        y = torch.empty(2)
@@ -9057,8 +9410,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
        del y
        x.sigmoid()
def test_storage_resurrected_weak_ref(self):
x = torch.UntypedStorage(2)
w_x = weakref.ref(x)
y = torch.tensor(x)
del x
x = w_x()
        # Use this to manually fix weak references after dereferencing them
x._fix_weakref()
del y
x.float()
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_fix_weakref_no_leak(self): def test_tensor_fix_weakref_no_leak(self):
import weakref import weakref
called = False called = False
@ -9074,6 +9439,23 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertTrue(called) self.assertTrue(called)
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_storage_fix_weakref_no_leak(self):
import weakref
called = False
a = torch.UntypedStorage(1)
def callback(w):
nonlocal called
called = True
wa = weakref.ref(a, callback)
a._fix_weakref()
del a
self.assertTrue(called)
    # FIXME: move to test_linalg
    @torch.inference_mode()
    def test_bmm_multithreaded(self):

View File

@@ -30,29 +30,6 @@ std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)>
std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)>
    layout_registry = {};

-at::DeprecatedTypeProperties* get_type_properties(
-    at::DeviceType device_type,
-    at::ScalarType scalarType) {
-  at::Backend backend = at::Backend::Undefined;
-  if (device_type == at::kCPU) {
-    backend = at::Backend::CPU;
-  } else if (device_type == at::kCUDA) {
-    backend = at::Backend::CUDA;
-  } else if (device_type == at::kXPU) {
-    backend = at::Backend::XPU;
-  } else if (device_type == at::kHPU) {
-    backend = at::Backend::HPU;
-  } else if (device_type == at::kMPS) {
-    backend = at::Backend::MPS;
-  } else if (device_type == at::DeviceType::Meta) {
-    backend = at::Backend::Undefined;
-  } else if (device_type == at::DeviceType::PrivateUse1) {
-    backend = at::Backend::PrivateUse1;
-  } else {
-    TORCH_CHECK(false, "Invalid device for storage: ", device_type);
-  }
-  return &at::getDeprecatedTypeProperties(backend, scalarType);
-}
-
} // namespace

void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) {
@@ -88,13 +65,10 @@ PyObject* createPyObject(const at::Storage& storage) {
  // information about storages from python). However, any accesses to the
  // data_ptr is not allowed, through methods like
  // x.untyped_storage().data_ptr()
-  PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPStorageClass);
-  auto obj = THPObjectPtr(type->tp_alloc(type, 0));
+  PyObject* obj = THPStorage_Wrap(storage);
  if (!obj)
    throw python_error();
-  ((THPStorage*)obj.get())->cdata =
-      c10::MaybeOwned<at::Storage>::owned(at::Storage(/* copy */ storage));
-  return obj.release();
+  return obj;
}

PyTypeObject* loadTypedStorageTypeObject() {
@@ -118,16 +92,13 @@ bool isStorage(PyObject* obj) {
  if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) {
    return true;
  }
-  auto obj_type = Py_TYPE(obj);
-  return obj_type == reinterpret_cast<PyTypeObject*>(THPStorageClass);
+  return THPStorage_Check(obj);
}

-at::Storage createStorageGetType(
-    PyObject* obj,
-    at::ScalarType& scalar_type,
-    bool& is_typed_storage) {
-  is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
+std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
+    PyObject* obj) {
+  at::ScalarType scalar_type = at::ScalarType::Undefined;
+  bool is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
  PyObject* untyped_storage_obj = nullptr;

  if (is_typed_storage) {
@@ -136,10 +107,9 @@ at::Storage createStorageGetType(
    // stay nonzero since the `TypedStorage` maintains a reference.
    PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
    TORCH_INTERNAL_ASSERT(dtype_obj);
-    Py_DECREF(dtype_obj);
    TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
    scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
+    Py_DECREF(dtype_obj);

    untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
    TORCH_INTERNAL_ASSERT(untyped_storage_obj);
@@ -150,22 +120,18 @@ at::Storage createStorageGetType(
    untyped_storage_obj = obj;
  }

-  if (Py_TYPE(untyped_storage_obj) !=
-      reinterpret_cast<PyTypeObject*>(THPStorageClass)) {
-    throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name);
-  }
+  TORCH_CHECK(
+      THPStorage_Check(untyped_storage_obj),
+      "not a storage '",
+      Py_TYPE(obj)->tp_name,
+      "'");

-  const auto& storage = THPStorage_Unpack(untyped_storage_obj);
-  c10::DeviceType device_type = storage.device().type();
-  auto type_properties = get_type_properties(device_type, at::kByte);
-  return type_properties->unsafeStorageFromTH(
-      storage.unsafeGetStorageImpl(), true);
+  auto storage = THPStorage_Unpack(untyped_storage_obj);
+  return std::make_tuple(storage, scalar_type, is_typed_storage);
}

at::Storage createStorage(PyObject* obj) {
-  at::ScalarType scalar_type = at::ScalarType::Undefined;
-  bool is_typed_storage = false;
-  return createStorageGetType(obj, scalar_type, is_typed_storage);
+  return std::get<0>(createStorageGetType(obj));
}

} // namespace torch

View File

@@ -27,10 +27,8 @@ void registerLayoutObject(THPLayout* thp_layout, at::Layout layout);
TORCH_PYTHON_API PyObject* createPyObject(const at::Storage& storage);
at::Storage createStorage(PyObject* obj);
-at::Storage createStorageGetType(
-    PyObject* obj,
-    at::ScalarType& scalar_type,
-    bool& is_typed_storage);
+std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
+    PyObject* obj);
bool isStorage(PyObject* obj);
TORCH_PYTHON_API THPDtype* getTHPDtype(at::ScalarType scalarType);
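The tuple return composes naturally with C++17 structured bindings; a hypothetical call site (`obj` is assumed to be some `PyObject*`):

    auto [storage, scalar_type, is_typed_storage] = torch::createStorageGetType(obj);
    // equivalent to the old out-parameter form, but with no reference params:
    at::Storage s = std::get<0>(torch::createStorageGetType(obj));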

View File

@@ -35,7 +35,7 @@ struct ConcretePyInterpreterVTable final
    : public c10::impl::PyInterpreterVTable {
  std::string name() const override;

-  void decref(PyObject* pyobj, bool is_tensor) const override;
+  void decref(PyObject* pyobj, bool has_pyobj_slot) const override;

  // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
  // operate upon a PyObjectSlot rather than a TensorImpl
@@ -189,15 +189,15 @@ py::object torchDispatchFromTensorImpl(
      TorchFunctionName::TorchDispatch));
}

-// NOTE [PyInterpreter::decref takes an `is_tensor` arg]
+// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Before calling PyInterpreter::decref, we must statically know if the
-// pyobj is a Tensor or not.
-// - If it is a tensor, we need to be careful about PyObject resurrection
-// - If it is not a tensor, we can freely decref
+// pyobj has a PyObjectSlot or not.
+// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
+// - If it does not have a PyObjectSlot, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
-void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
+void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
    const {
  // Leak the pyobj if not initialized. This can happen if we are running
  // exit handlers that are destructing tensors with residual (owned)
@@ -207,23 +207,33 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
  pybind11::gil_scoped_acquire gil;
  // Two possibilities:
-  // 1. We are decref-ing a tensor. Then we must be careful about
-  //    PyObject resurrection (this only applies to Tensors, see
+  // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
+  //    Storage. Then we must be careful about PyObject resurrection (see
  //    THPVariable_clear).
  // 2. We are decref-ing some other Python object. We don't do
  //    PyObject resurrection on non-Tensors, so we just carry on as usual
-  if (is_tensor && Py_REFCNT(pyobj) > 1) {
-    // It's still alive! This can happen if a weak ref resurrected
-    // the PyObject without flipping ownership. At this point it is
-    // too late to rescue the object, so just stub out the PyObject
-    // so that it fails on subsequent uses. Don't raise an error here;
-    // you're probably in a destructor.
-    TORCH_WARN(
-        "Deallocating Tensor that still has live PyObject references. "
-        "This probably happened because you took out a weak reference to "
-        "Tensor and didn't call _fix_weakref() after dereferencing it. "
-        "Subsequent accesses to this tensor via the PyObject will now fail.");
-    ((THPVariable*)pyobj)->cdata = c10::MaybeOwned<torch::autograd::Variable>();
+  if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
+    if (THPVariable_Check(pyobj)) {
+      // It's still alive! This can happen if a weak ref resurrected
+      // the PyObject without flipping ownership. At this point it is
+      // too late to rescue the object, so just stub out the PyObject
+      // so that it fails on subsequent uses. Don't raise an error here;
+      // you're probably in a destructor.
+      TORCH_WARN(
+          "Deallocating Tensor that still has live PyObject references. "
+          "This probably happened because you took out a weak reference to "
+          "Tensor and didn't call _fix_weakref() after dereferencing it. "
+          "Subsequent accesses to this tensor via the PyObject will now fail.");
+      ((THPVariable*)pyobj)->cdata =
+          c10::MaybeOwned<torch::autograd::Variable>();
+    } else if (THPStorage_Check(pyobj)) {
+      TORCH_WARN(
+          "Deallocating UntypedStorage that still has live PyObject references. "
+          "This probably happened because you took out a weak reference to "
+          "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
+          "Subsequent accesses to this storage via the PyObject will now fail.");
+      ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
+    }
  }
  Py_DECREF(pyobj);
};
@@ -548,8 +558,8 @@ static void set_tensor_attr_with_capsule(
    const c10::TensorImpl* tensor,
    py::capsule& capsule,
    const char* attr_name) {
-  c10::optional<PyObject*> mb_obj =
-      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
+  c10::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
+      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  py::handle(mb_obj.value()).attr(attr_name) = capsule;

View File

@@ -6,6 +6,7 @@
#include <ATen/mps/MPSDevice.h>
#include <c10/core/CPUAllocator.h>
+#include <c10/core/RefcountedDeleter.h>
#include <libshm.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/Device.h>
@@ -15,6 +16,7 @@
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/copy_utils.h>
+#include <torch/csrc/utils/pyobject_preservation.h>
#include <torch/csrc/utils/python_arg_parser.h>

#include <c10/util/intrusive_ptr.h>
@@ -27,28 +29,265 @@ void THPPointer<c10::StorageImpl>::free() {
  }
}

-PyObject* THPStorageClass = nullptr;
+PyTypeObject* THPStorageClass = nullptr;

-PyObject* THPStorage_New(c10::Storage storage) {
-  PyTypeObject* type = (PyTypeObject*)THPStorageClass;
-  PyObject* obj = type->tp_alloc(type, 0);
-  if (obj) {
-    ((THPStorage*)obj)->cdata =
-        c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
-  }
-  return obj;
-}
+PyObject* THPStorage_NewWithStorage(
+    PyTypeObject* type,
+    c10::Storage _storage,
+    c10::impl::PyInterpreterStatus status,
+    bool allow_preexisting_pyobj) {
+  TORCH_CHECK(
+      PyType_IsSubtype(type, &THPStorageType),
+      "Creating a Storage subclass from a class that does not inherit from ",
+      "Storage is not possible. Make sure your class inherits from Storage.");
+
+  auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
+      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
+  if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
+    TORCH_CHECK(
+        allow_preexisting_pyobj,
+        "Creating a new Storage subclass ",
+        type->tp_name,
+        " but the raw Storage object is already associated to a python object ",
+        "of type ",
+        maybe_pyobj.value()->ob_type->tp_name);
+    PyObject* obj = *maybe_pyobj;
+    PyTypeObject* obj_type = Py_TYPE(obj);
+    TORCH_CHECK(
+        obj_type == type || PyType_IsSubtype(obj_type, type),
+        "Creating a new Storage subclass ",
+        type->tp_name,
+        " but the raw Storage object is already associated to a python object ",
+        "of type ",
+        maybe_pyobj.value()->ob_type->tp_name,
+        " which is not a subclass of the "
+        "requested type");
+    return THPStorage_Wrap(std::move(_storage));
+  }
+
+  PyObject* obj = type->tp_alloc(type, 0);
+  TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
+
+  auto s = (THPStorage*)obj;
+  new (&s->cdata) c10::MaybeOwned<c10::Storage>();
+  s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));
+
+  if (!c10::impl::HermeticPyObjectTLS::get_state()) {
+    s->is_hermetic = false;
+    const auto& storage = THPStorage_Unpack(s);
+    storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(
+        getPyInterpreter(), obj, status);
+  } else {
+    s->is_hermetic = true;
+  }
+
+  return obj;
+}
// Wraps the c10::Storage with a storage PyObject
PyObject* THPStorage_Wrap(c10::Storage storage) {
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPStorage_NewWithStorage(
THPStorageClass,
std::move(storage),
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
}
c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();
// If the StorageImpl has a PyObject that is managed by a different
// interpreter than the current one, create a new StorageImpl that points to
// the same data and then create the Python storage from that.
// NOTE: This is only supposed to happen in MultiPy
if (pyobj_slot->has_pyobj_nonhermetic() &&
!pyobj_slot->check_interpreter(getPyInterpreter())) {
return THPStorage_NewWithStorage(
THPStorageClass,
c10::newStorageImplFromRefcountedDataPtr(storage),
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
}
c10::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
getPyInterpreter(), /*ignore_hermetic_tls=*/false);
c10::impl::PyInterpreterStatus status =
c10::impl::PyInterpreterStatus::TAGGED_BY_US;
if (maybe_pyobj.has_value()) {
auto obj = *maybe_pyobj;
if (obj) {
TORCH_CHECK(
THPStorage_Check(obj),
"Expected a storage type, but got ",
Py_TYPE(obj)->tp_name);
if (pyobj_slot->owns_pyobj()) {
pyobj_slot->set_owns_pyobj(false);
reinterpret_cast<THPStorage*>(obj)->cdata =
c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
return obj;
} else {
Py_INCREF(obj);
return obj;
}
}
status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
} else {
if (storage.use_count() <= 1) {
status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
} else {
status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
}
}
return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status);
}
static bool THPStorage_isPreservable(THPStorage* self) {
if (self->cdata.unsafeIsBorrowed()) {
return false;
}
auto const& storage = THPStorage_Unpack(self);
if (self->is_hermetic) {
return false;
}
if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
getPyInterpreter(), /*ignore_hermetic_tls=*/true) !=
c10::make_optional((PyObject*)self)) {
return false;
}
if (storage.use_count() <= 1) {
return false;
}
return true;
}
static bool THPStorage_tryPreserve(THPStorage* self) {
if (!THPStorage_isPreservable(self)) {
return false;
}
const auto& storage = THPStorage_Unpack(self);
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj(
getPyInterpreter(),
/*ignore_hermetic_tls=*/true);
// NOTE: It is possible to just set the PyObjectSlot here, but the point is
// that we should have already set PyObjectSlot when the storage PyObject was
// created.
TORCH_INTERNAL_ASSERT(
maybe_pyobj.has_value(),
"Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");
PyObject* pyobj = *maybe_pyobj;
TORCH_CHECK(
THPStorage_Check(pyobj),
"Expected a storage type, but got ",
Py_TYPE(pyobj)->tp_name);
TORCH_INTERNAL_ASSERT(
(void*)pyobj == (void*)self,
"Python storage and the PyObject in the internal PyObjectSlot are not at the same address");
TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
storage_impl->pyobj_slot()->set_owns_pyobj(true);
Py_INCREF(self);
self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
return true;
}
static void THPStorage_subclass_dealloc(PyObject* self) {
  THPStorage* _self = (THPStorage*)self;

-  // Some subclass of StorageBase are GC-tracked objects even
-  // though the base class is not.
+  if (THPStorage_tryPreserve(_self)) {
+    return;
+  }
+
+  // Some subclass of StorageBase could be GC-tracked objects even
+  // though the base class is not
  auto* type = Py_TYPE(self);
  if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
    PyObject_GC_UnTrack(self);
  }
bool has_finalizer = type->tp_finalize || type->tp_del;
if (type->tp_finalize) {
PyObject_GC_Track(self);
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
// The finalizer has resurrected the PyObject and there is a new Python
// reference to it, so we can just stop deallocating. Read about
// resurrection from `__del__` here:
// https://docs.python.org/3/reference/datamodel.html#object.__del__
return;
}
PyObject_GC_UnTrack(self);
}
  // base test is unnecessary as THPStorage does not set this
if (type->tp_weaklistoffset) {
PyObject_ClearWeakRefs(self);
}
if (type->tp_del) {
PyObject_GC_Track(self);
type->tp_del(self);
if (self->ob_refcnt > 0) {
// Resurrected (see above comment about resurrection from `__del__`)
return;
}
PyObject_GC_UnTrack(self);
}
if (has_finalizer) {
/* New weakrefs could be created during the finalizer call.
If this occurs, clear them out without calling their
finalizers since they might rely on part of the object
being finalized that has already been destroyed. */
if (type->tp_weaklistoffset) {
/* Modeled after GET_WEAKREFS_LISTPTR() */
PyWeakReference** list =
(PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
while (*list)
_PyWeakref_ClearRef(*list);
}
}
// Clear slots
{
PyTypeObject* base = type;
while (base != &THPStorageType) {
if (Py_SIZE(base)) {
clear_slots(base, self);
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}
// Clear __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr != nullptr) {
PyObject* dict = *dictptr;
if (dict != nullptr) {
Py_DECREF(dict);
*dictptr = nullptr;
}
}
}
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
  _self->cdata.~MaybeOwned<c10::Storage>();
  Py_TYPE(_self)->tp_free(self);

  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
  Py_DECREF(type);
}

c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
@@ -151,32 +390,35 @@ static PyObject* THPStorage_pynew(
      "(): only one or neither of 'allocator' or 'device' can ",
      "be given, but not both");

-  THPStoragePtr self((THPStorage*)type->tp_alloc(type, 0));
-  THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
+  PyObject* self = nullptr;
  c10::Allocator* allocator = nullptr;

  // torch.Storage(*, ...)
  if (r.idx == 0) {
-    self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
-        c10::StorageImpl::use_byte_size_t(),
-        0,
-        allocator,
-        /*resizable=*/true,
-        allocator_opt,
-        device_opt));
-    return (PyObject*)self.release();
+    self = THPStorage_NewWithStorage(
+        type,
+        make_storage_impl(
+            c10::StorageImpl::use_byte_size_t(),
+            0,
+            allocator,
+            /*resizable=*/true,
+            allocator_opt,
+            device_opt),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);

    // torch.Storage(size, *, ...)
  } else if (r.idx == 1) {
    int64_t size = r.toInt64(0);
-    self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
-        c10::StorageImpl::use_byte_size_t(),
-        size,
-        allocator,
-        /*resizable=*/true,
-        allocator_opt,
-        device_opt));
-    return (PyObject*)self.release();
+    self = THPStorage_NewWithStorage(
+        type,
+        make_storage_impl(
+            c10::StorageImpl::use_byte_size_t(),
+            size,
+            allocator,
+            /*resizable=*/true,
+            allocator_opt,
+            device_opt),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);

    // torch.Storage(sequence, *, ...)
  } else if (r.idx == 2) {
@@ -192,19 +434,22 @@ static PyObject* THPStorage_pynew(
        THPStorageStr,
        "(): Could not obtain the length of sequence of type ",
        THPUtils_typename(sequence));
-    self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
-        c10::StorageImpl::use_byte_size_t(),
-        length,
-        allocator,
-        /*resizable=*/true,
-        allocator_opt,
-        device_opt));
+    self = THPStorage_NewWithStorage(
+        type,
+        make_storage_impl(
+            c10::StorageImpl::use_byte_size_t(),
+            length,
+            allocator,
+            /*resizable=*/true,
+            allocator_opt,
+            device_opt),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
    THPObjectPtr item;
    try {
+      const auto& storage = THPStorage_Unpack(self);
      for (Py_ssize_t i = 0; i < length; i++) {
        item = PySequence_GetItem(sequence, i);
        uint8_t value = THPByteUtils_unpackReal(item.get());
-        const auto& storage = THPStorage_Unpack(self);
        if (allocator == c10::GetDefaultCPUAllocator()) {
          static_cast<uint8_t*>(storage.mutable_data())[i] = value;
        } else {
@@ -221,20 +466,22 @@ static PyObject* THPStorage_pynew(
          THPUtils_typename(item.get()));
      return nullptr;
    }
-    return (PyObject*)self.release();
  }
+  return self;

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
static Py_ssize_t THPStorage_length(THPStorage* self) {
  HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
  return static_cast<Py_ssize_t>(THPStorage_Unpack(self).nbytes());
  END_HANDLE_TH_ERRORS_RET(-1)
}

static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
  HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  int64_t len = static_cast<int64_t>(storage.nbytes());
  /* Integer index */
@@ -289,7 +536,11 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
        old_storage_impl->allocator(),
        /* resizable */ false);

-    PyObject* _ret = THPStorage_New(std::move(new_storage_impl));
+    PyObject* _ret = THPStorage_NewWithStorage(
+        Py_TYPE(self),
+        std::move(new_storage_impl),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
    return _ret;
  }
  PyErr_Format(
@@ -302,6 +553,7 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
  HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
  if (!THPByteUtils_checkReal(value)) {
    THPUtils_setError(
        "can only set storage content with a int types, but got "
@@ -451,6 +703,7 @@ int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
static PyObject* THPStorage_device(THPStorage* self, void* unused) {
  HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
  return THPDevice_New(THPStorage_Unpack(self).device());
  END_HANDLE_TH_ERRORS
}
@@ -490,7 +743,17 @@ bool THPStorage_init(PyObject* module) {
}

void THPStorage_postInit(PyObject* module) {
-  THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage");
+  THPStorageClass =
+      (PyTypeObject*)PyObject_GetAttrString(module, "UntypedStorage");
  if (!THPStorageClass)
    throw python_error();
}
void THPStorage_assertNotNull(THPStorage* storage) {
TORCH_CHECK(
THPStorage_Unpack(storage).unsafeGetStorageImpl(), "Got a null Storage");
}
void THPStorage_assertNotNull(PyObject* obj) {
THPStorage_assertNotNull((THPStorage*)obj);
}

View File

@@ -5,84 +5,44 @@
#define THPStorageStr "torch.UntypedStorage"

-namespace c10 {
-
-template <>
-struct MaybeOwnedTraits<c10::Storage> {
-  using owned_type = c10::Storage;
-  using borrow_type = c10::Storage;
-
-  static borrow_type createBorrow(const owned_type& from) {
-    return borrow_type(from);
-  }
-
-  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
-    lhs.unsafeReleaseStorageImpl();
-    lhs = borrow_type(rhs);
-  }
-
-  static void destroyBorrow(borrow_type& toDestroy) {
-    toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
-  }
-
-  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
-    return borrow;
-  }
-
-  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
-    return &borrow;
-  }
-
-  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
-    return true;
-  }
-};
-
-template <>
-struct ExclusivelyOwnedTraits<c10::Storage> {
-  using repr_type = c10::Storage;
-  using pointer_type = c10::Storage*;
-  using const_pointer_type = const c10::Storage*;
-
-  static repr_type nullRepr() {
-    return c10::Storage();
-  }
-
-  template <class... Args>
-  static repr_type createInPlace(Args&&... args) {
-    return c10::Storage(std::forward<Args>(args)...);
-  }
-
-  static repr_type moveToRepr(c10::Storage&& x) {
-    return std::move(x);
-  }
-
-  static c10::Storage take(c10::Storage& x) {
-    return std::move(x);
-  }
-
-  static pointer_type getImpl(repr_type& x) {
-    return &x;
-  }
-
-  static const_pointer_type getImpl(const repr_type& x) {
-    return &x;
-  }
-};
-
-} // namespace c10
struct THPStorage { struct THPStorage {
PyObject_HEAD; PyObject_HEAD;
c10::MaybeOwned<c10::Storage> cdata; c10::MaybeOwned<c10::Storage> cdata;
bool is_hermetic;
}; };
TORCH_PYTHON_API PyObject* THPStorage_New(c10::Storage storage); TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage);
extern PyObject* THPStorageClass; TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage(
PyTypeObject* type,
c10::Storage _storage,
c10::impl::PyInterpreterStatus status,
bool allow_preexisting_pyobj = false);
extern PyTypeObject* THPStorageClass;
static inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) {
return tp == THPStorageClass;
}
static inline bool THPStorage_CheckExact(PyObject* obj) {
return THPStorage_CheckTypeExact(Py_TYPE(obj));
}
inline bool THPStorage_Check(PyObject* obj) {
if (!THPStorageClass)
return false;
const auto result = PyObject_IsInstance(obj, (PyObject*)THPStorageClass);
if (result == -1)
throw python_error();
return result;
}
bool THPStorage_init(PyObject* module); bool THPStorage_init(PyObject* module);
void THPStorage_postInit(PyObject* module); void THPStorage_postInit(PyObject* module);
void THPStorage_assertNotNull(THPStorage* storage);
void THPStorage_assertNotNull(PyObject* obj);
extern PyTypeObject THPStorageType; extern PyTypeObject THPStorageType;
inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) { inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) {
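
The three new checks mirror their THPVariable counterparts. A hypothetical caller (illustrative only, not part of the patch):

// Prefer the cheap exact-type test; fall back to the subclass-aware
// isinstance check. The function name here is hypothetical.
static bool is_untyped_storage(PyObject* obj) {
  if (THPStorage_CheckExact(obj)) {
    return true; // obj is exactly torch.UntypedStorage
  }
  return THPStorage_Check(obj); // also true for subclasses
}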
View File: torch/csrc/StorageMethods.cpp
@@ -41,6 +41,7 @@
 static PyObject* THPStorage_nbytes(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   return py::cast(THPStorage_Unpack(self).sym_nbytes()).release().ptr();
   END_HANDLE_TH_ERRORS
 }
@@ -66,6 +67,7 @@ static PyObject* THPStorage_copy_(
     PyObject* args,
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   at::Storage self_ = torch::createStorage(self);
@@ -96,12 +98,14 @@ static PyObject* THPStorage_copy_(
 static PyObject* THPStorage_elementSize(PyObject* _self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(_self);
   return THPUtils_packInt64(sizeof(uint8_t));
   END_HANDLE_TH_ERRORS
 }
 static PyObject* THPStorage_new(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   c10::Allocator* allocator = THPStorage_Unpack(self).allocator();
   auto new_storage = c10::make_intrusive<at::StorageImpl>(
       c10::StorageImpl::use_byte_size_t(),
@@ -109,12 +113,13 @@ static PyObject* THPStorage_new(PyObject* self, PyObject* noargs) {
       allocator,
       /*resizable=*/true);
-  return THPStorage_New(std::move(new_storage));
+  return THPStorage_Wrap(std::move(new_storage));
   END_HANDLE_TH_ERRORS
 }
 static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   // See Note [Invalid Python Storages]
   auto invalid = storage.data() == nullptr &&
@@ -185,6 +190,7 @@ static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {
 static PyObject* THPStorage_fill_(PyObject* self, PyObject* number_arg) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   // See Note [Invalid Python Storages]
   auto invalid = storage.data() == nullptr &&
@@ -389,7 +395,7 @@ static PyObject* THPStorage_fromBuffer(
   }
   PyBuffer_Release(&buffer);
-  return (PyObject*)THPStorage_New(storage);
+  return THPStorage_Wrap(storage);
   END_HANDLE_TH_ERRORS
 }
@@ -429,12 +435,16 @@ static PyObject* THPStorage_fromFile(
     storage->set_nbytes(actual_nbytes);
   }
-  return (PyObject*)THPStorage_New(std::move(storage));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      std::move(storage),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }
 PyObject* THPStorage_writeFile(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   // See Note [Invalid Python Storages]
   auto invalid = storage.data() == nullptr &&
@@ -486,12 +496,13 @@ PyObject* THPStorage_newWithFile(PyObject* _unused, PyObject* args) {
   auto storage = THPStorage_readFileRaw<int>(fd, {}, element_size);
   if (!storage.defined())
     return nullptr;
-  return THPStorage_New(std::move(storage));
+  return THPStorage_Wrap(std::move(storage));
   END_HANDLE_TH_ERRORS
 }
 static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   PyObject* file = PyTuple_GET_ITEM(args, 0);
   PyObject* offset = PyTuple_GET_ITEM(args, 1);
@@ -617,6 +628,12 @@ PyObject* THPStorage_byteswap(PyObject* self, PyObject* args) {
   END_HANDLE_TH_ERRORS
 }
+static PyObject* THPStorage_fix_weakref(PyObject* self, PyObject* noargs) {
+  const auto& storage = THPStorage_Unpack(self);
+  Py_DECREF(THPStorage_Wrap(storage));
+  Py_RETURN_NONE;
+}
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
 static PyMethodDef THPStorage_methods[] = {
     {"copy_",
@@ -645,6 +662,7 @@ static PyMethodDef THPStorage_methods[] = {
      nullptr},
     {"_set_cdata", THPStorage__setCdata, METH_O, nullptr},
     {"_byteswap", THPStorage_byteswap, METH_VARARGS, nullptr},
+    {"_fix_weakref", THPStorage_fix_weakref, METH_NOARGS, nullptr},
     {nullptr}};
 PyMethodDef* THPStorage_getMethods() {
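
The _fix_weakref binding is terse enough to deserve a gloss; the following is the same function with explanatory comments (a reading of the patch, not new behavior):

static PyObject* THPStorage_fix_weakref_annotated(PyObject* self, PyObject*) {
  const auto& storage = THPStorage_Unpack(self);
  // THPStorage_Wrap returns the canonical PyObject for this StorageImpl,
  // re-linking the pyobj slot if the C++ side outlived the Python wrapper.
  // The immediate Py_DECREF drops the temporary strong reference; the slot
  // keeps the wrapper alive, so existing Python weak references to the
  // storage resolve again instead of observing a dead object.
  Py_DECREF(THPStorage_Wrap(storage));
  Py_RETURN_NONE;
}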
View File: torch/csrc/StorageSharing.cpp
@@ -33,6 +33,7 @@
 static PyObject* THPStorage_sharedDecref(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   c10::DeviceType device_type = storage.device_type();
   if (device_type == at::kCPU) {
@@ -49,6 +50,7 @@ static PyObject* THPStorage_sharedDecref(PyObject* self, PyObject* noargs) {
 static PyObject* THPStorage_sharedIncref(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   c10::DeviceType device_type = storage.device_type();
   if (device_type == at::kCPU) {
@@ -76,18 +78,22 @@ static PyObject* THPStorage_pyNewFilenameStorage(
   int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
   std::string handle = at::NewProcessWideShmHandle();
-  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
-      c10::StorageImpl::use_byte_size_t(),
-      size,
-      THManagedMapAllocator::makeDataPtr(
-          "", handle.c_str(), flags, static_cast<size_t>(size)),
-      /*allocator=*/nullptr,
-      /*resizable=*/false));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      c10::make_intrusive<at::StorageImpl>(
+          c10::StorageImpl::use_byte_size_t(),
+          size,
+          THManagedMapAllocator::makeDataPtr(
+              "", handle.c_str(), flags, static_cast<size_t>(size)),
+          /*allocator=*/nullptr,
+          /*resizable=*/false),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }
 static PyObject* THPStorage_shareFilename(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   TORCH_CHECK(
       storage.device_type() == at::kCPU,
@@ -168,13 +174,16 @@ static PyObject* THPStorage_newSharedFilename(
   const char* object_handle = PyBytes_AS_STRING(_object_handle);
   uint64_t size = THPUtils_unpackUInt64(_size);
   int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE;
-  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
-      c10::StorageImpl::use_byte_size_t(),
-      size,
-      THManagedMapAllocator::makeDataPtr(
-          manager_handle, object_handle, flags, size),
-      /*allocator=*/nullptr,
-      /*resizable=*/false));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      c10::make_intrusive<at::StorageImpl>(
+          c10::StorageImpl::use_byte_size_t(),
+          size,
+          THManagedMapAllocator::makeDataPtr(
+              manager_handle, object_handle, flags, size),
+          /*allocator=*/nullptr,
+          /*resizable=*/false),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }
@@ -187,12 +196,16 @@ static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) {
   if (size < 0) {
     return nullptr;
   }
-  return THPStorage_New(at::new_shm_fd_storage(size));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      at::new_shm_fd_storage(size),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }
 static PyObject* THPStorage_shareFd(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   TORCH_CHECK(
       storage.device_type() == at::kCPU, "_share_fd_: only available on CPU");
@@ -257,17 +270,22 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) {
   int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE |
       at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_FROMFD;
-  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
-      c10::StorageImpl::use_byte_size_t(),
-      size,
-      at::MapAllocator::makeDataPtr(at::WITH_FD, "", fd, flags, size, nullptr),
-      /*allocator=*/nullptr,
-      /*resizable=*/false));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      c10::make_intrusive<at::StorageImpl>(
+          c10::StorageImpl::use_byte_size_t(),
+          size,
+          at::MapAllocator::makeDataPtr(
+              at::WITH_FD, "", fd, flags, size, nullptr),
+          /*allocator=*/nullptr,
+          /*resizable=*/false),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }
 static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
 #ifdef USE_CUDA
   const auto& storage = THPStorage_Unpack(self);
   TORCH_CHECK(
@@ -547,7 +565,10 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
   base->set_resizable(false);
   base->set_received_cuda(true);
-  return THPStorage_New(std::move(base));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      std::move(base),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
 #else
   TORCH_CHECK(false, "CUDA is not available");
 #endif
@@ -572,7 +593,7 @@ PyObject* THPStorage_newWithWeakPtr(PyObject* _unused, PyObject* arg) {
       THPUtils_checkLong(arg), "_new_with_weak_ptr(): arg must be an 'int'");
   c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
   if (auto* storage = c10::raw::weak_intrusive_ptr::lock(weak_storage)) {
-    return THPStorage_New(
+    return THPStorage_Wrap(
         c10::intrusive_ptr<c10::StorageImpl>::reclaim(storage));
   }
   Py_RETURN_NONE;
@@ -604,6 +625,7 @@ PyObject* THPStorage_expired(PyObject* _unused, PyObject* arg) {
 PyObject* THPStorage_sharedFd(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   at::MapAllocator* ctx = nullptr;
   const auto& storage = THPStorage_Unpack(self);
   if (storage.device_type() == at::kCPU) {
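
Every creation path above now passes an explicit c10::impl::PyInterpreterStatus. A rough sketch of how THPStorage_NewWithStorage presumably consumes it (simplified; the real body lives in an elided part of torch/csrc/Storage.cpp):

// Simplified sketch, not the verbatim patch: allocate the Python wrapper,
// move the c10::Storage into it, and record the wrapper in the StorageImpl's
// pyobj slot so later THPStorage_Wrap calls find it again. `status` tells
// init_pyobj how much tag-checking it may skip (e.g. DEFINITELY_UNINITIALIZED
// for an impl that no other interpreter can have seen).
PyObject* THPStorage_NewWithStorage_sketch(
    PyTypeObject* type,
    c10::Storage _storage,
    c10::impl::PyInterpreterStatus status) {
  PyObject* obj = type->tp_alloc(type, 0);
  TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
  auto v = reinterpret_cast<THPStorage*>(obj);
  new (&v->cdata) c10::MaybeOwned<c10::Storage>();
  v->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));
  THPStorage_Unpack(v).unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(
      getPyInterpreter(), obj, status);
  return obj;
}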
View File: torch/csrc/autograd/python_variable.cpp
@@ -26,6 +26,7 @@
 #include <torch/csrc/tensor/python_tensor.h>
 #include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
+#include <torch/csrc/utils/pyobject_preservation.h>
 #include <torch/csrc/utils/python_arg_parser.h>
 #include <torch/csrc/utils/python_dispatch.h>
 #include <torch/csrc/utils/python_strings.h>
@@ -268,7 +269,8 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
   }
   c10::optional<PyObject*> mb_obj =
-      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
   c10::impl::PyInterpreterStatus status;
   if (mb_obj.has_value()) {
     auto obj = *mb_obj;
@@ -345,7 +347,8 @@ bool isResurrectable(THPVariable* self) {
   }
   // Check if this is hermetic. If it is, no resurrection.
   if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
-          getPyInterpreter()) != c10::make_optional((PyObject*)self)) {
+          getPyInterpreter(), /*ignore_hermetic_tls=*/false) !=
+      c10::make_optional((PyObject*)self)) {
     return false;
   }
   return true;
@@ -369,7 +372,16 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
   TORCH_INTERNAL_ASSERT(
       !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
-  tensor.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(true);
+  c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
+  auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
+      getPyInterpreter(),
+      /*ignore_hermetic_tls=*/false);
+  TORCH_INTERNAL_ASSERT(
+      maybe_pyobj.has_value(),
+      "Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");
+  tensor_impl->pyobj_slot()->set_owns_pyobj(true);
   // Resurrect the Python object. This is something CPython does
   // internally occasionally, see
@@ -443,7 +455,8 @@ static int THPVariable_clear(THPVariable* self) {
   if (!self->cdata.unsafeIsBorrowed() &&
       tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
-          getPyInterpreter()) == c10::make_optional((PyObject*)self)) {
+          getPyInterpreter(), /*ignore_hermetic_tls=*/false) ==
+      c10::make_optional((PyObject*)self)) {
     // TODO: empirically, on OS X this assert appears to be untrue
     // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
     // distributed/rpc/test_process_group_agent.py
@@ -1747,26 +1760,6 @@ PyObject* THPVariable_pynew(
   END_HANDLE_TH_ERRORS
 }
-static void clear_slots(PyTypeObject* type, PyObject* self) {
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  Py_ssize_t i, n;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  PyMemberDef* mp;
-  n = Py_SIZE(type);
-  mp = type->tp_members;
-  for (i = 0; i < n; i++, mp++) {
-    if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
-      char* addr = (char*)self + mp->offset;
-      PyObject* obj = *(PyObject**)addr;
-      if (obj != nullptr) {
-        *(PyObject**)addr = nullptr;
-        Py_DECREF(obj);
-      }
-    }
-  }
-}
 // NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc
 // on subclasses. It's never valid to construct a THPVariable so it's not
 // necessary to implement the dealloc for that case
@@ -1886,8 +1879,8 @@ static PyObject* THPVariable_NewWithVar(
   // This function overwrite the Tensor's pyobj field without extra checks
   // Make sure it is not set otherwise we would leak memory
-  auto mb_obj =
-      _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+  auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
   // Under some circumstances, we may attempt to create a new Python
   // object for a variable that already has a Python object. The most common
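
All check_pyobj call sites now pass ignore_hermetic_tls explicitly. A simplified sketch of the semantics, assuming the PyObjectSlot API in c10/core/impl/PyObjectSlot.h (the member function is shown here as a free function; error handling is abbreviated):

c10::optional<PyObject*> check_pyobj_sketch(
    c10::impl::PyObjectSlot& slot,
    c10::impl::PyInterpreter* self_interpreter,
    bool ignore_hermetic_tls) {
  c10::impl::PyInterpreter* tagged = slot.pyobj_interpreter();
  if (tagged == nullptr) {
    return c10::nullopt; // never wrapped by any interpreter
  }
  TORCH_CHECK(tagged == self_interpreter, "tagged by another interpreter");
  if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
    // Inside a hermetic region, report "no PyObject" so hermetic code
    // neither reuses nor leaks the non-hermetic Python wrapper.
    return c10::nullopt;
  }
  return c10::make_optional(slot._unchecked_untagged_pyobj());
}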
View File: torch/csrc/utils/pyobject_preservation.cpp
@@ -0,0 +1,19 @@
#include <torch/csrc/utils/pyobject_preservation.h>

#include <structmember.h>

void clear_slots(PyTypeObject* type, PyObject* self) {
  Py_ssize_t n = Py_SIZE(type);
  PyMemberDef* mp = type->tp_members;

  for (Py_ssize_t i = 0; i < n; i++, mp++) {
    if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
      char* addr = (char*)self + mp->offset;
      PyObject* obj = *(PyObject**)addr;
      if (obj != nullptr) {
        *(PyObject**)addr = nullptr;
        Py_DECREF(obj);
      }
    }
  }
}
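
clear_slots is the helper previously defined locally in python_variable.cpp (removed above), now shared so other Python-object deallocs can reuse it. A sketch of the kind of subclass tp_dealloc that calls it (illustrative; the function name and exact ordering are hypothetical):

static void THPSubclass_dealloc_sketch(PyObject* self) {
  PyTypeObject* type = Py_TYPE(self);
  PyObject_GC_UnTrack(self); // stop the GC from visiting a dying object
  clear_slots(type, self); // Py_DECREF every writable T_OBJECT_EX slot
  if (type->tp_dictoffset) { // clear the instance __dict__, if any
    PyObject** dictptr = _PyObject_GetDictPtr(self);
    if (dictptr && *dictptr) {
      Py_CLEAR(*dictptr);
    }
  }
  type->tp_free(self);
}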
View File: torch/csrc/utils/pyobject_preservation.h
@@ -0,0 +1,7 @@
#pragma once

#include <torch/csrc/python_headers.h>

// This file contains utilities used for handling PyObject preservation

void clear_slots(PyTypeObject* type, PyObject* self);
View File: torch/csrc/utils/python_arg_parser.h
@@ -1106,8 +1106,8 @@ inline at::Storage PythonArgs::storage(
     is_typed_storage = false;
     storage_scalar_type = at::ScalarType::Undefined;
   } else {
-    storage =
-        createStorageGetType(args[i], storage_scalar_type, is_typed_storage);
+    std::tie(storage, storage_scalar_type, is_typed_storage) =
+        createStorageGetType(args[i]);
   }
   return storage;
 }
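
This hunk and the tensor_new.cpp hunk below are the two call sites of the createStorageGetType refactor, which trades out-parameters for a tuple return. The refactored declaration is not shown in this excerpt, but from the std::tie call sites it presumably reads:

// Assumed declaration, inferred from the call sites above and below.
std::tuple<at::Storage, at::ScalarType, bool /*is_typed_storage*/>
createStorageGetType(PyObject* obj);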
View File: torch/csrc/utils/tensor_new.cpp
@@ -385,10 +385,12 @@ Tensor internal_new_from_data(
   at::tracer::impl::NoTracerDispatchMode tracer_guard;
   if (isStorage(data)) {
-    ScalarType storage_scalar_type{ScalarType::Undefined};
     bool is_typed_storage = false;
-    Storage storage =
-        createStorageGetType(data, storage_scalar_type, is_typed_storage);
+    ScalarType storage_scalar_type{ScalarType::Undefined};
+    Storage storage;
+    std::tie(storage, storage_scalar_type, is_typed_storage) =
+        createStorageGetType(data);
     TORCH_CHECK(
         !is_typed_storage || storage_scalar_type == scalar_type,
         "Expected a Storage of type ",