mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: From Kurt Mohler, see https://github.com/pytorch/pytorch/pull/113396 (manually imported due to ghimport problems) Test Plan: sandcastle, OSS CI Differential Revision: D52610522 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117053 Approved by: https://github.com/malfet, https://github.com/kurtamohler
43 lines
1.1 KiB
C++
43 lines
1.1 KiB
C++
#include <c10/core/impl/COWDeleter.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <mutex>
|
|
|
|
namespace c10::impl {
|
|
|
|
// Deleter callback: interprets `ctx` as a COWDeleterContext and drops one
// reference to it. Whatever decrement_refcount() returns is discarded at the
// end of the call statement.
void cow::cow_deleter(void* ctx) {
  auto* context = static_cast<cow::COWDeleterContext*>(ctx);
  context->decrement_refcount();
}
|
|
|
|
cow::COWDeleterContext::COWDeleterContext(
|
|
std::unique_ptr<void, DeleterFnPtr> data)
|
|
: data_(std::move(data)) {
|
|
// We never wrap a COWDeleterContext.
|
|
TORCH_INTERNAL_ASSERT(data_.get_deleter() != cow::cow_deleter);
|
|
}
|
|
|
|
auto cow::COWDeleterContext::increment_refcount() -> void {
|
|
auto refcount = ++refcount_;
|
|
TORCH_INTERNAL_ASSERT(refcount > 1);
|
|
}
|
|
|
|
auto cow::COWDeleterContext::decrement_refcount()
|
|
-> std::variant<NotLastReference, LastReference> {
|
|
auto refcount = --refcount_;
|
|
TORCH_INTERNAL_ASSERT(refcount >= 0, refcount);
|
|
if (refcount == 0) {
|
|
std::unique_lock lock(mutex_);
|
|
auto result = std::move(data_);
|
|
lock.unlock();
|
|
delete this;
|
|
return {std::move(result)};
|
|
}
|
|
|
|
return std::shared_lock(mutex_);
|
|
}
|
|
|
|
// Destruction is only valid once the reference count has reached zero; the
// internal assert below checks that.
cow::COWDeleterContext::~COWDeleterContext() {
  TORCH_INTERNAL_ASSERT(refcount_ == 0);
}
|
|
|
|
} // namespace c10::impl
|