Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-17 08:11:08 +08:00)

Compare commits: pyobjectsl... → ciflow/h10... (22 commits)

Commits: 610f9b437d, 7d0eb9b4f6, af6ae22dbd, e3afc32110, 15aa7e01a9, fc5133bacb,
1a49d0cda4, e9a3814dea, 2ace9e465a, d990b72872, a8243bd1d4, 1ccc757cac, 2abf4ecf2f,
ff3e2942b4, a81e5177de, f02dba7893, 09abf0ceff, b4d23566db, 20ca3c48de, d83d25dee4,
528d3fc4ce, fd178b2e17
@@ -245,9 +245,6 @@ class TORCH_API TensorBase {
   size_t weak_use_count() const noexcept {
     return impl_.weak_use_count();
   }
-  bool is_uniquely_owned() const noexcept {
-    return impl_.is_uniquely_owned();
-  }

   std::string toString() const;

@@ -223,62 +223,6 @@ CONVERT_FROM_BF16_TEMPLATE(double)
 CONVERT_FROM_BF16_TEMPLATE(float16_t)
 #endif

-#ifdef __ARM_FEATURE_BF16
-
-// clang-[17, 20] crashes when autovectorizing static cast to bf16
-// Below is a workaround to have some vectorization
-// Works decently well for smaller int types
-template <typename from_type>
-inline void convertToBf16Impl(
-    const from_type* __restrict src,
-    c10::BFloat16* __restrict dst,
-    uint64_t n) {
-  bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
-  uint64_t loopBound = n - (n % 16);
-  uint64_t i = 0;
-  for (; i < loopBound; i += 16) {
-    float32x4_t a, b, c, d;
-    a[0] = static_cast<float>(src[i]);
-    a[1] = static_cast<float>(src[i + 1]);
-    a[2] = static_cast<float>(src[i + 2]);
-    a[3] = static_cast<float>(src[i + 3]);
-    b[0] = static_cast<float>(src[i + 4]);
-    b[1] = static_cast<float>(src[i + 5]);
-    b[2] = static_cast<float>(src[i + 6]);
-    b[3] = static_cast<float>(src[i + 7]);
-    c[0] = static_cast<float>(src[i + 8]);
-    c[1] = static_cast<float>(src[i + 9]);
-    c[2] = static_cast<float>(src[i + 10]);
-    c[3] = static_cast<float>(src[i + 11]);
-    d[0] = static_cast<float>(src[i + 12]);
-    d[1] = static_cast<float>(src[i + 13]);
-    d[2] = static_cast<float>(src[i + 14]);
-    d[3] = static_cast<float>(src[i + 15]);
-
-    vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
-    vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
-  }
-
-#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
-  for (; i < n; i++) {
-    float a = static_cast<float>(src[i]);
-    dstPtr[i] = vcvth_bf16_f32(a);
-  }
-}
-
-#define CONVERT_TO_BF16_TEMPLATE(from_type) \
-  template <> \
-  inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
-    return convertToBf16Impl<from_type>(src, dst, n); \
-  }
-
-CONVERT_TO_BF16_TEMPLATE(uint8_t)
-CONVERT_TO_BF16_TEMPLATE(int8_t)
-CONVERT_TO_BF16_TEMPLATE(int16_t)
-CONVERT_TO_BF16_TEMPLATE(int32_t)
-
-#endif
-
 inline void convertBoolToBfloat16Impl(
     const bool* __restrict src,
     c10::BFloat16* __restrict dst,
@@ -10,13 +10,6 @@
 ...
 }

-{
-   ignore_empty_generic_uninitialised_conditional_jump
-   Memcheck:Cond
-   fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
-   ...
-}
-
 {
    Cond_cuda
    Memcheck:Cond
@@ -44,7 +44,7 @@ struct C10_API SafePyObject {
       (*other.pyinterpreter_)->incref(other.data_);
     }
     if (data_ != nullptr) {
-      (*pyinterpreter_)->decref(data_);
+      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
     }
     data_ = other.data_;
     pyinterpreter_ = other.pyinterpreter_;
@@ -53,7 +53,7 @@ struct C10_API SafePyObject {

   ~SafePyObject() {
     if (data_ != nullptr) {
-      (*pyinterpreter_)->decref(data_);
+      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
     }
   }

@@ -48,30 +48,6 @@ void warnDeprecatedDataPtr() {
   TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
 }

-void StorageImpl::incref_pyobject() const {
-  // Because intrusive_ptr incref uses relaxed memory order, we need to
-  // do an acquire fence to ensure that the kHasPyObject bit was
-  // observed before the load of the PyObject* below.
-  // NB: This is a no-op on x86/x86-64
-  std::atomic_thread_fence(std::memory_order_acquire);
-
-  PyObject* obj = pyobj_slot_.load_pyobj();
-  (*pyobj_slot_.pyobj_interpreter())->incref(obj);
-}
-
-void StorageImpl::decref_pyobject() const {
-  PyObject* obj = pyobj_slot_.load_pyobj();
-  (*pyobj_slot_.pyobj_interpreter())->decref(obj);
-}
-
-bool StorageImpl::try_incref_pyobject() const {
-  c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
-  if (C10_UNLIKELY(!interp)) {
-    return false;
-  }
-  return (*interp)->try_incref(pyobj_slot_);
-}
-
 void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
   // Allowlist verification.
   // Only if the devicetype is in the allowlist,
@@ -105,12 +105,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
     data_ptr_.clear();
   }

-  void incref_pyobject() const override final;
-
-  void decref_pyobject() const override final;
-
-  bool try_incref_pyobject() const override final;
-
   size_t nbytes() const {
     // OK to do this instead of maybe_as_int as nbytes is guaranteed positive
     TORCH_CHECK(!size_bytes_is_heap_allocated_);
@@ -376,18 +370,4 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
     bool resizable,
     std::optional<at::Device> device_opt);

-namespace detail {
-
-#ifndef C10_MOBILE
-template <class T>
-struct TargetTraits<
-    T,
-    std::enable_if_t<
-        std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
-  static constexpr bool can_have_pyobject = true;
-};
-#endif
-
-} // namespace detail
-
 } // namespace c10
@@ -277,6 +277,7 @@ void TensorImpl::release_resources() {
   if (storage_) {
     storage_ = {};
   }
+  pyobj_slot_.maybe_destroy_pyobj();
 }

 #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@@ -988,30 +989,6 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
   }
 }

-void TensorImpl::incref_pyobject() const {
-  // Because intrusive_ptr incref uses relaxed memory order, we need to
-  // do an acquire fence to ensure that the kHasPyObject bit was
-  // observed before the load of the PyObject* below.
-  // NB: This is a no-op on x86/x86-64
-  std::atomic_thread_fence(std::memory_order_acquire);
-
-  PyObject* obj = pyobj_slot_.load_pyobj();
-  (*pyobj_slot_.pyobj_interpreter())->incref(obj);
-}
-
-void TensorImpl::decref_pyobject() const {
-  PyObject* obj = pyobj_slot_.load_pyobj();
-  (*pyobj_slot_.pyobj_interpreter())->decref(obj);
-}
-
-bool TensorImpl::try_incref_pyobject() const {
-  c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
-  if (C10_UNLIKELY(!interp)) {
-    return false;
-  }
-  return (*interp)->try_incref(pyobj_slot_);
-}
-
 namespace impl {

 namespace {
@@ -2178,12 +2178,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return &pyobj_slot_;
   }

-  void incref_pyobject() const override final;
-
-  void decref_pyobject() const override final;
-
-  bool try_incref_pyobject() const override final;
-
  private:
   // See NOTE [std::optional operator usage in CUDA]
   // We probably don't want to expose this publicly until
@@ -3085,19 +3079,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   friend class C10_TensorImpl_Size_Check_Dummy_Class;
 };

-namespace detail {
-
-#ifndef C10_MOBILE
-template <class T>
-struct TargetTraits<
-    T,
-    std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
-  static constexpr bool can_have_pyobject = true;
-};
-#endif
-
-} // namespace detail
-
 // Note [TensorImpl size constraints]
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 // Changed the size of TensorImpl? If the size went down, good for
@@ -11,11 +11,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {

   void incref(PyObject* pyobj) const override {} // do nothing

-  void decref(PyObject* pyobj) const override {} // do nothing
-  bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
-    return false;
-  }
+  void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
+  } // do nothing

 #define PANIC(m) \
   TORCH_INTERNAL_ASSERT( \
@@ -23,10 +20,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
       "attempted to call " #m \
       " on a Tensor with nontrivial PyObject after corresponding interpreter died")

-  size_t refcnt(PyObject* pyobj) const override {
-    PANIC(refcnt);
-  }
-
   c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
     PANIC(detach);
   }
@@ -18,9 +18,6 @@ namespace c10 {
 struct IValue;
 class OperatorHandle;
 struct TensorImpl;
-namespace impl {
-struct PyObjectSlot;
-} // namespace impl
 } // namespace c10

 namespace torch::jit {
@@ -129,12 +126,9 @@ struct C10_API PyInterpreterVTable {

   // Run Py_INCREF on a PyObject.
   virtual void incref(PyObject* pyobj) const = 0;
-  // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
-  virtual void decref(PyObject* pyobj) const = 0;
-  // Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
-  virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
-  // Run Py_REFCNT on a PyObject.
-  virtual size_t refcnt(PyObject* pyobj) const = 0;
+  // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
+  // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
+  virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;

   // Perform a detach by deferring to the __torch_dispatch__ implementation of
   // detach, which will also arrange for the PyObject to get copied in this
c10/core/impl/PyObjectSlot.cpp (new file, 56 lines)
@@ -0,0 +1,56 @@
+#include <c10/core/impl/PyObjectSlot.h>
+
+namespace c10::impl {
+
+PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
+
+PyObjectSlot::~PyObjectSlot() {
+  maybe_destroy_pyobj();
+}
+
+void PyObjectSlot::maybe_destroy_pyobj() {
+  if (owns_pyobj()) {
+    TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
+    TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
+    (*pyobj_interpreter_.load(std::memory_order_acquire))
+        ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
+    // NB: this destructor can only be entered when there are no
+    // references to this C++ object (obviously), NOR any references
+    // to the PyObject (if there are references to the PyObject,
+    // then the PyObject holds an owning reference to the tensor).
+    // So it is OK to clear pyobj_ here as it is impossible for it to
+    // be used again (modulo weak reference races)
+    pyobj_ = nullptr; // for safety
+  }
+}
+
+PyInterpreter* PyObjectSlot::pyobj_interpreter() {
+  return pyobj_interpreter_.load(std::memory_order_acquire);
+}
+
+PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
+  // NOLINTNEXTLINE(performance-no-int-to-ptr)
+  return reinterpret_cast<PyObject*>(
+      reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
+}
+
+PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
+  auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
+  if (interpreter) {
+    return *interpreter;
+  }
+  TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
+}
+
+bool PyObjectSlot::owns_pyobj() {
+  // NOLINTNEXTLINE(performance-no-int-to-ptr)
+  return reinterpret_cast<uintptr_t>(pyobj_) & 1;
+}
+
+void PyObjectSlot::set_owns_pyobj(bool b) {
+  // NOLINTNEXTLINE(performance-no-int-to-ptr)
+  pyobj_ = reinterpret_cast<PyObject*>(
+      reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
+}
+
+} // namespace c10::impl
@@ -8,58 +8,117 @@

 #include <atomic>

-namespace torch::utils {
-class PyObjectPreservation;
-}
-
 namespace c10::impl {

 struct C10_API PyObjectSlot {
  public:
-  PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
+  PyObjectSlot();

+  ~PyObjectSlot();

+  void maybe_destroy_pyobj();

+  // Associate the TensorImpl with the specified PyObject, and, if necessary,
+  // also tag the interpreter.
+  //
+  // NB: This lives in a header so that we can inline away the switch on status
+  //
+  // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
+  // PyObject if necessary!
+  void init_pyobj(PyObject* pyobj) {
+    pyobj_interpreter_.store(
+        getGlobalPyInterpreter(), std::memory_order_relaxed);
+    pyobj_ = pyobj;
+  }

   // Query the PyObject interpreter. This may return null if there is no
-  // interpreter.
-  PyInterpreter* pyobj_interpreter() const {
-    return pyobj_interpreter_.load(std::memory_order_acquire);
+  // interpreter. This is racy!
+  PyInterpreter* pyobj_interpreter();
+
+  PyObject* _unchecked_untagged_pyobj() const;
+
+  // Test the interpreter tag. If tagged for the current interpreter, return
+  // a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
+  // returns a nullopt. If it is definitely invalid, raises an error.
+  //
+  // If `ignore_hermetic_tls` is false and this function is called from a
+  // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
+  // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
+  // context is ignored, allowing you to check the interpreter tag of a
+  // nonhermetic PyObject from within a hermetic context. This is necessary
+  // because there are some cases where the deallocator function of a
+  // nonhermetic PyObject is called from within a hermetic context, so it must
+  // be properly treated as a nonhermetic PyObject.
+  //
+  // NB: this lives in header so that we can avoid actually creating the
+  // std::optional
+
+  // @todo alban: I'm not too sure what's going on here, we can probably delete
+  // it but it's worthwhile making sure
+  std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
+    impl::PyInterpreter* interpreter =
+        pyobj_interpreter_.load(std::memory_order_acquire);
+    if (interpreter == nullptr) {
+      return std::nullopt;
+    }
+
+    if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
+      return std::nullopt;
+    } else {
+      return _unchecked_untagged_pyobj();
+    }
   }

-  PyInterpreter& load_pyobj_interpreter() const {
-    auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
-    TORCH_INTERNAL_ASSERT(
-        interpreter, "cannot access PyObject for Tensor - no interpreter set");
-    return *interpreter;
-  }
+  PyInterpreter& load_pyobj_interpreter() const;

-  PyObject* load_pyobj() const {
-    return pyobj_.load(std::memory_order_acquire);
-  }
+  bool owns_pyobj();

-  void store_pyobj(PyObject* obj) {
-    pyobj_.store(obj, std::memory_order_release);
-  }
-
-  bool has_unique_reference() const {
-    PyObject* pyobj = load_pyobj();
-    return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
-  }
-
-  void clear() {
-    pyobj_.store(nullptr, std::memory_order_relaxed);
-    pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
-  }
+  void set_owns_pyobj(bool b);

  private:
-  // This is now always the global interpreter if the PyObject is set.
-  // Maybe we can remove this field some day...
+  // This field contains the interpreter tag for this object. See
+  // Note [Python interpreter tag] for general context
+  //
+  // Note [Memory ordering on Python interpreter tag]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // What memory_order do we need when accessing this atomic? We don't
+  // need a single total modification order (as provided by
+  // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
+  // transition from -1 to some positive integer and never changes afterwards.
+  // Because there is only one modification, it trivially already has a total
+  // modification order (e.g., we don't need fences or locked instructions on
+  // x86)
+  //
+  // In fact, one could make a reasonable argument that relaxed reads are OK,
+  // due to the presence of external locking (GIL) to ensure that interactions
+  // with other data structures are still correctly synchronized, so that
+  // we fall in the "Single-Location Data Structures" case as described in
+  // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
+  // However, on x86, it doesn't matter if I use acquire or relaxed on the load
+  // as I get the same assembly in both cases. So I just use the more
+  // conservative acquire (which will impede compiler optimizations but I don't
+  // care)
   std::atomic<PyInterpreter*> pyobj_interpreter_;

-  // The PyObject representing this Tensor or nullptr. Ownership is managed
-  // by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
-  // reference is already dead.
-  std::atomic<PyObject*> pyobj_;
-
-  friend class torch::utils::PyObjectPreservation;
+  // This field contains a reference to a PyObject representing this Tensor.
+  // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
+  // PyObject for it and set this field. This field does not have to be
+  // protected by an atomic as it is only allowed to be accessed when you hold
+  // the GIL, or during destruction of the tensor.
+  //
+  // When a PyObject dies, you are obligated to clear this field
+  // (otherwise, you will try to use-after-free the pyobj); this currently
+  // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
+  //
+  // NB: Ordinarily, this should not be a strong reference, as if the
+  // PyObject owns the Tensor, this would create a reference cycle.
+  // However, sometimes this ownership flips. To track who owns
+  // who, this has a single pointer tag indicating whether or not the
+  // C++ object owns the PyObject (the common case, zero, means PyObject
+  // owns the C++ object); see _unchecked_untagged_pyobj for raw access
+  // or check_pyobj for checked access. See references to PyObject
+  // resurrection in torch/csrc/autograd/python_variable.cpp
+  PyObject* pyobj_;
 };

 } // namespace c10::impl
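The restored `PyObjectSlot` keeps the "C++ owns the PyObject" flag in the low bit of the `PyObject*` field itself, as `owns_pyobj`/`set_owns_pyobj`/`_unchecked_untagged_pyobj` above show. A minimal, hedged sketch of that tagging trick in isolation (illustrative names, not the c10 API; assumes pointers are at least 2-byte aligned so bit 0 is free):

```cpp
#include <cassert>
#include <cstdint>

// Toy tagged pointer: bit 0 records whether the C++ side owns the wrapper.
struct TaggedPtr {
  void* raw_ = nullptr;  // stand-in for PyObject*

  void* untagged() const {
    // Mask off the ownership bit before using the pointer.
    return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(raw_) & ~uintptr_t{1});
  }
  bool owns_pyobj() const {
    return (reinterpret_cast<uintptr_t>(raw_) & 1) != 0;
  }
  void set_owns_pyobj(bool b) {
    raw_ = reinterpret_cast<void*>(
        reinterpret_cast<uintptr_t>(untagged()) | static_cast<uintptr_t>(b));
  }
};

int main() {
  alignas(8) static int fake_obj;  // any suitably aligned object works
  TaggedPtr slot;
  slot.raw_ = &fake_obj;
  slot.set_owns_pyobj(true);
  assert(slot.owns_pyobj() && slot.untagged() == &fake_obj);
  slot.set_owns_pyobj(false);
  assert(!slot.owns_pyobj() && slot.untagged() == &fake_obj);
}
```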
@@ -12,10 +12,6 @@ template <typename, typename...>
 class class_;
 }

-namespace torch::utils {
-class PyObjectPreservation;
-}
-
 namespace c10 {
 class intrusive_ptr_target;
 namespace raw {
@@ -37,8 +33,6 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
 constexpr uint64_t kReferenceCountOne = 1;
 constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
 constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
-// Indicates whether the object has a PyObject wrapper.
-constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);

 template <class TTarget>
 struct intrusive_target_default_null_type final {
@@ -61,11 +55,7 @@ inline uint32_t refcount(uint64_t combined_refcount) {
 }

 inline uint32_t weakcount(uint64_t combined_refcount) {
-  return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32);
-}
+  return static_cast<uint32_t>(combined_refcount >> 32);

-inline bool has_pyobject(uint64_t combined_refcount) {
-  return (combined_refcount & kHasPyObject) != 0;
 }

 // The only requirement for refcount increment is that it happens-before
@@ -76,6 +66,12 @@ inline uint64_t atomic_combined_refcount_increment(
   return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
 }

+inline uint32_t atomic_refcount_increment(
+    std::atomic<uint64_t>& combined_refcount) {
+  return detail::refcount(atomic_combined_refcount_increment(
+      combined_refcount, kReferenceCountOne));
+}
+
 inline uint32_t atomic_weakcount_increment(
     std::atomic<uint64_t>& combined_refcount) {
   return detail::weakcount(atomic_combined_refcount_increment(
@@ -103,11 +99,6 @@ inline uint32_t atomic_weakcount_decrement(
       combined_refcount, kWeakReferenceCountOne));
 }

-template <class T, class = void>
-struct TargetTraits {
-  static constexpr bool can_have_pyobject = false;
-};
-
 } // namespace detail

 /**
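The helpers above pack a 32-bit strong count and a 32-bit weak count into one 64-bit atomic so both can be adjusted with a single `fetch_add`. A hedged, self-contained illustration of the packing arithmetic (the constants mirror `kReferenceCountOne`/`kWeakReferenceCountOne`, but this is not the real c10 header):

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>

constexpr uint64_t kRefOne = 1;               // strong count lives in the low 32 bits
constexpr uint64_t kWeakOne = kRefOne << 32;  // weak count lives in the high 32 bits

uint32_t refcount(uint64_t combined) { return static_cast<uint32_t>(combined); }
uint32_t weakcount(uint64_t combined) { return static_cast<uint32_t>(combined >> 32); }

int main() {
  // One strong and one weak reference, as on construction (kUniqueRef).
  std::atomic<uint64_t> combined{kRefOne | kWeakOne};
  combined.fetch_add(kRefOne, std::memory_order_relaxed);   // retain_(): strong 1 -> 2
  combined.fetch_sub(kWeakOne, std::memory_order_acq_rel);  // drop a weak reference
  std::printf("strong=%u weak=%u\n",
              refcount(combined.load()), weakcount(combined.load()));
}
```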
@@ -164,23 +155,6 @@ class C10_API intrusive_ptr_target {
   // we can atomically operate on both at the same time for performance
   // and defined behaviors.
   //
-  // Note [PyObject preservation for Tensor and Storages]
-  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-  // intrusive_ptr has special support for preserving PyObject wrappers
-  // for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
-  // the combined_refcount_ is used to indicate whether the object has a
-  // PyObject wrapper.
-  //
-  // - The PyObject, if it exists, holds a strong reference to the
-  //   intrusive_ptr_target.
-  //
-  // - When the refcount goes from 1 to 2, we incref the PyObject.
-  //
-  // - When the refcount goes from 2 to 1, we decref the PyObject.
-  //
-  // In other words, the intrusive_ptr keeps the PyObject alive as long as there
-  // are other C++ references to the intrusive_ptr_target.
-
   mutable std::atomic<uint64_t> combined_refcount_;
   static_assert(sizeof(std::atomic<uint64_t>) == 8);
   static_assert(alignof(std::atomic<uint64_t>) == 8);
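The removed note describes the preservation protocol that the `retain_`/`reset_` hunks further down also delete: incref the PyObject when the strong count goes 1 → 2, decref it when it goes 2 → 1, so the wrapper stays alive exactly while some other C++ reference exists. A hedged toy model of that rule, with generic callbacks standing in for `Py_INCREF`/`Py_DECREF` (not the c10 implementation):

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <functional>

// Toy model: on_incref fires when the strong count goes 1 -> 2 (a second,
// C++-side owner appeared); on_decref fires when it goes 2 -> 1 (only the
// wrapper's own reference remains).
struct PreservedTarget {
  std::atomic<uint32_t> refcount{1};
  bool has_wrapper = false;
  std::function<void()> on_incref, on_decref;

  void retain() {
    uint32_t now = refcount.fetch_add(1, std::memory_order_relaxed) + 1;
    if (has_wrapper && now == 2) on_incref();  // keep the wrapper alive
  }
  void release() {
    uint32_t now = refcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
    if (has_wrapper && now == 1) on_decref();  // wrapper may now die with us
  }
};

int main() {
  int increfs = 0, decrefs = 0;
  PreservedTarget t;
  t.has_wrapper = true;
  t.on_incref = [&] { ++increfs; };
  t.on_decref = [&] { ++decrefs; };
  t.retain();   // 1 -> 2: incref
  t.release();  // 2 -> 1: decref
  assert(increfs == 1 && decrefs == 1);
}
```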
@@ -198,8 +172,6 @@ class C10_API intrusive_ptr_target {
   template <typename T>
   friend struct ExclusivelyOwnedTensorTraits;

-  friend class torch::utils::PyObjectPreservation;
-
  protected:
   // protected destructor. We never want to destruct intrusive_ptr_target*
   // directly.
@@ -283,16 +255,6 @@ class C10_API intrusive_ptr_target {
   */
   virtual void release_resources() {}

-  /**
-   * These two methods are called when the refcount transitions between one
-   * and two and the object has a PyObject wrapper.
-   */
-  virtual void incref_pyobject() const {}
-  virtual void decref_pyobject() const {}
-  virtual bool try_incref_pyobject() const {
-    return false;
-  }
-
   uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
     return detail::refcount(combined_refcount_.load(order));
   }
@@ -303,19 +265,6 @@ class C10_API intrusive_ptr_target {
   }
 };

-namespace detail {
-
-#ifndef C10_MOBILE
-template <>
-struct TargetTraits<c10::intrusive_ptr_target> {
-  // A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
-  // or StorageImpl, so we have to allow for PyObject support.
-  static constexpr bool can_have_pyobject = true;
-};
-#endif
-
-} // namespace detail
-
 template <class TTarget, class NullType>
 class weak_intrusive_ptr;

@@ -365,34 +314,18 @@ class intrusive_ptr final {

   void retain_() {
     if (target_ != NullType::singleton()) {
-      uint64_t combined = detail::atomic_combined_refcount_increment(
-          target_->combined_refcount_, detail::kReferenceCountOne);
-      uint32_t new_refcount = detail::refcount(combined);
+      uint32_t new_refcount =
+          detail::atomic_refcount_increment(target_->combined_refcount_);
       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
           new_refcount != 1,
           "intrusive_ptr: Cannot increase refcount after it reached zero.");
-
-      if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
-        // If the refcount transitioned from 1 to 2, we need to incref the
-        // PyObject. In other words, we need to ensure that the PyObject stays
-        // alive now that we have a C++ reference to this object in addition to
-        // the PyObject itself.
-        if (C10_UNLIKELY(
-                detail::has_pyobject(combined) &&
-                detail::refcount(combined) == 2)) {
-          target_->incref_pyobject();
-        }
-      } else {
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-            !detail::has_pyobject(combined),
-            "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
-      }
     }
   }

   void reset_() noexcept {
     if (target_ != NullType::singleton()) {
-      if (is_uniquely_owned()) {
+      if (target_->combined_refcount_.load(std::memory_order_acquire) ==
+          detail::kUniqueRef) {
         // Both counts are 1, so there are no weak references and
         // we are releasing the last strong reference. No other
         // threads can observe the effects of this target_ deletion
@@ -404,10 +337,9 @@ class intrusive_ptr final {

       auto combined_refcount = detail::atomic_combined_refcount_decrement(
           target_->combined_refcount_, detail::kReferenceCountOne);
-      uint32_t new_refcount = detail::refcount(combined_refcount);
-      bool has_pyobject = detail::has_pyobject(combined_refcount);
-      if (new_refcount == 0) {
-        bool should_delete = detail::weakcount(combined_refcount) == 1;
+      if (detail::refcount(combined_refcount) == 0) {
+        bool should_delete =
+            (combined_refcount == detail::kWeakReferenceCountOne);
         // See comment above about weakcount. As long as refcount>0,
         // weakcount is one larger than the actual number of weak references.
         // So we need to decrement it here.
@@ -424,18 +356,6 @@ class intrusive_ptr final {
         if (should_delete) {
           delete target_;
         }
-      } else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
-        // If the refcount transitioned from 2 to 1, we need to decref the
-        // PyObject. In other words, we don't want to keep the PyObject alive if
-        // there are no C++ references to this object other than the PyObject
-        // itself.
-        if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
-          target_->decref_pyobject();
-        }
-      } else {
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-            !has_pyobject,
-            "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
       }
     }
   }
@@ -602,16 +522,6 @@ class intrusive_ptr final {
     return use_count() == 1;
   }

-  /**
-   * Stronger than unique() in that it must not have any weakrefs as well.
-   */
-  bool is_uniquely_owned() const noexcept {
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
-    uint64_t combined =
-        target_->combined_refcount_.load(std::memory_order_acquire);
-    return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
-  }
-
   /**
    * Returns an owning (!) pointer to the underlying object and makes the
    * intrusive_ptr instance invalid. That means the refcount is not decreased.
@@ -1022,7 +932,6 @@ class weak_intrusive_ptr final {
     if (target_ == NullType::singleton()) {
       return intrusive_ptr<TTarget, NullType>();
     } else {
-      bool increfed = false;
       auto combined_refcount =
           target_->combined_refcount_.load(std::memory_order_relaxed);
       do {
@@ -1031,31 +940,12 @@ class weak_intrusive_ptr final {
           // Return nullptr.
           return intrusive_ptr<TTarget, NullType>();
         }
-        if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
-          if (detail::has_pyobject(combined_refcount) &&
-              detail::refcount(combined_refcount) == 1 && !increfed) {
-            // Object has a python wrapper with no other C++ references.
-            // We need to to incref the Python object before we acquire a
-            // strong reference to the C++ object to avoid a situation
-            // where the Python object is deallocated concurrently.
-            if (!target_->try_incref_pyobject()) {
-              return intrusive_ptr<TTarget, NullType>();
-            }
-            increfed = true;
-          }
-        }
       } while (!target_->combined_refcount_.compare_exchange_weak(
           combined_refcount,
           combined_refcount + detail::kReferenceCountOne,
           std::memory_order_acquire,
           std::memory_order_relaxed));

-      if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
-        if (increfed && detail::refcount(combined_refcount) != 1) {
-          target_->decref_pyobject();
-        }
-      }
-
       return intrusive_ptr<TTarget, NullType>(
           target_, raw::DontIncreaseRefcount{});
     }
@@ -1170,18 +1060,7 @@ namespace intrusive_ptr {
 // NullType::singleton to this function
 inline void incref(intrusive_ptr_target* self) {
   if (self) {
-    uint64_t combined = detail::atomic_combined_refcount_increment(
-        self->combined_refcount_, detail::kReferenceCountOne);
-
-#ifndef C10_MOBILE
-    if (C10_UNLIKELY(
-            detail::has_pyobject(combined) &&
-            detail::refcount(combined) == 2)) {
-      self->incref_pyobject();
-    }
-#else
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
-#endif
+    detail::atomic_refcount_increment(self->combined_refcount_);
   }
 }

@@ -1,113 +0,0 @@
-# Device Management
-
-## Background
-
-Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
-
-The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
-
-## Design
-
-Accelerator vendors need to implement these core functions:
-
-| Function Name | Description | Application Scenarios |
-| ------------- | ----------- | --------------------- |
-| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
-| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
-| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
-| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
-| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
-
-These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
-
-## Implementation
-
-This section shows how to implement device management using `set_device` as an example. The implementation requires:
-1. C++ wrappers around the device runtime
-2. Python bindings to expose the C++ functions
-3. User-friendly Python APIs
-
-### C++ Side
-
-Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
-
-```{eval-rst}
-.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
-   :language: c++
-   :start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
-   :end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
-   :linenos:
-```
-```{eval-rst}
-.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
-   :language: c++
-   :start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
-   :end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
-   :linenos:
-```
-
-### Binding
-
-Expose the C++ functions to Python using pybind11:
-
-```{eval-rst}
-.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
-   :language: c++
-   :start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
-   :end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
-   :linenos:
-```
-```{eval-rst}
-.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
-   :language: c++
-   :start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
-   :end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
-   :linenos:
-   :emphasize-lines: 5
-```
-
-### Python Side
-
-Wrap the C++ bindings with user-friendly Python functions:
-
-```{eval-rst}
-.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
-   :language: python
-   :start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
-   :end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
-   :linenos:
-```
-
-Here's the complete mapping from C++ to Python:
-
-| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
-| -------------------- | -------------------------- | --------------- | ----------- |
-| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
-| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
-| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
-| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
-
-## Guard
-
-Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
-
-Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
-
-```{eval-rst}
-.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
-   :language: c++
-   :start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
-   :end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
-   :linenos:
-```
-
-**What needs to be implemented:**
-
-1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
-2. **getDevice()**: Get the current device
-3. **setDevice()**: Set the active device
-4. **Type checking**: Validate that device type matches the backend
-
-This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
-
-[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"
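The deleted guide's `SetDevice` pattern (wrap the runtime call, early-return when the requested device is already current, funnel failures through a check macro) also appears in the OpenRegFunctions.cpp hunk further down. A hedged, self-contained sketch of that shape against a made-up runtime; `fakeGetDevice`/`fakeSetDevice` and `FAKE_CHECK` are stand-ins, not the OpenReg API:

```cpp
#include <cstdio>
#include <stdexcept>

enum fakeError_t { fakeSuccess = 0, fakeErrorInvalidDevice = 1 };

// Stand-ins for a vendor runtime; a real backend would call its driver here.
static int g_current = 0;
fakeError_t fakeGetDevice(int* dev) { *dev = g_current; return fakeSuccess; }
fakeError_t fakeSetDevice(int dev) {
  if (dev < 0 || dev >= 2) return fakeErrorInvalidDevice;
  g_current = dev;
  return fakeSuccess;
}

#define FAKE_CHECK(expr)                                          \
  do {                                                            \
    if ((expr) != fakeSuccess)                                    \
      throw std::runtime_error("device runtime call failed");     \
  } while (0)

// Mirrors the guide's SetDevice wrapper: skip the runtime call when the
// requested device is already the current one.
fakeError_t SetDevice(int device) {
  int cur = -1;
  FAKE_CHECK(fakeGetDevice(&cur));
  if (device == cur) {
    return fakeSuccess;
  }
  return fakeSetDevice(device);
}

int main() {
  FAKE_CHECK(SetDevice(1));
  std::printf("current device: %d\n", g_current);
}
```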
@@ -42,7 +42,6 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
    :glob:
    :maxdepth: 1

-   device
    hooks
    autoload
    operators
@@ -4,12 +4,17 @@

 #include <c10/util/Exception.h>

-void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
+void orCheckFail(
+    const char* func,
+    const char* file,
+    uint32_t line,
+    const char* msg = "");

 #define OPENREG_CHECK(EXPR, ...) \
   do { \
     const orError_t __err = EXPR; \
-    if (C10_UNLIKELY(__err != orSuccess)) { \
-      orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
-    } \
+    if (__err != orSuccess) { \
+      orCheckFail( \
+          __func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
+    } \
   } while (0)
@@ -1,4 +1,3 @@
-#include <c10/util/Exception.h>
 #include <include/openreg.h>

 #include "OpenRegException.h"
@@ -10,22 +9,21 @@ orError_t GetDeviceCount(int* dev_count) {
   return orGetDeviceCount(dev_count);
 }

-orError_t GetDevice(DeviceIndex* device) {
+orError_t GetDevice(c10::DeviceIndex* device) {
   int tmp_device = -1;
   auto err = orGetDevice(&tmp_device);
-  *device = static_cast<DeviceIndex>(tmp_device);
+  *device = static_cast<c10::DeviceIndex>(tmp_device);
   return err;
 }
-// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
-orError_t SetDevice(DeviceIndex device) {
+
+orError_t SetDevice(c10::DeviceIndex device) {
   int cur_device = -1;
-  OPENREG_CHECK(orGetDevice(&cur_device));
+  orGetDevice(&cur_device);
   if (device == cur_device) {
     return orSuccess;
   }
   return orSetDevice(device);
 }
-// LITERALINCLUDE END: OPENREG SetDevice FUNCTION

 int device_count_impl() {
   int count = 0;
@@ -33,37 +31,34 @@ int device_count_impl() {
   return count;
 }

-OPENREG_EXPORT DeviceIndex device_count() noexcept {
+OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
   // initialize number of devices only once
   static int count = []() {
     try {
       auto result = device_count_impl();
       TORCH_CHECK(
-          result <= std::numeric_limits<DeviceIndex>::max(),
+          result <= std::numeric_limits<c10::DeviceIndex>::max(),
           "Too many devices, DeviceIndex overflowed");
       return result;
-    } catch (const Error& ex) {
+    } catch (const c10::Error& ex) {
       // We don't want to fail, but still log the warning
       // msg() returns the message without the stack trace
       TORCH_WARN("Device initialization: ", ex.msg());
       return 0;
     }
   }();
-  return static_cast<DeviceIndex>(count);
+  return static_cast<c10::DeviceIndex>(count);
 }

-OPENREG_EXPORT DeviceIndex current_device() {
-  DeviceIndex cur_device = -1;
-  OPENREG_CHECK(GetDevice(&cur_device));
+OPENREG_EXPORT c10::DeviceIndex current_device() {
+  c10::DeviceIndex cur_device = -1;
+  GetDevice(&cur_device);
   return cur_device;
 }

-// LITERALINCLUDE START: OPENREG set_device FUNCTION
-OPENREG_EXPORT void set_device(DeviceIndex device) {
-  check_device_index(device);
-  OPENREG_CHECK(SetDevice(device));
+OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
+  SetDevice(device);
 }
-// LITERALINCLUDE END: OPENREG set_device FUNCTION

 OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
   int current_device = -1;
@@ -76,8 +71,4 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
   return current_device;
 }

-OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
-  check_device_index(to_device);
-  return ExchangeDevice(to_device);
-}
 } // namespace c10::openreg
@@ -9,20 +9,10 @@

 namespace c10::openreg {

-OPENREG_EXPORT DeviceIndex device_count() noexcept;
-OPENREG_EXPORT DeviceIndex current_device();
-OPENREG_EXPORT void set_device(DeviceIndex device);
-OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
+OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
+OPENREG_EXPORT c10::DeviceIndex current_device();
+OPENREG_EXPORT void set_device(c10::DeviceIndex device);

 OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);

-static inline void check_device_index(int64_t device) {
-  TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
-    "The device index is out of range. It must be in [0, ",
-    static_cast<int>(c10::openreg::device_count()),
-    "), but got ",
-    static_cast<int>(device),
-    ".");
-}
-
 } // namespace c10::openreg
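The removed `check_device_index` helper shows the validation the deleted guide recommends: reject indices outside `[0, device_count())` before touching the runtime. A hedged standalone sketch of the same check, with a plain exception standing in for `TORCH_CHECK` and a hard-coded `device_count()` stand-in:

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Assumed stand-in for the backend's device_count(); the real helper calls
// c10::openreg::device_count().
int64_t device_count() { return 2; }

inline void check_device_index(int64_t device) {
  if (device < 0 || device >= device_count()) {
    throw std::runtime_error(
        "The device index is out of range. It must be in [0, " +
        std::to_string(device_count()) + "), but got " +
        std::to_string(device) + ".");
  }
}

int main() {
  check_device_index(1);     // in range, no-op
  try {
    check_device_index(2);   // out of range -> throws
  } catch (const std::runtime_error&) {
  }
}
```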
@@ -2,8 +2,6 @@

 namespace c10::openreg {

-// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
 C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
-// LITERALINCLUDE END: OPENREG GUARD REGISTRATION

 } // namespace c10::openreg
@@ -11,7 +11,6 @@

 namespace c10::openreg {

-// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
 struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
   static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;

@@ -59,7 +58,6 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {

     set_device(d.index());
   }
-  // LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE

   /**
    * Set the current device to c10::Device, without checking for errors
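`OpenRegGuardImpl` is what PyTorch's generic device guards call into for `PrivateUse1`; the deleted guide describes them as lock-guard-like (switch on construction, restore on destruction). A hedged minimal sketch of that RAII shape, with toy `ExchangeDevice`/`set_device` stand-ins rather than the real OpenReg symbols:

```cpp
#include <cassert>

// Toy device state and wrappers; a real backend calls its runtime here.
static int g_device = 0;
int ExchangeDevice(int device) { int prev = g_device; g_device = device; return prev; }
void set_device(int device) { g_device = device; }

// Switch to `device` now, restore the previous device when the scope ends.
class ScopedDevice {
 public:
  explicit ScopedDevice(int device) : prev_(ExchangeDevice(device)) {}
  ~ScopedDevice() { set_device(prev_); }
  ScopedDevice(const ScopedDevice&) = delete;
  ScopedDevice& operator=(const ScopedDevice&) = delete;

 private:
  int prev_;
};

int main() {
  set_device(0);
  {
    ScopedDevice guard(1);
    assert(g_device == 1);  // active inside the scope
  }
  assert(g_device == 0);    // restored on destruction
}
```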
@@ -27,10 +27,6 @@ class TestDevice(TestCase):
         self.assertEqual(torch.accelerator.current_device_index(), 1)
         self.assertEqual(torch.accelerator.current_device_index(), device)

-    def test_invalid_device_index(self):
-        with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
-            torch.accelerator.set_device_index(2)
-

 if __name__ == "__main__":
     run_tests()
@@ -34,21 +34,18 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
 }
 // LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR

-// LITERALINCLUDE START: MODULE SET DEVICE HELPER
-
 PyObject* _setDevice(PyObject* self, PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
-  auto device = THPUtils_unpackDeviceIndex(arg);
+  auto device = THPUtils_unpackLong(arg);

   torch::utils::device_lazy_init(at::kPrivateUse1);
-  c10::openreg::set_device(device);
+  c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));

   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }

-// LITERALINCLUDE END: MODULE SET DEVICE HELPER
-
 PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");
|
|||||||
return torch_openreg._C._get_device()
|
return torch_openreg._C._get_device()
|
||||||
|
|
||||||
|
|
||||||
# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
|
|
||||||
def set_device(device) -> None:
|
def set_device(device) -> None:
|
||||||
if device >= 0:
|
return torch_openreg._C._set_device(device)
|
||||||
torch_openreg._C._set_device(device)
|
|
||||||
|
|
||||||
|
|
||||||
# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
|
|
||||||
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import torch.distributed._functional_collectives as funcol
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.distributed.tensor import DeviceMesh, DTensor, Shard
|
from torch.distributed.tensor import DeviceMesh, DTensor, Shard
|
||||||
from torch.distributed.tensor.debug import CommDebugMode
|
from torch.distributed.tensor.debug import CommDebugMode
|
||||||
from torch.testing._internal.common_distributed import requires_nccl
|
from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
|
||||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||||
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
|
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
|
||||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||||
@ -14,6 +14,9 @@ from torch.testing._internal.distributed.fake_pg import FakeStore
|
|||||||
|
|
||||||
c10d_functional = torch.ops.c10d_functional
|
c10d_functional = torch.ops.c10d_functional
|
||||||
c10d_ops = torch.ops.c10d
|
c10d_ops = torch.ops.c10d
|
||||||
|
device_type = (
|
||||||
|
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestCommMode(TestCase):
|
class TestCommMode(TestCase):
|
||||||
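Reviewer note: the new module-level device_type constant is the usual device-agnostic idiom; a minimal sketch of the same pattern outside the test (only the "cpu" fallback string is an assumption):

    import torch

    # Pick the active accelerator if one exists, otherwise fall back to CPU.
    device_type = (
        acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
    )
    x = torch.ones(4, device=device_type)
    print(device_type, x.device)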
@@ -28,7 +31,7 @@ class TestCommMode(TestCase):
 dist.init_process_group(
 backend="fake", rank=1, world_size=self.world_size, store=store
 )
-self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
+self.device_type = device_type
 self.world_pg = dist.distributed_c10d._get_default_group()
 
 def checksAssert(self, comm_mode, key, expected_value, expected_total_value):
@@ -111,12 +114,12 @@ class TestCommMode(TestCase):
 self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
 self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)
 
-@requires_nccl()
+@requires_accelerator_dist_backend(["nccl", "xccl"])
 def test_comm_mode_with_c10d(self):
-if not torch.cuda.is_available():
+if not torch.accelerator.is_available():
 return
 
-inp = torch.rand(2, 8, 16).cuda()
+inp = torch.rand(2, 8, 16).to(device_type)
 all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)
 
 comm_mode = CommDebugMode()
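Reviewer note: for readers unfamiliar with CommDebugMode, it counts the collectives dispatched under it, which is what the assertions above rely on. A rough sketch using the same fake process group as this test file (the all_reduce choice and world size are illustrative assumptions):

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol
    from torch.distributed.tensor.debug import CommDebugMode
    from torch.testing._internal.distributed.fake_pg import FakeStore

    # The fake backend records collectives without doing real communication.
    dist.init_process_group("fake", rank=0, world_size=2, store=FakeStore())
    comm_mode = CommDebugMode()
    with comm_mode:
        out = funcol.all_reduce(torch.ones(4), "sum", group=dist.group.WORLD)
    print(comm_mode.get_comm_counts()[torch.ops.c10d_functional.all_reduce])
    dist.destroy_process_group()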
@@ -658,11 +658,11 @@ class DTensorMeshTest(DTensorTestBase):
 
 @with_comms
 def test_dtensor_device_mesh_device_conversion(self):
-# construct a cuda device mesh
+# construct a gpu device mesh
 mesh = self.build_device_mesh()
 
-# construct from a cpu local tensor with cuda device mesh
-# should automatically convert the dist tensor to cuda
+# construct from a cpu local tensor with gpu device mesh
+# should automatically convert the dist tensor to gpu
 placements = [Shard(0)]
 local_tensor = torch.randn(3, 3)
 dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
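Reviewer note: a compact reference for what from_local does with a CPU local tensor and a mesh, mirroring the conversion the test above checks (single-process fake world, so the world size is an illustrative assumption):

    import torch
    import torch.distributed as dist
    from torch.distributed.tensor import DeviceMesh, DTensor, Shard
    from torch.testing._internal.distributed.fake_pg import FakeStore

    dist.init_process_group("fake", rank=0, world_size=4, store=FakeStore())
    mesh = DeviceMesh("cpu", torch.arange(4))
    local = torch.randn(3, 3)                      # plain CPU tensor
    dt = DTensor.from_local(local, mesh, [Shard(0)])
    print(dt.device_mesh.device_type, dt.to_local().shape)
    dist.destroy_process_group()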
@@ -711,7 +711,7 @@ class DTensorMeshTest(DTensorTestBase):
 @with_comms
 def test_dtensor_2d_mesh(self):
 mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
-# construct a cuda device mesh
+# construct a gpu device mesh
 mesh = DeviceMesh(self.device_type, mesh_tensor)
 
 # construct a dist tensor on 2d device mesh and test if works
@@ -733,7 +733,7 @@ class DTensorMeshTest(DTensorTestBase):
 
 @with_comms
 def test_device_mesh_nd(self):
-# construct a cuda device mesh
+# construct a gpu device mesh
 mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
 mesh = DeviceMesh(self.device_type, mesh_tensor)
 # construct a dist tensor on 3d device mesh and test if works
@@ -1064,8 +1064,8 @@ class TestDTensorPlacementTypes(DTensorTestBase):
 # Keep everything deterministic.
 torch.manual_seed(0)
 tensor = torch.rand(size)
-if self.device_type == "cuda":
-return tensor.cuda()
+if self.device_type != "cpu":
+return tensor.to(self.device_type)
 else:
 return tensor
 
@@ -39,6 +39,7 @@ from torch.distributed.tensor.parallel import (
 RowwiseParallel,
 )
 from torch.distributed.tensor.placement_types import _StridedShard
+from torch.testing._internal.common_device_type import skipXPUIf
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import (
@@ -47,8 +48,6 @@ from torch.testing._internal.common_utils import (
 run_tests,
 skipIfHpu,
 skipIfTorchDynamo,
-TEST_CUDA,
-TEST_HPU,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
 DTensorTestBase,
@@ -95,6 +94,10 @@ aot_eager_graph = aot_autograd(
 partition_fn=min_cut_rematerialization_partition,
 )
 
+device_type = (
+acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
 
 def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh):
 """
@@ -141,7 +144,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
 
 @property
 def device_type(self) -> str:
-return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu"
+return device_type
 
 @property
 def world_size(self) -> int:
@@ -160,9 +163,9 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
 res = fn(x)
 res.to_local().sum().backward()
 
-@unittest.skipIf(not TEST_CUDA, "CUDA not available")
+@unittest.skipIf(not torch.accelerator.is_available(), "accelerator not available")
 def test_dtensor_basic_export(self):
-mesh = DeviceMesh("cuda", torch.arange(self.world_size))
+mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
 param = torch.randn(4, 4)
 param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)
@@ -188,10 +191,10 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
 )
 self.assertExpectedInline(
 str(ep.graph_module.code).strip(),
-"""\
+f"""\
 def forward(self, b_buffer, x):
 _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
-to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None
+to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}')); x = None
 view_as = torch.ops.aten.view_as.default(to, to); to = None
 dtensor___init__0 = self.dtensor___init__0
 dtensor_const_func_spec0 = self.dtensor_const_func_spec0
@@ -206,10 +209,10 @@ def forward(self, b_buffer, x):
 # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add
 self.assertExpectedInline(
 str(ep.run_decompositions({}).graph_module.code).strip(),
-"""\
+f"""\
 def forward(self, b_parametrizations_buffer_original0, x):
 _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
-_to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
+_to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}', index=0)); x = None
 view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
 add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
 view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
@@ -377,6 +380,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
 self.assertEqual(res, ref)
 
 @skipIfHpu
+@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1981")
 def test_dtensor_dynamic_loss_parallel_log_softmax(self):
 mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
@@ -815,13 +819,13 @@ def forward(self, b_parametrizations_buffer_original0, x):
 out = layer_norm.permute(0, 2, 1)
 return out
 
-x = torch.randn(4, 2, 4, requires_grad=True, device="cuda")
+x = torch.randn(4, 2, 4, requires_grad=True, device=self.device_type)
 x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False)
 
-y = torch.randn(4, requires_grad=True, device="cuda")
+y = torch.randn(4, requires_grad=True, device=self.device_type)
 y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)
 
-z = torch.randn(4, requires_grad=True, device="cuda")
+z = torch.randn(4, requires_grad=True, device=self.device_type)
 z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False)
 
 opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
@@ -919,7 +923,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
 # pass in tensor as inputs/outputs, create DTensor and run redistribute
 # (allgather collective) inside the fn
 def fn(x_dt):
-if x_dt.device_mesh.device_type == "cuda":
+if x_dt.device_mesh.device_type == f"{self.device_type}":
 return x_dt + 1
 else:
 return x_dt + 2
@@ -1051,7 +1055,7 @@ def forward(self, primals_1):
 
 model = FakeTransformer().to(self.device_type)
 
-tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
+tp_mesh = init_device_mesh(self.device_type, (2,), mesh_dim_names=("tp",))
 
 # apply sequence parallel
 parallel_plan = {
@@ -27,8 +27,6 @@ from torch.testing._internal.common_utils import (
 instantiate_parametrized_tests,
 parametrize,
 run_tests,
-TEST_CUDA,
-TEST_HPU,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
 create_local_tensor_test_class,
@@ -541,7 +539,7 @@ class RedistributeTest(DTensorTestBase):
 local_out_dt = out_dt.to_local()
 local_expected_dt = expected_dt.to_local()
 self.assertEqual(out_dt.to_local(), expected_dt.to_local())
-if TEST_HPU or TEST_CUDA:
+if torch.accelerator.is_available():
 self.assertEqual(
 comm_mode.get_comm_counts()[
 torch.ops._dtensor.shard_dim_alltoall
@@ -296,8 +296,8 @@ class DistTensorOpsTest(DTensorTestBase):
 self.assertEqual(dist_tensor.dtype, torch.float32)
 self.assertEqual(zeros_like_dt.dtype, torch.bfloat16)
 
-@with_comms
 @skip_if_lt_x_gpu(4)
+@with_comms
 def test_stack(self):
 mesh_2d = DeviceMesh(
 self.device_type, torch.arange(self.world_size).reshape(2, 2)
@@ -952,9 +952,7 @@ User code traceback:
 self.assertExpectedInline(
 munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
 """\
-Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
-torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
-The graph break occurred in the following user code:
+Graph break: skip: from user code at:
 File "test_error_messages.py", line N, in fn
 assert x is None
 """,
@@ -1080,88 +1078,6 @@ from user code:
 """,
 )
 
-@torch._dynamo.config.patch(verbose=True)
-@make_logging_test(graph_breaks=True)
-def test_skipped_frame_with_verbose_traceback(self, records):
-def fn(x):
-with GenericCtxMgr():
-torch._dynamo.graph_break()
-return x + 1
-
-torch.compile(fn, backend="eager")(torch.randn(3))
-self.assertEqual(len(records), 1)
-self.assertExpectedInline(
-munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
-"""\
-Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
-torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
-The graph break occurred in the following user code:
-File "test_error_messages.py", line N, in fn
-torch._dynamo.graph_break()
-""",
-)
-self.assertExpectedInline(
-munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
-"""\
-Graph break under GenericContextWrappingVariable
-Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
-Hint: Move the offending context manager(s) to outside the compiled region.
-Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
-
-Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
-
-For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
-
-from user code:
-File "test_error_messages.py", line N, in fn
-torch._dynamo.graph_break()
-""",
-)
-
-@make_logging_test(graph_breaks=True)
-def test_skip_frame_in_loop_message(self, records):
-def fn(x):
-for i in range(2):
-with GenericCtxMgr():
-if x.sum() > 0:
-x = x + 1
-return x
-
-torch.compile(fn, backend="eager")(torch.randn(3))
-self.assertEqual(len(records), 1)
-self.assertExpectedInline(
-munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
-"""\
-Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
-torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
-The graph break occurred in the following user code:
-File "test_error_messages.py", line N, in fn
-if x.sum() > 0:
-""",
-)
-
-@make_logging_test(dynamo=logging.DEBUG)
-def test_skip_frame_empty_function_message(self, records):
-def empty_fn(x):
-pass
-
-torch.compile(empty_fn, backend="eager")(torch.randn(3))
-skip_messages = [
-r
-for r in records
-if "intentionally decided to skip the frame" in r.getMessage()
-]
-self.assertEqual(len(skip_messages), 1)
-msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
-msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
-
-self.assertExpectedInline(
-msg,
-"""\
-Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
-Reason: no content in function call empty_fn test_error_messages.py N""",
-)
-
 @make_logging_test(graph_breaks=True)
 def test_nested_compile_user_frames(self, records):
 def fn(x):
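Reviewer note: the deleted tests asserted on the exact skip messages; the scenario itself is easy to reproduce outside the test harness. A minimal repro distilled from the removed test body (running it with TORCH_LOGS="graph_breaks" to surface the message interactively is an assumption about workflow, not part of this diff):

    import torch
    import torch._dynamo

    class GenericCtxMgr:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            pass

    def fn(x):
        with GenericCtxMgr():            # graph break inside a live context manager
            torch._dynamo.graph_break()  # Dynamo cannot resume here, so it skips fn
        return x + 1

    torch.compile(fn, backend="eager")(torch.randn(3))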
@@ -1708,110 +1624,6 @@ from user code:
 )
 
-
-class NestedGraphBreakLoggingTests(
-LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
-):
-@make_logging_test(graph_breaks=True)
-def test_skipped_frame_with_verbose_traceback_nested(self, records):
-global f1, f2, f3
-
-class GenericCtxMgr:
-def __enter__(self):
-return self
-
-def __exit__(self, exc_type, exc_value, traceback):
-pass
-
-def f1(x):
-with GenericCtxMgr():
-torch._dynamo.graph_break()
-return x + 1
-
-def f2(x):
-return f1(x + 2)
-
-def f3(x):
-return f2(x + 3)
-
-torch.compile(f3, backend="eager")(torch.randn(3))
-self.assertEqual(len(records), 1)
-self.assertExpectedInline(
-munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
-"""\
-Graph break in user code at test_error_messages.py:N
-Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
-Graph break under GenericContextWrappingVariable
-Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
-Hint: Move the offending context manager(s) to outside the compiled region.
-Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
-
-Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
-
-For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
-User code traceback:
-File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
-torch.compile(f3, backend="eager")(torch.randn(3))
-File "test_error_messages.py", line N, in f3
-return f2(x + 3)
-File "test_error_messages.py", line N, in f2
-return f1(x + 2)
-File "test_error_messages.py", line N, in f1
-torch._dynamo.graph_break()
-""",
-)
-
-@make_logging_test(graph_breaks=True)
-def test_skip_frame_in_loop_message_nested(self, records):
-global f1, f2, f3
-
-class GenericCtxMgr:
-def __enter__(self):
-return self
-
-def __exit__(self, exc_type, exc_value, traceback):
-pass
-
-def f1(x):
-for i in range(2):
-with GenericCtxMgr():
-if x.sum() > 0:
-x = x + 1
-return x
-
-def f2(x):
-return f1(x + 4)
-
-def f3(x):
-return f2(x + 5)
-
-result = torch.compile(f3, backend="eager")(torch.randn(3))  # noqa: F841
-self.assertEqual(len(records), 1)
-self.assertExpectedInline(
-munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
-"""\
-Graph break in user code at test_error_messages.py:N
-Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
-Data-dependent branching
-Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
-Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
-Hint: Use `torch.cond` to express dynamic control flow.
-
-Developer debug context: attempted to jump with TensorVariable()
-
-For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
-User code traceback:
-File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
-result = torch.compile(f3, backend="eager")(torch.randn(3))  # noqa: F841
-File "test_error_messages.py", line N, in f3
-return f2(x + 5)
-File "test_error_messages.py", line N, in f2
-return f1(x + 4)
-File "test_error_messages.py", line N, in f1
-if x.sum() > 0:
-""",
-)
-
 
 if __name__ == "__main__":
 from torch._dynamo.test_case import run_tests
 
@@ -14036,44 +14036,6 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
 except Exception as e:
 self.fail(f"torch.compile failed with error: {e}")
 
-@torch._dynamo.config.patch(capture_scalar_outputs=True)
-def test_tensorify_track_item_symint(self):
-def _random_resize(image: torch.Tensor):
-image_metanet = image
-default_patch_size = 14
-rand_cnn_resolution = (224, 256)
-min_nump = rand_cnn_resolution[0] // default_patch_size
-max_nump = rand_cnn_resolution[1] // default_patch_size
-new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
-torch._check(new_nump > 0)
-torch._check(new_nump * default_patch_size > 1)
-
-image_metanet = F.interpolate(
-image_metanet,
-size=(new_nump * default_patch_size, new_nump * default_patch_size),
-mode="bilinear",
-align_corners=True,
-)
-img_h_new, img_w_new = image_metanet.shape[2:]
-
-return (img_h_new, img_w_new), image_metanet
-
-_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
-
-# Test the function
-input_tensor = torch.rand(1, 3, 224, 224)
-(h, w), output = _random_resize_compiled(input_tensor)
-
-# Verify output properties
-self.assertEqual(output.shape[0], 1)
-self.assertEqual(output.shape[1], 3)
-self.assertEqual(output.shape[2], h)
-self.assertEqual(output.shape[3], w)
-self.assertTrue(h % 14 == 0)
-self.assertTrue(w % 14 == 0)
-self.assertTrue(224 <= h <= 256)
-self.assertTrue(224 <= w <= 256)
-
 
 if __name__ == "__main__":
 from torch._dynamo.test_case import run_tests
 
@@ -10895,34 +10895,6 @@ get_out().sum().backward()
 
 self.assertTrue(gradcheck(func, x, fast_mode=True))
 
-def test_grad_thread_safety(self):
-import threading
-from concurrent.futures import ThreadPoolExecutor
-
-NUM_ITERS = 10
-NUM_THREADS = 4
-
-# Concurrent calls to tensor.untyped_storage()
-def access_grad(tensor, barrier):
-barrier.wait()
-return weakref.ref(tensor.grad)
-
-for i in range(NUM_ITERS):
-tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
-(tensor**2).sum().backward()
-
-barrier = threading.Barrier(NUM_THREADS)
-with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
-futures = [
-executor.submit(access_grad, tensor, barrier)
-for _ in range(NUM_THREADS)
-]
-
-# Check that all the grad tensors returned were the same
-for future in futures:
-self.assertEqual(future.result()(), tensor.grad)
-self.assertIsNotNone(tensor.grad)
-
 
 def index_perm_variable(shape, max_indices):
 if not isinstance(shape, tuple):
@@ -259,8 +259,7 @@ class TestTorchDeviceType(TestCase):
 def test_storage_use_count(self, device):
 a = torch.randn(10, device=device)
 prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata)
-# Two references: 'a' and the wrapper returned by untyped_storage()
-self.assertEqual(prev_cf, 2)
+self.assertEqual(prev_cf, 1)
 b = a.view(2, 5)
 self.assertEqual(torch._C._storage_Use_Count(b.untyped_storage()._cdata), prev_cf + 1)
 
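Reviewer note: the expectation flip above (2 vs 1) comes down to whether the Python wrapper returned by untyped_storage() holds its own reference to the StorageImpl. A quick probe (the printed numbers depend on which side of this diff you run, so treat them as illustrative):

    import torch

    a = torch.randn(10)
    print(torch._C._storage_Use_Count(a.untyped_storage()._cdata))  # 1 or 2

    b = a.view(2, 5)  # the view shares the same StorageImpl, adding one use
    print(torch._C._storage_Use_Count(b.untyped_storage()._cdata))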
@@ -9325,7 +9324,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
 member_var = object()
 
 err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
-with self.assertRaisesRegex(TypeError, err_msg):
+with self.assertRaisesRegex(RuntimeError, err_msg):
 s0 = t0.as_subclass(BadSubTensor)
 
 # FIXME: Port to a test suite that better fits slicing
@@ -10325,21 +10324,20 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
 
 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
 def test_tensor_dead_weak_ref(self):
-x = torch.ones(2)
+x = torch.empty(2)
 w_x = weakref.ref(x)
-y = torch.ones(2)
+y = torch.empty(2)
 y.grad = x
 del x
 
 x = w_x()
-# x should keep the tensor live. This didn't happen in earlier PyTorch
-# versions.
+# Ideally, x would keep the tensor live. But CPython doesn't
+# provide enough hooks to do this. So it will go dead and x
+# will transmute into an undefined tensor. Not great, but the
+# best we can do.
 del y
 
-self.assertEqual(2, x.sum())
+self.assertRaises(RuntimeError, lambda: x.sigmoid())
 
-del x
-self.assertIsNone(w_x())
 
 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
 def test_storage_dead_weak_ref(self):
@@ -10347,9 +10345,16 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
 w_x = weakref.ref(x)
 y = torch.tensor(x)
 del x
-self.assertIsNotNone(w_x())
+x = w_x()
+# Ideally, x would keep the storage live. But CPython doesn't
+# provide enough hooks to do this. So it will go dead and x
+# will transmute into storage with null StorageImpl. Not great, but the
+# best we can do.
 del y
-self.assertIsNone(w_x())
+self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
+self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
 
 def test_tensor_resurrected_weak_ref(self):
 x = torch.empty(2)
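Reviewer note: both weak-ref tests now assert that the PyObject goes dead rather than being kept alive through the weak reference. A compact repro of the tensor case (which branch you run decides which arm of the try fires, hence the hedged comments):

    import torch
    import weakref

    x = torch.empty(2)
    w_x = weakref.ref(x)
    y = torch.empty(2)
    y.grad = x            # the C++ side now also owns the TensorImpl via y.grad
    del x

    x = w_x()             # on both sides the weakref still resolves at this point
    del y                 # drop the last strong C++ owner
    try:
        print(x.sum())    # one side keeps the tensor usable ...
    except RuntimeError as err:
        print("tensor went dead:", err)  # ... the other leaves an undefined tensor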
@@ -10410,31 +10415,6 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
 
 self.assertTrue(called)
 
-def test_storage_thread_safety(self):
-import threading
-from concurrent.futures import ThreadPoolExecutor
-
-NUM_ITERS = 10
-NUM_THREADS = 4
-
-# Concurrent calls to tensor.untyped_storage()
-def access_untyped_storage(tensor, barrier):
-barrier.wait()
-return weakref.ref(tensor.untyped_storage())
-
-for i in range(NUM_ITERS):
-tensor = torch.tensor([1.0, 2.0, 3.0])
-barrier = threading.Barrier(NUM_THREADS)
-with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
-futures = [
-executor.submit(access_untyped_storage, tensor, barrier)
-for _ in range(NUM_THREADS)
-]
-
-# Check that all the storages returned were the same
-for future in futures:
-self.assertEqual(future.result()(), tensor.untyped_storage())
-
 # FIXME: move to test_linalg
 @torch.inference_mode()
 def test_bmm_multithreaded(self):
@@ -1870,7 +1870,7 @@ class ConvertFrame:
 raise
 
 soft_fail = isinstance(e, Unsupported)
-code = frame.f_code
 # This is a soft failure. In the sense, the code path reaches here
 # when we do not support graph breaks on bytecodes like LOAD_ATTR,
 # BUILD_SET etc. In such case, we can fallback to eager without
@@ -1885,13 +1885,7 @@ class ConvertFrame:
 user_stack_formatted = "".join(
 traceback.format_list(user_stack)
 )
-frame_info = exc.format_frame_info(code)
-user_stack_trace = (
-"Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
-f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
-"The graph break occurred in the following user code:\n"
-f"{user_stack_formatted}"
-)
+user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
 torch._logging.trace_structured(
 "artifact",
 metadata_fn=lambda: {
@@ -1903,7 +1897,6 @@ class ConvertFrame:
 graph_break_log.debug(
 user_stack_trace,
 exc_info=True,
-stack_info=config.verbose,
 )
 
 if not config.suppress_errors and not soft_fail:
@@ -794,38 +794,6 @@ def format_error_msg_verbose(
 return msg
 
 
-def format_frame_info(code: types.CodeType) -> str:
-return (
-f"{getattr(code, 'co_name', '<unknown>')} "
-f"({getattr(code, 'co_filename', '<unknown>')} "
-f"line {getattr(code, 'co_firstlineno', 0)})"
-)
-
-
-def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
-if code is not None:
-frame_info = format_frame_info(code)
-return (
-f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
-f"Reason: {reason}"
-)
-else:
-return (
-f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
-f"Reason: {reason}"
-)
-
-
-def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
-frame_info = format_frame_info(code)
-return (
-"Skipping frame because there is a graph break in a for/while loop\n"
-f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
-f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
-f"{frame_summary}"
-)
-
-
 def format_error_msg(
 exc: Exception,
 code: types.CodeType,
@@ -94,8 +94,6 @@ from .exc import (
 BackendCompilerFailed,
 collapse_resume_frames,
 format_graph_break_message,
-format_loop_skip_frame_message,
-format_skip_frame_message,
 get_stack_above_dynamo,
 ResumePrologueTracingError,
 StepUnsupported,
@@ -607,9 +605,9 @@ def generic_jump(
 )
 # compile a partial subgraph prefix then jump into user code
 if self.maybe_has_backedge():
-msg = format_loop_skip_frame_message(
-self.f_code,
-"".join(traceback.format_list([self.frame_summary()])),
+msg = (
+"Skipping frame because there is a graph break in a for/while loop\n"
+f"{self.frame_summary()}"
 )
 log.info(msg)
 raise exc.SkipFrame(msg)
@@ -885,9 +883,9 @@ def break_graph_if_unsupported(
 )
 
 if self.maybe_has_backedge():
-msg = format_loop_skip_frame_message(
-self.f_code,
-"".join(traceback.format_list([self.frame_summary()])),
+msg = (
+"Skipping frame because there is a graph break in a for/while loop\n"
+f"{self.frame_summary()}"
 )
 log.info(msg)
 raise exc.SkipFrame(msg) from excp
@@ -4628,9 +4626,8 @@ class InstructionTranslator(InstructionTranslatorBase):
 and not self.error_on_graph_break
 and not self.is_tracing_resume_prologue
 ):
-raise exc.SkipFrame(
-format_skip_frame_message(self.f_code, "no content in function call")
-)
+raise exc.SkipFrame("because no content in function call")
 self.instruction_pointer = None
 _step_logger()(
 logging.INFO,
@@ -2248,15 +2248,12 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
 try:
 val.data_ptr()  # will throw for functorch tensors
 except RuntimeError as e:
-from .exc import format_skip_frame_message, SkipFrame
+from .exc import SkipFrame
 
 # This will be GradTrackingTensor/BatchedTensor/etc
 functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
 raise SkipFrame(
-format_skip_frame_message(
-None,
-f"torch.compile cannot be run in context: {functorch_subclass_name}",
-)
+f"torch.compile cannot be run in context: {functorch_subclass_name}"
 ) from e
 
 
@@ -42,7 +42,6 @@ from torch._guards import Source
 from .. import config, graph_break_hints, polyfills, variables
 from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
 from ..exc import (
-format_skip_frame_message,
 get_dynamo_observed_exception,
 handle_observed_exception,
 InfiniteGeneratorError,
@@ -1653,13 +1652,8 @@ class SkipFunctionVariable(VariableTracker):
 skip_frame_msg = kwargs.get("msg")
 if skip_frame_msg:
 skip_frame_msg = skip_frame_msg.as_python_constant()
-else:
-skip_frame_msg = ""
 raise SkipFrame(
-format_skip_frame_message(
-tx.f_code,
-f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
-)
+f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
 )
 elif self.value is torch._dynamo.step_unsupported:
 raise StepUnsupported
@@ -536,14 +536,9 @@ class StorageWeakRefWrapper:
 if self.extra_ref_check is not None and not self.extra_ref_check():
 return False
 
+# if extra_ref_check is not None we expect an additional reference
 stor_count = torch._C._storage_Use_Count(self.ref.cdata)
-if self.extra_ref_check is not None:
-# if extra_ref_check is not None we expect two additional references:
-# - one from the Python storage object
-# - one from the cached Tensor
-stor_count -= 2
-assert stor_count >= 0
-return stor_count == 0
+return (stor_count - (self.extra_ref_check is not None)) == 0
 
 def __repr__(self) -> str:
 if self.ref is None or self.ref.expired():
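Reviewer note: the cudagraphs bookkeeping above boils down to "is anything other than our own expected references keeping this storage alive"; the check reduces to roughly the helper below (a paraphrase, with the expected extra-reference count as the knob this hunk changes):

    import torch

    def storage_is_dead(storage_cdata, expected_extra_refs):
        # Mirrors StorageWeakRefWrapper.__call__: the storage counts as dead once
        # only the expected bookkeeping references to its StorageImpl remain.
        count = torch._C._storage_Use_Count(storage_cdata)
        return (count - expected_extra_refs) == 0

    t = torch.randn(4)
    print(storage_is_dead(t.untyped_storage()._cdata, 0))  # False: t is still alive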
@ -1444,15 +1439,7 @@ class CUDAGraphNode:
|
|||||||
self_loc = self_ref()
|
self_loc = self_ref()
|
||||||
if self_loc is None:
|
if self_loc is None:
|
||||||
return False
|
return False
|
||||||
refcount = self_loc.get_output_refcount(i)
|
return self_loc.get_output_refcount(i) == 2
|
||||||
# pyrefly: ignore
|
|
||||||
if self_loc.cached_tensor_outputs[i]._use_count() > 1:
|
|
||||||
# c10::Tensor may also holds one reference count
|
|
||||||
assert refcount >= 3
|
|
||||||
return refcount == 3
|
|
||||||
else:
|
|
||||||
assert refcount >= 2
|
|
||||||
return refcount == 2
|
|
||||||
|
|
||||||
check = functools.partial(check_refcount, i=i)
|
check = functools.partial(check_refcount, i=i)
|
||||||
|
|
||||||
|
|||||||
@ -891,14 +891,10 @@ class TorchLogsFormatter(logging.Formatter):
|
|||||||
# exception handling - copied from logging.Formatter.format
|
# exception handling - copied from logging.Formatter.format
|
||||||
s = record.message
|
s = record.message
|
||||||
if record.exc_info:
|
if record.exc_info:
|
||||||
from torch._dynamo import config
|
|
||||||
|
|
||||||
should_format_exc = config.verbose or artifact_name != "graph_breaks"
|
|
||||||
# Cache the traceback text to avoid converting it multiple times
|
# Cache the traceback text to avoid converting it multiple times
|
||||||
# (it's constant anyway)
|
# (it's constant anyway)
|
||||||
if should_format_exc:
|
if not record.exc_text:
|
||||||
if not record.exc_text:
|
record.exc_text = self.formatException(record.exc_info)
|
||||||
record.exc_text = self.formatException(record.exc_info)
|
|
||||||
if record.exc_text:
|
if record.exc_text:
|
||||||
if s[-1:] != "\n":
|
if s[-1:] != "\n":
|
||||||
s = s + "\n"
|
s = s + "\n"
|
||||||
|
|||||||
@ -398,27 +398,36 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
|
|||||||
|
|
||||||
// weak_use_count() adds 1 if use_count is non-zero
|
// weak_use_count() adds 1 if use_count is non-zero
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
a->cdata.weak_use_count() == 1,
|
a->cdata->weak_use_count() == 1,
|
||||||
"Expected no weakrefs to t1's Tensor object but got ",
|
"Expected no weakrefs to t1's Tensor object but got ",
|
||||||
a->cdata.weak_use_count() - 1);
|
a->cdata->weak_use_count() - 1);
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
b->cdata.weak_use_count() == 1,
|
b->cdata->weak_use_count() == 1,
|
||||||
"Expected no weakrefs to t2's Tensor object but got ",
|
"Expected no weakrefs to t2's Tensor object but got ",
|
||||||
b->cdata.weak_use_count() - 1);
|
b->cdata->weak_use_count() - 1);
|
||||||
|
|
||||||
// NB: Creating local copies of *both* Tensors here ensures that they each
|
|
||||||
// hold a strong reference to their PyObject. This avoids having to fix up
|
|
||||||
// reference counts when we swap the PyObject slots below.
|
|
||||||
at::Tensor tmp_a = a->cdata;
|
|
||||||
at::Tensor tmp_b = b->cdata;
|
|
||||||
|
|
||||||
// Swap the Tensor Impl
|
// Swap the Tensor Impl
|
||||||
a->cdata = tmp_b;
|
c10::MaybeOwned<at::Tensor> tmp = a->cdata;
|
||||||
b->cdata = tmp_a;
|
|
||||||
|
|
||||||
// Fix up the PyObjects associated with each TensorImpl
|
// The TensorImpls contain PyObjectSlots that have a reference to the PyObject
|
||||||
a->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(a_);
|
// associated with the TensorImpl. Swap this field as well.
|
||||||
b->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(b_);
|
std::optional<PyObject*> mb_obj_a =
|
||||||
|
a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
|
||||||
|
/*ignore_hermetic_tls=*/false);
|
||||||
|
std::optional<PyObject*> mb_obj_b =
|
||||||
|
b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
|
||||||
|
/*ignore_hermetic_tls=*/false);
|
||||||
|
TORCH_INTERNAL_ASSERT(
|
||||||
|
mb_obj_a.has_value() && mb_obj_b.has_value(),
|
||||||
|
"Both tensors should have PyObjects tagged by the current python interpreter");
|
||||||
|
TORCH_CHECK(mb_obj_a.value() == a_);
|
||||||
|
TORCH_CHECK(mb_obj_b.value() == b_);
|
||||||
|
|
||||||
|
a->cdata = b->cdata;
|
||||||
|
b->cdata = tmp;
|
||||||
|
|
||||||
|
a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_);
|
||||||
|
b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_);
|
||||||
|
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
|
|||||||
@ -45,9 +45,7 @@ struct ConcretePyInterpreterVTable final
|
|||||||
std::string name() const override;
|
std::string name() const override;
|
||||||
|
|
||||||
void incref(PyObject* pyobj) const override;
|
void incref(PyObject* pyobj) const override;
|
||||||
void decref(PyObject* pyobj) const override;
|
void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
|
||||||
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override;
|
|
||||||
size_t refcnt(PyObject* pyobj) const override;
|
|
||||||
|
|
||||||
// TODO: Need to make this work for StorageImpl too. I imagine I'll want to
|
// TODO: Need to make this work for StorageImpl too. I imagine I'll want to
|
||||||
// operate upon a PyObjectSlot rather than a TensorImpl
|
// operate upon a PyObjectSlot rather than a TensorImpl
|
||||||
@ -237,13 +235,53 @@ py::object torchDispatchFromTensorImpl(
|
|||||||
TorchFunctionName::TorchDispatch));
|
TorchFunctionName::TorchDispatch));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConcretePyInterpreterVTable::decref(PyObject* pyobj) const {
|
// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
|
||||||
|
// Before calling PyInterpreter::decref, we must statically know if the
|
||||||
|
// pyobj has a PyObjectSlot or not.
|
||||||
|
// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
|
||||||
|
// - If it does not have a PyObjectSlot, we can freely decref
|
||||||
|
// One alternative to this is using PyObject_IsInstance
|
||||||
|
// to get at this information. However, we don't want to risk an incorrect
|
||||||
|
// `__instancecheck__` changing the semantics here.
|
||||||
|
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
|
||||||
|
const {
|
||||||
// Leak the pyobj if not initialized. This can happen if we are running
|
// Leak the pyobj if not initialized. This can happen if we are running
|
||||||
// exit handlers that are destructing tensors with residual (owned)
|
// exit handlers that are destructing tensors with residual (owned)
|
||||||
// PyObjects stored in them.
|
// PyObjects stored in them.
|
||||||
if (!Py_IsInitialized())
|
if (!Py_IsInitialized())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
pybind11::gil_scoped_acquire gil;
|
pybind11::gil_scoped_acquire gil;
|
||||||
|
// Two possibilities:
|
||||||
|
// 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
|
||||||
|
// Storage. Then we must be careful about PyObject resurrection (see
|
||||||
|
// THPVariable_clear).
|
||||||
|
// 2. We are decref-ing some other Python object. We don't do
|
||||||
|
// PyObject resurrection on non-Tensors, so we just carry on as usual
|
||||||
|
if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
|
||||||
|
if (THPVariable_Check(pyobj)) {
|
||||||
|
// It's still alive! This can happen if a weak ref resurrected
|
||||||
|
// the PyObject without flipping ownership. At this point it is
|
||||||
|
// too late to rescue the object, so just stub out the PyObject
|
||||||
|
// so that it fails on subsequent uses. Don't raise an error here;
|
||||||
|
// you're probably in a destructor.
|
||||||
|
TORCH_WARN(
|
||||||
|
"Deallocating Tensor that still has live PyObject references. "
|
||||||
|
"This probably happened because you took out a weak reference to "
|
||||||
|
"Tensor and didn't call _fix_weakref() after dereferencing it. "
|
||||||
|
"Subsequent accesses to this tensor via the PyObject will now fail.");
|
||||||
|
(reinterpret_cast<THPVariable*>(pyobj))->cdata =
|
||||||
|
c10::MaybeOwned<torch::autograd::Variable>();
|
||||||
|
} else if (THPStorage_Check(pyobj)) {
|
||||||
|
TORCH_WARN(
|
||||||
|
"Deallocating UntypedStorage that still has live PyObject references. "
|
||||||
|
"This probably happened because you took out a weak reference to "
|
||||||
|
"UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
|
||||||
|
"Subsequent accesses to this storage via the PyObject will now fail.");
|
||||||
|
(reinterpret_cast<THPStorage*>(pyobj))->cdata =
|
||||||
|
c10::MaybeOwned<c10::Storage>();
|
||||||
|
}
|
||||||
|
}
|
||||||
Py_DECREF(pyobj);
|
Py_DECREF(pyobj);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -254,25 +292,6 @@ void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
|
|||||||
Py_INCREF(pyobj);
|
Py_INCREF(pyobj);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConcretePyInterpreterVTable::try_incref(
|
|
||||||
const c10::impl::PyObjectSlot& pyobj_slot) const {
|
|
||||||
if (!Py_IsInitialized())
|
|
||||||
return false;
|
|
||||||
pybind11::gil_scoped_acquire gil;
|
|
||||||
PyObject* pyobj = pyobj_slot.load_pyobj();
|
|
||||||
if (!pyobj) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return PyUnstable_TryIncRef(pyobj);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t ConcretePyInterpreterVTable::refcnt(PyObject* pyobj) const {
|
|
||||||
if (!Py_IsInitialized() || pyobj == nullptr)
|
|
||||||
return 0;
|
|
||||||
pybind11::gil_scoped_acquire gil;
|
|
||||||
return Py_REFCNT(pyobj);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isPythonTensor(const at::Tensor& tensor) {
|
bool isPythonTensor(const at::Tensor& tensor) {
|
||||||
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
|
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
|
||||||
}
|
}
|
||||||
@ -601,7 +620,11 @@ static void set_tensor_attr_with_capsule(
|
|||||||
const c10::TensorImpl* tensor,
|
const c10::TensorImpl* tensor,
|
||||||
py::capsule& capsule,
|
py::capsule& capsule,
|
||||||
const char* attr_name) {
|
const char* attr_name) {
|
||||||
PyObject* obj = tensor->pyobj_slot()->load_pyobj();
|
std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
|
||||||
|
/*ignore_hermetic_tls=*/false);
|
||||||
|
TORCH_CHECK(
|
||||||
|
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
|
||||||
|
auto obj = mb_obj.value();
|
||||||
py::handle(obj).attr(attr_name) = capsule;
|
py::handle(obj).attr(attr_name) = capsule;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -625,7 +648,11 @@ static c10::ArrayRef<T> get_set_cached_attr(
|
|||||||
const c10::TensorImpl* tensor,
|
const c10::TensorImpl* tensor,
|
||||||
const char* base_attr_name,
|
const char* base_attr_name,
|
||||||
const py::object& obj) {
|
const py::object& obj) {
|
||||||
PyObject* tensor_obj = tensor->pyobj_slot()->load_pyobj();
|
std::optional<PyObject*> mb_obj =
|
||||||
|
tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
|
||||||
|
TORCH_CHECK(
|
||||||
|
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
|
||||||
|
auto tensor_obj = mb_obj.value();
|
||||||
auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");
|
auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");
|
||||||
|
|
||||||
bool is_buffer_allocated = false;
|
bool is_buffer_allocated = false;
|
||||||
|
|||||||
@ -23,8 +23,6 @@
|
|||||||
#include <c10/util/intrusive_ptr.h>
|
#include <c10/util/intrusive_ptr.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
using torch::utils::PyObjectPreservation;
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void THPPointer<c10::StorageImpl>::free() {
|
void THPPointer<c10::StorageImpl>::free() {
|
||||||
if (ptr) {
|
if (ptr) {
|
||||||
@@ -34,72 +32,238 @@ void THPPointer<c10::StorageImpl>::free() {

 PyTypeObject* THPStorageClass = nullptr;

-// Create a new Python Storage object, but don't set the pyobj slot on the
-// c10::Storage object.
-static PyObject* THPStorage_New(PyTypeObject* type, c10::Storage _storage) {
+PyObject* THPStorage_NewWithStorage(
+    PyTypeObject* type,
+    c10::Storage _storage,
+    bool allow_preexisting_pyobj) {
+  TORCH_CHECK(
+      PyType_IsSubtype(type, &THPStorageType),
+      "Creating a Storage subclass from a class that does not inherit from ",
+      "Storage is not possible. Make sure your class inherits from Storage.");
+
+  auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
+      /*ignore_hermetic_tls=*/false);
+  if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
+    TORCH_CHECK(
+        allow_preexisting_pyobj,
+        "Creating a new Storage subclass ",
+        type->tp_name,
+        " but the raw Storage object is already associated to a python object ",
+        "of type ",
+        maybe_pyobj.value()->ob_type->tp_name);
+    PyObject* obj = *maybe_pyobj;
+    PyTypeObject* obj_type = Py_TYPE(obj);
+    TORCH_CHECK(
+        obj_type == type || PyType_IsSubtype(obj_type, type),
+        "Creating a new Storage subclass ",
+        type->tp_name,
+        " but the raw Storage object is already associated to a python object ",
+        "of type ",
+        maybe_pyobj.value()->ob_type->tp_name,
+        " which is not a subclass of the "
+        "requested type");
+    return THPStorage_Wrap(std::move(_storage));
+  }
+
   PyObject* obj = type->tp_alloc(type, 0);
   TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");

-  // Ensure that PyUnstable_TryIncref calls don't fail spuriously in
-  // free-threaded Python.
-  PyUnstable_EnableTryIncRef(obj);
+  auto s = reinterpret_cast<THPStorage*>(obj);

-  auto s = (THPStorage*)obj;
-  new (&s->cdata) c10::Storage(std::move(_storage));
-  return obj;
-}
+  new (&s->cdata) c10::MaybeOwned<c10::Storage>();

-// Create a new Python Storage object for a new c10::Storage, and set the
-// pyobj slot. The c10::Storage must not already have a pyobj set.
-PyObject* THPStorage_NewWithStorage(PyTypeObject* type, c10::Storage _storage) {
-  TORCH_CHECK(
-      type == THPStorageClass || PyType_IsSubtype(type, &THPStorageType),
-      "Creating a Storage subclass from a class that does not inherit from ",
-      "Storage is not possible. Make sure your class inherits from Storage.");
-  TORCH_INTERNAL_ASSERT(_storage.use_count() == 1);
+  s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));

-  c10::StorageImpl* storage_impl = _storage.unsafeGetStorageImpl();
-  PyObject* obj = THPStorage_New(type, std::move(_storage));
-  PyObjectPreservation::init_fresh_nonatomic(
-      storage_impl, storage_impl->pyobj_slot(), obj);
-  return obj;
-}
+  if (!c10::impl::HermeticPyObjectTLS::get_state()) {
+    s->is_hermetic = false;
+    const auto& storage = THPStorage_Unpack(s);
+    storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj);
+  } else {
+    s->is_hermetic = true;

-// Returns a PyObject wrapper for the c10::Storage object. The existing
-// wrapper is returned if it already exists.
-PyObject* THPStorage_Wrap(c10::Storage storage) {
-  if (c10::impl::HermeticPyObjectTLS::get_state()) {
-    return THPStorage_New(THPStorageClass, std::move(storage));
   }

+  return obj;
+}
+
+// Wraps the c10::Storage with a storage PyObject
+PyObject* THPStorage_Wrap(c10::Storage storage) {
   c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
+  if (c10::impl::HermeticPyObjectTLS::get_state()) {
+    return THPStorage_NewWithStorage(THPStorageClass, std::move(storage));
+  }
   c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();

-  PyObject* obj = pyobj_slot->load_pyobj();
-  if (obj) {
-    return Py_NewRef(obj);
-  }
+  std::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
+      /*ignore_hermetic_tls=*/false);
+  if (maybe_pyobj.has_value()) {
+    auto obj = *maybe_pyobj;
+    if (obj) {
+      TORCH_CHECK(
+          THPStorage_Check(obj),
+          "Expected a storage type, but got ",
+          Py_TYPE(obj)->tp_name);

-  obj = THPStorage_New(THPStorageClass, std::move(storage));
-  PyObject* wrapper =
-      PyObjectPreservation::init_once(storage_impl, pyobj_slot, obj);
-  if (wrapper != obj) {
-    // Another thread beat us to it
-    Py_DECREF(obj);
-    return Py_NewRef(wrapper);
-  }
-  return obj;
+      if (pyobj_slot->owns_pyobj()) {
+        pyobj_slot->set_owns_pyobj(false);
+        reinterpret_cast<THPStorage*>(obj)->cdata =
+            c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
+        return obj;
+      } else {
+        Py_INCREF(obj);
+        return obj;
+      }
+    }
+  }
+  return THPStorage_NewWithStorage(THPStorageClass, std::move(storage));
 }

-static void THPStorage_dealloc(PyObject* self) {
-  THPStorage* _self = reinterpret_cast<THPStorage*>(self);
-  auto pyobj_slot = _self->cdata.unsafeGetStorageImpl()->pyobj_slot();
+static bool THPStorage_isPreservable(THPStorage* self) {
+  if (self->cdata.unsafeIsBorrowed()) {
+    return false;
-  if (pyobj_slot->load_pyobj() == self) {
-    TORCH_INTERNAL_ASSERT(_self->cdata.use_count() == 1);
-    pyobj_slot->clear();
   }
-  _self->cdata.~Storage();
+  auto const& storage = THPStorage_Unpack(self);

+  if (self->is_hermetic) {
+    return false;
+  }
+
+  if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
+          /*ignore_hermetic_tls=*/true) != reinterpret_cast<PyObject*>(self)) {
+    return false;
+  }
+  if (storage.use_count() <= 1) {
+    return false;
+  }
+  return true;
+}
+
+static bool THPStorage_tryPreserve(THPStorage* self) {
+  if (!THPStorage_isPreservable(self)) {
+    return false;
+  }
+
+  const auto& storage = THPStorage_Unpack(self);
+  c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
+
+  auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj(
+      /*ignore_hermetic_tls=*/true);
+  // NOTE: It is possible to just set the PyObjectSlot here, but the point is
+  // that we should have already set PyObjectSlot when the storage PyObject
+  // was created.
+  TORCH_INTERNAL_ASSERT(
+      maybe_pyobj.has_value(),
+      "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");
+
+  PyObject* pyobj = *maybe_pyobj;
+
+  TORCH_CHECK(
+      THPStorage_Check(pyobj),
+      "Expected a storage type, but got ",
+      Py_TYPE(pyobj)->tp_name);
+
+  TORCH_INTERNAL_ASSERT(
+      (void*)pyobj == (void*)self,
+      "Python storage and the PyObject in the internal PyObjectSlot are not at the same address");
+
+  TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
+
+  storage_impl->pyobj_slot()->set_owns_pyobj(true);
+  // When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
+  // ensure the PyObject is in a valid state
+  _Py_NewReference(reinterpret_cast<PyObject*>(self));
+
+  self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
+  return true;
+}
+
+static void THPStorage_subclass_dealloc(PyObject* self) {
+  THPStorage* _self = reinterpret_cast<THPStorage*>(self);
+
+  if (THPStorage_tryPreserve(_self)) {
+    return;
+  }
+
+  // Some subclass of StorageBase could be GC-tracked objects even
+  // though the base class is not
+  auto* type = Py_TYPE(self);
+  if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
+    PyObject_GC_UnTrack(self);
+  }
+
+  bool has_finalizer = type->tp_finalize || type->tp_del;
+
+  if (type->tp_finalize) {
+    PyObject_GC_Track(self);
+    if (PyObject_CallFinalizerFromDealloc(self) < 0) {
+      // The finalizer has resurrected the PyObject and there is a new Python
+      // reference to it, so we can just stop deallocating. Read about
+      // resurrection from `__del__` here:
+      // https://docs.python.org/3/reference/datamodel.html#object.__del__
+      return;
+    }
+    PyObject_GC_UnTrack(self);
+  }
+
+  // base test is unnecessary as THPStorae does not set this
+  if (type->tp_weaklistoffset) {
+    PyObject_ClearWeakRefs(self);
+  }
+
+  if (type->tp_del) {
+    PyObject_GC_Track(self);
+    type->tp_del(self);
+    if (Py_REFCNT(self) > 0) {
+      // Resurrected (see above comment about resurrection from `__del__`)
+      return;
+    }
+    PyObject_GC_UnTrack(self);
+  }
+
+  if (has_finalizer) {
+    /* New weakrefs could be created during the finalizer call.
+       If this occurs, clear them out without calling their
+       finalizers since they might rely on part of the object
+       being finalized that has already been destroyed. */
+    if (type->tp_weaklistoffset) {
+      /* Modeled after GET_WEAKREFS_LISTPTR() */
+      PyWeakReference** list = reinterpret_cast<PyWeakReference**>(
+          PyObject_GET_WEAKREFS_LISTPTR(self));
+      while (*list)
+        _PyWeakref_ClearRef(*list);
+    }
+  }
+
+  // Clear slots
+  {
+    PyTypeObject* base = type;
+    while (base != &THPStorageType) {
+      if (Py_SIZE(base)) {
+        clear_slots(base, self);
+      }
+      base = base->tp_base;
+      TORCH_INTERNAL_ASSERT(base);
+    }
+  }
+
+  // Clear __dict__
+  if (C10_LIKELY(type->tp_dictoffset)) {
+    PyObject** dictptr = _PyObject_GetDictPtr(self);
+    if (dictptr != nullptr) {
+      PyObject* dict = *dictptr;
+      if (dict != nullptr) {
+        Py_DECREF(dict);
+        *dictptr = nullptr;
+      }
+    }
+  }
+
+  TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
+
+  _self->cdata.~MaybeOwned<c10::Storage>();
   Py_TYPE(_self)->tp_free(self);
+
+  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
+  Py_DECREF(type);
 }

 static PyObject* THPStorage_pynew(
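
The new THPStorage code in this hunk keys everything off whether cdata is an owned or a borrowed c10::MaybeOwned<c10::Storage>. A minimal sketch of those two states, assuming only the owned/borrowed/unsafeIsBorrowed API visible above (get_some_storage() is a hypothetical source of a Storage):

    // Illustrative only.
    c10::Storage s = get_some_storage(); // hypothetical
    // Owned: the PyObject keeps the Storage alive; destroying the wrapper
    // releases the last C++ reference.
    auto owned = c10::MaybeOwned<c10::Storage>::owned(std::move(s));
    // Borrowed: the C++ StorageImpl owns the PyObject (owns_pyobj() == true)
    // and the wrapper merely points back at it, as set up in
    // THPStorage_tryPreserve above.
    const c10::Storage& backing = *owned;
    auto borrowed = c10::MaybeOwned<c10::Storage>::borrowed(backing);
    TORCH_INTERNAL_ASSERT(!owned.unsafeIsBorrowed());
    TORCH_INTERNAL_ASSERT(borrowed.unsafeIsBorrowed());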
@@ -389,13 +553,64 @@ static PyMappingMethods THPStorage_mappingmethods = {
     reinterpret_cast<binaryfunc>(THPStorage_get),
     reinterpret_cast<objobjargproc>(THPStorage_set)};

+struct THPStorageMeta {
+  PyHeapTypeObject base;
+};
+
+static int THPStorageMetaType_init(
+    PyObject* cls,
+    PyObject* args,
+    PyObject* kwargs);
+
+static PyTypeObject THPStorageMetaType = {
+    PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0)
+    "torch._C._StorageMeta", /* tp_name */
+    sizeof(THPStorageMeta), /* tp_basicsize */
+    0, /* tp_itemsize */
+    nullptr, /* tp_dealloc */
+    0, /* tp_vectorcall_offset */
+    nullptr, /* tp_getattr */
+    nullptr, /* tp_setattr */
+    nullptr, /* tp_reserved */
+    nullptr, /* tp_repr */
+    nullptr, /* tp_as_number */
+    nullptr, /* tp_as_sequence */
+    nullptr, /* tp_as_mapping */
+    nullptr, /* tp_hash */
+    nullptr, /* tp_call */
+    nullptr, /* tp_str */
+    nullptr, /* tp_getattro */
+    nullptr, /* tp_setattro */
+    nullptr, /* tp_as_buffer */
+    // NOLINTNEXTLINE(misc-redundant-expression)
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
+    nullptr, /* tp_doc */
+    nullptr, /* tp_traverse */
+    nullptr, /* tp_clear */
+    nullptr, /* tp_richcompare */
+    0, /* tp_weaklistoffset */
+    nullptr, /* tp_iter */
+    nullptr, /* tp_iternext */
+    nullptr, /* tp_methods */
+    nullptr, /* tp_members */
+    nullptr, /* tp_getset */
+    DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
+    nullptr, /* tp_dict */
+    nullptr, /* tp_descr_get */
+    nullptr, /* tp_descr_set */
+    0, /* tp_dictoffset */
+    THPStorageMetaType_init, /* tp_init */
+    nullptr, /* tp_alloc */
+    nullptr, /* tp_new */
+};
+
 // TODO: implement equality
 PyTypeObject THPStorageType = {
-    PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0)
+    PyVarObject_HEAD_INIT(&THPStorageMetaType, 0)
     "torch._C.StorageBase", /* tp_name */
     sizeof(THPStorage), /* tp_basicsize */
     0, /* tp_itemsize */
-    THPStorage_dealloc, /* tp_dealloc */
+    nullptr, /* tp_dealloc */
     0, /* tp_vectorcall_offset */
     nullptr, /* tp_getattr */
     nullptr, /* tp_setattr */
@@ -434,6 +649,15 @@ PyTypeObject THPStorageType = {
     THPStorage_pynew, /* tp_new */
 };

+int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
+  if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
+    return -1;
+  }
+  (reinterpret_cast<PyTypeObject*>(cls))->tp_dealloc =
+      static_cast<destructor>(THPStorage_subclass_dealloc);
+  return 0;
+}
+
 static PyObject* THPStorage_device(THPStorage* self, void* unused) {
   HANDLE_TH_ERRORS
   THPStorage_assertNotNull(self);
@@ -468,6 +692,13 @@ bool THPStorage_init(PyObject* module) {
   THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
   THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());

+  THPStorageMetaType.tp_base = &PyType_Type;
+  if (PyType_Ready(&THPStorageMetaType) < 0)
+    return false;
+  Py_INCREF(&THPStorageMetaType);
+  PyModule_AddObject(
+      module, "_StorageMeta", reinterpret_cast<PyObject*>(&THPStorageMetaType));
+
   THPStorageType.tp_methods = methods.data();
   THPStorageType.tp_getset = THPStorage_properties;
   if (PyType_Ready(&THPStorageType) < 0)
@@ -11,13 +11,15 @@

 struct THPStorage {
   PyObject_HEAD
-  c10::Storage cdata;
+  c10::MaybeOwned<c10::Storage> cdata;
+  bool is_hermetic;
 };

 TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage);
 TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage(
     PyTypeObject* type,
-    c10::Storage _storage);
+    c10::Storage _storage,
+    bool allow_preexisting_pyobj = false);
 TORCH_PYTHON_API extern PyTypeObject* THPStorageClass;

 inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) {
@@ -47,7 +49,7 @@ TORCH_PYTHON_API void THPStorage_assertNotNull(PyObject* obj);
 TORCH_PYTHON_API extern PyTypeObject THPStorageType;

 inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) {
-  return storage->cdata;
+  return *storage->cdata;
 }

 inline const c10::Storage& THPStorage_Unpack(PyObject* obj) {
@@ -529,8 +529,9 @@ static PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) {
       THPUtils_typename(new_cdata));
   c10::StorageImpl* ptr =
       static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(new_cdata));
-  self->cdata =
-      c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr));
+  self->cdata.~MaybeOwned<c10::Storage>();
+  self->cdata = c10::MaybeOwned<c10::Storage>::owned(
+      c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr)));
   Py_INCREF(self);
   return reinterpret_cast<PyObject*>(self);
   END_HANDLE_TH_ERRORS
@@ -180,9 +180,7 @@ struct TORCH_API AccumulateGrad : public Node {
     if (!GradMode::is_enabled() && !new_grad.is_sparse() &&
         !new_grad.is_sparse_csr() &&
         !(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) &&
-        impl::is_tensor_stealable(
-            new_grad,
-            num_expected_refs + at::caching::is_cached_tensor(new_grad)) &&
+        at::caching::adjusted_use_count(new_grad) <= num_expected_refs &&
         (new_grad.is_mkldnn() ||
          utils::obeys_layout_contract(new_grad, variable))) {
       // See Case 1.1: Stealable dense new_grad
@@ -195,7 +193,7 @@ struct TORCH_API AccumulateGrad : public Node {
         // SparseTensor should be the only one holding a reference to these.
         new_grad._indices().use_count() <= 1 &&
         new_grad._values().use_count() <= 1 &&
-        impl::is_tensor_stealable(new_grad, num_expected_refs)) {
+        new_grad.use_count() <= num_expected_refs) {
       // Case 1.2: Stealable sparse new_grad
       // No scenario where we expect this to be true currently
       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
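
Both AccumulateGrad hunks replace the impl::is_tensor_stealable(...) helper with explicit use-count comparisons. A sketch of the underlying test, with num_expected_refs standing in for however many references the caller knows it is holding itself (the helper name is illustrative):

    // Illustrative only: a gradient can be "stolen" (moved into variable.grad)
    // when nobody else holds a reference to it beyond the ones we expect.
    static bool grad_is_stealable(
        const at::Tensor& new_grad, size_t num_expected_refs) {
      // adjusted_use_count() discounts the extra reference held by the
      // cached-tensor machinery, matching the dense-gradient check above.
      return at::caching::adjusted_use_count(new_grad) <= num_expected_refs;
    }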
@@ -86,8 +86,8 @@ bool can_accumulate_inplace(const Variable& v) {
       v.is_non_overlapping_and_dense() &&

       // and we hold the last reference
-      impl::is_tensor_stealable(v, 1 + at::caching::is_cached_tensor(v)) &&
-      v.has_storage() && v.storage().use_count() == 1);
+      at::caching::adjusted_use_count(v) == 1 && v.has_storage() &&
+      v.storage().use_count() == 1);
 }
 } // anonymous namespace

@@ -54,7 +54,6 @@
 using namespace at;
 using namespace torch;
 using namespace torch::autograd;
-using torch::utils::PyObjectPreservation;

 namespace {
 class OperatorArgsKwargsView {
@@ -322,15 +321,20 @@ PyObject* THPVariableClass = nullptr;

 PyObject* ParameterClass = nullptr;

+static PyObject* THPVariable_NewWithVar(
+    PyTypeObject* type,
+    const at::TensorBase& _var,
+    bool allow_preexisting_pyobj = false,
+    std::optional<bool> has_torch_dispatch_if_known = std::nullopt);
+
 // clang-tidy gets confused by static const
 static constexpr const char* VOLATILE_WARNING =
     "volatile was removed and now has no effect. Use "
     "`with torch.no_grad():` instead.";

-static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls);
-
 static bool check_has_torch_dispatch(PyObject* obj) {
-  if (THPVariable_CheckExact(obj)) {
+  PyTypeObject* tp = Py_TYPE(obj);
+  if (THPVariable_CheckTypeExact(tp)) {
     return false;
   }
   py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__");
@@ -366,86 +370,152 @@ void activateGPUTrace() {
   c10::impl::GPUTrace::set_trace(getPyInterpreter());
 }

-static void check_tensor_subclass(PyObject* obj, PyTypeObject* type) {
-  TORCH_CHECK(
-      PyObject_TypeCheck(obj, type),
-      "Creating a new Tensor subclass ",
-      type->tp_name,
-      " but the raw Tensor object is already associated to a python object ",
-      "of type ",
-      Py_TYPE(obj)->tp_name,
-      " which is not a subclass of the requested type");
-}
-
-// Generic for const Tensor& or Tensor&&
-template <typename T>
-static PyObject* THPVariable_WrapWithType(
-    T&& var,
-    std::optional<PyTypeObject*> desired_type) {
+PyObject* THPVariable_Wrap(const at::TensorBase& var) {
   if (!var.defined()) {
     Py_RETURN_NONE;
   }

-  c10::TensorImpl* tensor_impl = var.unsafeGetTensorImpl();
-  c10::impl::PyObjectSlot* pyobj_slot = tensor_impl->pyobj_slot();
+  if (c10::impl::HermeticPyObjectTLS::get_state()) {
+    return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
+  }

-  PyObject* obj = pyobj_slot->load_pyobj();
-  if (obj) {
-    if (desired_type) {
-      check_tensor_subclass(obj, *desired_type);
+  std::optional<PyObject*> mb_obj =
+      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+          /*ignore_hermetic_tls=*/false);
+  if (mb_obj.has_value()) {
+    auto obj = *mb_obj;
+    if (obj) {
+      if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
+        // C++ owns the Python object; this implies there weren't any other
+        // owning references to the Python object. Since we're making the
+        // object "live" again on Python side, let's flip back the ownership
+        // (Python owns C++) as it would now be unsound to deallocate the C++
+        // object if all C++ references go to zero
+        var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
+        reinterpret_cast<THPVariable*>(obj)->cdata =
+            MaybeOwned<Variable>::owned(Variable(var));
+        // NB: incref is not necessary, because we are "stealing" the previous
+        // ownership from the Variable to return it here for the wrap
+        return obj;
+      }
+      Py_INCREF(obj);
+      return obj;
     }
-    return Py_NewRef(obj);
+    // TODO: a better invariant is that if we tagged, we MUST have a valid
+    // PyObject. That's PyObject preservation
+    // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR
+    // being a thing, the PyObject field will get cleared when all references
+    // to the Python object are removed.
   }

-  PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPVariableClass);
-  if (desired_type) {
-    type = *desired_type;
-  } else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) {
-    if (auto clazz = getPythonTensorClass(var.device())) {
-      type = reinterpret_cast<PyTypeObject*>(clazz);
-    }
+  if (C10_LIKELY(var.device().type() != c10::kXLA)) {
+    return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
   }

-  obj = type->tp_alloc(type, 0);
-  TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
-
-  // Ensure that PyUnstable_TryIncref calls don't fail spuriously in
-  // free-threaded Python.
-  PyUnstable_EnableTryIncRef(obj);
-
-  auto v = reinterpret_cast<THPVariable*>(obj);
-  new (&v->cdata) Tensor(std::forward<T>(var));
-
-  if (THPVariable_Unpack(obj).is_uniquely_owned()) {
-    // We can use a faster non-atomic code path if we have the only reference to
-    // a fresh Tensor.
-    PyObjectPreservation::init_fresh_nonatomic(tensor_impl, pyobj_slot, obj);
-    return obj;
+  if (auto clazz = getPythonTensorClass(var.device())) {
+    return THPVariable_NewWithVar((PyTypeObject*)clazz, var);
   }

-  PyObject* wrapper =
-      PyObjectPreservation::init_once(tensor_impl, pyobj_slot, obj);
-  if (wrapper != obj) {
-    // Another thread beat us to it
-    Py_DECREF(obj);
-    if (desired_type) {
-      check_tensor_subclass(wrapper, *desired_type);
-    }
-    return Py_NewRef(wrapper);
-  }
-  return obj;
+  return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
 }

-PyObject* THPVariable_Wrap(at::TensorBase&& var) {
-  return THPVariable_WrapWithType(std::move(var), std::nullopt);
+static bool isResurrectable(THPVariable* self) {
+  // We want to divide this check into 2 cases.
+
+  // 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is
+  // true). You might think that in this case, it is impossible for tp_clear to
+  // be called: surely the C++ reference to the PyObject is keeping it live? And
+  // you'd be right! In fact, when C++ owns the PyObject, we have an invariant
+  // that the refcount on the PyObject should be precisely one (because if you
+  // take out another reference to the PyObject, we're supposed to flip the
+  // ownership pointer back). In reality, you can violate this invariant
+  // temporarily with weak references, so we don't test for it in asserts.
+
+  // 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is
+  // false). In this case, tp_clear can get called if the PyObject is referenced
+  // from a dead cycle, and nowhere else. But if resurrection did not occur,
+  // then the reference to C++ from the PyObject must be the ONLY reference to
+  // the C++ object.
+  if (self->cdata.unsafeIsBorrowed()) {
+    return false;
+  }
+  auto const& tensor = THPVariable_Unpack(self);
+  if (!tensor.defined() || tensor.use_count() <= 1) {
+    return false;
+  }
+  // Check if this is hermetic. If it is, no resurrection.
+  if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+          /*ignore_hermetic_tls=*/false) != (PyObject*)self) {
+    return false;
+  }
+  return true;
 }

-PyObject* THPVariable_Wrap(const at::TensorBase& var) {
-  return THPVariable_WrapWithType(var, std::nullopt);
+// returns true if successfully rezzed; if so, cancel the
+// rest of deallocation
+static bool THPVariable_tryResurrect(THPVariable* self) {
+  const auto& tensor = THPVariable_Unpack(self);
+
+  if (!isResurrectable(self)) {
+    return false;
+  }
+
+  // At this point, we are definitely going to resurrect the tensor. So, the
+  // tensor better be defined :)
+  TORCH_INTERNAL_ASSERT(tensor.defined());
+
+  // There are other C++ owners of the tensor. Flip ownership
+  // so that C++ owns this Python object, and cancel deallocation.
+  TORCH_INTERNAL_ASSERT(
+      !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
+
+  c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
+  auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
+      /*ignore_hermetic_tls=*/false);
+
+  TORCH_INTERNAL_ASSERT(
+      maybe_pyobj.has_value(),
+      "Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");
+
+  tensor_impl->pyobj_slot()->set_owns_pyobj(true);
+
+  // Resurrect the Python object. This is something CPython does
+  // internally occasionally, see
+  // https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259
+  // so we just copy the pattern here. Note that we don't have to worry
+  // about saving and restoring the refcount (as the quoted code does)
+  // because we actually DO need to reset the refcount to one here, we
+  // can't assume that some other code has taken care of it.
+  // NB: this will overreport _Py_RefTotal but based on inspection of object.c
+  // there is no way to avoid this
+
+  // When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
+  // ensure the PyObject is in a valid state
+  _Py_NewReference((PyObject*)self);
+
+  // Flip THPVariable to be non-owning
+  // (near use-after-free miss here: fresh MaybeOwned is created breaking
+  // reference on Tensor in struct BEFORE we overwrite the old one)
+  TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
+  self->cdata = MaybeOwned<Variable>::borrowed(tensor);
+
+  // NB: At this point, tensor *could* be dead (e.g., some other C++ thread
+  // decrefed it.) At this point, it is probably waiting on the GIL to
+  // deallocate the Python object and will kill self, BUT NOT YET.
+
+  return true;
 }

-PyObject* THPVariable_Wrap(const at::TensorBase& var, PyTypeObject* type) {
-  return THPVariable_WrapWithType(var, type);
+static int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) {
+  TORCH_INTERNAL_ASSERT(
+      false, "TensorBase tp_traverse function was not overridden properly");
+  return 0;
+}
+
+static int THPFake_clear(THPVariable* self) {
+  TORCH_INTERNAL_ASSERT(
+      false, "TensorBase tp_clear function was not overridden properly");
+  return 0;
 }

 static PyObject* THPVariable_pynew(
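
The isResurrectable/tryResurrect pair above encodes a two-state ownership protocol between a THPVariable and its TensorImpl. A compact restatement as a sketch (the enum and helper are illustrative, not part of the patch):

    // Illustrative only.
    enum class Ownership { PyObjectOwnsCpp, CppOwnsPyObject };

    static Ownership ownership_of(THPVariable* self) {
      // Borrowed cdata  <=> the C++ TensorImpl owns the PyObject and keeps it
      // alive; the PyObject's refcount is nominally one.
      // Owned cdata     <=> the PyObject owns the C++ tensor; if it dies while
      // tensor.use_count() > 1, tryResurrect() flips to the borrowed state
      // instead of deallocating.
      return self->cdata.unsafeIsBorrowed() ? Ownership::CppOwnsPyObject
                                            : Ownership::PyObjectOwnsCpp;
    }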
@@ -607,16 +677,16 @@ static PyObject* THPVariable_as_subclass(
   ParsedArgs<1> parsed_args{};
   auto r = parser.parse(_self, args, kwargs, parsed_args);
   PyObject* cls = r.pyobject(0);
-  TORCH_CHECK_TENSOR_SUBTYPE(cls);
+  TORCH_CHECK_TYPE(
+      PyType_Check(cls),
+      "cls must be a type (got ",
+      Py_TYPE(cls)->tp_name,
+      ")");
   // guard completely turns off torch dispatch modes, doesn't just pop off the
   // stack
   torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
   c10::impl::DisablePythonDispatcher dpd_g;
-  PyObject* obj = THPVariable_WrapWithType(self.alias(), (PyTypeObject*)cls);
-  if (check_has_torch_dispatch(obj)) {
-    THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
-  }
-  return obj;
+  return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias());
   END_HANDLE_TH_ERRORS
 }

@@ -631,7 +701,11 @@ static PyObject* THPVariable_make_subclass(
   ParsedArgs<7> parsed_args{};
   auto r = parser.parse(args, kwargs, parsed_args);
   PyObject* cls = r.pyobject(0);
-  TORCH_CHECK_TENSOR_SUBTYPE(cls);
+  TORCH_CHECK_TYPE(
+      PyType_Check(cls),
+      "cls must be a type (got ",
+      Py_TYPE(cls)->tp_name,
+      ")");
   // guard completely turns off torch dispatch modes, doesn't just pop off the
   // stack
   torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
@@ -664,11 +738,7 @@ static PyObject* THPVariable_make_subclass(
     data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
   }

-  PyObject* obj = THPVariable_WrapWithType(data, (PyTypeObject*)cls);
-  if (check_has_torch_dispatch(obj)) {
-    THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
-  }
-  return obj;
+  return THPVariable_NewWithVar((PyTypeObject*)cls, data);
   END_HANDLE_TH_ERRORS
 }

@@ -765,7 +835,11 @@ static PyObject* THPVariable_make_wrapper_subclass(
   auto r = parser.parse(args, kwargs, parsed_args);
   PyObject* cls = r.pyobject(0);

-  TORCH_CHECK_TENSOR_SUBTYPE(cls);
+  TORCH_CHECK_TYPE(
+      PyType_Check(cls),
+      "cls must be a type (got ",
+      Py_TYPE(cls)->tp_name,
+      ")");

   // This is an important safety check; without it, the default behavior will be
   // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch
@@ -803,8 +877,6 @@ static PyObject* THPVariable_make_wrapper_subclass(
       /*storage_size=*/r.toSymIntOptional(14),
       r.toDispatchKeySetOptional(13));

-  tensor.unsafeGetTensorImpl()->set_python_dispatch(true);
-
   const auto sizes_strides_policy = r.stringViewOptional(10);
   if (sizes_strides_policy.has_value()) {
     tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
@@ -820,7 +892,13 @@ static PyObject* THPVariable_make_wrapper_subclass(
     tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
   }

-  return THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls);
+  return THPVariable_NewWithVar(
+      (PyTypeObject*)cls,
+      tensor,
+      // false is the default
+      /*allow_preexisting_pyobj=*/false,
+      // we checked __torch_dispatch__ above; avoid checking again.
+      /*has_torch_dispatch_if_known=*/true);
   END_HANDLE_TH_ERRORS
 }

@@ -1621,7 +1699,11 @@ static PyObject* THPVariable_dtensor_new(
   auto r = parser.parse(args, kwargs, parsed_args);
   PyObject* cls = r.pyobject(0);

-  TORCH_CHECK_TENSOR_SUBTYPE(cls);
+  TORCH_CHECK_TYPE(
+      PyType_Check(cls),
+      "cls must be a type (got ",
+      Py_TYPE(cls)->tp_name,
+      ")");

 #ifndef NDEBUG
   // This is specifically for making a DTensor, which we know defines
@@ -1674,9 +1756,14 @@ static PyObject* THPVariable_dtensor_new(
       /*storage_size=*/std::nullopt,
       extra_dispatch_keys);
   tensor.set_requires_grad(requires_grad);
-  tensor.unsafeGetTensorImpl()->set_python_dispatch(true);
-  py::object py_tensor = py::reinterpret_steal<py::object>(
-      THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls));
+  py::object py_tensor =
+      py::reinterpret_steal<py::object>(THPVariable_NewWithVar(
+          (PyTypeObject*)cls,
+          tensor,
+          // false is the default
+          /*allow_preexisting_pyobj=*/false,
+          // we know DTensor has __torch_dispatch__; avoid checking again.
+          /*has_torch_dispatch_if_known=*/true));
   py_tensor.attr(dtensor_interned_strings._spec) = spec;
   py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor;
   return py_tensor.release().ptr();
@@ -3353,16 +3440,15 @@ static PyTypeObject THPVariableMetaType = {
     nullptr, /* tp_new */
 };

-static void THPVariable_dealloc(PyObject* self);
-static int THPVariable_clear(THPVariable* self);
-static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg);
-
 static PyTypeObject THPVariableType = {
     PyVarObject_HEAD_INIT(&THPVariableMetaType, 0)
     "torch._C.TensorBase", /* tp_name */
     sizeof(THPVariable), /* tp_basicsize */
     0, /* tp_itemsize */
-    THPVariable_dealloc, /* tp_dealloc */
+    // This is unspecified, because it is illegal to create a THPVariableType
+    // directly. Subclasses will have their tp_dealloc set appropriately
+    // by the metaclass
+    nullptr, /* tp_dealloc */
     0, /* tp_vectorcall_offset */
     nullptr, /* tp_getattr */
     nullptr, /* tp_setattr */
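
The nullptr slots above rely on the Tensor metaclass's tp_init to patch each subclass, the same trick _StorageMeta uses earlier in this diff. The Variable metaclass body itself is not part of this excerpt; a sketch of what such an init is assumed to do, using only names that do appear in this diff:

    // Illustrative sketch only; the real THPVariableMetaType_init is outside
    // this excerpt.
    static int ExampleMeta_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
      if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
        return -1;
      }
      auto* type = reinterpret_cast<PyTypeObject*>(cls);
      // Route destruction of every Python-defined subclass through the
      // subclass-aware dealloc instead of the base type's nullptr slot.
      type->tp_dealloc = static_cast<destructor>(THPVariable_subclass_dealloc);
      return 0;
    }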
@@ -3381,8 +3467,9 @@ static PyTypeObject THPVariableType = {
     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
         Py_TPFLAGS_HAVE_GC, /* tp_flags */
     nullptr, /* tp_doc */
-    (traverseproc)THPVariable_traverse, /* tp_traverse */
-    (inquiry)THPVariable_clear, /* tp_clear */
+    // Also set by metaclass
+    (traverseproc)THPFake_traverse, /* tp_traverse */
+    (inquiry)THPFake_clear, /* tp_clear */
     nullptr, /* tp_richcompare */
     0, /* tp_weaklistoffset */
     nullptr, /* tp_iter */
@ -3411,68 +3498,345 @@ PyObject* THPVariable_pynew(
|
|||||||
type != &THPVariableType,
|
type != &THPVariableType,
|
||||||
"Cannot directly construct TensorBase; subclass it and then construct that");
|
"Cannot directly construct TensorBase; subclass it and then construct that");
|
||||||
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
|
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
|
||||||
|
auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
|
||||||
// WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
|
// WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
|
||||||
// given a raw pointer that will refcount bump
|
// given a raw pointer that will refcount bump
|
||||||
// NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
|
// NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
|
||||||
// alias(), lift_fresh()) which can return Tensor subclasses. We allow
|
// alias(), lift_fresh()) which can return Tensor subclasses. We allow
|
||||||
// these to be passed on directly.
|
// these to be passed on directly.
|
||||||
PyObject* obj = THPVariable_WrapWithType(
|
return THPVariable_NewWithVar(
|
||||||
torch::utils::base_tensor_ctor(args, kwargs), type);
|
type,
|
||||||
if (check_has_torch_dispatch(obj)) {
|
tensor,
|
||||||
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
|
/*allow_preexisting_pyobj=*/true);
|
||||||
}
|
|
||||||
return obj;
|
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
}
|
}
|
||||||
|
|
||||||
static int THPVariable_clear(THPVariable* self) {
|
static int THPVariable_subclass_clear(THPVariable* self) {
|
||||||
|
// Is it OK for an object to still be live after running
|
||||||
|
// tp_clear? Yes. When Python is breaking reference cycles, it can't assume
|
||||||
|
// that an object will dealloc after it's cleared. The source code explicitly
|
||||||
|
// handles this case:
|
||||||
|
// https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025
|
||||||
|
|
||||||
|
// Note that we don't need to actually resurrect here. There are 2 cases:
|
||||||
|
// 1. The PyObject is not part of a reference cycle. In this case, we don't
|
||||||
|
// need to do anything. The GC will move on to try and break the reference
|
||||||
|
// cycle on another object, which will eventually trigger tp_dealloc (and thus
|
||||||
|
// resurrection).
|
||||||
|
|
||||||
|
// 2. The PyObject is part of a reference cycle. This case should not actually
|
||||||
|
// be possible, due to the logic in our tp_traverse
|
||||||
|
// (THPVariable_subclass_traverse).
|
||||||
|
|
||||||
|
// In fact, resurrecting here breaks the invariant that "C++ owns Python only
|
||||||
|
// when PyObject's refcount would otherwise be 0". Most immediately, as we're
|
||||||
|
// merely breaking reference cycles here, there can be other references to the
|
||||||
|
// PyObject. *However*, if other objects in the refcycle resurrect, then we
|
||||||
|
// will be in a state where the PyObject has multiple Python references, yet
|
||||||
|
// C++ owns the PyObject.
|
||||||
|
|
||||||
|
// See https://github.com/pytorch/pytorch/pull/75933 for more discussion.
|
||||||
|
if (isResurrectable(self)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
// First clear Tensor specific things
|
// First clear Tensor specific things
|
||||||
|
|
||||||
Py_CLEAR(self->backward_hooks);
|
Py_CLEAR(self->backward_hooks);
|
||||||
Py_CLEAR(self->post_accumulate_grad_hooks);
|
Py_CLEAR(self->post_accumulate_grad_hooks);
|
||||||
if (self->cdata.defined()) {
|
const auto& tensor = THPVariable_Unpack(self);
|
||||||
auto pyobj_slot = self->cdata.unsafeGetTensorImpl()->pyobj_slot();
|
if (tensor.defined()) {
|
||||||
// Typically the Tensor's pyobj_slot points back to this object. The only
|
// Two situations to consider:
|
||||||
// time that's not the case is if we had a race in THPVariable_Wrap and we
|
// PyObject -owns-> Tensor
|
||||||
// need to discard the Python object because some other thread beat us to
|
// unsafeIsBorrowed() is FALSE. We're obligated to look through
|
||||||
// setting the pyobj_slot.
|
// Tensor to break references. Clearing cdata must induce the
|
||||||
if (pyobj_slot->load_pyobj() == (PyObject*)self) {
|
// destruction of the C++ Tensor. If there were other references
|
||||||
// A Tensor's Python object should only be destroyed when the Tensor has
|
// to C++ tensor, the Python object would have been resurrected
|
||||||
// no other references too.
|
// by flipping the ownership.
|
||||||
TORCH_INTERNAL_ASSERT(self->cdata.use_count() == 1);
|
// Tensor -owns-> PyObject
|
||||||
|
// unsafeIsBorrowed() is TRUE. We're deallocating the PyObject
|
||||||
|
// because Tensor asked us to (it's already destructing).
|
||||||
|
|
||||||
// Clear the pyobj_slot so that a try_incref() call from
|
if (!self->cdata.unsafeIsBorrowed() &&
|
||||||
// weak_intrusive_ptr::lock() won't see a freed pointer.
|
tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
|
||||||
pyobj_slot->clear();
|
/*ignore_hermetic_tls=*/false) == (PyObject*)self) {
|
||||||
|
// TODO: empirically, on OS X this assert appears to be untrue
|
||||||
|
// In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
|
||||||
|
// distributed/rpc/test_process_group_agent.py
|
||||||
|
//
|
||||||
|
// libc++abi.dylib: terminating with uncaught exception of type
|
||||||
|
// c10::Error:
|
||||||
|
// !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
|
||||||
|
// ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
|
||||||
|
// please report a bug to PyTorch. Exception raised from
|
||||||
|
// THPVariable_subclass_clear at
|
||||||
|
// ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
|
||||||
|
// first): frame #0: c10::Error::Error(c10::SourceLocation,
|
||||||
|
// std::__1::basic_string<char, std::__1::char_traits<char>,
|
||||||
|
// std::__1::allocator<char> >) + 98 (0x1158a0442 in libc10.dylib) frame
|
||||||
|
// #1: c10::detail::torchCheckFail(char const*, char const*, unsigned
|
||||||
|
// int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2:
|
||||||
|
// c10::detail::torchInternalAssertFail(char const*, char const*,
|
||||||
|
// unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
|
||||||
|
// (0x1141e3f89 in libtorch_python.dylib) frame #3:
|
||||||
|
// THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in
|
||||||
|
// libtorch_python.dylib) frame #4:
|
||||||
|
// THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
|
||||||
|
// libtorch_python.dylib) frame #5: (anonymous
|
||||||
|
// namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,
|
||||||
|
// _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6:
|
||||||
|
// c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in
|
||||||
|
// libc10.dylib) frame #7:
|
||||||
|
// c10::MaybeOwned<at::Tensor>::operator=(c10::MaybeOwned<at::Tensor>&&)
|
||||||
|
// + 91 (0x11488c11b in libtorch_python.dylib) frame #8:
|
||||||
|
// THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in
|
||||||
|
// libtorch_python.dylib) <omitting python frames> frame #47: start + 1
|
||||||
|
// (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???)
|
||||||
|
// TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
|
||||||
|
if (auto grad_acc =
|
||||||
|
torch::autograd::impl::try_get_grad_accumulator(tensor)) {
|
||||||
|
grad_acc->pre_hooks().clear();
|
||||||
|
grad_acc->tensor_pre_hooks().clear();
|
||||||
|
grad_acc->retains_grad_hooks().clear();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
TORCH_INTERNAL_ASSERT(!isResurrectable(self));
|
||||||
{
|
{
|
||||||
// MapAllocator can take significant time to release large tensors;
|
// MapAllocator can take significant time to release large tensors;
|
||||||
// release the GIL here to avoid impacting main thread perf.
|
// release the GIL here to avoid impacting main thread perf.
|
||||||
pybind11::gil_scoped_release no_gil;
|
pybind11::gil_scoped_release no_gil;
|
||||||
self->cdata = Variable();
|
self->cdata = MaybeOwned<Variable>();
|
||||||
}
|
}
|
||||||
|
// Since we override the basic subtype_clear from CPython, we need a crappy
|
||||||
|
// version here just like for traverse and dealloc
|
||||||
|
|
||||||
|
// Clear all slots until we get to the base Tensor class
|
||||||
|
PyTypeObject* type = Py_TYPE((PyObject*)self);
|
||||||
|
PyTypeObject* base = type;
|
||||||
|
while (base != &THPVariableType) {
|
||||||
|
if (Py_SIZE(base))
|
||||||
|
clear_slots(base, (PyObject*)self);
|
||||||
|
base = base->tp_base;
|
||||||
|
TORCH_INTERNAL_ASSERT(base);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assume we never have managed dict for Tensors as we don't set the flag on
|
||||||
|
// the base class
|
||||||
|
if (C10_LIKELY(type->tp_dictoffset)) {
|
||||||
|
PyObject** dictptr = _PyObject_GetDictPtr((PyObject*)self);
|
||||||
|
if (dictptr && *dictptr)
|
||||||
|
Py_CLEAR(*dictptr);
|
||||||
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void THPVariable_dealloc(PyObject* self) {
|
// NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc
|
||||||
|
// on subclasses. It's never valid to construct a THPVariable so it's not
|
||||||
|
// necessary to implement the dealloc for that case
|
||||||
|
static void THPVariable_subclass_dealloc(PyObject* self) {
|
||||||
|
if (THPVariable_tryResurrect((THPVariable*)self))
|
||||||
|
return;
|
||||||
|
|
||||||
|
// This is like a crappy version of subtype_dealloc.
|
||||||
|
// Unfortunately, we cannot directly delegate to
|
||||||
|
// subtype_dealloc as it will start walking the parent
|
||||||
|
// chain *starting with* the type of self, which will cause
|
||||||
|
// us to go back to our custom dealloc.
|
||||||
|
//
|
||||||
|
// We have to replicate the subtype_dealloc logic to ensure
|
||||||
|
// that finalizers are handled correctly
|
||||||
|
PyTypeObject* type = Py_TYPE(self);
|
||||||
|
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
|
||||||
|
TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented");
|
||||||
|
|
||||||
PyObject_GC_UnTrack(self);
|
PyObject_GC_UnTrack(self);
|
||||||
THPVariable_clear((THPVariable*)self);
|
// TODO: consider using trash can
|
||||||
((THPVariable*)self)->cdata.~Variable();
|
|
||||||
|
bool has_finalizer = type->tp_finalize || type->tp_del;
|
||||||
|
|
||||||
|
if (type->tp_finalize) {
|
||||||
|
PyObject_GC_Track(self);
|
||||||
|
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
|
||||||
|
/* Resurrected */
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PyObject_GC_UnTrack(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
// base test is unnecessary as THPVariable does not set this
|
||||||
|
if (type->tp_weaklistoffset) {
|
||||||
|
PyObject_ClearWeakRefs(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type->tp_del) {
|
||||||
|
PyObject_GC_Track(self);
|
||||||
|
type->tp_del(self);
|
||||||
|
if (Py_REFCNT(self) > 0) {
|
||||||
|
/* Resurrected */
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PyObject_GC_UnTrack(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_finalizer) {
|
||||||
|
/* New weakrefs could be created during the finalizer call.
|
||||||
|
If this occurs, clear them out without calling their
|
||||||
|
finalizers since they might rely on part of the object
|
||||||
|
being finalized that has already been destroyed. */
|
||||||
|
if (type->tp_weaklistoffset) {
|
||||||
|
/* Modeled after GET_WEAKREFS_LISTPTR() */
|
||||||
|
PyWeakReference** list =
|
||||||
|
(PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
|
||||||
|
while (*list)
|
||||||
|
_PyWeakref_ClearRef(*list);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear all slots until we get to base class THPVariableType
|
||||||
|
{
|
||||||
|
PyTypeObject* base = type;
|
||||||
|
while (base != &THPVariableType) {
|
||||||
|
if (Py_SIZE(base)) {
|
||||||
|
clear_slots(base, self);
|
||||||
|
}
|
||||||
|
base = base->tp_base;
|
||||||
|
TORCH_INTERNAL_ASSERT(base);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All Python defined classes have __dict__
|
||||||
|
if (C10_LIKELY(type->tp_dictoffset)) {
|
||||||
|
PyObject** dictptr = _PyObject_GetDictPtr(self);
|
||||||
|
if (dictptr != nullptr) {
|
||||||
|
PyObject* dict = *dictptr;
|
||||||
|
if (dict != nullptr) {
|
||||||
|
Py_DECREF(dict);
|
||||||
|
*dictptr = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// subtype_dealloc allows for this but we don't
|
||||||
|
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
|
||||||
|
|
||||||
|
// Finally clear out the base THPVariable
|
||||||
|
THPVariable_subclass_clear((THPVariable*)self);
|
||||||
|
((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
|
||||||
Py_TYPE(self)->tp_free(self);
|
Py_TYPE(self)->tp_free(self);
|
||||||
|
|
||||||
|
// Python defined subclasses should always be on the heap
|
||||||
|
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
|
||||||
|
Py_DECREF(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
-static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) {
-  TORCH_CHECK_TYPE(
-      PyType_Check(cls),
-      "cls must be a type (got ",
-      Py_TYPE(cls)->tp_name,
-      ")");
-  PyTypeObject* type = reinterpret_cast<PyTypeObject*>(cls);
-  TORCH_CHECK_TYPE(
-      type == &THPVariableType || cls == THPVariableClass ||
-          PyType_IsSubtype(type, &THPVariableType),
-      "Creating a Tensor subclass from a class that does not inherit from "
-      "Tensor is not possible. Make sure your class inherits from Tensor.");
+// Creates a new Python object for a Variable.
+static PyObject* THPVariable_NewWithVar(
+    PyTypeObject* type,
+    const at::TensorBase& _var,
+    bool allow_preexisting_pyobj,
+    std::optional<bool> has_torch_dispatch_if_known) {
+  // Make sure that the reinterpret into a THPVariable* will be valid
+  TORCH_CHECK(
+      type == &THPVariableType || PyType_IsSubtype(type, &THPVariableType),
+      "Creating a Tensor subclass from a class ",
+      "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");
+  // This function overwrite the Tensor's pyobj field without extra checks
+  // Make sure it is not set otherwise we would leak memory
+  auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+      /*ignore_hermetic_tls=*/false);
+
+  // Under some circumstances, we may attempt to create a new Python
+  // object for a variable that already has a Python object. The most common
+  // situation this can occur is if you have a TorchDispatchMode active that
+  // is returning a subclass from lift_fresh (which is invoked to
+  // appropriately "wrap" a constant tensor into whatever ambient modes are
+  // active.)
+  //
+  // In general, it is impossible to handle this case compositionally.
+  // Suppose you have a user call ATensor([1, 2, 3]) when a mode is active
+  // that is transforming all ops (including the internal lift_fresh call that
+  // transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output
+  // BTensor, where ATensor and BTensor are completely unrelated subclasses
+  // and there is no way to compose them. There is no way to satisfy the user
+  // request here: in particular, you can't just try to re-invoke the ATensor
+  // constructor on the returned BTensor, because (1) this could cause an
+  // infinite loop--we are already in ATensor.__new__ and (2) there isn't any
+  // guarantee that ATensor.__new__ supports a single element constructor
+  // anyway.
+  //
+  // However, a more common case is a user just called torch.Tensor([1, 2, 3]),
+  // and a fake tensor mode is active. Really, all you want is to get back
+  // a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3)
+  // would have returned a fake tensor (concretely, the way this happens
+  // is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it
+  // turns into a FakeTensor when we call lift_fresh on this real tensor).
+  // This case is compositional because FakeTensor is a subclass of Tensor, so
+  // it's valid for us to return it in place of a Tensor. So this is what we
+  // do.
+
+  if (mb_obj.has_value() && mb_obj.value()) {
+    TORCH_CHECK(
+        allow_preexisting_pyobj,
+        "Creating a new Tensor subclass ",
+        type->tp_name,
+        " but the raw Tensor object is already associated to a python object ",
+        "of type ",
+        mb_obj.value()->ob_type->tp_name);
+    // Even if we allow pre-existing PyObject, we don't allow completely
+    // ignoring the requested type. Check that we fulfilled a subtype
+    // relation here. In the common case the requested type is Tensor and
+    // this always succeeds.
+    PyObject* obj = *mb_obj;
+    // Check if it's OK to just directly return the Python object without
+    // allocating a new variable. We just check that the existing Python
+    // object is a subclass of the requested type.
+    PyTypeObject* obj_type = Py_TYPE(obj);
+    TORCH_CHECK(
+        obj_type == type || PyType_IsSubtype(obj_type, type),
+        "Creating a new Tensor subclass ",
+        type->tp_name,
+        " but the raw Tensor object is already associated to a python object ",
+        "of type ",
+        mb_obj.value()->ob_type->tp_name,
+        " which is not a subclass of the "
+        "requested type");
+    // We may (in fact, we typically will) need to resurrect this
+    return THPVariable_Wrap(_var);
+  }
+
+  PyObject* obj = type->tp_alloc(type, 0);
+  if (obj) {
+    auto v = (THPVariable*)obj;
+    // TODO: named constructor to avoid default initialization
+    new (&v->cdata) MaybeOwned<Variable>();
+    if (c10::impl::HermeticPyObjectTLS::get_state()) {
+      // Do NOT initialize pyobj field on the tensor, you own the C++
+      v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
+      TORCH_INTERNAL_ASSERT(
+          !check_has_torch_dispatch(obj),
+          "While HermeticPyObject was enabled, we attempted to create a tensor "
+          "subclass with __torch_dispatch__. This violates the invariant that "
+          "operations in HermeticPyObject have equivalent C++ implementations. "
+          "If your operator registered from Python operator registration isn't "
+          "doing anything strange, there may be an internal PyTorch bug involving "
+          "not appropriately disabling TorchDispatchMode before executing "
+          "Python op registration.");
+    } else {
+      // Normal codepath
+      v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
+      const auto& var = THPVariable_Unpack(v);
+      var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj);
+      if (has_torch_dispatch_if_known.has_value()
+              ? *has_torch_dispatch_if_known
+              : check_has_torch_dispatch(obj)) {
+        var.unsafeGetTensorImpl()->set_python_dispatch(true);
+      }
+    }
+  }
+  return obj;
 }
 
 /// NOTE [ PyObject Traversal ]
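The long comment in THPVariable_NewWithVar above boils down to one decision: if the TensorImpl already carries a PyObject (for example because a TorchDispatchMode returned a subclass from lift_fresh), validate that the existing wrapper satisfies the requested type and return it; otherwise allocate a fresh wrapper and register it on the TensorImpl. A condensed sketch of that decision follows; get_existing_pyobj and attach_pyobj are hypothetical stand-ins for the pyobj_slot() calls in the patch, not PyTorch API:

    #include <Python.h>
    #include <ATen/core/TensorBase.h>
    #include <c10/util/Exception.h>

    // Hypothetical helpers standing in for pyobj_slot()->check_pyobj()/init_pyobj().
    PyObject* get_existing_pyobj(const at::TensorBase& t);
    void attach_pyobj(const at::TensorBase& t, PyObject* obj);

    // Illustrative sketch only; condensed from THPVariable_NewWithVar above.
    PyObject* wrap_tensor(
        PyTypeObject* requested,
        const at::TensorBase& t,
        bool allow_preexisting) {
      if (PyObject* existing = get_existing_pyobj(t)) {
        TORCH_CHECK(allow_preexisting, "tensor already has a Python object");
        // Reusing the wrapper is only OK if it still satisfies the requested type.
        TORCH_CHECK(
            PyObject_TypeCheck(existing, requested),
            "existing wrapper is not a subclass of the requested type");
        Py_INCREF(existing);
        return existing;
      }
      PyObject* obj = requested->tp_alloc(requested, 0);  // fresh wrapper
      if (obj) {
        attach_pyobj(t, obj);  // record the new wrapper on the TensorImpl
      }
      return obj;
    }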
@@ -3491,7 +3855,7 @@ static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) {
 /// into account these C++ ownership links.
 ///
 /// The main danger here comes from the fact that, while all python-related code
-/// is thread safe wrt the GC execution, other threads might
+/// is thread safe wrt the GC execution (thanks to the GIL), other threads might
 /// be using our C++ objects arbitrarily which can lead to shared_ptr ref count
 /// going up or down in between the different traverse/clear invocations. The
 /// one constraint we add here that is not explicitly mentioned in the GC
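The constraint in this note is easiest to see in a small example: a PyObject edge owned by a C++ object should only be reported to Python's GC when the Python wrapper is the sole owner of that C++ object, because otherwise another thread could mint new references between the traverse and clear phases. A minimal sketch with std::shared_ptr standing in for the shared C++ ownership in the patch; the types here are invented for illustration:

    #include <Python.h>
    #include <memory>

    struct CppNode {
      PyObject* pyobj = nullptr;  // owned reference kept alive by the C++ object
    };

    struct Wrapper {
      PyObject_HEAD
      std::shared_ptr<CppNode> node;
    };

    static int Wrapper_traverse(PyObject* self, visitproc visit, void* arg) {
      auto* w = reinterpret_cast<Wrapper*>(self);
      // Sole owner: the C++-held PyObject is only reachable through self, so it is
      // safe to report the edge; otherwise another thread could resurrect it mid-GC.
      if (w->node && w->node.use_count() == 1) {
        Py_VISIT(w->node->pyobj);
      }
      return 0;
    }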
@@ -3521,46 +3885,124 @@ static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) {
 /// https://github.com/pytorch/pytorch/issues/7343
 ///
 
-static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg) {
+static int traverse_slots(
+    PyTypeObject* type,
+    PyObject* self,
+    visitproc visit,
+    void* arg) {
+  auto n = Py_SIZE(type);
+  auto mp = type->tp_members;
+  for (Py_ssize_t i = 0; i < n; i++, mp++) {
+    if (mp->type == T_OBJECT_EX) {
+      char* addr = (char*)self + mp->offset;
+      PyObject* obj = *(PyObject**)addr;
+      if (obj != nullptr) {
+        int err = visit(obj, arg);
+        if (err)
+          return err;
+      }
+    }
+  }
+  return 0;
+}
+
+static int THPVariable_subclass_traverse(
+    PyObject* self,
+    visitproc visit,
+    void* arg) {
+  // If the tensor is eligible to be resurrected, don't traverse it; instead
+  // treat all of its references as a root (as they WOULD be a root since we
+  // can treat the inbound C++ references as root owners).
+  //
+  // This works because unlike conventional GCs, Python's GC operates in two
+  // phases: first it uses traverse to discover roots, and then it uses traverse
+  // to do reachability. Bypassing traverse during root discovery forces Python
+  // to treat self as a root for everything it refers to. For a full
+  // explanation of the algorithm see
+  // https://devguide.python.org/garbage_collector/
+  //
+  // NB: if we don't hold an owning reference to the underlying Tensor, it is
+  // possible that the underlying Tensor has already gone dead. In that case,
+  // it's not safe to access it. But it's also safe to traverse, because if
+  // the underlying Tensor *is* live, then root discovery will determine that
+  // self is live, and nothing will get GC'ed anyway (resurrection cannot happen
+  // if the C++ objects owns the PyObject)
   THPVariable* var = reinterpret_cast<THPVariable*>(self);
+  if (isResurrectable(var)) {
+    return 0;
+  }
+
+  // Crappy version of subtype_traverse; same deal as
+  // THPVariable_subclass_dealloc
+
+  PyTypeObject* type = Py_TYPE(self);
+  // Traverse slots until we get to base class THPVariableType
+  {
+    PyTypeObject* base = type;
+    while (base != &THPVariableType) {
+      if (Py_SIZE(base)) {
+        int err = traverse_slots(base, self, visit, arg);
+        if (err)
+          return err;
+      }
+      base = base->tp_base;
+      TORCH_INTERNAL_ASSERT(base);
+    }
+  }
+
+  // All Python defined classes have __dict__
+  if (C10_LIKELY(type->tp_dictoffset)) {
+    PyObject** dictptr = _PyObject_GetDictPtr(self);
+    if (dictptr && *dictptr)
+      Py_VISIT(*dictptr);
+  }
+
+  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
+  Py_VISIT(type);
+
+  // Finally traverse THPVariable special stuff
   Py_VISIT(var->backward_hooks);
   Py_VISIT(var->post_accumulate_grad_hooks);
-  const auto& tensor = THPVariable_Unpack(var);
-  if (tensor.defined()) {
-    // WARNING: The grad_fn traversal logic is very subtle, if you change
-    // this, be very careful not to re-introduce this bug:
-    // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
+  if (!var->cdata.unsafeIsBorrowed()) {
+    const auto& tensor = THPVariable_Unpack(var);
+    if (tensor.defined()) {
+      // WARNING: The grad_fn traversal logic is very subtle, if you change
+      // this, be very careful not to re-introduce this bug:
+      // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
 
       // We ensure that we follow NOTE [ PyObject Traversal ] he by checking
       // that this python object is the sole owner of the underlying Tensor and
       // that this Tensor is the sole owner of its grad_fn. In this case, the
       // only way to get a new reference to the grad_fn is by using this python
       // object, which requires the GIL to be accessed. Note that this is only
       // valid as long as user don't share non-owning references across
       // different threads (which is crazy and should never be done).
       auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
       if (tensor.use_count() == 1) {
+        if (autograd_meta) {
+          // Do NOT call grad_fn() here as that might trigger a recompute
+          const auto& grad_fn = autograd_meta->grad_fn_;
+          if (grad_fn && grad_fn.use_count() == 1) {
+            // All Node can have a pyobj (stored in "pyobj_")
+            Py_VISIT(grad_fn->pyobj());
+            // PyNode are special as they also have an "obj" field
+            if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
+              Py_VISIT(py_node_fn->obj);
+            }
+          }
+        }
+      }
       if (autograd_meta) {
-        // Do NOT call grad_fn() here as that might trigger a recompute
-        const auto& grad_fn = autograd_meta->grad_fn_;
-        if (grad_fn && grad_fn.use_count() == 1) {
-          // All Node can have a pyobj (stored in "pyobj_")
-          Py_VISIT(grad_fn->pyobj());
-          // PyNode are special as they also have an "obj" field
-          if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
-            Py_VISIT(py_node_fn->obj);
+        for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
+          if (auto pyhook =
+                  dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
+            Py_VISIT(pyhook->dict);
           }
         }
       }
     }
-  if (autograd_meta) {
-    for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
-      if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
-        Py_VISIT(pyhook->dict);
-      }
-    }
-  }
   }
 
   return 0;
 }
 
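The early return at the top of THPVariable_subclass_traverse is what makes resurrection safe under the GC: by reporting no outgoing edges for an object the C++ side can still revive, root discovery treats everything it references as externally rooted. A stripped-down sketch of just that pattern; can_be_resurrected is a hypothetical stand-in for isResurrectable:

    #include <Python.h>

    bool can_be_resurrected(PyObject* self);  // hypothetical stand-in for isResurrectable()

    static int Subclass_traverse(PyObject* self, visitproc visit, void* arg) {
      if (can_be_resurrected(self)) {
        // Report no edges: the GC then treats everything this object references
        // as reachable from outside, so nothing it owns is collected this cycle.
        return 0;
      }
      // Otherwise visit slots, __dict__, the heap type, and any C++-owned PyObjects.
      Py_VISIT(Py_TYPE(self));
      return 0;
    }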
@@ -3568,6 +4010,17 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
   if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
     return -1;
   }
+  // It is important for all three of these to be overridden correctly for the
+  // resurrection checks to properly happen. In particular, an older version
+  // was not overriding tp_clear here. This lead to the default subtype_clear
+  // running on the Tensor object (as only TensorBase tp_clear was custom),
+  // clearing the __dict__ field, before the TensorBase custom clear was called
+  // and would properly detect the resurrect.
+  // See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior
+  ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
+  ((PyTypeObject*)cls)->tp_traverse =
+      (traverseproc)THPVariable_subclass_traverse;
+  ((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear;
 
   // Don't do anything for the base Tensor class
   if (!THPVariableClass) {
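The hunk above installs custom tp_dealloc/tp_traverse/tp_clear on every Python-defined subclass from the metatype's tp_init, so that dealloc, traverse, and clear all agree on the resurrection checks. The general CPython pattern looks like this; the Subclass_* functions are placeholders for the THPVariable_subclass_* functions installed above:

    #include <Python.h>

    // Placeholders with the standard destructor/traverseproc/inquiry signatures.
    void Subclass_dealloc(PyObject* self);
    int Subclass_traverse(PyObject* self, visitproc visit, void* arg);
    int Subclass_clear(PyObject* self);

    static int Meta_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
      if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
        return -1;
      }
      auto* type = reinterpret_cast<PyTypeObject*>(cls);
      type->tp_dealloc = Subclass_dealloc;    // must match tp_clear's notion of ownership
      type->tp_traverse = Subclass_traverse;  // and report the same edges tp_clear drops
      type->tp_clear = Subclass_clear;
      return 0;
    }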
@@ -17,7 +17,7 @@ namespace py = pybind11;
 struct THPVariable {
   PyObject_HEAD
   // Payload
-  at::Tensor cdata;
+  c10::MaybeOwned<at::Tensor> cdata;
   // Hooks to be run on backwards pass (corresponds to Python attr
   // '_backwards_hooks', set by 'register_hook')
   PyObject* backward_hooks = nullptr;
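This hunk appears to come from the THPVariable header: the cdata field changes between a plain at::Tensor and a c10::MaybeOwned<at::Tensor> (the latter is what the + side uses, and what the unsafeIsBorrowed() check in the traverse hunk relies on). A small sketch of the two states a MaybeOwned field can be in, assuming it is built against libtorch:

    #include <ATen/core/Tensor.h>
    #include <c10/util/MaybeOwned.h>

    void maybe_owned_example(const at::Tensor& external) {
      // Owned: this field keeps the TensorImpl alive by itself.
      c10::MaybeOwned<at::Tensor> owned =
          c10::MaybeOwned<at::Tensor>::owned(at::Tensor(external));
      // Borrowed: this field only points at a tensor someone else keeps alive.
      c10::MaybeOwned<at::Tensor> borrowed =
          c10::MaybeOwned<at::Tensor>::borrowed(external);

      bool is_borrowed = borrowed.unsafeIsBorrowed();  // true; owned.unsafeIsBorrowed() is false
      const at::Tensor& t = *owned;                    // operator* yields the underlying tensor
      (void)is_borrowed;
      (void)t;
    }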
@@ -37,11 +37,7 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass;
 TORCH_PYTHON_API extern PyObject* ParameterClass;
 
 bool THPVariable_initModule(PyObject* module);
-TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase&& var);
 TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var);
-TORCH_PYTHON_API PyObject* THPVariable_Wrap(
-    const at::TensorBase& var,
-    PyTypeObject* type);
 
 inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
   // Check that a python object is a `Tensor`, but not a `Tensor` subclass.

@@ -73,7 +69,7 @@ inline bool THPVariable_Check(PyObject* obj) {
 }
 
 inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
-  return var->cdata;
+  return *var->cdata;
 }
 
 inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
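A related change in this compare is the removal of the rvalue THPVariable_Wrap(at::TensorBase&&) overload in the hunk above. The point of such an overload pair is that a caller handing over a temporary lets the wrapper steal the reference instead of bumping a refcount. A generic sketch of the same overload shape, using std::shared_ptr rather than TensorBase so it stands alone:

    #include <memory>
    #include <utility>

    struct Handle {
      std::shared_ptr<int> impl;
    };

    // Copying overload: bumps the reference count of impl.
    Handle wrap(const Handle& h) {
      return Handle{h.impl};
    }

    // Stealing overload: moves impl out of the temporary, no refcount traffic.
    Handle wrap(Handle&& h) {
      return Handle{std::move(h.impl)};
    }

    Handle make() {
      return wrap(Handle{std::make_shared<int>(42)});  // binds to the && overload
    }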
@@ -65,9 +65,7 @@ inline at::Tensor clone_obey_contract(
             .new_empty_strided_symint(
                 variable.sym_sizes(),
                 variable.sym_strides(),
-                variable.options()
-                    .memory_format(std::nullopt)
-                    .dtype(new_grad.dtype()))
+                variable.options().memory_format(std::nullopt))
             .copy_(new_grad));
   } else {
     // (2)

@@ -70,10 +70,6 @@ inline PyObject* wrap(const at::Tensor& tensor) {
   return THPVariable_Wrap(tensor);
 }
 
-inline PyObject* wrap(at::Tensor&& tensor) {
-  return THPVariable_Wrap(std::move(tensor));
-}
-
 inline PyObject* wrap(const at::Scalar& scalar) {
   return wrap(scalar_to_tensor(scalar));
 }
@@ -197,22 +197,6 @@ TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
 TORCH_API void create_cpp_hook(
     const at::TensorBase& /*self*/,
     bool is_retains_grad_hooks = false);
-
-inline bool is_tensor_stealable(
-    const at::Tensor& new_grad,
-    size_t num_expected_refs = 1) {
-  size_t use_count = new_grad.use_count();
-  if (use_count <= num_expected_refs) {
-    return true;
-  }
-  if (use_count >= 2 &&
-      new_grad.unsafeGetTensorImpl()->pyobj_slot()->has_unique_reference()) {
-    // The Python wrapper, if it exists, also has a reference to the Tensor.
-    num_expected_refs++;
-  }
-  return use_count <= num_expected_refs;
-}
-
 } // namespace impl
 
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
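The removed is_tensor_stealable helper encodes a small reference-count budget: a tensor may be reused in place when every live reference is one the caller already expects, and the Python wrapper, if present, contributes exactly one extra expected reference. The same accounting in isolation, with plain integers standing in for Tensor::use_count() and the pyobj_slot query:

    #include <cstddef>

    // Illustrative sketch only; mirrors the accounting of the removed helper.
    bool stealable(size_t use_count, size_t num_expected_refs, bool pyobj_holds_a_ref) {
      if (use_count <= num_expected_refs) {
        return true;
      }
      if (use_count >= 2 && pyobj_holds_a_ref) {
        // The Python wrapper's reference is known and harmless, so raise the budget.
        ++num_expected_refs;
      }
      return use_count <= num_expected_refs;
    }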
@@ -910,7 +894,7 @@ inline Variable make_variable(
     bool requires_grad = false,
     bool allow_tensor_metadata_change = true) {
   if (data.defined()) {
-    if (impl::is_tensor_stealable(data) &&
+    if (data.getIntrusivePtr().use_count() == 1 &&
         data.getIntrusivePtr()->unique_version()) {
       auto data_impl = data.unsafeReleaseIntrusivePtr();
       data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);

@@ -1,67 +1,19 @@
 #include <torch/csrc/utils/pyobject_preservation.h>
 
-#include <c10/core/impl/PyObjectSlot.h>
-#include <c10/util/intrusive_ptr.h>
+#include <structmember.h>
 
-namespace torch::utils {
-
-using c10::intrusive_ptr_target;
-using c10::impl::PyObjectSlot;
-
-void PyObjectPreservation::init_fresh_nonatomic(
-    intrusive_ptr_target* target,
-    PyObjectSlot* slot,
-    PyObject* pyobj) {
-  TORCH_INTERNAL_ASSERT(slot->load_pyobj() == nullptr);
-  TORCH_INTERNAL_ASSERT(
-      target->combined_refcount_.load(std::memory_order_relaxed) ==
-      c10::detail::kUniqueRef);
-
-  slot->pyobj_.store(pyobj, std::memory_order_relaxed);
-  slot->pyobj_interpreter_.store(
-      c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed);
-  target->combined_refcount_.store(
-      c10::detail::kHasPyObject | c10::detail::kUniqueRef,
-      std::memory_order_relaxed);
-}
-
-PyObject* PyObjectPreservation::init_once(
-    intrusive_ptr_target* target,
-    PyObjectSlot* slot,
-    PyObject* pyobj) {
-  PyObject* expected = nullptr;
-  if (!slot->pyobj_.compare_exchange_strong(
-          expected, pyobj, std::memory_order_acq_rel)) {
-    TORCH_INTERNAL_ASSERT(expected != nullptr);
-    return expected;
-  }
-
-  slot->pyobj_interpreter_.store(
-      c10::impl::getGlobalPyInterpreter(), std::memory_order_release);
-
-  bool increfed = false;
-  auto combined = target->combined_refcount_.load(std::memory_order_relaxed);
-  do {
-    TORCH_INTERNAL_ASSERT(!c10::detail::has_pyobject(combined));
-    if (c10::detail::refcount(combined) > 1 && !increfed) {
-      // We need to incref the object to preserve the invariant that
-      // if refcount > 1, the c10 object holds a reference to the PyObject.
-      // This must happen before we set the kHasPyObject bit.
-      Py_INCREF(pyobj);
-      increfed = true;
-    }
-  } while (!target->combined_refcount_.compare_exchange_weak(
-      combined,
-      combined | c10::detail::kHasPyObject,
-      std::memory_order_acq_rel,
-      std::memory_order_relaxed));
-
-  if (increfed && c10::detail::refcount(combined) == 1) {
-    // Fix up if refcount if we did the incref in a failed compare-exchange
-    Py_DECREF(pyobj);
-  }
-
-  return pyobj;
-}
-
-} // namespace torch::utils
+void clear_slots(PyTypeObject* type, PyObject* self) {
+  Py_ssize_t n = Py_SIZE(type);
+  PyMemberDef* mp = type->tp_members;
+
+  for (Py_ssize_t i = 0; i < n; i++, mp++) {
+    if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
+      char* addr = (char*)self + mp->offset;
+      PyObject* obj = *(PyObject**)addr;
+      if (obj != nullptr) {
+        *(PyObject**)addr = nullptr;
+        Py_DECREF(obj);
+      }
+    }
+  }
+}
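The removed PyObjectPreservation::init_once publishes a PyObject on a shared refcount word with a compare-exchange loop: if the combined refcount shows other owners, the PyObject gets an extra incref before the kHasPyObject bit becomes visible, and a final fixup undoes an incref taken on a failed attempt. The shape of that loop on a plain std::atomic, with an invented bit layout and incref/decref callbacks, purely as a sketch:

    #include <atomic>
    #include <cstdint>

    constexpr uint64_t kHasPyObjectBit = uint64_t(1) << 63;  // invented layout for the sketch

    void publish_pyobject(
        std::atomic<uint64_t>& combined_refcount,
        void (*incref)(),
        void (*decref)()) {
      bool increfed = false;
      uint64_t cur = combined_refcount.load(std::memory_order_relaxed);
      do {
        if ((cur & ~kHasPyObjectBit) > 1 && !increfed) {
          // Another owner exists: the C++ object must hold a reference to the
          // PyObject before the bit becomes visible to other threads.
          incref();
          increfed = true;
        }
      } while (!combined_refcount.compare_exchange_weak(
          cur,
          cur | kHasPyObjectBit,
          std::memory_order_acq_rel,
          std::memory_order_relaxed));
      if (increfed && (cur & ~kHasPyObjectBit) == 1) {
        // The incref was taken on an earlier, failed attempt; undo it.
        decref();
      }
    }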
@@ -4,28 +4,4 @@
 
 // This file contains utilities used for handling PyObject preservation
 
-namespace c10 {
-class intrusive_ptr_target;
-namespace impl {
-struct PyObjectSlot;
-} // namespace impl
-} // namespace c10
-
-namespace torch::utils {
-
-class PyObjectPreservation {
- public:
-  // Store a PyObject wrapper on a fresh c10 wrapper. The caller must hold
-  // a unique reference to `target`.
-  static void init_fresh_nonatomic(
-      c10::intrusive_ptr_target* target,
-      c10::impl::PyObjectSlot* slot,
-      PyObject* pyobj);
-
-  static PyObject* init_once(
-      c10::intrusive_ptr_target* target,
-      c10::impl::PyObjectSlot* slot,
-      PyObject* pyobj);
-};
-
-} // namespace torch::utils
+void clear_slots(PyTypeObject* type, PyObject* self);

@@ -207,19 +207,12 @@ def tensorify_python_scalars(
                 and node.target is torch.ops.aten._local_scalar_dense.default
             ):
                 dtype = node.args[0].meta["val"].dtype
+                if not dtype.is_floating_point:
+                    continue
 
                 assert isinstance(node.args[0], fx.Node), node.args[0]
 
                 s = node.meta["val"].node.expr
-
-                expr_to_sym_proxy[s] = MetaProxy(
-                    node, tracer=tracer, fake_mode=fake_mode
-                )
-
-                # only tensorify if the dtype is floating point
-                if not dtype.is_floating_point:
-                    continue
-
                 expr_to_tensor_proxy[s] = MetaProxy(
                     node.args[0], tracer=tracer, fake_mode=fake_mode
                 )

@@ -227,7 +220,9 @@ def tensorify_python_scalars(
                     expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
                         expr_to_tensor_proxy[s], torch.float64
                     )
+                expr_to_sym_proxy[s] = MetaProxy(
+                    node, tracer=tracer, fake_mode=fake_mode
+                )
             # pyrefly: ignore [bad-argument-type]
             elif (sym_expr := _get_sym_val(node)) is not None:
                 if sym_expr not in expr_to_sym_proxy and not isinstance(

@@ -387,7 +387,7 @@ class DTensorTestBase(MultiProcessTestCase):
 
     @property
     def backend(self) -> str:
-        backend = dist.get_default_backend_for_device(DEVICE_TYPE)
+        backend = dist.get_default_backend_for_device(self.device_type)
         return backend
 
     def init_manual_seed_for_rank(self) -> None: