Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Reland 2: Add PyObject preservation for UntypedStorage (#109039)
Relands #103907 after it was reverted. This PR makes the new `ignore_hermetic_tls` argument of `check_pyobj` optional, to avoid causing a compilation error in torchdistx. Part of #91395.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109039
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 6dc56d3490
Commit: 4c5e43574c
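The reland-specific fix is that `ignore_hermetic_tls` on `PyObjectSlot::check_pyobj` is now a defaulted parameter. A minimal sketch of why the default matters (`slot` and `interp` are hypothetical names, not from the diff); one-argument call sites such as those in torchdistx keep compiling unchanged:

    // Old-style call site (e.g. in torchdistx): still compiles because the
    // new argument defaults to false.
    c10::optional<PyObject*> a = slot->check_pyobj(interp);
    // Equivalent explicit form, as used by the call sites updated in this PR:
    c10::optional<PyObject*> b =
        slot->check_pyobj(interp, /*ignore_hermetic_tls=*/false);

The full diff follows.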
@@ -896,6 +896,7 @@ libtorch_python_core_sources = [
     "torch/csrc/utils/python_dispatch.cpp",
     "torch/csrc/utils/python_symnode.cpp",
     "torch/csrc/utils/pybind.cpp",
+    "torch/csrc/utils/pyobject_preservation.cpp",
     "torch/csrc/utils/structseq.cpp",
     "torch/csrc/utils/tensor_apply.cpp",
     "torch/csrc/utils/tensor_dtypes.cpp",
c10/core/RefcountedDeleter.cpp (new file, 78 lines)
@@ -0,0 +1,78 @@
#include <c10/core/RefcountedDeleter.h>

#include <mutex>

namespace c10 {

void refcounted_deleter(void* ctx_) {
  RefcountedDeleterContext& ctx =
      *reinterpret_cast<RefcountedDeleterContext*>(ctx_);
  ctx.refcount--;
  if (ctx.refcount == 0) {
    ctx.other_ctx = nullptr;
    delete &ctx;
  }
}

std::mutex replace_data_ptr_mutex;

void maybeApplyRefcountedDeleter(const c10::Storage& storage) {
  std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
  c10::DataPtr& data_ptr = storage.mutable_data_ptr();

  if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
    // Data pointer is already shared
    return;
  }

  void* data = data_ptr.get();
  void* other_ctx = data_ptr.get_context();
  c10::DeleterFnPtr other_deleter = data_ptr.get_deleter();
  c10::Device device = data_ptr.device();

  // Release the context of the original DataPtr so that the data doesn't
  // get deleted when the original DataPtr is replaced
  data_ptr.release_context();

  c10::RefcountedDeleterContext* refcount_ctx =
      new c10::RefcountedDeleterContext(other_ctx, other_deleter);

  c10::DataPtr new_data_ptr(
      data,
      reinterpret_cast<void*>(refcount_ctx),
      &c10::refcounted_deleter,
      device);
  storage.set_data_ptr(std::move(new_data_ptr));
}

c10::Storage newStorageImplFromRefcountedDataPtr(const c10::Storage& storage) {
  c10::maybeApplyRefcountedDeleter(storage);

  c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();

  c10::DataPtr& data_ptr = storage.mutable_data_ptr();
  c10::DataPtr new_data_ptr(
      data_ptr.get(),
      data_ptr.get_context(),
      data_ptr.get_deleter(),
      data_ptr.device());

  // NOTE: This refcount increment should always happen immediately after
  // `new_data_ptr` is created. No other lines of code should be added between
  // them in the future, unless there's a very good reason for it, because if
  // any errors are raised and `new_data_ptr` is deleted before the refcount is
  // incremented, the refcount will get decremented and end up being one less
  // than it should be.
  reinterpret_cast<c10::RefcountedDeleterContext*>(data_ptr.get_context())
      ->refcount++;

  c10::Storage new_storage = c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      storage_impl->nbytes(),
      std::move(new_data_ptr),
      storage_impl->allocator(),
      /*resizable=*/storage_impl->resizable());
  return new_storage;
}

} // namespace c10
c10/core/RefcountedDeleter.h (new file, 51 lines)
@@ -0,0 +1,51 @@
#pragma once

#include <c10/core/Storage.h>
#include <c10/util/UniqueVoidPtr.h>

#include <atomic>
#include <memory>

namespace c10 {

// A RefcountedDeleterContext object is used as the `ctx` argument for DataPtr
// to implement a shared DataPtr. Normally, a DataPtr is unique, but we use
// this custom context and the `refcounted_deleter` function below to make the
// DataPtr act like a non-unique DataPtr. This context object holds onto an
// inner context and deleter function which handle the actual deletion of the
// data when the refcount reaches 0.
//
// This shared DataPtr feature is only used when storages are shared between
// multiple Python interpreters in MultiPy. Before storages had PyObject
// preservation, interpreters could just share the same StorageImpl instance.
// But now a StorageImpl can only be associated with one interpreter in order
// to properly manage a zombie PyObject. So we share storages across Python
// interpreters by creating a different StorageImpl instance for each one, but
// they all point to the same data.
struct C10_API RefcountedDeleterContext {
  RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter)
      : other_ctx(other_ctx, other_deleter), refcount(1) {}

  std::unique_ptr<void, c10::DeleterFnPtr> other_ctx;
  std::atomic_int refcount;
};

// `refcounted_deleter` is used as the `ctx_deleter` for DataPtr to implement
// a shared DataPtr.
//
// Warning: This should only be called on a pointer to
// a RefcountedDeleterContext that was allocated on the heap with `new`,
// because when the refcount reaches 0, the context is deleted with `delete`
C10_API void refcounted_deleter(void* ctx_);

// If the storage's DataPtr does not use `refcounted_deleter`, replace it with
// a DataPtr that does, so it can be shared between multiple StorageImpls
C10_API void maybeApplyRefcountedDeleter(const c10::Storage& storage);

// Create a new StorageImpl that points to the same data. If the original
// StorageImpl's DataPtr does not use `refcounted_deleter`, it will be replaced
// with one that does
C10_API c10::Storage newStorageImplFromRefcountedDataPtr(
    const c10::Storage& storage);

} // namespace c10
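As a sketch of the API above (illustrative, not part of the commit): starting from any existing `c10::Storage a`, `newStorageImplFromRefcountedDataPtr` yields a second StorageImpl over the same bytes, and the `is_alias_of` change further down recognizes the pair via `isSharedStorageAlias`:

    // Hypothetical usage; `a` is any existing c10::Storage.
    c10::Storage b = c10::newStorageImplFromRefcountedDataPtr(a);
    // Distinct StorageImpls...
    TORCH_INTERNAL_ASSERT(a.unsafeGetStorageImpl() != b.unsafeGetStorageImpl());
    // ...sharing one allocation: both DataPtrs now use refcounted_deleter with
    // the same RefcountedDeleterContext, so the data is freed only after both
    // storages are destroyed.
    TORCH_INTERNAL_ASSERT(a.data_ptr().get() == b.data_ptr().get());
    TORCH_INTERNAL_ASSERT(c10::isSharedStorageAlias(a, b));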
@@ -33,7 +33,7 @@ struct C10_API SafePyObject {

   ~SafePyObject() {
     if (data_ != nullptr) {
-      (*pyinterpreter_)->decref(data_, /*is_tensor*/ false);
+      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
     }
   }

@@ -1,3 +1,18 @@
+#include <c10/core/RefcountedDeleter.h>
 #include <c10/core/Storage.h>

-namespace c10 {} // namespace c10
+namespace c10 {
+
+bool isSharedStorageAlias(const Storage& storage0, const Storage& storage1) {
+  c10::DeleterFnPtr deleter_expected = &c10::refcounted_deleter;
+  c10::DeleterFnPtr deleter0 = storage0.data_ptr().get_deleter();
+  c10::DeleterFnPtr deleter1 = storage1.data_ptr().get_deleter();
+
+  if ((deleter0 != deleter_expected) || (deleter1 != deleter_expected)) {
+    return false;
+  }
+
+  return storage0.data_ptr().get_context() == storage1.data_ptr().get_context();
+}
+
+} // namespace c10
@@ -1,12 +1,22 @@
 #pragma once

 #include <c10/core/StorageImpl.h>
 #include <c10/util/ExclusivelyOwned.h>

 namespace c10 {

+struct Storage;
+
+C10_API bool isSharedStorageAlias(
+    const Storage& storage0,
+    const Storage& storage1);
+
 struct C10_API Storage {
  public:
   struct use_byte_size_t {};
+  struct unsafe_borrow_t {
+    explicit unsafe_borrow_t() = default;
+  };

   Storage() = default;
   Storage(c10::intrusive_ptr<StorageImpl> ptr)
@@ -40,6 +50,14 @@ struct C10_API Storage {
           allocator,
           resizable)) {}

+ protected:
+  explicit Storage(unsafe_borrow_t, const Storage& rhs)
+      : storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(
+            rhs.storage_impl_.get())) {}
+
+  friend MaybeOwnedTraits<Storage>;
+
+ public:
   // Legacy constructor for partially initialized (dtype or memory) storages
   // that can be temporarily created with Caffe2 APIs. See the note on top of
   // TensorImpl.h for details.
@@ -144,7 +162,9 @@ struct C10_API Storage {
   }

   bool is_alias_of(const Storage& other) const {
-    return storage_impl_ == other.storage_impl_;
+    return (
+        storage_impl_ == other.storage_impl_ ||
+        isSharedStorageAlias(*this, other));
   }

   void UniqueStorageShareExternalPointer(
@@ -175,4 +195,67 @@ struct C10_API Storage {
   c10::intrusive_ptr<StorageImpl> storage_impl_;
 };

+template <>
+struct MaybeOwnedTraits<c10::Storage> {
+  using owned_type = c10::Storage;
+  using borrow_type = c10::Storage;
+
+  static borrow_type createBorrow(const owned_type& from) {
+    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
+  }
+
+  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
+    lhs.unsafeReleaseStorageImpl();
+    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
+  }
+
+  static void destroyBorrow(borrow_type& toDestroy) {
+    toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
+  }
+
+  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
+    return borrow;
+  }
+
+  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
+    return &borrow;
+  }
+
+  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
+    return true;
+  }
+};
+
+template <>
+struct ExclusivelyOwnedTraits<c10::Storage> {
+  using repr_type = c10::Storage;
+  using pointer_type = c10::Storage*;
+  using const_pointer_type = const c10::Storage*;
+
+  static repr_type nullRepr() {
+    return c10::Storage();
+  }
+
+  template <class... Args>
+  static repr_type createInPlace(Args&&... args) {
+    return c10::Storage(std::forward<Args>(args)...);
+  }
+
+  static repr_type moveToRepr(c10::Storage&& x) {
+    return std::move(x);
+  }
+
+  static c10::Storage take(c10::Storage& x) {
+    return std::move(x);
+  }
+
+  static pointer_type getImpl(repr_type& x) {
+    return &x;
+  }
+
+  static const_pointer_type getImpl(const repr_type& x) {
+    return &x;
+  }
+};
+
 } // namespace c10
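The traits above plug `c10::Storage` into `c10::MaybeOwned`, which the Python storage object uses for its `cdata` field. A brief sketch of the two modes (illustrative only):

    c10::Storage s = /* any existing storage */;
    // Borrowed: no refcount traffic; destroyBorrow "leaks" the +0 handle.
    c10::MaybeOwned<c10::Storage> borrowed =
        c10::MaybeOwned<c10::Storage>::borrowed(s);
    // Owned: holds a reference, released when the MaybeOwned is destroyed.
    c10::MaybeOwned<c10::Storage> owned =
        c10::MaybeOwned<c10::Storage>::owned(std::move(s));
    // Either way, *borrowed and borrowed->nbytes() reach the storage.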
@@ -203,6 +203,14 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
     return received_cuda_;
   }

+  impl::PyObjectSlot* pyobj_slot() {
+    return &pyobj_slot_;
+  }
+
+  const impl::PyObjectSlot* pyobj_slot() const {
+    return &pyobj_slot_;
+  }
+
  private:
   DataPtr data_ptr_;
   SymInt size_bytes_;
@@ -73,9 +73,7 @@ void TensorImpl::_set_fw_grad(
   autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
 }

-TensorImpl::~TensorImpl() {
-  pyobj_slot_.destroy_pyobj_if_needed();
-}
+TensorImpl::~TensorImpl() = default;

 TensorImpl::TensorImpl(
     Storage&& storage,
@@ -582,7 +580,7 @@ void TensorImpl::release_resources() {
   if (storage_) {
     storage_ = {};
   }
-  pyobj_slot_.destroy_pyobj_if_needed();
+  pyobj_slot_.maybe_destroy_pyobj();
 }

 #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@@ -10,7 +10,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
     return "<unloaded interpreter>";
   }

-  void decref(PyObject* pyobj, bool is_tensor) const override {} // do nothing
+  void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
+  } // do nothing

 #define PANIC(m) \
   TORCH_INTERNAL_ASSERT( \
@@ -127,8 +127,8 @@ struct C10_API PyInterpreterVTable {
   virtual std::string name() const = 0;

   // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
-  // See NOTE [PyInterpreter::decref takes an `is_tensor` arg]
-  virtual void decref(PyObject* pyobj, bool is_tensor) const = 0;
+  // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
+  virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;

   // Perform a detach by deferring to the __torch_dispatch__ implementation of
   // detach, which will also arrange for the PyObject to get copied in this
@@ -5,12 +5,16 @@ namespace impl {

 PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}

-void PyObjectSlot::destroy_pyobj_if_needed() {
+PyObjectSlot::~PyObjectSlot() {
+  maybe_destroy_pyobj();
+}
+
+void PyObjectSlot::maybe_destroy_pyobj() {
   if (owns_pyobj()) {
     TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
     TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
     (*pyobj_interpreter_.load(std::memory_order_acquire))
-        ->decref(_unchecked_untagged_pyobj(), /*is_tensor*/ true);
+        ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
     // NB: this destructor can only be entered when there are no
     // references to this C++ object (obviously), NOR any references
     // to the PyObject (if there are references to the PyObject,
@@ -47,6 +51,15 @@ PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
       (*pyobj_interpreter_.load())->name());
 }

+bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) {
+  return interpreter == pyobj_interpreter();
+}
+
+bool PyObjectSlot::has_pyobj_nonhermetic() {
+  return check_pyobj(pyobj_interpreter(), /*ignore_hermetic_tls=*/true)
+      .has_value();
+}
+
 bool PyObjectSlot::owns_pyobj() {
   // NOLINTNEXTLINE(performance-no-int-to-ptr)
   return reinterpret_cast<uintptr_t>(pyobj_) & 1;
@@ -14,7 +14,9 @@ struct C10_API PyObjectSlot {
  public:
   PyObjectSlot();

-  void destroy_pyobj_if_needed();
+  ~PyObjectSlot();
+
+  void maybe_destroy_pyobj();

   // Associate the TensorImpl with the specified PyObject, and, if necessary,
   // also tag the interpreter.
@@ -82,9 +84,20 @@ struct C10_API PyObjectSlot {
   // a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
   // returns a nullopt. If it is definitely invalid, raises an error.
   //
+  // If `ignore_hermetic_tls` is false and this function is called from a
+  // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
+  // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
+  // context is ignored, allowing you to check the interpreter tag of a
+  // nonhermetic PyObject from within a hermetic context. This is necessary
+  // because there are some cases where the deallocator function of a
+  // nonhermetic PyObject is called from within a hermetic context, so it must
+  // be properly treated as a nonhermetic PyObject.
+  //
   // NB: this lives in header so that we can avoid actually creating the
   // c10::optional
-  c10::optional<PyObject*> check_pyobj(PyInterpreter* self_interpreter) const {
+  c10::optional<PyObject*> check_pyobj(
+      PyInterpreter* self_interpreter,
+      bool ignore_hermetic_tls = false) const {
     // Note [Memory ordering on Python interpreter tag]
     impl::PyInterpreter* interpreter =
         pyobj_interpreter_.load(std::memory_order_acquire);
@@ -97,7 +110,7 @@ struct C10_API PyObjectSlot {
       return c10::nullopt;
     } else if (interpreter == self_interpreter) {
       // NB: pyobj_ could still be null!
-      if (c10::impl::HermeticPyObjectTLS::get_state()) {
+      if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
         return c10::nullopt;
       } else {
         return c10::make_optional(_unchecked_untagged_pyobj());
@@ -118,6 +131,13 @@ struct C10_API PyObjectSlot {

   PyInterpreter& load_pyobj_interpreter() const;

+  // Check if the PyObjectSlot's interpreter is the same as the specified
+  // interpreter
+  bool check_interpreter(PyInterpreter* interpreter);
+
+  // Check if the PyObjectSlot is holding a PyObject, owned or non-owned
+  bool has_pyobj_nonhermetic();
+
   bool owns_pyobj();

   void set_owns_pyobj(bool b);
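A sketch of how callers consume the three-way result of `check_pyobj` (the `impl` pointer is a hypothetical stand-in; the pattern mirrors the call sites updated later in this diff):

    c10::optional<PyObject*> mb_obj = impl->pyobj_slot()->check_pyobj(
        getPyInterpreter(), /*ignore_hermetic_tls=*/false);
    if (!mb_obj.has_value()) {
      // Untagged, tagged by another interpreter, or in a hermetic context.
    } else if (*mb_obj == nullptr) {
      // Tagged by this interpreter, but no PyObject has been created yet.
    } else {
      PyObject* obj = *mb_obj;  // live PyObject tracked by this slot
    }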
@@ -1118,7 +1118,7 @@ int64_t _Tensor_ndim(mpy::handle h) {
 mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
   // fast case: tensor is live in python
   c10::optional<PyObject*> mb_obj =
-      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
   if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
     return *mb_obj;
   }
@@ -8832,6 +8832,16 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         T()

+    def test_storage_base_init(self):
+        # Direct construction not OK
+        self.assertRaises(RuntimeError, lambda: torch._C.StorageBase())
+
+        # But construction of subclass is OK
+        class T(torch._C.StorageBase):
+            pass
+
+        T()
+
     def test_tensor_base_new(self):

         # OK to call super().__new__, see
@@ -8844,6 +8854,18 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         x = torch.ones(5)
         test_tensor = TestTensor(x)

+    def test_storage_base_new(self):
+
+        # OK to call super().__new__, see
+        # https://github.com/pytorch/pytorch/issues/57421
+        class TestStorage(torch._C.StorageBase):
+            @staticmethod
+            def __new__(cls, x, *args, **kwargs):
+                return super().__new__(cls, x, *args, **kwargs)
+
+        x = torch.UntypedStorage(5)
+        test_storage = TestStorage(x)
+
     def test_pyobj_preserved(self):
         x = torch.empty(2)
         x.foo = 2  # put something on __dict__
@@ -8868,6 +8890,160 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del z  # it's dead again
         self.assertEqual(type(y.grad), MyTensor)

+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_dealloc(self):
+        m, t = Tracker.make()
+        s0 = torch.UntypedStorage(10)
+        s1 = s0
+        s0._tracker = t
+        del t
+
+        self.assertFalse(m[0])
+        del s0
+        self.assertFalse(m[0])
+        del s1
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_from_tensor_dealloc(self):
+        m, t = Tracker.make()
+        a = torch.randn(10)
+        s0 = a.untyped_storage()
+        s0._tracker = t
+        del t
+
+        s1 = a.untyped_storage()
+        self.assertTrue(s0 is s1)
+        self.assertTrue(hasattr(s1, '_tracker'))
+
+        del a
+
+        self.assertFalse(m[0])
+        del s0
+        self.assertFalse(m[0])
+        del s1
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_from_tensor_dealloc_zombie(self):
+        m, t = Tracker.make()
+        a = torch.randn(10)
+        s0 = a.untyped_storage()
+        s0._tracker = t
+        del t
+
+        s1 = a.untyped_storage()
+        self.assertTrue(s0 is s1)
+        self.assertTrue(hasattr(s1, '_tracker'))
+
+        self.assertFalse(m[0])
+        del s0
+        self.assertFalse(m[0])
+        del s1
+        self.assertFalse(m[0])
+        del a
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_from_tensor_dealloc_resurrected(self):
+        m, t = Tracker.make()
+        a = torch.randn(10)
+        s0 = a.untyped_storage()
+        s0._tracker = t
+        del t
+
+        s1 = a.untyped_storage()
+        self.assertTrue(s0 is s1)
+        self.assertTrue(hasattr(s1, '_tracker'))
+
+        self.assertFalse(m[0])
+        del s0
+        self.assertFalse(m[0])
+        del s1
+        self.assertFalse(m[0])
+
+        s0 = a.untyped_storage()
+        self.assertTrue(isinstance(s0, torch.UntypedStorage))
+
+        del a
+        self.assertFalse(m[0])
+        del s0
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_dealloc_resurrected(self):
+        m, t = Tracker.make()
+        s = torch.UntypedStorage(10)
+        s._tracker = t
+        del t
+
+        a = torch.tensor(s)
+        self.assertFalse(m[0])
+        del s
+
+        self.assertFalse(m[0])
+
+        s = a.untyped_storage()
+        self.assertTrue(isinstance(s, torch.UntypedStorage))
+
+        del a
+        self.assertFalse(m[0])
+        del s
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_dealloc_subclass_zombie(self):
+        class MyStorage(torch.UntypedStorage):
+            finalized_count = 0
+
+            def __del__(self):
+                MyStorage.finalized_count += 1
+
+        m, t = Tracker.make()
+        s = MyStorage(10)
+        s._tracker = t
+        del t
+
+        a = torch.tensor(s)
+        self.assertFalse(m[0])
+        del s
+
+        self.assertEqual(MyStorage.finalized_count, 0)
+        self.assertFalse(m[0])
+
+        del a
+        self.assertEqual(MyStorage.finalized_count, 1)
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_dealloc_subclass_resurrected(self):
+        class MyStorage(torch.UntypedStorage):
+            finalized_count = 0
+
+            def __del__(self):
+                MyStorage.finalized_count += 1
+
+        m, t = Tracker.make()
+        s = MyStorage(10)
+        s._tracker = t
+        del t
+
+        a = torch.tensor(s)
+        self.assertFalse(m[0])
+        del s
+
+        self.assertEqual(MyStorage.finalized_count, 0)
+        self.assertFalse(m[0])
+
+        s = a.untyped_storage()
+        del a
+        self.assertFalse(m[0])
+        self.assertEqual(MyStorage.finalized_count, 0)
+        self.assertTrue(isinstance(s, MyStorage))
+        del s
+        self.assertEqual(MyStorage.finalized_count, 1)
+        self.assertTrue(m[0])
+
     def test_tensor_slot_dealloc(self):

         class SlotTensor1(torch._C._TensorBase):
@@ -8889,6 +9065,27 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m1[0])
         self.assertTrue(m2[0])

+    def test_storage_slot_dealloc(self):
+
+        class SlotStorage1(torch._C.StorageBase):
+            __slots__ = ['slot1']
+
+        class SlotStorage2(SlotStorage1):
+            __slots__ = ['slot2']
+
+        m1, t1 = Tracker.make()
+        m2, t2 = Tracker.make()
+        slot_storage = SlotStorage2(torch.UntypedStorage(2))
+        slot_storage.slot1 = t1
+        slot_storage.slot2 = t2
+        del t1
+        del t2
+        self.assertFalse(m1[0])
+        self.assertFalse(m2[0])
+        del slot_storage
+        self.assertTrue(m1[0])
+        self.assertTrue(m2[0])
+
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_tensor_dict_dealloc(self):
         m, t = Tracker.make()
@@ -8899,6 +9096,16 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del x
         self.assertTrue(m[0])

+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+    def test_storage_dict_dealloc(self):
+        m, t = Tracker.make()
+        x = torch.UntypedStorage(2)
+        x.arf = t
+        del t
+        self.assertFalse(m[0])
+        del x
+        self.assertTrue(m[0])
+
     def test_tensor_finalizer_dealloc(self):
         m = [False]

@@ -8911,9 +9118,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del fin_tensor
         self.assertTrue(m[0])

+    def test_storage_finalizer_dealloc(self):
+        m = [False]
+
+        class FinalizerStorage(torch._C.StorageBase):
+            def __del__(self):
+                m[0] = True
+
+        fin_storage = FinalizerStorage(torch.UntypedStorage(2))
+        self.assertFalse(m[0])
+        del fin_storage
+        self.assertTrue(m[0])
+
     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
     def test_tensor_weakref_dealloc(self):

         x = torch.empty(2)
         m = [False]

@@ -8925,6 +9143,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m[0])
         self.assertEqual(wref(), None)

+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+    def test_storage_weakref_dealloc(self):
+
+        x = torch.UntypedStorage(2)
+        m = [False]
+
+        def cb(r):
+            m[0] = True
+
+        wref = weakref.ref(x, cb)
+        del x
+        self.assertTrue(m[0])
+        self.assertEqual(wref(), None)
+
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_tensor_cycle_via_dict(self):
         m1, t1 = Tracker.make()
@@ -8968,6 +9200,49 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m1[0])
         self.assertTrue(m2[0])

+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+    def test_storage_cycle_via_dict(self):
+        m1, t1 = Tracker.make()
+        x = torch.UntypedStorage(2)
+        x._tracker = t1
+        del t1
+
+        m2, t2 = Tracker.make()
+        y = torch.UntypedStorage(2)
+        y._tracker = t2
+        del t2
+
+        x._loop = y
+        y._loop = x
+
+        # C++ reference should keep the cycle live!
+        # This exercise THPVariable_subtype_traverse
+        # NB: Because z.grad is a reference done entirely in C++, cycles
+        # involving it directly are NOT broken by Python GC; you've
+        # set up a good old C++ reference cycle which we cannot safely
+        # break (because C++ references are allowed to be accessed
+        # multithreaded-ly) (TODO: except maybe if you can prove that
+        # only Python has access to the C++ object, in which case you can
+        # also prove that no multithreaded access occurs)
+        z = torch.UntypedStorage(2)
+        z.grad = x
+
+        del x
+        del y
+
+        gc.collect()
+        self.assertFalse(m1[0])
+        self.assertFalse(m2[0])
+
+        with disable_gc():
+            del z
+            self.assertFalse(m1[0])
+            self.assertFalse(m2[0])
+
+        gc.collect()
+        self.assertTrue(m1[0])
+        self.assertTrue(m2[0])
+
     def test_tensor_cycle_via_slots(self):
         m1 = [False]
         m2 = [False]
@@ -9000,6 +9275,67 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m1[0])
         self.assertTrue(m2[0])

+    def test_storage_cycle_via_slots(self):
+        m1 = [False]
+        m2 = [False]
+
+        class SlotStorage1(torch._C.StorageBase):
+            __slots__ = ['slot1']
+
+            def __del__(self):
+                m1[0] = True
+
+        class SlotStorage2(SlotStorage1):
+            __slots__ = ['slot2']
+
+            def __del__(self):
+                m2[0] = True
+
+        x = SlotStorage1(torch.UntypedStorage(2))
+        y = SlotStorage2(torch.UntypedStorage(2))
+
+        x.slot1 = y
+        y.slot2 = x
+
+        del x
+        with disable_gc():
+            del y
+            self.assertFalse(m1[0])
+            self.assertFalse(m2[0])
+
+        gc.collect()
+        self.assertTrue(m1[0])
+        self.assertTrue(m2[0])
+
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+    def test_storage_preserve_nonhermetic_in_hermetic_context(self):
+        from torch.library import Library, impl
+        global _my_storage
+
+        my_lib = Library("my_lib", "DEF")
+        my_lib.define('my_func() -> None')
+
+        a = torch.tensor([1.])
+        _my_storage = a.untyped_storage()
+
+        m, t = Tracker.make()
+        _my_storage._tracker = t
+        del t
+
+        @impl(my_lib, 'my_func', '')
+        def my_func():
+            global _my_storage
+            del _my_storage
+
+        self.assertFalse(m[0])
+        torch.ops.my_lib.my_func()
+        self.assertFalse(m[0])
+
+        s = a.untyped_storage()
+        del a
+        del s
+        self.assertTrue(m[0])
+
     # FIXME: move to test_autograd?
     @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
     def test_backward_hooks_traverse(self):
@@ -9028,7 +9364,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m2[0])

     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
-    def test_dead_weak_ref(self):
+    def test_tensor_dead_weak_ref(self):
         x = torch.empty(2)
         w_x = weakref.ref(x)
         y = torch.empty(2)
@@ -9044,7 +9380,24 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],

         self.assertRaises(RuntimeError, lambda: x.sigmoid())

-    def test_resurrected_weak_ref(self):
+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+    def test_storage_dead_weak_ref(self):
+        x = torch.UntypedStorage(2)
+        w_x = weakref.ref(x)
+        y = torch.tensor(x)
+        del x
+
+        x = w_x()
+        # Ideally, x would keep the storage live. But CPython doesn't
+        # provide enough hooks to do this. So it will go dead and x
+        # will transmute into storage with null StorageImpl. Not great, but the
+        # best we can do.
+        del y
+
+        self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
+        self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
+
+    def test_tensor_resurrected_weak_ref(self):
         x = torch.empty(2)
         w_x = weakref.ref(x)
         y = torch.empty(2)
@@ -9057,8 +9410,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del y
         x.sigmoid()

+    def test_storage_resurrected_weak_ref(self):
+        x = torch.UntypedStorage(2)
+        w_x = weakref.ref(x)
+        y = torch.tensor(x)
+        del x
+
+        x = w_x()
+        # Use this to manually fix weak reference after dereferencing them
+        x._fix_weakref()
+        del y
+        x.float()
+
     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
-    def test_fix_weakref_no_leak(self):
+    def test_tensor_fix_weakref_no_leak(self):
         import weakref

         called = False
@@ -9074,6 +9439,23 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],

         self.assertTrue(called)

+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+    def test_storage_fix_weakref_no_leak(self):
+        import weakref
+
+        called = False
+
+        a = torch.UntypedStorage(1)
+
+        def callback(w):
+            nonlocal called
+            called = True
+        wa = weakref.ref(a, callback)
+        a._fix_weakref()
+        del a
+
+        self.assertTrue(called)
+
     # FIXME: move to test_linalg
     @torch.inference_mode()
     def test_bmm_multithreaded(self):
@@ -30,29 +30,6 @@ std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)>
 std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)>
     layout_registry = {};

-at::DeprecatedTypeProperties* get_type_properties(
-    at::DeviceType device_type,
-    at::ScalarType scalarType) {
-  at::Backend backend = at::Backend::Undefined;
-  if (device_type == at::kCPU) {
-    backend = at::Backend::CPU;
-  } else if (device_type == at::kCUDA) {
-    backend = at::Backend::CUDA;
-  } else if (device_type == at::kXPU) {
-    backend = at::Backend::XPU;
-  } else if (device_type == at::kHPU) {
-    backend = at::Backend::HPU;
-  } else if (device_type == at::kMPS) {
-    backend = at::Backend::MPS;
-  } else if (device_type == at::DeviceType::Meta) {
-    backend = at::Backend::Undefined;
-  } else if (device_type == at::DeviceType::PrivateUse1) {
-    backend = at::Backend::PrivateUse1;
-  } else {
-    TORCH_CHECK(false, "Invalid device for storage: ", device_type);
-  }
-  return &at::getDeprecatedTypeProperties(backend, scalarType);
-}
 } // namespace

 void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) {
@@ -88,13 +65,10 @@ PyObject* createPyObject(const at::Storage& storage) {
   // information about storages from python). However, any accesses to the
   // data_ptr is not allowed, through methods like
   // x.untyped_storage().data_ptr()
-  PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPStorageClass);
-  auto obj = THPObjectPtr(type->tp_alloc(type, 0));
+  PyObject* obj = THPStorage_Wrap(storage);
   if (!obj)
     throw python_error();
-  ((THPStorage*)obj.get())->cdata =
-      c10::MaybeOwned<at::Storage>::owned(at::Storage(/* copy */ storage));
-  return obj.release();
+  return obj;
 }

 PyTypeObject* loadTypedStorageTypeObject() {
@@ -118,16 +92,13 @@ bool isStorage(PyObject* obj) {
   if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) {
     return true;
   }
-  auto obj_type = Py_TYPE(obj);
-
-  return obj_type == reinterpret_cast<PyTypeObject*>(THPStorageClass);
+  return THPStorage_Check(obj);
 }

-at::Storage createStorageGetType(
-    PyObject* obj,
-    at::ScalarType& scalar_type,
-    bool& is_typed_storage) {
-  is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
+std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
+    PyObject* obj) {
+  at::ScalarType scalar_type = at::ScalarType::Undefined;
+  bool is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
   PyObject* untyped_storage_obj = nullptr;

   if (is_typed_storage) {
@@ -136,10 +107,9 @@ at::Storage createStorageGetType(
     // stay nonzero since the `TypedStorage` maintains a reference.
     PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
     TORCH_INTERNAL_ASSERT(dtype_obj);
-    Py_DECREF(dtype_obj);
-
     TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
     scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
+    Py_DECREF(dtype_obj);

     untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
     TORCH_INTERNAL_ASSERT(untyped_storage_obj);
@@ -150,22 +120,18 @@ at::Storage createStorageGetType(
     untyped_storage_obj = obj;
   }

-  if (Py_TYPE(untyped_storage_obj) !=
-      reinterpret_cast<PyTypeObject*>(THPStorageClass)) {
-    throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name);
-  }
+  TORCH_CHECK(
+      THPStorage_Check(untyped_storage_obj),
+      "not a storage '",
+      Py_TYPE(obj)->tp_name,
+      "'");

-  const auto& storage = THPStorage_Unpack(untyped_storage_obj);
-  c10::DeviceType device_type = storage.device().type();
-  auto type_properties = get_type_properties(device_type, at::kByte);
-  return type_properties->unsafeStorageFromTH(
-      storage.unsafeGetStorageImpl(), true);
+  auto storage = THPStorage_Unpack(untyped_storage_obj);
+  return std::make_tuple(storage, scalar_type, is_typed_storage);
 }

 at::Storage createStorage(PyObject* obj) {
-  at::ScalarType scalar_type = at::ScalarType::Undefined;
-  bool is_typed_storage = false;
-  return createStorageGetType(obj, scalar_type, is_typed_storage);
+  return std::get<0>(createStorageGetType(obj));
 }

 } // namespace torch
@@ -27,10 +27,8 @@ void registerLayoutObject(THPLayout* thp_layout, at::Layout layout);

 TORCH_PYTHON_API PyObject* createPyObject(const at::Storage& storage);
 at::Storage createStorage(PyObject* obj);
-at::Storage createStorageGetType(
-    PyObject* obj,
-    at::ScalarType& scalar_type,
-    bool& is_typed_storage);
+std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
+    PyObject* obj);
 bool isStorage(PyObject* obj);

 TORCH_PYTHON_API THPDtype* getTHPDtype(at::ScalarType scalarType);
@@ -35,7 +35,7 @@ struct ConcretePyInterpreterVTable final
     : public c10::impl::PyInterpreterVTable {
   std::string name() const override;

-  void decref(PyObject* pyobj, bool is_tensor) const override;
+  void decref(PyObject* pyobj, bool has_pyobj_slot) const override;

   // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
   // operate upon a PyObjectSlot rather than a TensorImpl
@@ -189,15 +189,15 @@ py::object torchDispatchFromTensorImpl(
       TorchFunctionName::TorchDispatch));
 }

-// NOTE [PyInterpreter::decref takes an `is_tensor` arg]
+// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
 // Before calling PyInterpreter::decref, we must statically know if the
-// pyobj is a Tensor or not.
-// - If it is a tensor, we need to be careful about PyObject resurrection
-// - If it is not a tensor, we can freely decref
+// pyobj has a PyObjectSlot or not.
+// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
+// - If it does not have a PyObjectSlot, we can freely decref
 // One alternative to this is using PyObject_IsInstance
 // to get at this information. However, we don't want to risk an incorrect
 // `__instancecheck__` changing the semantics here.
-void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
+void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
     const {
   // Leak the pyobj if not initialized. This can happen if we are running
   // exit handlers that are destructing tensors with residual (owned)
@@ -207,23 +207,33 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)

   pybind11::gil_scoped_acquire gil;
   // Two possibilities:
-  // 1. We are decref-ing a tensor. Then we must be careful about
-  //    PyObject resurrection (this only applies to Tensors, see
+  // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
+  //    Storage. Then we must be careful about PyObject resurrection (see
   //    THPVariable_clear).
   // 2. We are decref-ing some other Python object. We don't do
   //    PyObject resurrection on non-Tensors, so we just carry on as usual
-  if (is_tensor && Py_REFCNT(pyobj) > 1) {
-    // It's still alive! This can happen if a weak ref resurrected
-    // the PyObject without flipping ownership. At this point it is
-    // too late to rescue the object, so just stub out the PyObject
-    // so that it fails on subsequent uses. Don't raise an error here;
-    // you're probably in a destructor.
-    TORCH_WARN(
-        "Deallocating Tensor that still has live PyObject references. "
-        "This probably happened because you took out a weak reference to "
-        "Tensor and didn't call _fix_weakref() after dereferencing it. "
-        "Subsequent accesses to this tensor via the PyObject will now fail.");
-    ((THPVariable*)pyobj)->cdata = c10::MaybeOwned<torch::autograd::Variable>();
+  if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
+    if (THPVariable_Check(pyobj)) {
+      // It's still alive! This can happen if a weak ref resurrected
+      // the PyObject without flipping ownership. At this point it is
+      // too late to rescue the object, so just stub out the PyObject
+      // so that it fails on subsequent uses. Don't raise an error here;
+      // you're probably in a destructor.
+      TORCH_WARN(
+          "Deallocating Tensor that still has live PyObject references. "
+          "This probably happened because you took out a weak reference to "
+          "Tensor and didn't call _fix_weakref() after dereferencing it. "
+          "Subsequent accesses to this tensor via the PyObject will now fail.");
+      ((THPVariable*)pyobj)->cdata =
+          c10::MaybeOwned<torch::autograd::Variable>();
+    } else if (THPStorage_Check(pyobj)) {
+      TORCH_WARN(
+          "Deallocating UntypedStorage that still has live PyObject references. "
+          "This probably happened because you took out a weak reference to "
+          "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
+          "Subsequent accesses to this storage via the PyObject will now fail.");
+      ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
+    }
   }
   Py_DECREF(pyobj);
 };
@@ -548,8 +558,8 @@ static void set_tensor_attr_with_capsule(
     const c10::TensorImpl* tensor,
     py::capsule& capsule,
     const char* attr_name) {
-  c10::optional<PyObject*> mb_obj =
-      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
+  c10::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
+      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
   TORCH_CHECK(
       mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
   py::handle(mb_obj.value()).attr(attr_name) = capsule;
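A compact sketch of the two decref modes after this rename (illustrative; `interpreter` and the two PyObject names are hypothetical): `SafePyObject` wraps arbitrary Python objects and passes `false`, so the interpreter simply decrefs, while `PyObjectSlot` passes `true`, which routes through the resurrection/stub-out logic above:

    (*interpreter)->decref(plain_pyobj, /*has_pyobj_slot=*/false);  // plain Py_DECREF
    (*interpreter)->decref(slot_pyobj, /*has_pyobj_slot=*/true);    // may stub out cdata first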
@ -6,6 +6,7 @@
|
||||
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/core/RefcountedDeleter.h>
|
||||
#include <libshm.h>
|
||||
#include <torch/csrc/CudaIPCTypes.h>
|
||||
#include <torch/csrc/Device.h>
|
||||
@ -15,6 +16,7 @@
|
||||
#include <torch/csrc/THP.h>
|
||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||
#include <torch/csrc/copy_utils.h>
|
||||
#include <torch/csrc/utils/pyobject_preservation.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
@ -27,28 +29,265 @@ void THPPointer<c10::StorageImpl>::free() {
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* THPStorageClass = nullptr;
|
||||
PyTypeObject* THPStorageClass = nullptr;
|
||||
|
||||
PyObject* THPStorage_New(c10::Storage storage) {
|
||||
PyTypeObject* type = (PyTypeObject*)THPStorageClass;
|
||||
PyObject* obj = type->tp_alloc(type, 0);
|
||||
if (obj) {
|
||||
((THPStorage*)obj)->cdata =
|
||||
c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
|
||||
PyObject* THPStorage_NewWithStorage(
|
||||
PyTypeObject* type,
|
||||
c10::Storage _storage,
|
||||
c10::impl::PyInterpreterStatus status,
|
||||
bool allow_preexisting_pyobj) {
|
||||
TORCH_CHECK(
|
||||
PyType_IsSubtype(type, &THPStorageType),
|
||||
"Creating a Storage subclass from a class that does not inherit from ",
|
||||
"Storage is not possible. Make sure your class inherits from Storage.");
|
||||
|
||||
auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
|
||||
getPyInterpreter(), /*ignore_hermetic_tls=*/false);
|
||||
if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
|
||||
TORCH_CHECK(
|
||||
allow_preexisting_pyobj,
|
||||
"Creating a new Storage subclass ",
|
||||
type->tp_name,
|
||||
" but the raw Storage object is already associated to a python object ",
|
||||
"of type ",
|
||||
maybe_pyobj.value()->ob_type->tp_name);
|
||||
PyObject* obj = *maybe_pyobj;
|
||||
PyTypeObject* obj_type = Py_TYPE(obj);
|
||||
TORCH_CHECK(
|
||||
obj_type == type || PyType_IsSubtype(obj_type, type),
|
||||
"Creating a new Storage subclass ",
|
||||
type->tp_name,
|
||||
" but the raw Storage object is already associated to a python object ",
|
||||
"of type ",
|
||||
maybe_pyobj.value()->ob_type->tp_name,
|
||||
" which is not a subclass of the "
|
||||
"requested type");
|
||||
return THPStorage_Wrap(std::move(_storage));
|
||||
}
|
||||
|
||||
PyObject* obj = type->tp_alloc(type, 0);
|
||||
TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
|
||||
|
||||
auto s = (THPStorage*)obj;
|
||||
|
||||
new (&s->cdata) c10::MaybeOwned<c10::Storage>();
|
||||
|
||||
s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));
|
||||
|
||||
if (!c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
s->is_hermetic = false;
|
||||
const auto& storage = THPStorage_Unpack(s);
|
||||
storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(
|
||||
getPyInterpreter(), obj, status);
|
||||
} else {
|
||||
s->is_hermetic = true;
|
||||
}
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
// Wraps the c10::Storage with a storage PyObject
|
||||
PyObject* THPStorage_Wrap(c10::Storage storage) {
|
||||
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
|
||||
if (c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return THPStorage_NewWithStorage(
|
||||
THPStorageClass,
|
||||
std::move(storage),
|
||||
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
|
||||
}
|
||||
c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();
|
||||
|
||||
// If the StorageImpl has a PyObject that is managed by a different
|
||||
// interpreter than the current one, create a new StorageImpl that points to
|
||||
// the same data and then create the Python storage from that.
|
||||
// NOTE: This is only supposed to happen in MultiPy
|
||||
if (pyobj_slot->has_pyobj_nonhermetic() &&
|
||||
!pyobj_slot->check_interpreter(getPyInterpreter())) {
|
||||
return THPStorage_NewWithStorage(
|
||||
THPStorageClass,
|
||||
c10::newStorageImplFromRefcountedDataPtr(storage),
|
||||
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
|
||||
}
|
||||
c10::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
|
||||
getPyInterpreter(), /*ignore_hermetic_tls=*/false);
|
||||
c10::impl::PyInterpreterStatus status =
|
||||
c10::impl::PyInterpreterStatus::TAGGED_BY_US;
|
||||
if (maybe_pyobj.has_value()) {
|
||||
auto obj = *maybe_pyobj;
|
||||
if (obj) {
|
||||
TORCH_CHECK(
|
||||
THPStorage_Check(obj),
|
||||
"Expected a storage type, but got ",
|
||||
Py_TYPE(obj)->tp_name);
|
||||
|
||||
if (pyobj_slot->owns_pyobj()) {
|
||||
pyobj_slot->set_owns_pyobj(false);
|
||||
reinterpret_cast<THPStorage*>(obj)->cdata =
|
||||
c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
|
||||
return obj;
|
||||
} else {
|
||||
Py_INCREF(obj);
|
||||
return obj;
|
||||
}
|
||||
}
|
||||
status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
|
||||
} else {
|
||||
if (storage.use_count() <= 1) {
|
||||
status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
|
||||
} else {
|
||||
status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
|
||||
}
|
||||
}
|
||||
return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status);
|
||||
}
|
||||
|
||||
static bool THPStorage_isPreservable(THPStorage* self) {
|
||||
if (self->cdata.unsafeIsBorrowed()) {
|
||||
return false;
|
||||
}
|
||||
auto const& storage = THPStorage_Unpack(self);
|
||||
|
||||
if (self->is_hermetic) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
|
||||
getPyInterpreter(), /*ignore_hermetic_tls=*/true) !=
|
||||
c10::make_optional((PyObject*)self)) {
|
||||
return false;
|
||||
}
|
||||
if (storage.use_count() <= 1) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool THPStorage_tryPreserve(THPStorage* self) {
|
||||
if (!THPStorage_isPreservable(self)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& storage = THPStorage_Unpack(self);
|
||||
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
|
||||
|
||||
auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj(
|
||||
getPyInterpreter(),
|
||||
/*ignore_hermetic_tls=*/true);
|
||||
// NOTE: It is possible to just set the PyObjectSlot here, but the point is
|
||||
// that we should have already set PyObjectSlot when the storage PyObject was
|
||||
// created.
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
maybe_pyobj.has_value(),
|
||||
"Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");
|
||||
|
||||
PyObject* pyobj = *maybe_pyobj;
|
||||
|
||||
TORCH_CHECK(
|
||||
THPStorage_Check(pyobj),
|
||||
"Expected a storage type, but got ",
|
||||
Py_TYPE(pyobj)->tp_name);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
(void*)pyobj == (void*)self,
|
||||
"Python storage and the PyObject in the internal PyObjectSlot are not at the same address");
|
||||
|
||||
TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
|
||||
|
||||
storage_impl->pyobj_slot()->set_owns_pyobj(true);
|
||||
Py_INCREF(self);
|
||||
|
||||
self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void THPStorage_subclass_dealloc(PyObject* self) {
|
||||
THPStorage* _self = (THPStorage*)self;
|
||||
// Some subclass of StorageBase are GC-tracked objects even
|
||||
// though the base class is not.
|
||||
|
||||
if (THPStorage_tryPreserve(_self)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Some subclass of StorageBase could be GC-tracked objects even
|
||||
// though the base class is not
|
||||
auto* type = Py_TYPE(self);
|
||||
if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
|
||||
PyObject_GC_UnTrack(self);
|
||||
}
|
||||
|
||||
bool has_finalizer = type->tp_finalize || type->tp_del;
|
||||
|
||||
if (type->tp_finalize) {
|
||||
PyObject_GC_Track(self);
|
||||
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
|
||||
// The finalizer has resurrected the PyObject and there is a new Python
|
||||
// reference to it, so we can just stop deallocating. Read about
|
||||
// resurrection from `__del__` here:
|
||||
// https://docs.python.org/3/reference/datamodel.html#object.__del__
|
||||
return;
|
||||
}
|
||||
PyObject_GC_UnTrack(self);
|
||||
}
|
||||
|
||||
// base test is unnecessary as THPStorae does not set this
|
||||
if (type->tp_weaklistoffset) {
|
||||
PyObject_ClearWeakRefs(self);
|
||||
}
|
||||
|
||||
if (type->tp_del) {
|
||||
PyObject_GC_Track(self);
|
||||
type->tp_del(self);
|
||||
if (self->ob_refcnt > 0) {
|
||||
// Resurrected (see above comment about resurrection from `__del__`)
|
||||
return;
|
||||
}
|
||||
PyObject_GC_UnTrack(self);
|
||||
}
|
||||
|
||||
if (has_finalizer) {
|
||||
/* New weakrefs could be created during the finalizer call.
|
||||
If this occurs, clear them out without calling their
|
||||
finalizers since they might rely on part of the object
|
||||
being finalized that has already been destroyed. */
|
||||
if (type->tp_weaklistoffset) {
|
||||
/* Modeled after GET_WEAKREFS_LISTPTR() */
|
||||
PyWeakReference** list =
|
||||
(PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
|
||||
while (*list)
|
||||
_PyWeakref_ClearRef(*list);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear slots
|
||||
{
|
||||
PyTypeObject* base = type;
|
||||
while (base != &THPStorageType) {
|
||||
if (Py_SIZE(base)) {
|
||||
clear_slots(base, self);
|
||||
}
|
||||
base = base->tp_base;
|
||||
TORCH_INTERNAL_ASSERT(base);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear __dict__
|
||||
if (C10_LIKELY(type->tp_dictoffset)) {
|
||||
PyObject** dictptr = _PyObject_GetDictPtr(self);
|
||||
if (dictptr != nullptr) {
|
||||
PyObject* dict = *dictptr;
|
||||
if (dict != nullptr) {
|
||||
Py_DECREF(dict);
|
||||
*dictptr = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
|
||||
|
||||
_self->cdata.~MaybeOwned<c10::Storage>();
|
||||
Py_TYPE(_self)->tp_free(self);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
|
||||
Py_DECREF(type);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
|
||||
@ -151,32 +390,35 @@ static PyObject* THPStorage_pynew(
|
||||
"(): only one or neither of 'allocator' or 'device' can ",
|
||||
"be given, but not both");
|
||||
|
||||
THPStoragePtr self((THPStorage*)type->tp_alloc(type, 0));
|
||||
THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
|
||||
PyObject* self = nullptr;
|
||||
c10::Allocator* allocator = nullptr;
|
||||
|
||||
// torch.Storage(*, ...)
|
||||
if (r.idx == 0) {
|
||||
self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
0,
|
||||
allocator,
|
||||
/*resizable=*/true,
|
||||
allocator_opt,
|
||||
device_opt));
|
||||
return (PyObject*)self.release();
|
||||
self = THPStorage_NewWithStorage(
|
||||
type,
|
||||
make_storage_impl(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
0,
|
||||
allocator,
|
||||
/*resizable=*/true,
|
||||
allocator_opt,
|
||||
device_opt),
|
||||
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
|
||||
|
||||
// torch.Storage(size, *, ...)
|
||||
} else if (r.idx == 1) {
|
||||
int64_t size = r.toInt64(0);
|
||||
self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
size,
|
||||
allocator,
|
||||
/*resizable=*/true,
|
||||
allocator_opt,
|
||||
device_opt));
|
||||
return (PyObject*)self.release();
|
||||
self = THPStorage_NewWithStorage(
|
||||
type,
|
||||
make_storage_impl(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
size,
|
||||
allocator,
|
||||
/*resizable=*/true,
|
||||
allocator_opt,
|
||||
device_opt),
|
||||
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
|
||||
|
||||
// torch.Storage(sequence, *, ...)
|
||||
} else if (r.idx == 2) {
|
||||
@ -192,19 +434,22 @@ static PyObject* THPStorage_pynew(
|
||||
THPStorageStr,
|
||||
"(): Could not obtain the length of sequence of type ",
|
||||
THPUtils_typename(sequence));
|
||||
self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
length,
|
||||
allocator,
|
||||
/*resizable=*/true,
|
||||
allocator_opt,
|
||||
device_opt));
|
||||
self = THPStorage_NewWithStorage(
|
||||
type,
|
||||
make_storage_impl(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
length,
|
||||
allocator,
|
||||
/*resizable=*/true,
|
||||
allocator_opt,
|
||||
device_opt),
|
||||
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
|
||||
THPObjectPtr item;
|
||||
try {
|
||||
const auto& storage = THPStorage_Unpack(self);
|
||||
for (Py_ssize_t i = 0; i < length; i++) {
|
||||
item = PySequence_GetItem(sequence, i);
|
||||
uint8_t value = THPByteUtils_unpackReal(item.get());
|
||||
const auto& storage = THPStorage_Unpack(self);
|
||||
if (allocator == c10::GetDefaultCPUAllocator()) {
|
||||
static_cast<uint8_t*>(storage.mutable_data())[i] = value;
|
||||
} else {
|
||||
@ -221,20 +466,22 @@ static PyObject* THPStorage_pynew(
|
||||
THPUtils_typename(item.get()));
|
||||
return nullptr;
|
||||
}
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
return self;
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static Py_ssize_t THPStorage_length(THPStorage* self) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  return static_cast<Py_ssize_t>(THPStorage_Unpack(self).nbytes());
  END_HANDLE_TH_ERRORS_RET(-1)
}

static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  int64_t len = static_cast<int64_t>(storage.nbytes());
  /* Integer index */
@ -289,7 +536,11 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
        old_storage_impl->allocator(),
        /* resizable */ false);

    PyObject* _ret = THPStorage_New(std::move(new_storage_impl));
    PyObject* _ret = THPStorage_NewWithStorage(
        Py_TYPE(self),
        std::move(new_storage_impl),
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);

    return _ret;
  }
  PyErr_Format(
@ -302,6 +553,7 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {

static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  if (!THPByteUtils_checkReal(value)) {
    THPUtils_setError(
        "can only set storage content with a int types, but got "
@ -451,6 +703,7 @@ int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {

static PyObject* THPStorage_device(THPStorage* self, void* unused) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  return THPDevice_New(THPStorage_Unpack(self).device());
  END_HANDLE_TH_ERRORS
}
@ -490,7 +743,17 @@ bool THPStorage_init(PyObject* module) {
}

void THPStorage_postInit(PyObject* module) {
  THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage");
  THPStorageClass =
      (PyTypeObject*)PyObject_GetAttrString(module, "UntypedStorage");
  if (!THPStorageClass)
    throw python_error();
}

void THPStorage_assertNotNull(THPStorage* storage) {
  TORCH_CHECK(
      THPStorage_Unpack(storage).unsafeGetStorageImpl(), "Got a null Storage");
}

void THPStorage_assertNotNull(PyObject* obj) {
  THPStorage_assertNotNull((THPStorage*)obj);
}
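The recurring change above swaps THPStorage_New, which could only allocate a PyObject of the default storage type, for THPStorage_NewWithStorage, which takes the target Python type plus a c10::impl::PyInterpreterStatus hint so the new PyObject can be registered in the StorageImpl's pyobj slot. A minimal sketch of the new allocation path, assuming only the signatures declared later in this diff; the storage construction here is illustrative, not code from the commit:

// Sketch: build a small CPU storage and bind it to the canonical
// Python storage type. DEFINITELY_UNINITIALIZED promises the pyobj
// slot machinery that no PyObject can exist yet for this fresh impl.
c10::Storage storage(c10::make_intrusive<c10::StorageImpl>(
    c10::StorageImpl::use_byte_size_t(),
    16,
    c10::GetDefaultCPUAllocator(),
    /*resizable=*/true));
PyObject* py_storage = THPStorage_NewWithStorage(
    THPStorageClass,
    std::move(storage),
    c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);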
@ -5,84 +5,44 @@

#define THPStorageStr "torch.UntypedStorage"

namespace c10 {

template <>
struct MaybeOwnedTraits<c10::Storage> {
  using owned_type = c10::Storage;
  using borrow_type = c10::Storage;

  static borrow_type createBorrow(const owned_type& from) {
    return borrow_type(from);
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.unsafeReleaseStorageImpl();
    lhs = borrow_type(rhs);
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};

template <>
struct ExclusivelyOwnedTraits<c10::Storage> {
  using repr_type = c10::Storage;
  using pointer_type = c10::Storage*;
  using const_pointer_type = const c10::Storage*;

  static repr_type nullRepr() {
    return c10::Storage();
  }

  template <class... Args>
  static repr_type createInPlace(Args&&... args) {
    return c10::Storage(std::forward<Args>(args)...);
  }

  static repr_type moveToRepr(c10::Storage&& x) {
    return std::move(x);
  }

  static c10::Storage take(c10::Storage& x) {
    return std::move(x);
  }

  static pointer_type getImpl(repr_type& x) {
    return &x;
  }

  static const_pointer_type getImpl(const repr_type& x) {
    return &x;
  }
};

} // namespace c10
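A c10::MaybeOwned<T> either owns a T outright or borrows one that somebody else keeps alive at +0; THPStorage::cdata in the struct below uses it so a Python storage object can wrap a c10::Storage without always paying for an extra reference. A rough usage sketch of the two modes, assuming a storage obtained elsewhere (illustrative, not code from the commit):

// Sketch: owned vs. borrowed handles. The borrow must not outlive `s`.
void maybe_owned_demo(const c10::Storage& s) {
  auto owned = c10::MaybeOwned<c10::Storage>::owned(c10::Storage(s));
  auto borrowed = c10::MaybeOwned<c10::Storage>::borrowed(s);
  // Both dereference to a const c10::Storage&.
  (void)owned->nbytes();
  (void)borrowed->nbytes();
}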
struct THPStorage {
  PyObject_HEAD;
  c10::MaybeOwned<c10::Storage> cdata;
  bool is_hermetic;
};

TORCH_PYTHON_API PyObject* THPStorage_New(c10::Storage storage);
extern PyObject* THPStorageClass;
TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage);
TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage(
    PyTypeObject* type,
    c10::Storage _storage,
    c10::impl::PyInterpreterStatus status,
    bool allow_preexisting_pyobj = false);
extern PyTypeObject* THPStorageClass;

static inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) {
  return tp == THPStorageClass;
}

static inline bool THPStorage_CheckExact(PyObject* obj) {
  return THPStorage_CheckTypeExact(Py_TYPE(obj));
}

inline bool THPStorage_Check(PyObject* obj) {
  if (!THPStorageClass)
    return false;

  const auto result = PyObject_IsInstance(obj, (PyObject*)THPStorageClass);
  if (result == -1)
    throw python_error();
  return result;
}
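These helpers go with the change of THPStorageClass from PyObject* to PyTypeObject*: CheckTypeExact and CheckExact match only torch.UntypedStorage itself, while THPStorage_Check also accepts subclasses via PyObject_IsInstance. A small usage sketch (the caller is hypothetical):

// Sketch: fast path for the exact type, slower subclass check after.
bool is_storage_like(PyObject* obj) {
  if (THPStorage_CheckExact(obj)) {
    return true;  // exactly torch.UntypedStorage
  }
  return THPStorage_Check(obj);  // subclass check; may throw python_error
}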
bool THPStorage_init(PyObject* module);
void THPStorage_postInit(PyObject* module);

void THPStorage_assertNotNull(THPStorage* storage);
void THPStorage_assertNotNull(PyObject* obj);

extern PyTypeObject THPStorageType;

inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) {

@ -41,6 +41,7 @@

static PyObject* THPStorage_nbytes(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  return py::cast(THPStorage_Unpack(self).sym_nbytes()).release().ptr();
  END_HANDLE_TH_ERRORS
}
@ -66,6 +67,7 @@ static PyObject* THPStorage_copy_(
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);

  at::Storage self_ = torch::createStorage(self);

@ -96,12 +98,14 @@ static PyObject* THPStorage_copy_(

static PyObject* THPStorage_elementSize(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(_self);
  return THPUtils_packInt64(sizeof(uint8_t));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStorage_new(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  c10::Allocator* allocator = THPStorage_Unpack(self).allocator();
  auto new_storage = c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
@ -109,12 +113,13 @@ static PyObject* THPStorage_new(PyObject* self, PyObject* noargs) {
      allocator,
      /*resizable=*/true);

  return THPStorage_New(std::move(new_storage));
  return THPStorage_Wrap(std::move(new_storage));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  // See Note [Invalid Python Storages]
  auto invalid = storage.data() == nullptr &&
@ -185,6 +190,7 @@ static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {

static PyObject* THPStorage_fill_(PyObject* self, PyObject* number_arg) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  // See Note [Invalid Python Storages]
  auto invalid = storage.data() == nullptr &&
@ -389,7 +395,7 @@ static PyObject* THPStorage_fromBuffer(
  }

  PyBuffer_Release(&buffer);
  return (PyObject*)THPStorage_New(storage);
  return THPStorage_Wrap(storage);
  END_HANDLE_TH_ERRORS
}

@ -429,12 +435,16 @@ static PyObject* THPStorage_fromFile(
    storage->set_nbytes(actual_nbytes);
  }

  return (PyObject*)THPStorage_New(std::move(storage));
  return THPStorage_NewWithStorage(
      THPStorageClass,
      std::move(storage),
      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
  END_HANDLE_TH_ERRORS
}

PyObject* THPStorage_writeFile(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  // See Note [Invalid Python Storages]
  auto invalid = storage.data() == nullptr &&
@ -486,12 +496,13 @@ PyObject* THPStorage_newWithFile(PyObject* _unused, PyObject* args) {
  auto storage = THPStorage_readFileRaw<int>(fd, {}, element_size);
  if (!storage.defined())
    return nullptr;
  return THPStorage_New(std::move(storage));
  return THPStorage_Wrap(std::move(storage));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  PyObject* file = PyTuple_GET_ITEM(args, 0);
  PyObject* offset = PyTuple_GET_ITEM(args, 1);
@ -617,6 +628,12 @@ PyObject* THPStorage_byteswap(PyObject* self, PyObject* args) {
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStorage_fix_weakref(PyObject* self, PyObject* noargs) {
  const auto& storage = THPStorage_Unpack(self);
  Py_DECREF(THPStorage_Wrap(storage));
  Py_RETURN_NONE;
}
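_fix_weakref leans on THPStorage_Wrap returning a strong reference to the canonical PyObject for the storage, re-establishing the binding if it had gone stale; the immediate Py_DECREF then drops the extra reference without undoing the binding. Judging only from the call sites visible in this diff, the idiom is:

// Sketch: re-bind the canonical PyObject, then balance the refcount.
PyObject* canonical = THPStorage_Wrap(storage);  // returns a +1 reference
Py_DECREF(canonical);  // back to steady state; the binding stays intact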
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStorage_methods[] = {
    {"copy_",
@ -645,6 +662,7 @@ static PyMethodDef THPStorage_methods[] = {
     nullptr},
    {"_set_cdata", THPStorage__setCdata, METH_O, nullptr},
    {"_byteswap", THPStorage_byteswap, METH_VARARGS, nullptr},
    {"_fix_weakref", THPStorage_fix_weakref, METH_NOARGS, nullptr},
    {nullptr}};

PyMethodDef* THPStorage_getMethods() {

@ -33,6 +33,7 @@

static PyObject* THPStorage_sharedDecref(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  c10::DeviceType device_type = storage.device_type();
  if (device_type == at::kCPU) {
@ -49,6 +50,7 @@ static PyObject* THPStorage_sharedDecref(PyObject* self, PyObject* noargs) {

static PyObject* THPStorage_sharedIncref(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  c10::DeviceType device_type = storage.device_type();
  if (device_type == at::kCPU) {
@ -76,18 +78,22 @@ static PyObject* THPStorage_pyNewFilenameStorage(

  int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
  std::string handle = at::NewProcessWideShmHandle();
  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      size,
      THManagedMapAllocator::makeDataPtr(
          "", handle.c_str(), flags, static_cast<size_t>(size)),
      /*allocator=*/nullptr,
      /*resizable=*/false));
  return THPStorage_NewWithStorage(
      THPStorageClass,
      c10::make_intrusive<at::StorageImpl>(
          c10::StorageImpl::use_byte_size_t(),
          size,
          THManagedMapAllocator::makeDataPtr(
              "", handle.c_str(), flags, static_cast<size_t>(size)),
          /*allocator=*/nullptr,
          /*resizable=*/false),
      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
  END_HANDLE_TH_ERRORS
}
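Two status values appear in this diff: the tp_new path earlier constructs the StorageImpl inline and passes DEFINITELY_UNINITIALIZED, promising that no PyObject can exist for the impl yet, while the shared-memory factories here pass TAGGED_BY_US, indicating the impl is already associated with this interpreter. A hedged sketch of the distinction; the `fresh` flag is hypothetical and the enum's precise semantics live in c10::impl, not in this excerpt:

// Sketch: pick the strongest guarantee that actually holds for the impl.
bool fresh = true;  // hypothetical: impl constructed on this line?
c10::impl::PyInterpreterStatus status = fresh
    ? c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED
    : c10::impl::PyInterpreterStatus::TAGGED_BY_US;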
static PyObject* THPStorage_shareFilename(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  TORCH_CHECK(
      storage.device_type() == at::kCPU,
@ -168,13 +174,16 @@ static PyObject* THPStorage_newSharedFilename(
  const char* object_handle = PyBytes_AS_STRING(_object_handle);
  uint64_t size = THPUtils_unpackUInt64(_size);
  int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE;
  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      size,
      THManagedMapAllocator::makeDataPtr(
          manager_handle, object_handle, flags, size),
      /*allocator=*/nullptr,
      /*resizable=*/false));
  return THPStorage_NewWithStorage(
      THPStorageClass,
      c10::make_intrusive<at::StorageImpl>(
          c10::StorageImpl::use_byte_size_t(),
          size,
          THManagedMapAllocator::makeDataPtr(
              manager_handle, object_handle, flags, size),
          /*allocator=*/nullptr,
          /*resizable=*/false),
      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
  END_HANDLE_TH_ERRORS
}

@ -187,12 +196,16 @@ static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) {
  if (size < 0) {
    return nullptr;
  }
  return THPStorage_New(at::new_shm_fd_storage(size));
  return THPStorage_NewWithStorage(
      THPStorageClass,
      at::new_shm_fd_storage(size),
      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStorage_shareFd(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  const auto& storage = THPStorage_Unpack(self);
  TORCH_CHECK(
      storage.device_type() == at::kCPU, "_share_fd_: only available on CPU");
@ -257,17 +270,22 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) {

  int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE |
      at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_FROMFD;
  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      size,
      at::MapAllocator::makeDataPtr(at::WITH_FD, "", fd, flags, size, nullptr),
      /*allocator=*/nullptr,
      /*resizable=*/false));
  return THPStorage_NewWithStorage(
      THPStorageClass,
      c10::make_intrusive<at::StorageImpl>(
          c10::StorageImpl::use_byte_size_t(),
          size,
          at::MapAllocator::makeDataPtr(
              at::WITH_FD, "", fd, flags, size, nullptr),
          /*allocator=*/nullptr,
          /*resizable=*/false),
      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
#ifdef USE_CUDA
  const auto& storage = THPStorage_Unpack(self);
  TORCH_CHECK(
@ -547,7 +565,10 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
  base->set_resizable(false);
  base->set_received_cuda(true);

  return THPStorage_New(std::move(base));
  return THPStorage_NewWithStorage(
      THPStorageClass,
      std::move(base),
      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
#else
  TORCH_CHECK(false, "CUDA is not available");
#endif
@ -572,7 +593,7 @@ PyObject* THPStorage_newWithWeakPtr(PyObject* _unused, PyObject* arg) {
      THPUtils_checkLong(arg), "_new_with_weak_ptr(): arg must be an 'int'");
  c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
  if (auto* storage = c10::raw::weak_intrusive_ptr::lock(weak_storage)) {
    return THPStorage_New(
    return THPStorage_Wrap(
        c10::intrusive_ptr<c10::StorageImpl>::reclaim(storage));
  }
  Py_RETURN_NONE;
@ -604,6 +625,7 @@ PyObject* THPStorage_expired(PyObject* _unused, PyObject* arg) {

PyObject* THPStorage_sharedFd(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  THPStorage_assertNotNull(self);
  at::MapAllocator* ctx = nullptr;
  const auto& storage = THPStorage_Unpack(self);
  if (storage.device_type() == at::kCPU) {

@ -26,6 +26,7 @@
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/pyobject_preservation.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_strings.h>
@ -268,7 +269,8 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
  }

  c10::optional<PyObject*> mb_obj =
      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  c10::impl::PyInterpreterStatus status;
  if (mb_obj.has_value()) {
    auto obj = *mb_obj;
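Every check_pyobj call in this file gains an explicit ignore_hermetic_tls argument. Passing false preserves the earlier behavior: while hermetic mode is active, the slot reports no PyObject, so wrapping allocates a fresh object instead of reusing (and mutating) the preserved binding. A sketch of the call shape, where `impl` stands for any TensorImpl or StorageImpl pointer:

// Sketch: query the pyobj slot under the new two-argument signature.
c10::optional<PyObject*> mb_obj = impl->pyobj_slot()->check_pyobj(
    getPyInterpreter(), /*ignore_hermetic_tls=*/false);
if (mb_obj.has_value()) {
  // A canonical PyObject is already registered for this interpreter.
}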
@ -345,7 +347,8 @@ bool isResurrectable(THPVariable* self) {
  }
  // Check if this is hermetic. If it is, no resurrection.
  if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter()) != c10::make_optional((PyObject*)self)) {
          getPyInterpreter(), /*ignore_hermetic_tls=*/false) !=
      c10::make_optional((PyObject*)self)) {
    return false;
  }
  return true;
@ -369,7 +372,16 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
  TORCH_INTERNAL_ASSERT(
      !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());

  tensor.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(true);
  c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
  auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
      getPyInterpreter(),
      /*ignore_hermetic_tls=*/false);

  TORCH_INTERNAL_ASSERT(
      maybe_pyobj.has_value(),
      "Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");

  tensor_impl->pyobj_slot()->set_owns_pyobj(true);

  // Resurrect the Python object. This is something CPython does
  // internally occasionally, see
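THPVariable_tryResurrect now asserts that the slot really holds a PyObject before flipping ownership, instead of blindly calling set_owns_pyobj(true). Resurrection itself follows the CPython idiom the comment above refers to: a dealloc that decides the object must live on bumps the refcount back above zero and returns without tearing anything down. A generic sketch of that idiom; should_keep_alive is a hypothetical predicate, not a PyTorch API:

// Sketch: the CPython-style resurrection pattern inside a tp_dealloc.
static void dealloc_with_resurrect(PyObject* self) {
  if (should_keep_alive(self)) {  // hypothetical predicate
    Py_INCREF(self);  // refcount 0 -> 1: the object survives dealloc
    return;           // skip teardown entirely
  }
  // ... normal teardown path ...
}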
@ -443,7 +455,8 @@ static int THPVariable_clear(THPVariable* self) {

  if (!self->cdata.unsafeIsBorrowed() &&
      tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter()) == c10::make_optional((PyObject*)self)) {
          getPyInterpreter(), /*ignore_hermetic_tls=*/false) ==
      c10::make_optional((PyObject*)self)) {
    // TODO: empirically, on OS X this assert appears to be untrue
    // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
    // distributed/rpc/test_process_group_agent.py
@ -1747,26 +1760,6 @@ PyObject* THPVariable_pynew(
  END_HANDLE_TH_ERRORS
}

static void clear_slots(PyTypeObject* type, PyObject* self) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  Py_ssize_t i, n;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  PyMemberDef* mp;

  n = Py_SIZE(type);
  mp = type->tp_members;
  for (i = 0; i < n; i++, mp++) {
    if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
      char* addr = (char*)self + mp->offset;
      PyObject* obj = *(PyObject**)addr;
      if (obj != nullptr) {
        *(PyObject**)addr = nullptr;
        Py_DECREF(obj);
      }
    }
  }
}

// NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc
// on subclasses. It's never valid to construct a THPVariable so it's not
// necessary to implement the dealloc for that case
@ -1886,8 +1879,8 @@ static PyObject* THPVariable_NewWithVar(

  // This function overwrite the Tensor's pyobj field without extra checks
  // Make sure it is not set otherwise we would leak memory
  auto mb_obj =
      _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
  auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
      getPyInterpreter(), /*ignore_hermetic_tls=*/false);

  // Under some circumstances, we may attempt to create a new Python
  // object for a variable that already has a Python object. The most common
19
torch/csrc/utils/pyobject_preservation.cpp
Normal file
@ -0,0 +1,19 @@
#include <torch/csrc/utils/pyobject_preservation.h>

#include <structmember.h>

void clear_slots(PyTypeObject* type, PyObject* self) {
  Py_ssize_t n = Py_SIZE(type);
  PyMemberDef* mp = type->tp_members;

  for (Py_ssize_t i = 0; i < n; i++, mp++) {
    if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
      char* addr = (char*)self + mp->offset;
      PyObject* obj = *(PyObject**)addr;
      if (obj != nullptr) {
        *(PyObject**)addr = nullptr;
        Py_DECREF(obj);
      }
    }
  }
}
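This is the former file-static clear_slots from python_variable.cpp (removed above), hoisted into a shared utility: it walks the type's PyMemberDef table and drops every writable T_OBJECT_EX member, which is how __slots__ entries are stored on a subclass instance. A sketch of a dealloc using it, with a hypothetical subclass type:

// Sketch: clear __slots__ members of a subclass before base teardown.
static void Subclass_dealloc(PyObject* self) {
  clear_slots(Py_TYPE(self), self);
  // ... then chain to the base type's tp_dealloc ...
}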
7
torch/csrc/utils/pyobject_preservation.h
Normal file
@ -0,0 +1,7 @@
#pragma once

#include <torch/csrc/python_headers.h>

// This file contains utilities used for handling PyObject preservation

void clear_slots(PyTypeObject* type, PyObject* self);

@ -1106,8 +1106,8 @@ inline at::Storage PythonArgs::storage(
    is_typed_storage = false;
    storage_scalar_type = at::ScalarType::Undefined;
  } else {
    storage =
        createStorageGetType(args[i], storage_scalar_type, is_typed_storage);
    std::tie(storage, storage_scalar_type, is_typed_storage) =
        createStorageGetType(args[i]);
  }
  return storage;
}
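Both call sites switch createStorageGetType from out-parameters to a tuple return, so the storage, its scalar type, and the typed-storage flag travel together. The declaration itself is not part of this excerpt, so the following is a presumed shape, not the committed one:

// Sketch: presumed new signature and the std::tie unpacking used above.
std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
    PyObject* obj);

at::Storage storage;
at::ScalarType scalar_type{at::ScalarType::Undefined};
bool is_typed_storage = false;
std::tie(storage, scalar_type, is_typed_storage) = createStorageGetType(obj);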
@ -385,10 +385,12 @@ Tensor internal_new_from_data(
  at::tracer::impl::NoTracerDispatchMode tracer_guard;

  if (isStorage(data)) {
    ScalarType storage_scalar_type{ScalarType::Undefined};
    bool is_typed_storage = false;
    Storage storage =
        createStorageGetType(data, storage_scalar_type, is_typed_storage);
    ScalarType storage_scalar_type{ScalarType::Undefined};
    Storage storage;
    std::tie(storage, storage_scalar_type, is_typed_storage) =
        createStorageGetType(data);

    TORCH_CHECK(
        !is_typed_storage || storage_scalar_type == scalar_type,
        "Expected a Storage of type ",