Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

Reland 2: Add PyObject preservation for UntypedStorage (#109039)

Relands #103907 after it was reverted. This PR makes the new `ignore_hermetic_tls` argument of `check_pyobj` optional to avoid causing a compilation error in torchdistx. Part of #91395.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109039
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
Parent: 6dc56d3490
Commit: 4c5e43574c

build_variables.bzl
@@ -896,6 +896,7 @@ libtorch_python_core_sources = [
     "torch/csrc/utils/python_dispatch.cpp",
     "torch/csrc/utils/python_symnode.cpp",
     "torch/csrc/utils/pybind.cpp",
+    "torch/csrc/utils/pyobject_preservation.cpp",
     "torch/csrc/utils/structseq.cpp",
     "torch/csrc/utils/tensor_apply.cpp",
     "torch/csrc/utils/tensor_dtypes.cpp",

c10/core/RefcountedDeleter.cpp (new file, 78 lines)
@@ -0,0 +1,78 @@
+#include <c10/core/RefcountedDeleter.h>
+
+#include <mutex>
+
+namespace c10 {
+
+void refcounted_deleter(void* ctx_) {
+  RefcountedDeleterContext& ctx =
+      *reinterpret_cast<RefcountedDeleterContext*>(ctx_);
+  ctx.refcount--;
+  if (ctx.refcount == 0) {
+    ctx.other_ctx = nullptr;
+    delete &ctx;
+  }
+}
+
+std::mutex replace_data_ptr_mutex;
+
+void maybeApplyRefcountedDeleter(const c10::Storage& storage) {
+  std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
+  c10::DataPtr& data_ptr = storage.mutable_data_ptr();
+
+  if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
+    // Data pointer is already shared
+    return;
+  }
+
+  void* data = data_ptr.get();
+  void* other_ctx = data_ptr.get_context();
+  c10::DeleterFnPtr other_deleter = data_ptr.get_deleter();
+  c10::Device device = data_ptr.device();
+
+  // Release the context of the original DataPtr so that the data doesn't
+  // get deleted when the original DataPtr is replaced
+  data_ptr.release_context();
+
+  c10::RefcountedDeleterContext* refcount_ctx =
+      new c10::RefcountedDeleterContext(other_ctx, other_deleter);
+
+  c10::DataPtr new_data_ptr(
+      data,
+      reinterpret_cast<void*>(refcount_ctx),
+      &c10::refcounted_deleter,
+      device);
+  storage.set_data_ptr(std::move(new_data_ptr));
+}
+
+c10::Storage newStorageImplFromRefcountedDataPtr(const c10::Storage& storage) {
+  c10::maybeApplyRefcountedDeleter(storage);
+
+  c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
+
+  c10::DataPtr& data_ptr = storage.mutable_data_ptr();
+  c10::DataPtr new_data_ptr(
+      data_ptr.get(),
+      data_ptr.get_context(),
+      data_ptr.get_deleter(),
+      data_ptr.device());
+
+  // NOTE: This refcount increment should always happen immediately after
+  // `new_data_ptr` is created. No other lines of code should be added between
+  // them in the future, unless there's a very good reason for it, because if
+  // any errors are raised and `new_data_ptr` is deleted before the refcount is
+  // incremented, the refcount will get decremented and end up being one less
+  // than it should be.
+  reinterpret_cast<c10::RefcountedDeleterContext*>(data_ptr.get_context())
+      ->refcount++;
+
+  c10::Storage new_storage = c10::make_intrusive<c10::StorageImpl>(
+      c10::StorageImpl::use_byte_size_t(),
+      storage_impl->nbytes(),
+      std::move(new_data_ptr),
+      storage_impl->allocator(),
+      /*resizable=*/storage_impl->resizable());
+  return new_storage;
+}
+
+} // namespace c10

c10/core/RefcountedDeleter.h (new file, 51 lines)
@@ -0,0 +1,51 @@
+#pragma once
+
+#include <c10/core/Storage.h>
+#include <c10/util/UniqueVoidPtr.h>
+
+#include <atomic>
+#include <memory>
+
+namespace c10 {
+
+// A RefcountedDeleterContext object is used as the `ctx` argument for DataPtr
+// to implement a shared DataPtr. Normally, a DataPtr is unique, but we use
+// this custom context and the `refcounted_deleter` function below to make the
+// DataPtr act like a non-unique DataPtr. This context object holds onto an
+// inner context and deleter function which handle the actual deletion of the
+// data when the refcount reaches 0.
+//
+// This shared DataPtr feature is only used when storages are shared between
+// multiple Python interpreters in MultiPy. Before storages had PyObject
+// preservation, interpreters could just share the same StorageImpl instance.
+// But now a StorageImpl can only be associated with one interpreter in order
+// to properly manage a zombie PyObject. So we share storages across Python
+// interpreters by creating a different StorageImpl instance for each one, but
+// they all point to the same data.
+struct C10_API RefcountedDeleterContext {
+  RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter)
+      : other_ctx(other_ctx, other_deleter), refcount(1) {}
+
+  std::unique_ptr<void, c10::DeleterFnPtr> other_ctx;
+  std::atomic_int refcount;
+};
+
+// `refcounted_deleter` is used as the `ctx_deleter` for DataPtr to implement
+// a shared DataPtr.
+//
+// Warning: This should only be called on a pointer to
+// a RefcountedDeleterContext that was allocated on the heap with `new`,
+// because when the refcount reaches 0, the context is deleted with `delete`
+C10_API void refcounted_deleter(void* ctx_);
+
+// If the storage's DataPtr does not use `refcounted_deleter`, replace it with
+// a DataPtr that does, so it can be shared between multiple StorageImpls
+C10_API void maybeApplyRefcountedDeleter(const c10::Storage& storage);
+
+// Create a new StorageImpl that points to the same data. If the original
+// StorageImpl's DataPtr does not use `refcounted_deleter`, it will be replaced
+// with one that does
+C10_API c10::Storage newStorageImplFromRefcountedDataPtr(
+    const c10::Storage& storage);
+
+} // namespace c10
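
The header comments above explain the mechanism; the sketch below walks through it end to end. This is an illustrative example, not part of the commit: `main`, the 16-byte size, and the final checks are invented for the demo, and it assumes a program linked against c10.

// Illustrative sketch (not from this commit): share one allocation between
// two StorageImpls using the refcounted deleter introduced above.
#include <c10/core/CPUAllocator.h>
#include <c10/core/RefcountedDeleter.h>
#include <c10/core/Storage.h>
#include <c10/core/StorageImpl.h>

int main() {
  // An ordinary storage starts with a unique DataPtr.
  c10::Storage s0(c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      16, // hypothetical buffer size in bytes
      c10::GetCPUAllocator(),
      /*resizable=*/false));

  // Make a second StorageImpl over the same bytes. As a side effect, s0's
  // DataPtr is rewritten to use refcounted_deleter, with refcount 2.
  c10::Storage s1 = c10::newStorageImplFromRefcountedDataPtr(s0);

  // Distinct impls, same data: both DataPtrs now hold the same
  // RefcountedDeleterContext, so the is_alias_of() change in this PR
  // reports true for the pair.
  bool distinct_impls = s0.unsafeGetStorageImpl() != s1.unsafeGetStorageImpl();
  bool same_data = s0.data_ptr().get() == s1.data_ptr().get();
  bool aliased = s0.is_alias_of(s1);

  // The buffer itself is freed only when the refcount hits 0, i.e. after
  // both s0 and s1 are gone.
  return (distinct_impls && same_data && aliased) ? 0 : 1;
}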

c10/core/SafePyObject.h
@@ -33,7 +33,7 @@ struct C10_API SafePyObject {

   ~SafePyObject() {
     if (data_ != nullptr) {
-      (*pyinterpreter_)->decref(data_, /*is_tensor*/ false);
+      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
     }
   }

c10/core/Storage.cpp
@@ -1,3 +1,18 @@
+#include <c10/core/RefcountedDeleter.h>
 #include <c10/core/Storage.h>

-namespace c10 {} // namespace c10
+namespace c10 {
+
+bool isSharedStorageAlias(const Storage& storage0, const Storage& storage1) {
+  c10::DeleterFnPtr deleter_expected = &c10::refcounted_deleter;
+  c10::DeleterFnPtr deleter0 = storage0.data_ptr().get_deleter();
+  c10::DeleterFnPtr deleter1 = storage1.data_ptr().get_deleter();
+
+  if ((deleter0 != deleter_expected) || (deleter1 != deleter_expected)) {
+    return false;
+  }
+
+  return storage0.data_ptr().get_context() == storage1.data_ptr().get_context();
+}
+
+} // namespace c10

c10/core/Storage.h
@@ -1,12 +1,22 @@
 #pragma once

 #include <c10/core/StorageImpl.h>
+#include <c10/util/ExclusivelyOwned.h>

 namespace c10 {

+struct Storage;
+
+C10_API bool isSharedStorageAlias(
+    const Storage& storage0,
+    const Storage& storage1);
+
 struct C10_API Storage {
  public:
   struct use_byte_size_t {};
+  struct unsafe_borrow_t {
+    explicit unsafe_borrow_t() = default;
+  };

   Storage() = default;
   Storage(c10::intrusive_ptr<StorageImpl> ptr)

@@ -40,6 +50,14 @@ struct C10_API Storage {
             allocator,
             resizable)) {}

+ protected:
+  explicit Storage(unsafe_borrow_t, const Storage& rhs)
+      : storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(
+            rhs.storage_impl_.get())) {}
+
+  friend MaybeOwnedTraits<Storage>;
+
+ public:
   // Legacy constructor for partially initialized (dtype or memory) storages
   // that can be temporarily created with Caffe2 APIs. See the note on top of
   // TensorImpl.h for details.

@@ -144,7 +162,9 @@ struct C10_API Storage {
   }

   bool is_alias_of(const Storage& other) const {
-    return storage_impl_ == other.storage_impl_;
+    return (
+        storage_impl_ == other.storage_impl_ ||
+        isSharedStorageAlias(*this, other));
   }

   void UniqueStorageShareExternalPointer(

@@ -175,4 +195,67 @@ struct C10_API Storage {
   c10::intrusive_ptr<StorageImpl> storage_impl_;
 };

+template <>
+struct MaybeOwnedTraits<c10::Storage> {
+  using owned_type = c10::Storage;
+  using borrow_type = c10::Storage;
+
+  static borrow_type createBorrow(const owned_type& from) {
+    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
+  }
+
+  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
+    lhs.unsafeReleaseStorageImpl();
+    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
+  }
+
+  static void destroyBorrow(borrow_type& toDestroy) {
+    toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
+  }
+
+  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
+    return borrow;
+  }
+
+  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
+    return &borrow;
+  }
+
+  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
+    return true;
+  }
+};
+
+template <>
+struct ExclusivelyOwnedTraits<c10::Storage> {
+  using repr_type = c10::Storage;
+  using pointer_type = c10::Storage*;
+  using const_pointer_type = const c10::Storage*;
+
+  static repr_type nullRepr() {
+    return c10::Storage();
+  }
+
+  template <class... Args>
+  static repr_type createInPlace(Args&&... args) {
+    return c10::Storage(std::forward<Args>(args)...);
+  }
+
+  static repr_type moveToRepr(c10::Storage&& x) {
+    return std::move(x);
+  }
+
+  static c10::Storage take(c10::Storage& x) {
+    return std::move(x);
+  }
+
+  static pointer_type getImpl(repr_type& x) {
+    return &x;
+  }
+
+  static const_pointer_type getImpl(const repr_type& x) {
+    return &x;
+  }
+};
+
 } // namespace c10
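
These specializations let `c10::MaybeOwned<Storage>` hold a borrowed storage at +0: `createBorrow` reclaims the raw impl without bumping the refcount, and `destroyBorrow` releases it again so nothing is ever decremented. A minimal usage sketch follows (illustrative, not from the commit; `use_borrowed` is a made-up function and a linked c10 build is assumed).

// Illustrative sketch (not from this commit): borrow a Storage without
// refcount traffic, relying on the MaybeOwnedTraits specialization above.
#include <c10/core/Storage.h>
#include <c10/util/MaybeOwned.h>

void use_borrowed(const c10::Storage& owner) {
  // No intrusive_ptr increment happens here: the borrow holds the impl at +0.
  c10::MaybeOwned<c10::Storage> borrowed =
      c10::MaybeOwned<c10::Storage>::borrowed(owner);

  // The borrow dereferences like a normal Storage.
  size_t nbytes = (*borrowed).nbytes();
  (void)nbytes;

  // On destruction, destroyBorrow() releases the impl back ("leaks" the +0
  // reference), leaving the owner's refcount exactly as it was.
}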

c10/core/StorageImpl.h
@@ -203,6 +203,14 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
     return received_cuda_;
   }

+  impl::PyObjectSlot* pyobj_slot() {
+    return &pyobj_slot_;
+  }
+
+  const impl::PyObjectSlot* pyobj_slot() const {
+    return &pyobj_slot_;
+  }
+
  private:
   DataPtr data_ptr_;
   SymInt size_bytes_;

c10/core/TensorImpl.cpp
@@ -73,9 +73,7 @@ void TensorImpl::_set_fw_grad(
   autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
 }

-TensorImpl::~TensorImpl() {
-  pyobj_slot_.destroy_pyobj_if_needed();
-}
+TensorImpl::~TensorImpl() = default;

 TensorImpl::TensorImpl(
     Storage&& storage,
@@ -582,7 +580,7 @@ void TensorImpl::release_resources() {
   if (storage_) {
     storage_ = {};
   }
-  pyobj_slot_.destroy_pyobj_if_needed();
+  pyobj_slot_.maybe_destroy_pyobj();
 }

 #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY

c10/core/impl/PyInterpreter.cpp
@@ -10,7 +10,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
     return "<unloaded interpreter>";
   }

-  void decref(PyObject* pyobj, bool is_tensor) const override {} // do nothing
+  void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
+  } // do nothing

 #define PANIC(m) \
   TORCH_INTERNAL_ASSERT( \

c10/core/impl/PyInterpreter.h
@@ -127,8 +127,8 @@ struct C10_API PyInterpreterVTable {
   virtual std::string name() const = 0;

   // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
-  // See NOTE [PyInterpreter::decref takes an `is_tensor` arg]
-  virtual void decref(PyObject* pyobj, bool is_tensor) const = 0;
+  // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
+  virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;

   // Perform a detach by deferring to the __torch_dispatch__ implementation of
   // detach, which will also arrange for the PyObject to get copied in this

c10/core/impl/PyObjectSlot.cpp
@@ -5,12 +5,16 @@ namespace impl {

 PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}

-void PyObjectSlot::destroy_pyobj_if_needed() {
+PyObjectSlot::~PyObjectSlot() {
+  maybe_destroy_pyobj();
+}
+
+void PyObjectSlot::maybe_destroy_pyobj() {
   if (owns_pyobj()) {
     TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
     TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
     (*pyobj_interpreter_.load(std::memory_order_acquire))
-        ->decref(_unchecked_untagged_pyobj(), /*is_tensor*/ true);
+        ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
     // NB: this destructor can only be entered when there are no
     // references to this C++ object (obviously), NOR any references
     // to the PyObject (if there are references to the PyObject,
@@ -47,6 +51,15 @@ PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
       (*pyobj_interpreter_.load())->name());
 }

+bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) {
+  return interpreter == pyobj_interpreter();
+}
+
+bool PyObjectSlot::has_pyobj_nonhermetic() {
+  return check_pyobj(pyobj_interpreter(), /*ignore_hermetic_tls=*/true)
+      .has_value();
+}
+
 bool PyObjectSlot::owns_pyobj() {
   // NOLINTNEXTLINE(performance-no-int-to-ptr)
   return reinterpret_cast<uintptr_t>(pyobj_) & 1;

c10/core/impl/PyObjectSlot.h
@@ -14,7 +14,9 @@ struct C10_API PyObjectSlot {
  public:
   PyObjectSlot();

-  void destroy_pyobj_if_needed();
+  ~PyObjectSlot();
+
+  void maybe_destroy_pyobj();

   // Associate the TensorImpl with the specified PyObject, and, if necessary,
   // also tag the interpreter.
@@ -82,9 +84,20 @@ struct C10_API PyObjectSlot {
   // a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
   // returns a nullopt. If it is definitely invalid, raises an error.
   //
+  // If `ignore_hermetic_tls` is false and this function is called from a
+  // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
+  // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
+  // context is ignored, allowing you to check the interpreter tag of a
+  // nonhermetic PyObject from within a hermetic context. This is necessary
+  // because there are some cases where the deallocator function of a
+  // nonhermetic PyObject is called from within a hermetic context, so it must
+  // be properly treated as a nonhermetic PyObject.
+  //
   // NB: this lives in header so that we can avoid actually creating the
   // c10::optional
-  c10::optional<PyObject*> check_pyobj(PyInterpreter* self_interpreter) const {
+  c10::optional<PyObject*> check_pyobj(
+      PyInterpreter* self_interpreter,
+      bool ignore_hermetic_tls = false) const {
     // Note [Memory ordering on Python interpreter tag]
     impl::PyInterpreter* interpreter =
         pyobj_interpreter_.load(std::memory_order_acquire);
@@ -97,7 +110,7 @@ struct C10_API PyObjectSlot {
       return c10::nullopt;
     } else if (interpreter == self_interpreter) {
       // NB: pyobj_ could still be null!
-      if (c10::impl::HermeticPyObjectTLS::get_state()) {
+      if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
         return c10::nullopt;
       } else {
         return c10::make_optional(_unchecked_untagged_pyobj());
@@ -118,6 +131,13 @@ struct C10_API PyObjectSlot {

   PyInterpreter& load_pyobj_interpreter() const;

+  // Check if the PyObjectSlot's interpreter is the same as the specified
+  // interpreter
+  bool check_interpreter(PyInterpreter* interpreter);
+
+  // Check if the PyObjectSlot is holding a PyObject, owned or non-owned
+  bool has_pyobj_nonhermetic();
+
   bool owns_pyobj();

   void set_owns_pyobj(bool b);
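
To make the hermetic-TLS semantics concrete, here is a hedged sketch of a caller probing the slot both ways. It is illustrative only; `probe`, `slot`, and `interp` are made-up names standing in for real objects.

// Illustrative sketch (not from this commit): the two ignore_hermetic_tls
// modes of check_pyobj, plus the helpers added in this PR.
#include <c10/core/impl/PyObjectSlot.h>

void probe(c10::impl::PyObjectSlot* slot, c10::impl::PyInterpreter* interp) {
  // Default mode: inside a hermetic context this yields nullopt even when
  // the slot holds a nonhermetic PyObject.
  c10::optional<PyObject*> strict = slot->check_pyobj(interp);

  // Relaxed mode: ignore the hermetic TLS, e.g. in a deallocator that must
  // see the nonhermetic PyObject regardless of the current context.
  c10::optional<PyObject*> relaxed =
      slot->check_pyobj(interp, /*ignore_hermetic_tls=*/true);

  // New helpers: check_interpreter compares the slot's interpreter tag, and
  // has_pyobj_nonhermetic uses the relaxed form internally.
  bool tagged_by_interp = slot->check_interpreter(interp);
  bool holds_pyobj = slot->has_pyobj_nonhermetic();

  (void)strict; (void)relaxed; (void)tagged_by_interp; (void)holds_pyobj;
}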

functorch/csrc/dim/dim.cpp
@@ -1118,7 +1118,7 @@ int64_t _Tensor_ndim(mpy::handle h) {
 mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
   // fast case: tensor is live in python
   c10::optional<PyObject*> mb_obj =
-      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
   if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
     return *mb_obj;
   }

test/test_torch.py
@@ -8832,6 +8832,16 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],

         T()

+    def test_storage_base_init(self):
+        # Direct construction not OK
+        self.assertRaises(RuntimeError, lambda: torch._C.StorageBase())
+
+        # But construction of subclass is OK
+        class T(torch._C.StorageBase):
+            pass
+
+        T()
+
     def test_tensor_base_new(self):

         # OK to call super().__new__, see
@@ -8844,6 +8854,18 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         x = torch.ones(5)
         test_tensor = TestTensor(x)

+    def test_storage_base_new(self):
+
+        # OK to call super().__new__, see
+        # https://github.com/pytorch/pytorch/issues/57421
+        class TestStorage(torch._C.StorageBase):
+            @staticmethod
+            def __new__(cls, x, *args, **kwargs):
+                return super().__new__(cls, x, *args, **kwargs)
+
+        x = torch.UntypedStorage(5)
+        test_storage = TestStorage(x)
+
     def test_pyobj_preserved(self):
         x = torch.empty(2)
         x.foo = 2  # put something on __dict__

@@ -8868,6 +8890,160 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del z  # it's dead again
         self.assertEqual(type(y.grad), MyTensor)

+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_dealloc(self):
+        m, t = Tracker.make()
+        s0 = torch.UntypedStorage(10)
+        s1 = s0
+        s0._tracker = t
+        del t
+
+        self.assertFalse(m[0])
+        del s0
+        self.assertFalse(m[0])
+        del s1
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_from_tensor_dealloc(self):
+        m, t = Tracker.make()
+        a = torch.randn(10)
+        s0 = a.untyped_storage()
+        s0._tracker = t
+        del t
+
+        s1 = a.untyped_storage()
+        self.assertTrue(s0 is s1)
+        self.assertTrue(hasattr(s1, '_tracker'))
+
+        del a
+
+        self.assertFalse(m[0])
+        del s0
+        self.assertFalse(m[0])
+        del s1
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_from_tensor_dealloc_zombie(self):
+        m, t = Tracker.make()
+        a = torch.randn(10)
+        s0 = a.untyped_storage()
+        s0._tracker = t
+        del t
+
+        s1 = a.untyped_storage()
+        self.assertTrue(s0 is s1)
+        self.assertTrue(hasattr(s1, '_tracker'))
+
+        self.assertFalse(m[0])
+        del s0
+        self.assertFalse(m[0])
+        del s1
+        self.assertFalse(m[0])
+        del a
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_from_tensor_dealloc_resurrected(self):
+        m, t = Tracker.make()
+        a = torch.randn(10)
+        s0 = a.untyped_storage()
+        s0._tracker = t
+        del t
+
+        s1 = a.untyped_storage()
+        self.assertTrue(s0 is s1)
+        self.assertTrue(hasattr(s1, '_tracker'))
+
+        self.assertFalse(m[0])
+        del s0
+        self.assertFalse(m[0])
+        del s1
+        self.assertFalse(m[0])
+
+        s0 = a.untyped_storage()
+        self.assertTrue(isinstance(s0, torch.UntypedStorage))
+
+        del a
+        self.assertFalse(m[0])
+        del s0
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_dealloc_resurrected(self):
+        m, t = Tracker.make()
+        s = torch.UntypedStorage(10)
+        s._tracker = t
+        del t
+
+        a = torch.tensor(s)
+        self.assertFalse(m[0])
+        del s
+
+        self.assertFalse(m[0])
+
+        s = a.untyped_storage()
+        self.assertTrue(isinstance(s, torch.UntypedStorage))
+
+        del a
+        self.assertFalse(m[0])
+        del s
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_dealloc_subclass_zombie(self):
+        class MyStorage(torch.UntypedStorage):
+            finalized_count = 0
+
+            def __del__(self):
+                MyStorage.finalized_count += 1
+
+        m, t = Tracker.make()
+        s = MyStorage(10)
+        s._tracker = t
+        del t
+
+        a = torch.tensor(s)
+        self.assertFalse(m[0])
+        del s
+
+        self.assertEqual(MyStorage.finalized_count, 0)
+        self.assertFalse(m[0])
+
+        del a
+        self.assertEqual(MyStorage.finalized_count, 1)
+        self.assertTrue(m[0])
+
+    @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo")
+    def test_storage_dealloc_subclass_resurrected(self):
+        class MyStorage(torch.UntypedStorage):
+            finalized_count = 0
+
+            def __del__(self):
+                MyStorage.finalized_count += 1
+
+        m, t = Tracker.make()
+        s = MyStorage(10)
+        s._tracker = t
+        del t
+
+        a = torch.tensor(s)
+        self.assertFalse(m[0])
+        del s
+
+        self.assertEqual(MyStorage.finalized_count, 0)
+        self.assertFalse(m[0])
+
+        s = a.untyped_storage()
+        del a
+        self.assertFalse(m[0])
+        self.assertEqual(MyStorage.finalized_count, 0)
+        self.assertTrue(isinstance(s, MyStorage))
+        del s
+        self.assertEqual(MyStorage.finalized_count, 1)
+        self.assertTrue(m[0])
+
     def test_tensor_slot_dealloc(self):

         class SlotTensor1(torch._C._TensorBase):

@@ -8889,6 +9065,27 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m1[0])
         self.assertTrue(m2[0])

+    def test_storage_slot_dealloc(self):
+
+        class SlotStorage1(torch._C.StorageBase):
+            __slots__ = ['slot1']
+
+        class SlotStorage2(SlotStorage1):
+            __slots__ = ['slot2']
+
+        m1, t1 = Tracker.make()
+        m2, t2 = Tracker.make()
+        slot_storage = SlotStorage2(torch.UntypedStorage(2))
+        slot_storage.slot1 = t1
+        slot_storage.slot2 = t2
+        del t1
+        del t2
+        self.assertFalse(m1[0])
+        self.assertFalse(m2[0])
+        del slot_storage
+        self.assertTrue(m1[0])
+        self.assertTrue(m2[0])
+
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_tensor_dict_dealloc(self):
         m, t = Tracker.make()
@@ -8899,6 +9096,16 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del x
         self.assertTrue(m[0])

+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+    def test_storage_dict_dealloc(self):
+        m, t = Tracker.make()
+        x = torch.UntypedStorage(2)
+        x.arf = t
+        del t
+        self.assertFalse(m[0])
+        del x
+        self.assertTrue(m[0])
+
     def test_tensor_finalizer_dealloc(self):
         m = [False]

@@ -8911,9 +9118,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del fin_tensor
         self.assertTrue(m[0])

+    def test_storage_finalizer_dealloc(self):
+        m = [False]
+
+        class FinalizerStorage(torch._C.StorageBase):
+            def __del__(self):
+                m[0] = True
+
+        fin_storage = FinalizerStorage(torch.UntypedStorage(2))
+        self.assertFalse(m[0])
+        del fin_storage
+        self.assertTrue(m[0])
+
     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
     def test_tensor_weakref_dealloc(self):

         x = torch.empty(2)
         m = [False]

@@ -8925,6 +9143,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m[0])
         self.assertEqual(wref(), None)

+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+    def test_storage_weakref_dealloc(self):
+
+        x = torch.UntypedStorage(2)
+        m = [False]
+
+        def cb(r):
+            m[0] = True
+
+        wref = weakref.ref(x, cb)
+        del x
+        self.assertTrue(m[0])
+        self.assertEqual(wref(), None)
+
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     def test_tensor_cycle_via_dict(self):
         m1, t1 = Tracker.make()

@@ -8968,6 +9200,49 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m1[0])
         self.assertTrue(m2[0])

+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+    def test_storage_cycle_via_dict(self):
+        m1, t1 = Tracker.make()
+        x = torch.UntypedStorage(2)
+        x._tracker = t1
+        del t1
+
+        m2, t2 = Tracker.make()
+        y = torch.UntypedStorage(2)
+        y._tracker = t2
+        del t2
+
+        x._loop = y
+        y._loop = x
+
+        # C++ reference should keep the cycle live!
+        # This exercises THPVariable_subtype_traverse
+        # NB: Because z.grad is a reference done entirely in C++, cycles
+        # involving it directly are NOT broken by Python GC; you've
+        # set up a good old C++ reference cycle which we cannot safely
+        # break (because C++ references are allowed to be accessed
+        # multithreaded-ly) (TODO: except maybe if you can prove that
+        # only Python has access to the C++ object, in which case you can
+        # also prove that no multithreaded access occurs)
+        z = torch.UntypedStorage(2)
+        z.grad = x
+
+        del x
+        del y
+
+        gc.collect()
+        self.assertFalse(m1[0])
+        self.assertFalse(m2[0])
+
+        with disable_gc():
+            del z
+            self.assertFalse(m1[0])
+            self.assertFalse(m2[0])
+
+        gc.collect()
+        self.assertTrue(m1[0])
+        self.assertTrue(m2[0])
+
     def test_tensor_cycle_via_slots(self):
         m1 = [False]
         m2 = [False]
@@ -9000,6 +9275,67 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m1[0])
         self.assertTrue(m2[0])

+    def test_storage_cycle_via_slots(self):
+        m1 = [False]
+        m2 = [False]
+
+        class SlotStorage1(torch._C.StorageBase):
+            __slots__ = ['slot1']
+
+            def __del__(self):
+                m1[0] = True
+
+        class SlotStorage2(SlotStorage1):
+            __slots__ = ['slot2']
+
+            def __del__(self):
+                m2[0] = True
+
+        x = SlotStorage1(torch.UntypedStorage(2))
+        y = SlotStorage2(torch.UntypedStorage(2))
+
+        x.slot1 = y
+        y.slot2 = x
+
+        del x
+        with disable_gc():
+            del y
+            self.assertFalse(m1[0])
+            self.assertFalse(m2[0])
+
+        gc.collect()
+        self.assertTrue(m1[0])
+        self.assertTrue(m2[0])
+
+    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+    def test_storage_preserve_nonhermetic_in_hermetic_context(self):
+        from torch.library import Library, impl
+        global _my_storage
+
+        my_lib = Library("my_lib", "DEF")
+        my_lib.define('my_func() -> None')
+
+        a = torch.tensor([1.])
+        _my_storage = a.untyped_storage()
+
+        m, t = Tracker.make()
+        _my_storage._tracker = t
+        del t
+
+        @impl(my_lib, 'my_func', '')
+        def my_func():
+            global _my_storage
+            del _my_storage
+
+        self.assertFalse(m[0])
+        torch.ops.my_lib.my_func()
+        self.assertFalse(m[0])
+
+        s = a.untyped_storage()
+        del a
+        del s
+        self.assertTrue(m[0])
+
     # FIXME: move to test_autograd?
     @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
     def test_backward_hooks_traverse(self):

@@ -9028,7 +9364,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m2[0])

     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
-    def test_dead_weak_ref(self):
+    def test_tensor_dead_weak_ref(self):
         x = torch.empty(2)
         w_x = weakref.ref(x)
         y = torch.empty(2)
@@ -9044,7 +9380,24 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],

         self.assertRaises(RuntimeError, lambda: x.sigmoid())

-    def test_resurrected_weak_ref(self):
+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+    def test_storage_dead_weak_ref(self):
+        x = torch.UntypedStorage(2)
+        w_x = weakref.ref(x)
+        y = torch.tensor(x)
+        del x
+
+        x = w_x()
+        # Ideally, x would keep the storage live. But CPython doesn't
+        # provide enough hooks to do this. So it will go dead and x
+        # will transmute into storage with null StorageImpl. Not great, but the
+        # best we can do.
+        del y
+
+        self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
+        self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
+
+    def test_tensor_resurrected_weak_ref(self):
         x = torch.empty(2)
         w_x = weakref.ref(x)
         y = torch.empty(2)
@@ -9057,8 +9410,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del y
         x.sigmoid()

+    def test_storage_resurrected_weak_ref(self):
+        x = torch.UntypedStorage(2)
+        w_x = weakref.ref(x)
+        y = torch.tensor(x)
+        del x
+
+        x = w_x()
+        # Use this to manually fix weak references after dereferencing them
+        x._fix_weakref()
+        del y
+        x.float()
+
     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
-    def test_fix_weakref_no_leak(self):
+    def test_tensor_fix_weakref_no_leak(self):
         import weakref

         called = False
@@ -9074,6 +9439,23 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],

         self.assertTrue(called)

+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+    def test_storage_fix_weakref_no_leak(self):
+        import weakref
+
+        called = False
+
+        a = torch.UntypedStorage(1)
+
+        def callback(w):
+            nonlocal called
+            called = True
+        wa = weakref.ref(a, callback)
+        a._fix_weakref()
+        del a
+
+        self.assertTrue(called)
+
     # FIXME: move to test_linalg
     @torch.inference_mode()
     def test_bmm_multithreaded(self):

torch/csrc/DynamicTypes.cpp
@@ -30,29 +30,6 @@ std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)>
 std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)>
     layout_registry = {};

-at::DeprecatedTypeProperties* get_type_properties(
-    at::DeviceType device_type,
-    at::ScalarType scalarType) {
-  at::Backend backend = at::Backend::Undefined;
-  if (device_type == at::kCPU) {
-    backend = at::Backend::CPU;
-  } else if (device_type == at::kCUDA) {
-    backend = at::Backend::CUDA;
-  } else if (device_type == at::kXPU) {
-    backend = at::Backend::XPU;
-  } else if (device_type == at::kHPU) {
-    backend = at::Backend::HPU;
-  } else if (device_type == at::kMPS) {
-    backend = at::Backend::MPS;
-  } else if (device_type == at::DeviceType::Meta) {
-    backend = at::Backend::Undefined;
-  } else if (device_type == at::DeviceType::PrivateUse1) {
-    backend = at::Backend::PrivateUse1;
-  } else {
-    TORCH_CHECK(false, "Invalid device for storage: ", device_type);
-  }
-  return &at::getDeprecatedTypeProperties(backend, scalarType);
-}
-
 } // namespace

 void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) {

@@ -88,13 +65,10 @@ PyObject* createPyObject(const at::Storage& storage) {
   // information about storages from python). However, any accesses to the
   // data_ptr is not allowed, through methods like
   // x.untyped_storage().data_ptr()
-  PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPStorageClass);
-  auto obj = THPObjectPtr(type->tp_alloc(type, 0));
+  PyObject* obj = THPStorage_Wrap(storage);
   if (!obj)
     throw python_error();
-  ((THPStorage*)obj.get())->cdata =
-      c10::MaybeOwned<at::Storage>::owned(at::Storage(/* copy */ storage));
-  return obj.release();
+  return obj;
 }

 PyTypeObject* loadTypedStorageTypeObject() {

@@ -118,16 +92,13 @@ bool isStorage(PyObject* obj) {
   if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) {
     return true;
   }
-  auto obj_type = Py_TYPE(obj);
-
-  return obj_type == reinterpret_cast<PyTypeObject*>(THPStorageClass);
+  return THPStorage_Check(obj);
 }

-at::Storage createStorageGetType(
-    PyObject* obj,
-    at::ScalarType& scalar_type,
-    bool& is_typed_storage) {
-  is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
+std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
+    PyObject* obj) {
+  at::ScalarType scalar_type = at::ScalarType::Undefined;
+  bool is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
   PyObject* untyped_storage_obj = nullptr;

   if (is_typed_storage) {

@@ -136,10 +107,9 @@ at::Storage createStorageGetType(
     // stay nonzero since the `TypedStorage` maintains a reference.
     PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
     TORCH_INTERNAL_ASSERT(dtype_obj);
-    Py_DECREF(dtype_obj);
-
     TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
     scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
+    Py_DECREF(dtype_obj);

     untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
     TORCH_INTERNAL_ASSERT(untyped_storage_obj);
|
|||||||
untyped_storage_obj = obj;
|
untyped_storage_obj = obj;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Py_TYPE(untyped_storage_obj) !=
|
TORCH_CHECK(
|
||||||
reinterpret_cast<PyTypeObject*>(THPStorageClass)) {
|
THPStorage_Check(untyped_storage_obj),
|
||||||
throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name);
|
"not a storage '",
|
||||||
}
|
Py_TYPE(obj)->tp_name,
|
||||||
|
"'");
|
||||||
|
|
||||||
const auto& storage = THPStorage_Unpack(untyped_storage_obj);
|
auto storage = THPStorage_Unpack(untyped_storage_obj);
|
||||||
c10::DeviceType device_type = storage.device().type();
|
return std::make_tuple(storage, scalar_type, is_typed_storage);
|
||||||
auto type_properties = get_type_properties(device_type, at::kByte);
|
|
||||||
return type_properties->unsafeStorageFromTH(
|
|
||||||
storage.unsafeGetStorageImpl(), true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Storage createStorage(PyObject* obj) {
|
at::Storage createStorage(PyObject* obj) {
|
||||||
at::ScalarType scalar_type = at::ScalarType::Undefined;
|
return std::get<0>(createStorageGetType(obj));
|
||||||
bool is_typed_storage = false;
|
|
||||||
return createStorageGetType(obj, scalar_type, is_typed_storage);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|

torch/csrc/DynamicTypes.h
@@ -27,10 +27,8 @@ void registerLayoutObject(THPLayout* thp_layout, at::Layout layout);

 TORCH_PYTHON_API PyObject* createPyObject(const at::Storage& storage);
 at::Storage createStorage(PyObject* obj);
-at::Storage createStorageGetType(
-    PyObject* obj,
-    at::ScalarType& scalar_type,
-    bool& is_typed_storage);
+std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
+    PyObject* obj);
 bool isStorage(PyObject* obj);

 TORCH_PYTHON_API THPDtype* getTHPDtype(at::ScalarType scalarType);
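
A hedged sketch of calling the refactored function (illustrative, not from the commit; `unpack_storage` is a made-up helper, and `obj` stands for any PyObject accepted by `isStorage`):

// Illustrative sketch (not from this commit): consuming the tuple-returning
// createStorageGetType instead of the old out-parameter signature.
#include <torch/csrc/DynamicTypes.h>

at::Storage unpack_storage(PyObject* obj) {
  TORCH_CHECK(torch::isStorage(obj), "expected a storage object");

  // scalar_type stays at::ScalarType::Undefined for an UntypedStorage;
  // is_typed_storage distinguishes TypedStorage wrappers.
  auto [storage, scalar_type, is_typed_storage] =
      torch::createStorageGetType(obj);
  (void)scalar_type;
  (void)is_typed_storage;
  return storage;
}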

torch/csrc/PyInterpreter.cpp
@@ -35,7 +35,7 @@ struct ConcretePyInterpreterVTable final
     : public c10::impl::PyInterpreterVTable {
   std::string name() const override;

-  void decref(PyObject* pyobj, bool is_tensor) const override;
+  void decref(PyObject* pyobj, bool has_pyobj_slot) const override;

   // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
   // operate upon a PyObjectSlot rather than a TensorImpl
|
|||||||
TorchFunctionName::TorchDispatch));
|
TorchFunctionName::TorchDispatch));
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE [PyInterpreter::decref takes an `is_tensor` arg]
|
// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
|
||||||
// Before calling PyInterpreter::decref, we must statically know if the
|
// Before calling PyInterpreter::decref, we must statically know if the
|
||||||
// pyobj is a Tensor or not.
|
// pyobj has a PyObjectSlot or not.
|
||||||
// - If it is a tensor, we need to be careful about PyObject resurrection
|
// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
|
||||||
// - If it is not a tensor, we can freely decref
|
// - If it does not have a PyObjectSlot, we can freely decref
|
||||||
// One alternative to this is using PyObject_IsInstance
|
// One alternative to this is using PyObject_IsInstance
|
||||||
// to get at this information. However, we don't want to risk an incorrect
|
// to get at this information. However, we don't want to risk an incorrect
|
||||||
// `__instancecheck__` changing the semantics here.
|
// `__instancecheck__` changing the semantics here.
|
||||||
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
|
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
|
||||||
const {
|
const {
|
||||||
// Leak the pyobj if not initialized. This can happen if we are running
|
// Leak the pyobj if not initialized. This can happen if we are running
|
||||||
// exit handlers that are destructing tensors with residual (owned)
|
// exit handlers that are destructing tensors with residual (owned)
|
||||||

@@ -207,23 +207,33 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)

   pybind11::gil_scoped_acquire gil;
   // Two possibilities:
-  // 1. We are decref-ing a tensor. Then we must be careful about
-  //    PyObject resurrection (this only applies to Tensors, see
-  //    THPVariable_clear).
+  // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
+  //    Storage. Then we must be careful about PyObject resurrection (see
+  //    THPVariable_clear).
   // 2. We are decref-ing some other Python object. We don't do
   //    PyObject resurrection on non-Tensors, so we just carry on as usual
-  if (is_tensor && Py_REFCNT(pyobj) > 1) {
-    // It's still alive! This can happen if a weak ref resurrected
-    // the PyObject without flipping ownership. At this point it is
-    // too late to rescue the object, so just stub out the PyObject
-    // so that it fails on subsequent uses. Don't raise an error here;
-    // you're probably in a destructor.
-    TORCH_WARN(
-        "Deallocating Tensor that still has live PyObject references. "
-        "This probably happened because you took out a weak reference to "
-        "Tensor and didn't call _fix_weakref() after dereferencing it. "
-        "Subsequent accesses to this tensor via the PyObject will now fail.");
-    ((THPVariable*)pyobj)->cdata = c10::MaybeOwned<torch::autograd::Variable>();
+  if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
+    if (THPVariable_Check(pyobj)) {
+      // It's still alive! This can happen if a weak ref resurrected
+      // the PyObject without flipping ownership. At this point it is
+      // too late to rescue the object, so just stub out the PyObject
+      // so that it fails on subsequent uses. Don't raise an error here;
+      // you're probably in a destructor.
+      TORCH_WARN(
+          "Deallocating Tensor that still has live PyObject references. "
+          "This probably happened because you took out a weak reference to "
+          "Tensor and didn't call _fix_weakref() after dereferencing it. "
+          "Subsequent accesses to this tensor via the PyObject will now fail.");
+      ((THPVariable*)pyobj)->cdata =
+          c10::MaybeOwned<torch::autograd::Variable>();
+    } else if (THPStorage_Check(pyobj)) {
+      TORCH_WARN(
+          "Deallocating UntypedStorage that still has live PyObject references. "
+          "This probably happened because you took out a weak reference to "
+          "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
+          "Subsequent accesses to this storage via the PyObject will now fail.");
+      ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
+    }
   }
   Py_DECREF(pyobj);
 };

@@ -548,8 +558,8 @@ static void set_tensor_attr_with_capsule(
     const c10::TensorImpl* tensor,
     py::capsule& capsule,
     const char* attr_name) {
-  c10::optional<PyObject*> mb_obj =
-      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
+  c10::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
+      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
   TORCH_CHECK(
       mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
   py::handle(mb_obj.value()).attr(attr_name) = capsule;

torch/csrc/Storage.cpp
@@ -6,6 +6,7 @@

 #include <ATen/mps/MPSDevice.h>
 #include <c10/core/CPUAllocator.h>
+#include <c10/core/RefcountedDeleter.h>
 #include <libshm.h>
 #include <torch/csrc/CudaIPCTypes.h>
 #include <torch/csrc/Device.h>
@@ -15,6 +16,7 @@
 #include <torch/csrc/THP.h>
 #include <torch/csrc/autograd/utils/wrap_outputs.h>
 #include <torch/csrc/copy_utils.h>
+#include <torch/csrc/utils/pyobject_preservation.h>
 #include <torch/csrc/utils/python_arg_parser.h>

 #include <c10/util/intrusive_ptr.h>
@@ -27,28 +29,265 @@ void THPPointer<c10::StorageImpl>::free() {
   }
 }

-PyObject* THPStorageClass = nullptr;
+PyTypeObject* THPStorageClass = nullptr;

-PyObject* THPStorage_New(c10::Storage storage) {
-  PyTypeObject* type = (PyTypeObject*)THPStorageClass;
-  PyObject* obj = type->tp_alloc(type, 0);
-  if (obj) {
-    ((THPStorage*)obj)->cdata =
-        c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
+PyObject* THPStorage_NewWithStorage(
+    PyTypeObject* type,
+    c10::Storage _storage,
+    c10::impl::PyInterpreterStatus status,
+    bool allow_preexisting_pyobj) {
+  TORCH_CHECK(
+      PyType_IsSubtype(type, &THPStorageType),
+      "Creating a Storage subclass from a class that does not inherit from ",
+      "Storage is not possible. Make sure your class inherits from Storage.");
+
+  auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
+      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
+  if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
+    TORCH_CHECK(
+        allow_preexisting_pyobj,
+        "Creating a new Storage subclass ",
+        type->tp_name,
+        " but the raw Storage object is already associated to a python object ",
+        "of type ",
+        maybe_pyobj.value()->ob_type->tp_name);
+    PyObject* obj = *maybe_pyobj;
+    PyTypeObject* obj_type = Py_TYPE(obj);
+    TORCH_CHECK(
+        obj_type == type || PyType_IsSubtype(obj_type, type),
+        "Creating a new Storage subclass ",
+        type->tp_name,
+        " but the raw Storage object is already associated to a python object ",
+        "of type ",
+        maybe_pyobj.value()->ob_type->tp_name,
+        " which is not a subclass of the "
+        "requested type");
+    return THPStorage_Wrap(std::move(_storage));
   }
+
+  PyObject* obj = type->tp_alloc(type, 0);
+  TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
+
+  auto s = (THPStorage*)obj;
+
+  new (&s->cdata) c10::MaybeOwned<c10::Storage>();
+
+  s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));
+
+  if (!c10::impl::HermeticPyObjectTLS::get_state()) {
+    s->is_hermetic = false;
+    const auto& storage = THPStorage_Unpack(s);
+    storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(
+        getPyInterpreter(), obj, status);
+  } else {
+    s->is_hermetic = true;
+  }
+
   return obj;
 }
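The `status` argument that THPStorage_NewWithStorage forwards to init_pyobj is a caller-supplied promise about the state of the impl's pyobj slot, letting the slot pick the cheapest tagging path. A rough sketch of the idea behind that enum, with illustrative comments (the real definition is c10::impl::PyInterpreterStatus; the exact enumerator order here is an assumption):

enum class PyInterpreterStatusSketch {
  DEFINITELY_UNINITIALIZED, // impl was just created: tag it without atomics
  TAGGED_BY_US,             // already tagged by this interpreter: cheap path
  MAYBE_UNINITIALIZED,      // unknown: claim the slot with an atomic exchange
  TAGGED_BY_OTHER,          // a different interpreter owns the impl
};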
+
+// Wraps the c10::Storage with a storage PyObject
+PyObject* THPStorage_Wrap(c10::Storage storage) {
+  c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
+  if (c10::impl::HermeticPyObjectTLS::get_state()) {
+    return THPStorage_NewWithStorage(
+        THPStorageClass,
+        std::move(storage),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
+  }
+  c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();
+
+  // If the StorageImpl has a PyObject that is managed by a different
+  // interpreter than the current one, create a new StorageImpl that points to
+  // the same data and then create the Python storage from that.
+  // NOTE: This is only supposed to happen in MultiPy
+  if (pyobj_slot->has_pyobj_nonhermetic() &&
+      !pyobj_slot->check_interpreter(getPyInterpreter())) {
+    return THPStorage_NewWithStorage(
+        THPStorageClass,
+        c10::newStorageImplFromRefcountedDataPtr(storage),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
+  }
+  c10::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
+      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
+  c10::impl::PyInterpreterStatus status =
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US;
+  if (maybe_pyobj.has_value()) {
+    auto obj = *maybe_pyobj;
+    if (obj) {
+      TORCH_CHECK(
+          THPStorage_Check(obj),
+          "Expected a storage type, but got ",
+          Py_TYPE(obj)->tp_name);
+
+      if (pyobj_slot->owns_pyobj()) {
+        pyobj_slot->set_owns_pyobj(false);
+        reinterpret_cast<THPStorage*>(obj)->cdata =
+            c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
+        return obj;
+      } else {
+        Py_INCREF(obj);
+        return obj;
+      }
+    }
+    status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
+  } else {
+    if (storage.use_count() <= 1) {
+      status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
+    } else {
+      status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
+    }
+  }
+  return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status);
+}
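The key move in THPStorage_Wrap is the owns_pyobj flip: rather than creating a second wrapper for a StorageImpl that already has one, the single strong reference is handed back and forth between C++ and Python. A minimal standalone model of that handshake (Impl, PyWrapper, and Wrap are stand-ins for illustration, not the real PyTorch API):

#include <cassert>

struct PyWrapper;

struct Impl {
  PyWrapper* pyobj = nullptr; // cached wrapper, i.e. the "PyObjectSlot"
  bool owns_pyobj = false;    // true when C++ holds the strong reference
};

struct PyWrapper {
  int ob_refcnt = 1;
  Impl* impl = nullptr;
};

PyWrapper* Wrap(Impl* impl) {
  if (impl->pyobj != nullptr) {
    if (impl->owns_pyobj) {
      // C++ held the strong reference; hand it back to the caller instead
      // of increfing (mirrors set_owns_pyobj(false) in THPStorage_Wrap).
      impl->owns_pyobj = false;
    } else {
      impl->pyobj->ob_refcnt++; // Python already owns it: plain incref
    }
    return impl->pyobj;
  }
  impl->pyobj = new PyWrapper{1, impl}; // first wrap: create and cache
  return impl->pyobj;
}

int main() {
  Impl impl;
  PyWrapper* a = Wrap(&impl);
  PyWrapper* b = Wrap(&impl); // the same wrapper comes back, never a copy
  assert(a == b && a->ob_refcnt == 2);
  delete a;
}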
+
+static bool THPStorage_isPreservable(THPStorage* self) {
+  if (self->cdata.unsafeIsBorrowed()) {
+    return false;
+  }
+  auto const& storage = THPStorage_Unpack(self);
+
+  if (self->is_hermetic) {
+    return false;
+  }
+
+  if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
+          getPyInterpreter(), /*ignore_hermetic_tls=*/true) !=
+      c10::make_optional((PyObject*)self)) {
+    return false;
+  }
+  if (storage.use_count() <= 1) {
+    return false;
+  }
+  return true;
+}
+
+static bool THPStorage_tryPreserve(THPStorage* self) {
+  if (!THPStorage_isPreservable(self)) {
+    return false;
+  }
+
+  const auto& storage = THPStorage_Unpack(self);
+  c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
+
+  auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj(
+      getPyInterpreter(),
+      /*ignore_hermetic_tls=*/true);
+  // NOTE: It is possible to just set the PyObjectSlot here, but the point is
+  // that we should have already set PyObjectSlot when the storage PyObject was
+  // created.
+  TORCH_INTERNAL_ASSERT(
+      maybe_pyobj.has_value(),
+      "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");
+
+  PyObject* pyobj = *maybe_pyobj;
+
+  TORCH_CHECK(
+      THPStorage_Check(pyobj),
+      "Expected a storage type, but got ",
+      Py_TYPE(pyobj)->tp_name);
+
+  TORCH_INTERNAL_ASSERT(
+      (void*)pyobj == (void*)self,
+      "Python storage and the PyObject in the internal PyObjectSlot are not at the same address");
+
+  TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
+
+  storage_impl->pyobj_slot()->set_owns_pyobj(true);
+  Py_INCREF(self);
+
+  self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
+  return true;
+}
+
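The use_count() <= 1 test above is what keeps preservation from leaking: if the Python wrapper holds the only reference to the storage, letting the wrapper die should free the storage as well, so there is nothing to preserve. A small sketch of the invariant, with std::shared_ptr standing in for the intrusive refcount of c10::Storage:

#include <cassert>
#include <memory>

int main() {
  auto storage = std::make_shared<int>(42); // the wrapper's own reference
  assert(storage.use_count() == 1);         // sole owner: nothing to preserve
  auto alias = storage;                     // some C++ consumer still uses it
  assert(storage.use_count() > 1);          // now preserving the wrapper pays off
}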
 static void THPStorage_subclass_dealloc(PyObject* self) {
   THPStorage* _self = (THPStorage*)self;
-  // Some subclass of StorageBase are GC-tracked objects even
-  // though the base class is not.
+
+  if (THPStorage_tryPreserve(_self)) {
+    return;
+  }
+
+  // Some subclass of StorageBase could be GC-tracked objects even
+  // though the base class is not
   auto* type = Py_TYPE(self);
   if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
     PyObject_GC_UnTrack(self);
   }
+
+  bool has_finalizer = type->tp_finalize || type->tp_del;
+
+  if (type->tp_finalize) {
+    PyObject_GC_Track(self);
+    if (PyObject_CallFinalizerFromDealloc(self) < 0) {
+      // The finalizer has resurrected the PyObject and there is a new Python
+      // reference to it, so we can just stop deallocating. Read about
+      // resurrection from `__del__` here:
+      // https://docs.python.org/3/reference/datamodel.html#object.__del__
+      return;
+    }
+    PyObject_GC_UnTrack(self);
+  }
+
+  // base test is unnecessary as THPStorage does not set this
+  if (type->tp_weaklistoffset) {
+    PyObject_ClearWeakRefs(self);
+  }
+
+  if (type->tp_del) {
+    PyObject_GC_Track(self);
+    type->tp_del(self);
+    if (self->ob_refcnt > 0) {
+      // Resurrected (see above comment about resurrection from `__del__`)
+      return;
+    }
+    PyObject_GC_UnTrack(self);
+  }
+
+  if (has_finalizer) {
+    /* New weakrefs could be created during the finalizer call.
+       If this occurs, clear them out without calling their
+       finalizers since they might rely on part of the object
+       being finalized that has already been destroyed. */
+    if (type->tp_weaklistoffset) {
+      /* Modeled after GET_WEAKREFS_LISTPTR() */
+      PyWeakReference** list =
+          (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
+      while (*list)
+        _PyWeakref_ClearRef(*list);
+    }
+  }
+
+  // Clear slots
+  {
+    PyTypeObject* base = type;
+    while (base != &THPStorageType) {
+      if (Py_SIZE(base)) {
+        clear_slots(base, self);
+      }
+      base = base->tp_base;
+      TORCH_INTERNAL_ASSERT(base);
+    }
+  }
+
+  // Clear __dict__
+  if (C10_LIKELY(type->tp_dictoffset)) {
+    PyObject** dictptr = _PyObject_GetDictPtr(self);
+    if (dictptr != nullptr) {
+      PyObject* dict = *dictptr;
+      if (dict != nullptr) {
+        Py_DECREF(dict);
+        *dictptr = nullptr;
+      }
+    }
+  }
+
+  TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
+
   _self->cdata.~MaybeOwned<c10::Storage>();
   Py_TYPE(_self)->tp_free(self);
+
+  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
+  Py_DECREF(type);
 }

 c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
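The dealloc path above follows CPython's subtype_dealloc protocol, and both THPStorage_tryPreserve and a user-defined __del__ can resurrect the object mid-teardown, in which case deallocation must stop immediately. A minimal standalone model of that control flow (names illustrative, not the real API):

#include <cstdio>

struct Obj {
  int refcnt = 0;
  bool (*finalize)(Obj*) = nullptr; // returns true if it resurrected the object
};

void dealloc(Obj* o) {
  // mirrors `if (PyObject_CallFinalizerFromDealloc(self) < 0) return;`
  if (o->finalize != nullptr && o->finalize(o)) {
    std::puts("resurrected: abort deallocation");
    return;
  }
  std::puts("freed");
}

int main() {
  Obj keep{0, [](Obj* o) { o->refcnt = 1; return true; }};
  dealloc(&keep); // prints "resurrected: abort deallocation"
  Obj drop;
  dealloc(&drop); // prints "freed"
}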
@@ -151,32 +390,35 @@ static PyObject* THPStorage_pynew(
       "(): only one or neither of 'allocator' or 'device' can ",
       "be given, but not both");

-  THPStoragePtr self((THPStorage*)type->tp_alloc(type, 0));
-  THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
+  PyObject* self = nullptr;
   c10::Allocator* allocator = nullptr;

   // torch.Storage(*, ...)
   if (r.idx == 0) {
-    self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
-        c10::StorageImpl::use_byte_size_t(),
-        0,
-        allocator,
-        /*resizable=*/true,
-        allocator_opt,
-        device_opt));
-    return (PyObject*)self.release();
+    self = THPStorage_NewWithStorage(
+        type,
+        make_storage_impl(
+            c10::StorageImpl::use_byte_size_t(),
+            0,
+            allocator,
+            /*resizable=*/true,
+            allocator_opt,
+            device_opt),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);

     // torch.Storage(size, *, ...)
   } else if (r.idx == 1) {
     int64_t size = r.toInt64(0);
-    self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
-        c10::StorageImpl::use_byte_size_t(),
-        size,
-        allocator,
-        /*resizable=*/true,
-        allocator_opt,
-        device_opt));
-    return (PyObject*)self.release();
+    self = THPStorage_NewWithStorage(
+        type,
+        make_storage_impl(
+            c10::StorageImpl::use_byte_size_t(),
+            size,
+            allocator,
+            /*resizable=*/true,
+            allocator_opt,
+            device_opt),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);

     // torch.Storage(sequence, *, ...)
   } else if (r.idx == 2) {
@@ -192,19 +434,22 @@ static PyObject* THPStorage_pynew(
         THPStorageStr,
         "(): Could not obtain the length of sequence of type ",
         THPUtils_typename(sequence));
-    self->cdata = c10::MaybeOwned<c10::Storage>::owned(make_storage_impl(
-        c10::StorageImpl::use_byte_size_t(),
-        length,
-        allocator,
-        /*resizable=*/true,
-        allocator_opt,
-        device_opt));
+    self = THPStorage_NewWithStorage(
+        type,
+        make_storage_impl(
+            c10::StorageImpl::use_byte_size_t(),
+            length,
+            allocator,
+            /*resizable=*/true,
+            allocator_opt,
+            device_opt),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
     THPObjectPtr item;
     try {
+      const auto& storage = THPStorage_Unpack(self);
       for (Py_ssize_t i = 0; i < length; i++) {
         item = PySequence_GetItem(sequence, i);
         uint8_t value = THPByteUtils_unpackReal(item.get());
-        const auto& storage = THPStorage_Unpack(self);
         if (allocator == c10::GetDefaultCPUAllocator()) {
           static_cast<uint8_t*>(storage.mutable_data())[i] = value;
         } else {
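Every torch.Storage constructor branch above now routes through THPStorage_NewWithStorage instead of writing cdata directly, so the impl's pyobj slot is initialized at the moment the wrapper is born. The shape of that funnel, as a self-contained sketch (types are stand-ins, not the real API):

#include <cassert>

struct Impl { void* pyobj = nullptr; };
struct Wrapper { Impl* impl; };

Wrapper* new_with_impl(Impl* impl) {
  Wrapper* w = new Wrapper{impl};
  impl->pyobj = w; // the slot is filled exactly once, at creation
  return w;
}

int main() {
  Impl impl;
  Wrapper* w = new_with_impl(&impl);
  assert(impl.pyobj == w);
  delete w;
}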
@@ -221,20 +466,22 @@ static PyObject* THPStorage_pynew(
           THPUtils_typename(item.get()));
       return nullptr;
     }
-    return (PyObject*)self.release();
   }
+  return self;
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }

 static Py_ssize_t THPStorage_length(THPStorage* self) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   return static_cast<Py_ssize_t>(THPStorage_Unpack(self).nbytes());
   END_HANDLE_TH_ERRORS_RET(-1)
 }

 static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   int64_t len = static_cast<int64_t>(storage.nbytes());
   /* Integer index */
@@ -289,7 +536,11 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
         old_storage_impl->allocator(),
         /* resizable */ false);

-    PyObject* _ret = THPStorage_New(std::move(new_storage_impl));
+    PyObject* _ret = THPStorage_NewWithStorage(
+        Py_TYPE(self),
+        std::move(new_storage_impl),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
+
     return _ret;
   }
   PyErr_Format(
@@ -302,6 +553,7 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {

 static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   if (!THPByteUtils_checkReal(value)) {
     THPUtils_setError(
         "can only set storage content with a int types, but got "
@@ -451,6 +703,7 @@ int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {

 static PyObject* THPStorage_device(THPStorage* self, void* unused) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   return THPDevice_New(THPStorage_Unpack(self).device());
   END_HANDLE_TH_ERRORS
 }
@@ -490,7 +743,17 @@ bool THPStorage_init(PyObject* module) {
 }

 void THPStorage_postInit(PyObject* module) {
-  THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage");
+  THPStorageClass =
+      (PyTypeObject*)PyObject_GetAttrString(module, "UntypedStorage");
   if (!THPStorageClass)
     throw python_error();
 }
+
+void THPStorage_assertNotNull(THPStorage* storage) {
+  TORCH_CHECK(
+      THPStorage_Unpack(storage).unsafeGetStorageImpl(), "Got a null Storage");
+}
+
+void THPStorage_assertNotNull(PyObject* obj) {
+  THPStorage_assertNotNull((THPStorage*)obj);
+}
@@ -5,84 +5,44 @@

 #define THPStorageStr "torch.UntypedStorage"

-namespace c10 {
-
-template <>
-struct MaybeOwnedTraits<c10::Storage> {
-  using owned_type = c10::Storage;
-  using borrow_type = c10::Storage;
-
-  static borrow_type createBorrow(const owned_type& from) {
-    return borrow_type(from);
-  }
-
-  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
-    lhs.unsafeReleaseStorageImpl();
-    lhs = borrow_type(rhs);
-  }
-
-  static void destroyBorrow(borrow_type& toDestroy) {
-    toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
-  }
-
-  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
-    return borrow;
-  }
-
-  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
-    return &borrow;
-  }
-
-  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
-    return true;
-  }
-};
-
-template <>
-struct ExclusivelyOwnedTraits<c10::Storage> {
-  using repr_type = c10::Storage;
-  using pointer_type = c10::Storage*;
-  using const_pointer_type = const c10::Storage*;
-
-  static repr_type nullRepr() {
-    return c10::Storage();
-  }
-
-  template <class... Args>
-  static repr_type createInPlace(Args&&... args) {
-    return c10::Storage(std::forward<Args>(args)...);
-  }
-
-  static repr_type moveToRepr(c10::Storage&& x) {
-    return std::move(x);
-  }
-
-  static c10::Storage take(c10::Storage& x) {
-    return std::move(x);
-  }
-
-  static pointer_type getImpl(repr_type& x) {
-    return &x;
-  }
-
-  static const_pointer_type getImpl(const repr_type& x) {
-    return &x;
-  }
-};
-
-} // namespace c10
-
 struct THPStorage {
   PyObject_HEAD;
   c10::MaybeOwned<c10::Storage> cdata;
+  bool is_hermetic;
 };

-TORCH_PYTHON_API PyObject* THPStorage_New(c10::Storage storage);
-extern PyObject* THPStorageClass;
+TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage);
+TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage(
+    PyTypeObject* type,
+    c10::Storage _storage,
+    c10::impl::PyInterpreterStatus status,
+    bool allow_preexisting_pyobj = false);
+extern PyTypeObject* THPStorageClass;
+
+static inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) {
+  return tp == THPStorageClass;
+}
+
+static inline bool THPStorage_CheckExact(PyObject* obj) {
+  return THPStorage_CheckTypeExact(Py_TYPE(obj));
+}
+
+inline bool THPStorage_Check(PyObject* obj) {
+  if (!THPStorageClass)
+    return false;
+
+  const auto result = PyObject_IsInstance(obj, (PyObject*)THPStorageClass);
+  if (result == -1)
+    throw python_error();
+  return result;
+}

 bool THPStorage_init(PyObject* module);
 void THPStorage_postInit(PyObject* module);
+
+void THPStorage_assertNotNull(THPStorage* storage);
+void THPStorage_assertNotNull(PyObject* obj);

 extern PyTypeObject THPStorageType;

 inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) {
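With the c10::MaybeOwned<c10::Storage> traits moved out of this header, THPStorage::cdata keeps the same owned/borrowed duality: owned while Python holds the strong reference, borrowed once tryPreserve hands ownership to C++. A simplified stand-in showing the two states (not the real c10::MaybeOwned):

#include <cassert>

template <class T>
struct MaybeOwnedLite {
  bool is_borrowed;
  T value;      // used when owned
  const T* ptr; // used when borrowed
  static MaybeOwnedLite owned(T v) { return {false, v, nullptr}; }
  static MaybeOwnedLite borrowed(const T& v) { return {true, T{}, &v}; }
  const T& get() const { return is_borrowed ? *ptr : value; }
  bool unsafeIsBorrowed() const { return is_borrowed; }
};

int main() {
  int backing = 7;
  auto o = MaybeOwnedLite<int>::owned(7);
  auto b = MaybeOwnedLite<int>::borrowed(backing);
  assert(!o.unsafeIsBorrowed() && b.unsafeIsBorrowed());
  assert(o.get() == 7 && b.get() == 7);
}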
@@ -41,6 +41,7 @@

 static PyObject* THPStorage_nbytes(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   return py::cast(THPStorage_Unpack(self).sym_nbytes()).release().ptr();
   END_HANDLE_TH_ERRORS
 }
@@ -66,6 +67,7 @@ static PyObject* THPStorage_copy_(
     PyObject* args,
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);

   at::Storage self_ = torch::createStorage(self);

@@ -96,12 +98,14 @@

 static PyObject* THPStorage_elementSize(PyObject* _self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(_self);
   return THPUtils_packInt64(sizeof(uint8_t));
   END_HANDLE_TH_ERRORS
 }

 static PyObject* THPStorage_new(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   c10::Allocator* allocator = THPStorage_Unpack(self).allocator();
   auto new_storage = c10::make_intrusive<at::StorageImpl>(
       c10::StorageImpl::use_byte_size_t(),
@@ -109,12 +113,13 @@ static PyObject* THPStorage_new(PyObject* self, PyObject* noargs) {
       allocator,
       /*resizable=*/true);

-  return THPStorage_New(std::move(new_storage));
+  return THPStorage_Wrap(std::move(new_storage));
   END_HANDLE_TH_ERRORS
 }

 static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   // See Note [Invalid Python Storages]
   auto invalid = storage.data() == nullptr &&
@@ -185,6 +190,7 @@ static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {

 static PyObject* THPStorage_fill_(PyObject* self, PyObject* number_arg) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   // See Note [Invalid Python Storages]
   auto invalid = storage.data() == nullptr &&
@@ -389,7 +395,7 @@ static PyObject* THPStorage_fromBuffer(
   }

   PyBuffer_Release(&buffer);
-  return (PyObject*)THPStorage_New(storage);
+  return THPStorage_Wrap(storage);
   END_HANDLE_TH_ERRORS
 }
@@ -429,12 +435,16 @@ static PyObject* THPStorage_fromFile(
     storage->set_nbytes(actual_nbytes);
   }

-  return (PyObject*)THPStorage_New(std::move(storage));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      std::move(storage),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }

 PyObject* THPStorage_writeFile(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   // See Note [Invalid Python Storages]
   auto invalid = storage.data() == nullptr &&
@@ -486,12 +496,13 @@ PyObject* THPStorage_newWithFile(PyObject* _unused, PyObject* args) {
   auto storage = THPStorage_readFileRaw<int>(fd, {}, element_size);
   if (!storage.defined())
     return nullptr;
-  return THPStorage_New(std::move(storage));
+  return THPStorage_Wrap(std::move(storage));
   END_HANDLE_TH_ERRORS
 }

 static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   PyObject* file = PyTuple_GET_ITEM(args, 0);
   PyObject* offset = PyTuple_GET_ITEM(args, 1);
@@ -617,6 +628,12 @@ PyObject* THPStorage_byteswap(PyObject* self, PyObject* args) {
   END_HANDLE_TH_ERRORS
 }

+static PyObject* THPStorage_fix_weakref(PyObject* self, PyObject* noargs) {
+  const auto& storage = THPStorage_Unpack(self);
+  Py_DECREF(THPStorage_Wrap(storage));
+  Py_RETURN_NONE;
+}
+
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
 static PyMethodDef THPStorage_methods[] = {
     {"copy_",
@@ -645,6 +662,7 @@ static PyMethodDef THPStorage_methods[] = {
      nullptr},
     {"_set_cdata", THPStorage__setCdata, METH_O, nullptr},
     {"_byteswap", THPStorage_byteswap, METH_VARARGS, nullptr},
+    {"_fix_weakref", THPStorage_fix_weakref, METH_NOARGS, nullptr},
     {nullptr}};

 PyMethodDef* THPStorage_getMethods() {
@@ -33,6 +33,7 @@

 static PyObject* THPStorage_sharedDecref(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   c10::DeviceType device_type = storage.device_type();
   if (device_type == at::kCPU) {
@@ -49,6 +50,7 @@ static PyObject* THPStorage_sharedDecref(PyObject* self, PyObject* noargs) {

 static PyObject* THPStorage_sharedIncref(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   c10::DeviceType device_type = storage.device_type();
   if (device_type == at::kCPU) {
@@ -76,18 +78,22 @@ static PyObject* THPStorage_pyNewFilenameStorage(

   int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
   std::string handle = at::NewProcessWideShmHandle();
-  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
-      c10::StorageImpl::use_byte_size_t(),
-      size,
-      THManagedMapAllocator::makeDataPtr(
-          "", handle.c_str(), flags, static_cast<size_t>(size)),
-      /*allocator=*/nullptr,
-      /*resizable=*/false));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      c10::make_intrusive<at::StorageImpl>(
+          c10::StorageImpl::use_byte_size_t(),
+          size,
+          THManagedMapAllocator::makeDataPtr(
+              "", handle.c_str(), flags, static_cast<size_t>(size)),
+          /*allocator=*/nullptr,
+          /*resizable=*/false),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }

 static PyObject* THPStorage_shareFilename(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   TORCH_CHECK(
       storage.device_type() == at::kCPU,
@@ -168,13 +174,16 @@ static PyObject* THPStorage_newSharedFilename(
   const char* object_handle = PyBytes_AS_STRING(_object_handle);
   uint64_t size = THPUtils_unpackUInt64(_size);
   int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE;
-  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
-      c10::StorageImpl::use_byte_size_t(),
-      size,
-      THManagedMapAllocator::makeDataPtr(
-          manager_handle, object_handle, flags, size),
-      /*allocator=*/nullptr,
-      /*resizable=*/false));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      c10::make_intrusive<at::StorageImpl>(
+          c10::StorageImpl::use_byte_size_t(),
+          size,
+          THManagedMapAllocator::makeDataPtr(
+              manager_handle, object_handle, flags, size),
+          /*allocator=*/nullptr,
+          /*resizable=*/false),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }

@@ -187,12 +196,16 @@ static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) {
   if (size < 0) {
     return nullptr;
   }
-  return THPStorage_New(at::new_shm_fd_storage(size));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      at::new_shm_fd_storage(size),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }

 static PyObject* THPStorage_shareFd(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   const auto& storage = THPStorage_Unpack(self);
   TORCH_CHECK(
       storage.device_type() == at::kCPU, "_share_fd_: only available on CPU");
@@ -257,17 +270,22 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) {

   int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE |
       at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_FROMFD;
-  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
-      c10::StorageImpl::use_byte_size_t(),
-      size,
-      at::MapAllocator::makeDataPtr(at::WITH_FD, "", fd, flags, size, nullptr),
-      /*allocator=*/nullptr,
-      /*resizable=*/false));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      c10::make_intrusive<at::StorageImpl>(
+          c10::StorageImpl::use_byte_size_t(),
+          size,
+          at::MapAllocator::makeDataPtr(
+              at::WITH_FD, "", fd, flags, size, nullptr),
+          /*allocator=*/nullptr,
+          /*resizable=*/false),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
   END_HANDLE_TH_ERRORS
 }

 static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
 #ifdef USE_CUDA
   const auto& storage = THPStorage_Unpack(self);
   TORCH_CHECK(
@@ -547,7 +565,10 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
   base->set_resizable(false);
   base->set_received_cuda(true);

-  return THPStorage_New(std::move(base));
+  return THPStorage_NewWithStorage(
+      THPStorageClass,
+      std::move(base),
+      c10::impl::PyInterpreterStatus::TAGGED_BY_US);
 #else
   TORCH_CHECK(false, "CUDA is not available");
 #endif
@@ -572,7 +593,7 @@ PyObject* THPStorage_newWithWeakPtr(PyObject* _unused, PyObject* arg) {
       THPUtils_checkLong(arg), "_new_with_weak_ptr(): arg must be an 'int'");
   c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
   if (auto* storage = c10::raw::weak_intrusive_ptr::lock(weak_storage)) {
-    return THPStorage_New(
+    return THPStorage_Wrap(
         c10::intrusive_ptr<c10::StorageImpl>::reclaim(storage));
   }
   Py_RETURN_NONE;
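_new_with_weak_ptr keeps its lock-then-wrap shape, but now returns the canonical wrapper via THPStorage_Wrap instead of allocating a fresh one. The weak-to-strong promotion it relies on, sketched with std::weak_ptr standing in for c10::raw::weak_intrusive_ptr::lock:

#include <cassert>
#include <memory>

int main() {
  std::shared_ptr<int> strong = std::make_shared<int>(1);
  std::weak_ptr<int> weak = strong;
  if (auto locked = weak.lock()) {
    assert(locked.use_count() == 2); // success: wrap `locked` here
  }
  strong.reset();
  assert(weak.lock() == nullptr); // expired: the binding returns None
}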
@@ -604,6 +625,7 @@ PyObject* THPStorage_expired(PyObject* _unused, PyObject* arg) {

 PyObject* THPStorage_sharedFd(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  THPStorage_assertNotNull(self);
   at::MapAllocator* ctx = nullptr;
   const auto& storage = THPStorage_Unpack(self);
   if (storage.device_type() == at::kCPU) {
@@ -26,6 +26,7 @@
 #include <torch/csrc/tensor/python_tensor.h>
 #include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
+#include <torch/csrc/utils/pyobject_preservation.h>
 #include <torch/csrc/utils/python_arg_parser.h>
 #include <torch/csrc/utils/python_dispatch.h>
 #include <torch/csrc/utils/python_strings.h>
@@ -268,7 +269,8 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
   }

   c10::optional<PyObject*> mb_obj =
-      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
   c10::impl::PyInterpreterStatus status;
   if (mb_obj.has_value()) {
     auto obj = *mb_obj;
@@ -345,7 +347,8 @@ bool isResurrectable(THPVariable* self) {
   }
   // Check if this is hermetic. If it is, no resurrection.
   if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
-          getPyInterpreter()) != c10::make_optional((PyObject*)self)) {
+          getPyInterpreter(), /*ignore_hermetic_tls=*/false) !=
+      c10::make_optional((PyObject*)self)) {
     return false;
   }
   return true;
@@ -369,7 +372,16 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
   TORCH_INTERNAL_ASSERT(
       !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());

-  tensor.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(true);
+  c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
+  auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
+      getPyInterpreter(),
+      /*ignore_hermetic_tls=*/false);
+
+  TORCH_INTERNAL_ASSERT(
+      maybe_pyobj.has_value(),
+      "Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");
+
+  tensor_impl->pyobj_slot()->set_owns_pyobj(true);

   // Resurrect the Python object. This is something CPython does
   // internally occasionally, see
@@ -443,7 +455,8 @@ static int THPVariable_clear(THPVariable* self) {

   if (!self->cdata.unsafeIsBorrowed() &&
       tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
-          getPyInterpreter()) == c10::make_optional((PyObject*)self)) {
+          getPyInterpreter(), /*ignore_hermetic_tls=*/false) ==
+      c10::make_optional((PyObject*)self)) {
     // TODO: empirically, on OS X this assert appears to be untrue
     // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
     // distributed/rpc/test_process_group_agent.py
@@ -1747,26 +1760,6 @@ PyObject* THPVariable_pynew(
   END_HANDLE_TH_ERRORS
 }

-static void clear_slots(PyTypeObject* type, PyObject* self) {
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  Py_ssize_t i, n;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  PyMemberDef* mp;
-
-  n = Py_SIZE(type);
-  mp = type->tp_members;
-  for (i = 0; i < n; i++, mp++) {
-    if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
-      char* addr = (char*)self + mp->offset;
-      PyObject* obj = *(PyObject**)addr;
-      if (obj != nullptr) {
-        *(PyObject**)addr = nullptr;
-        Py_DECREF(obj);
-      }
-    }
-  }
-}
-
 // NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc
 // on subclasses. It's never valid to construct a THPVariable so it's not
 // necessary to implement the dealloc for that case
@@ -1886,8 +1879,8 @@ static PyObject* THPVariable_NewWithVar(

   // This function overwrite the Tensor's pyobj field without extra checks
   // Make sure it is not set otherwise we would leak memory
-  auto mb_obj =
-      _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+  auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+      getPyInterpreter(), /*ignore_hermetic_tls=*/false);

   // Under some circumstances, we may attempt to create a new Python
   // object for a variable that already has a Python object. The most common
torch/csrc/utils/pyobject_preservation.cpp (new file, 19 lines)
@@ -0,0 +1,19 @@
+#include <torch/csrc/utils/pyobject_preservation.h>
+
+#include <structmember.h>
+
+void clear_slots(PyTypeObject* type, PyObject* self) {
+  Py_ssize_t n = Py_SIZE(type);
+  PyMemberDef* mp = type->tp_members;
+
+  for (Py_ssize_t i = 0; i < n; i++, mp++) {
+    if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
+      char* addr = (char*)self + mp->offset;
+      PyObject* obj = *(PyObject**)addr;
+      if (obj != nullptr) {
+        *(PyObject**)addr = nullptr;
+        Py_DECREF(obj);
+      }
+    }
+  }
+}
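clear_slots walks tp_members and drops T_OBJECT_EX entries, which is how __slots__-style attributes on Storage and Tensor subclasses get released during dealloc. For reference, a hypothetical C-defined subclass whose member table would be affected (illustrative only, not part of the commit):

#include <Python.h>
#include <structmember.h>

struct MyStorageSub {
  PyObject_HEAD
  PyObject* extra; // a T_OBJECT_EX slot: clear_slots nulls and decrefs it
};

static PyMemberDef MyStorageSub_members[] = {
    {"extra", T_OBJECT_EX, offsetof(MyStorageSub, extra), 0, nullptr},
    {nullptr, 0, 0, 0, nullptr}};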
torch/csrc/utils/pyobject_preservation.h (new file, 7 lines)
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+
+// This file contains utilities used for handling PyObject preservation
+
+void clear_slots(PyTypeObject* type, PyObject* self);
@@ -1106,8 +1106,8 @@ inline at::Storage PythonArgs::storage(
     is_typed_storage = false;
     storage_scalar_type = at::ScalarType::Undefined;
   } else {
-    storage =
-        createStorageGetType(args[i], storage_scalar_type, is_typed_storage);
+    std::tie(storage, storage_scalar_type, is_typed_storage) =
+        createStorageGetType(args[i]);
   }
   return storage;
 }
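createStorageGetType now returns its results as a tuple instead of filling out-parameters, and call sites unpack with std::tie. The pattern in isolation (a hypothetical stand-in with the same shape as the new signature):

#include <cassert>
#include <tuple>

std::tuple<int, bool> parseStorage() {
  return std::make_tuple(42, true);
}

int main() {
  int storage = 0;
  bool is_typed_storage = false;
  std::tie(storage, is_typed_storage) = parseStorage();
  assert(storage == 42 && is_typed_storage);
}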
@@ -385,10 +385,12 @@ Tensor internal_new_from_data(
   at::tracer::impl::NoTracerDispatchMode tracer_guard;

   if (isStorage(data)) {
-    ScalarType storage_scalar_type{ScalarType::Undefined};
     bool is_typed_storage = false;
-    Storage storage =
-        createStorageGetType(data, storage_scalar_type, is_typed_storage);
+    ScalarType storage_scalar_type{ScalarType::Undefined};
+    Storage storage;
+    std::tie(storage, storage_scalar_type, is_typed_storage) =
+        createStorageGetType(data);

     TORCH_CHECK(
         !is_typed_storage || storage_scalar_type == scalar_type,
         "Expected a Storage of type ",