mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
implement a function to convert a storage to copy-on-write (#100819)
implement a function to convert a storage to copy-on-write Summary: This will be used in the _lazy_clone() operator as well as reshape(). Test Plan: 100% coverage of reachable lines. --- Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/100819). * #100821 * #100820 * __->__ #100819 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100819 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
f0f700e8d2
commit
aec11b8c80
@ -53,6 +53,8 @@ class C10_API DataPtr {
|
||||
operator bool() const {
|
||||
return static_cast<bool>(ptr_);
|
||||
}
|
||||
|
||||
/// @see c10::UniqueVoidPtr::cast_context
|
||||
template <typename T>
|
||||
T* cast_context(DeleterFnPtr expected_deleter) const {
|
||||
return ptr_.cast_context<T>(expected_deleter);
|
||||
|
@ -109,6 +109,18 @@ def define_targets(rules):
|
||||
visibility = ["//c10/test:__pkg__"],
|
||||
)
|
||||
|
||||
rules.cc_library(
|
||||
name = "impl/cow/try_ensure",
|
||||
srcs = ["impl/cow/try_ensure.cpp"],
|
||||
hdrs = ["impl/cow/try_ensure.h"],
|
||||
deps = [
|
||||
":base",
|
||||
":impl/cow/context",
|
||||
"//c10/util:base",
|
||||
],
|
||||
visibility = ["//c10/test:__pkg__"],
|
||||
)
|
||||
|
||||
rules.filegroup(
|
||||
name = "headers",
|
||||
srcs = rules.glob(
|
||||
|
96
c10/core/impl/cow/try_ensure.cpp
Normal file
96
c10/core/impl/cow/try_ensure.cpp
Normal file
@ -0,0 +1,96 @@
|
||||
#include <c10/core/impl/cow/try_ensure.h>
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/core/StorageImpl.h>
|
||||
#include <c10/core/impl/cow/context.h>
|
||||
#include <c10/core/impl/cow/deleter.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/UniqueVoidPtr.h>
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
namespace {
|
||||
|
||||
// Wraps a DataPtr with a copy-on-write DataPtr.
|
||||
auto make_data_ptr(at::DataPtr const& data_ptr, cow::Context& ctx)
|
||||
-> at::DataPtr {
|
||||
return at::DataPtr(
|
||||
data_ptr.get(), &ctx, cow::delete_context, data_ptr.device());
|
||||
}
|
||||
|
||||
/// Copies a copy-on-write DataPtr.
|
||||
auto copy_data_ptr(at::DataPtr const& data_ptr) -> at::DataPtr {
|
||||
auto* ctx = data_ptr.cast_context<cow::Context>(cow::delete_context);
|
||||
TORCH_INTERNAL_ASSERT(ctx != nullptr);
|
||||
ctx->increment_refcount();
|
||||
return make_data_ptr(data_ptr, *ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
auto C10_API cow::try_ensure(StorageImpl& storage)
|
||||
-> c10::intrusive_ptr<StorageImpl> {
|
||||
at::DataPtr& data_ptr = storage.mutable_data_ptr();
|
||||
|
||||
// There are three possible circumstances:
|
||||
//
|
||||
// 1) the storage does not already have a copy on write context. In
|
||||
// this case there can be no blind aliases to the storage impl:
|
||||
// they all will be public aliases and the user is expected to
|
||||
// synchronize manually.
|
||||
//
|
||||
// No locking is required in this case.
|
||||
//
|
||||
// 2) the storage has a context that is not the copy on write
|
||||
// context. This is not supported, so we just return null.
|
||||
//
|
||||
// No locking is required in this case.
|
||||
//
|
||||
// 3) there is already a copy on write context on the storage. There
|
||||
// is a potential race condition with a blind alias (i.e. an
|
||||
// alias that the user is not required to synchronize
|
||||
// with). Because our input storage is bound to a live reference
|
||||
// to the data, we know that it isn't going away. A blind alias
|
||||
// could be copying from it right now, but we will grab the
|
||||
// context's mutex to protect us.
|
||||
//
|
||||
// We do not need to lock in this case either, because we're just
|
||||
// wrapping a context that we know isn't going away.
|
||||
|
||||
std::optional<DataPtr> new_data_ptr; // must be set below
|
||||
|
||||
if (data_ptr.get() == data_ptr.get_context()) {
|
||||
// Case 1) We have a simple data pointer: wrap it.
|
||||
std::unique_ptr<void, DeleterFnPtr> original_ctx = data_ptr.move_context();
|
||||
TORCH_INTERNAL_ASSERT(original_ctx.get() == data_ptr.get());
|
||||
|
||||
// Save this for the result.
|
||||
new_data_ptr =
|
||||
make_data_ptr(data_ptr, *new cow::Context(std::move(original_ctx)));
|
||||
|
||||
// Update this storage to the new copy on write context.
|
||||
storage.set_data_ptr_noswap(copy_data_ptr(*new_data_ptr));
|
||||
} else if (data_ptr.get_deleter() != cow::delete_context) {
|
||||
// Case 2) There is a context and it's not copy-on-write. Nothing
|
||||
// we can do here.
|
||||
return nullptr;
|
||||
} else {
|
||||
// Case 3): there is already a copy on write context. Just return a
|
||||
// new storage impl.
|
||||
new_data_ptr = copy_data_ptr(data_ptr);
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
|
||||
|
||||
return make_intrusive<StorageImpl>(
|
||||
StorageImpl::use_byte_size_t(),
|
||||
storage.sym_nbytes(),
|
||||
*std::move(new_data_ptr),
|
||||
storage.allocator(),
|
||||
storage.resizable());
|
||||
}
|
||||
|
||||
} // namespace c10::impl
|
23
c10/core/impl/cow/try_ensure.h
Normal file
23
c10/core/impl/cow/try_ensure.h
Normal file
@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
|
||||
namespace c10 {
|
||||
struct StorageImpl;
|
||||
}; // namespace c10
|
||||
|
||||
namespace c10::impl::cow {
|
||||
|
||||
// Ensures storage is copy-on-write, returning a new StorageImpl
|
||||
// sharing the data.
|
||||
//
|
||||
// The result is suitable for creating a new Tensor that is logically
|
||||
// distinct but shares data still.
|
||||
//
|
||||
// This will try to convert the storage to use a copy-on-write context
|
||||
// if it is not already. Returns null only if Storage does not have a
|
||||
// copy-on-write context upon completion.
|
||||
auto C10_API try_ensure(StorageImpl& storage) -> intrusive_ptr<StorageImpl>;
|
||||
|
||||
} // namespace c10::impl::cow
|
@ -18,6 +18,17 @@ def define_targets(rules):
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_test(
|
||||
name = "core/impl/cow/try_ensure_test",
|
||||
srcs = ["core/impl/cow/try_ensure_test.cpp"],
|
||||
deps = [
|
||||
"//c10/core:CPUAllocator",
|
||||
"//c10/core:impl/cow/context",
|
||||
"//c10/core:impl/cow/try_ensure",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_test(
|
||||
name = "core_tests",
|
||||
size = "small",
|
||||
|
92
c10/test/core/impl/cow/try_ensure_test.cpp
Normal file
92
c10/test/core/impl/cow/try_ensure_test.cpp
Normal file
@ -0,0 +1,92 @@
|
||||
#include <c10/core/impl/cow/try_ensure.h>
|
||||
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/core/StorageImpl.h>
|
||||
#include <c10/core/impl/cow/context.h>
|
||||
#include <c10/core/impl/cow/deleter.h>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
namespace c10::impl {
|
||||
namespace {
|
||||
|
||||
MATCHER(is_copy_on_write, "") {
|
||||
const c10::StorageImpl& storage = std::ref(arg);
|
||||
return storage.data_ptr().get_deleter() == cow::delete_context;
|
||||
}
|
||||
|
||||
TEST(try_ensure_test, no_context) {
|
||||
StorageImpl original_storage(
|
||||
{}, /*size_bytes=*/7, GetCPUAllocator(), /*resizable=*/false);
|
||||
ASSERT_THAT(original_storage, testing::Not(is_copy_on_write()));
|
||||
|
||||
intrusive_ptr<StorageImpl> new_storage = cow::try_ensure(original_storage);
|
||||
ASSERT_THAT(new_storage, testing::NotNull());
|
||||
|
||||
// The original storage was modified in-place to now hold a copy on
|
||||
// write context.
|
||||
ASSERT_THAT(original_storage, is_copy_on_write());
|
||||
|
||||
// The result is a different storage impl.
|
||||
ASSERT_THAT(&*new_storage, testing::Ne(&original_storage));
|
||||
// But it is also copy-on-write.
|
||||
ASSERT_THAT(*new_storage, is_copy_on_write());
|
||||
// But they share the same data!
|
||||
ASSERT_THAT(new_storage->data(), testing::Eq(original_storage.data()));
|
||||
}
|
||||
|
||||
class OpaqueContext {};
|
||||
|
||||
TEST(try_ensure_test, different_context) {
|
||||
StorageImpl storage(
|
||||
{},
|
||||
/*size_bytes=*/5,
|
||||
at::DataPtr(
|
||||
/*data=*/new std::byte[5],
|
||||
/*ctx=*/new OpaqueContext,
|
||||
+[](void* opaque_ctx) {
|
||||
delete static_cast<OpaqueContext*>(opaque_ctx);
|
||||
},
|
||||
Device(Device::Type::CPU)),
|
||||
/*allocator=*/nullptr,
|
||||
/*resizable=*/false);
|
||||
|
||||
// We can't handle an arbitrary context.
|
||||
ASSERT_THAT(cow::try_ensure(storage), testing::IsNull());
|
||||
}
|
||||
|
||||
TEST(try_ensure_test, already_copy_on_write) {
|
||||
std::unique_ptr<void, DeleterFnPtr> data(
|
||||
new std::byte[5],
|
||||
+[](void* bytes) { delete[] static_cast<std::byte*>(bytes); });
|
||||
void* data_ptr = data.get();
|
||||
StorageImpl original_storage(
|
||||
{},
|
||||
/*size_bytes=*/5,
|
||||
at::DataPtr(
|
||||
/*data=*/data_ptr,
|
||||
/*ctx=*/new cow::Context(std::move(data)),
|
||||
cow::delete_context,
|
||||
Device(Device::Type::CPU)),
|
||||
/*allocator=*/nullptr,
|
||||
/*resizable=*/false);
|
||||
|
||||
ASSERT_THAT(original_storage, is_copy_on_write());
|
||||
|
||||
intrusive_ptr<StorageImpl> new_storage = cow::try_ensure(original_storage);
|
||||
ASSERT_THAT(new_storage, testing::NotNull());
|
||||
|
||||
// The result is a different storage.
|
||||
ASSERT_THAT(&*new_storage, testing::Ne(&original_storage));
|
||||
// But it is also copy-on-write.
|
||||
ASSERT_THAT(*new_storage, is_copy_on_write());
|
||||
// But they share the same data!
|
||||
ASSERT_THAT(new_storage->data(), testing::Eq(original_storage.data()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace c10::impl
|
@ -75,6 +75,11 @@ class UniqueVoidPtr {
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Casts the context to the requested type, contingent on the
|
||||
/// deleter matching.
|
||||
///
|
||||
/// Returns null without attempting a cast if the deleter does not
|
||||
/// match.
|
||||
template <typename T>
|
||||
T* cast_context(DeleterFnPtr expected_deleter) const {
|
||||
if (get_deleter() != expected_deleter)
|
||||
|
Reference in New Issue
Block a user