implement a function to convert a storage to copy-on-write (#100819)

Summary:
This will be used in the _lazy_clone() operator as well as reshape().

Test Plan: 100% coverage of reachable lines.

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/100819).
* #100821
* #100820
* __->__ #100819

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100819
Approved by: https://github.com/ezyang
Author: mikey dagitses
Date: 2023-05-12 17:45:04 +00:00
Committed by: PyTorch MergeBot
Parent: f0f700e8d2
Commit: aec11b8c80

7 changed files with 241 additions and 0 deletions

c10/core/Allocator.h

@@ -53,6 +53,8 @@ class C10_API DataPtr {
  operator bool() const {
    return static_cast<bool>(ptr_);
  }

  /// @see c10::UniqueVoidPtr::cast_context
  template <typename T>
  T* cast_context(DeleterFnPtr expected_deleter) const {
    return ptr_.cast_context<T>(expected_deleter);
  }

c10/core/build.bzl

@@ -109,6 +109,18 @@ def define_targets(rules):
        visibility = ["//c10/test:__pkg__"],
    )

    rules.cc_library(
        name = "impl/cow/try_ensure",
        srcs = ["impl/cow/try_ensure.cpp"],
        hdrs = ["impl/cow/try_ensure.h"],
        deps = [
            ":base",
            ":impl/cow/context",
            "//c10/util:base",
        ],
        visibility = ["//c10/test:__pkg__"],
    )

    rules.filegroup(
        name = "headers",
        srcs = rules.glob(

c10/core/impl/cow/try_ensure.cpp

@@ -0,0 +1,96 @@
#include <c10/core/impl/cow/try_ensure.h>

#include <c10/core/Allocator.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/impl/cow/context.h>
#include <c10/core/impl/cow/deleter.h>
#include <c10/util/Exception.h>
#include <c10/util/UniqueVoidPtr.h>

#include <memory>
#include <optional>

namespace c10::impl {
namespace {

// Wraps a DataPtr with a copy-on-write DataPtr.
auto make_data_ptr(at::DataPtr const& data_ptr, cow::Context& ctx)
    -> at::DataPtr {
  return at::DataPtr(
      data_ptr.get(), &ctx, cow::delete_context, data_ptr.device());
}

/// Copies a copy-on-write DataPtr.
auto copy_data_ptr(at::DataPtr const& data_ptr) -> at::DataPtr {
  auto* ctx = data_ptr.cast_context<cow::Context>(cow::delete_context);
  TORCH_INTERNAL_ASSERT(ctx != nullptr);
  ctx->increment_refcount();
  return make_data_ptr(data_ptr, *ctx);
}

} // namespace

auto C10_API cow::try_ensure(StorageImpl& storage)
    -> c10::intrusive_ptr<StorageImpl> {
  at::DataPtr& data_ptr = storage.mutable_data_ptr();
  // There are three possible circumstances:
  //
  // 1) the storage does not already have a copy-on-write context. In
  //    this case there can be no blind aliases to the storage impl:
  //    they will all be public aliases and the user is expected to
  //    synchronize manually.
  //
  //    No locking is required in this case.
  //
  // 2) the storage has a context that is not the copy-on-write
  //    context. This is not supported, so we just return null.
  //
  //    No locking is required in this case.
  //
  // 3) there is already a copy-on-write context on the storage. There
  //    is a potential race condition with a blind alias (i.e. an
  //    alias that the user is not required to synchronize with).
  //    Because our input storage is bound to a live reference to the
  //    data, we know that it isn't going away. A blind alias could be
  //    copying from it right now, but we will grab the context's mutex
  //    to protect us.
  //
  //    We do not need to lock in this case either, because we're just
  //    wrapping a context that we know isn't going away.
  std::optional<DataPtr> new_data_ptr; // must be set below

  if (data_ptr.get() == data_ptr.get_context()) {
    // Case 1) We have a simple data pointer: wrap it.
    std::unique_ptr<void, DeleterFnPtr> original_ctx = data_ptr.move_context();
    TORCH_INTERNAL_ASSERT(original_ctx.get() == data_ptr.get());

    // Save this for the result.
    new_data_ptr =
        make_data_ptr(data_ptr, *new cow::Context(std::move(original_ctx)));

    // Update this storage to the new copy-on-write context.
    storage.set_data_ptr_noswap(copy_data_ptr(*new_data_ptr));
  } else if (data_ptr.get_deleter() != cow::delete_context) {
    // Case 2) There is a context and it's not copy-on-write. Nothing
    // we can do here.
    return nullptr;
  } else {
    // Case 3) There is already a copy-on-write context. Just return a
    // new storage impl.
    new_data_ptr = copy_data_ptr(data_ptr);
  }

  TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
  return make_intrusive<StorageImpl>(
      StorageImpl::use_byte_size_t(),
      storage.sym_nbytes(),
      *std::move(new_data_ptr),
      storage.allocator(),
      storage.resizable());
}

} // namespace c10::impl
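
The file above leans on cow::Context and cow::delete_context from c10/core/impl/cow/context.h and deleter.h, which are introduced elsewhere in the stack and not shown in this diff. As a rough, non-authoritative sketch inferred only from the call sites above, the interface used here amounts to:

// Illustrative sketch only; the real declarations live in
// c10/core/impl/cow/context.h and c10/core/impl/cow/deleter.h.
namespace c10::impl::cow {

class Context {
 public:
  // Takes ownership of the original allocation and the deleter that
  // knows how to free it (see the case-1 wrapping above).
  explicit Context(std::unique_ptr<void, DeleterFnPtr> data);

  // Records one more DataPtr sharing this context (see copy_data_ptr).
  void increment_refcount();
};

// DeleterFnPtr-compatible deleter installed on every copy-on-write
// DataPtr; presumably it decrements the refcount and frees the data
// with the original deleter once the last reference is gone.
void delete_context(void* ctx);

} // namespace c10::impl::cow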

c10/core/impl/cow/try_ensure.h

@@ -0,0 +1,23 @@
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>

namespace c10 {
struct StorageImpl;
} // namespace c10

namespace c10::impl::cow {

// Ensures the storage is copy-on-write, returning a new StorageImpl
// that shares the data.
//
// The result is suitable for creating a new Tensor that is logically
// distinct but still shares data.
//
// This will try to convert the storage to use a copy-on-write context
// if it does not already have one. Returns null only if the storage
// still lacks a copy-on-write context upon completion.
auto C10_API try_ensure(StorageImpl& storage) -> intrusive_ptr<StorageImpl>;

} // namespace c10::impl::cow
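
The Summary names _lazy_clone() and reshape() as the intended callers. A minimal sketch of how a caller might consume this contract follows; the wrapper function and its error handling are hypothetical, not part of this change:

// Hypothetical helper showing the intended calling pattern; only
// cow::try_ensure itself is provided by this commit.
c10::intrusive_ptr<c10::StorageImpl> lazy_clone_storage(
    c10::StorageImpl& storage) {
  c10::intrusive_ptr<c10::StorageImpl> copy =
      c10::impl::cow::try_ensure(storage);
  // A null result means the storage carried a foreign context (case 2
  // in try_ensure.cpp); a real operator would likely fall back to an
  // eager copy here rather than fail.
  TORCH_CHECK(copy != nullptr, "storage cannot be converted to copy-on-write");
  return copy; // shares the same bytes as `storage`
}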

c10/test/build.bzl

@@ -18,6 +18,17 @@ def define_targets(rules):
        ],
    )

    rules.cc_test(
        name = "core/impl/cow/try_ensure_test",
        srcs = ["core/impl/cow/try_ensure_test.cpp"],
        deps = [
            "//c10/core:CPUAllocator",
            "//c10/core:impl/cow/context",
            "//c10/core:impl/cow/try_ensure",
            "@com_google_googletest//:gtest_main",
        ],
    )

    rules.cc_test(
        name = "core_tests",
        size = "small",

c10/test/core/impl/cow/try_ensure_test.cpp

@@ -0,0 +1,92 @@
#include <c10/core/impl/cow/try_ensure.h>

#include <c10/core/CPUAllocator.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/impl/cow/context.h>
#include <c10/core/impl/cow/deleter.h>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstddef>
#include <memory>

namespace c10::impl {
namespace {

// Matches a StorageImpl whose DataPtr carries the copy-on-write deleter.
MATCHER(is_copy_on_write, "") {
  const c10::StorageImpl& storage = std::ref(arg);
  return storage.data_ptr().get_deleter() == cow::delete_context;
}

TEST(try_ensure_test, no_context) {
  StorageImpl original_storage(
      {}, /*size_bytes=*/7, GetCPUAllocator(), /*resizable=*/false);
  ASSERT_THAT(original_storage, testing::Not(is_copy_on_write()));

  intrusive_ptr<StorageImpl> new_storage = cow::try_ensure(original_storage);
  ASSERT_THAT(new_storage, testing::NotNull());

  // The original storage was modified in place and now holds a
  // copy-on-write context.
  ASSERT_THAT(original_storage, is_copy_on_write());

  // The result is a different storage impl.
  ASSERT_THAT(&*new_storage, testing::Ne(&original_storage));
  // But it is also copy-on-write.
  ASSERT_THAT(*new_storage, is_copy_on_write());
  // And they share the same data!
  ASSERT_THAT(new_storage->data(), testing::Eq(original_storage.data()));
}

class OpaqueContext {};

TEST(try_ensure_test, different_context) {
  StorageImpl storage(
      {},
      /*size_bytes=*/5,
      at::DataPtr(
          /*data=*/new std::byte[5],
          /*ctx=*/new OpaqueContext,
          +[](void* opaque_ctx) {
            delete static_cast<OpaqueContext*>(opaque_ctx);
          },
          Device(Device::Type::CPU)),
      /*allocator=*/nullptr,
      /*resizable=*/false);

  // We can't handle an arbitrary context.
  ASSERT_THAT(cow::try_ensure(storage), testing::IsNull());
}

TEST(try_ensure_test, already_copy_on_write) {
  std::unique_ptr<void, DeleterFnPtr> data(
      new std::byte[5],
      +[](void* bytes) { delete[] static_cast<std::byte*>(bytes); });
  void* data_ptr = data.get();
  StorageImpl original_storage(
      {},
      /*size_bytes=*/5,
      at::DataPtr(
          /*data=*/data_ptr,
          /*ctx=*/new cow::Context(std::move(data)),
          cow::delete_context,
          Device(Device::Type::CPU)),
      /*allocator=*/nullptr,
      /*resizable=*/false);
  ASSERT_THAT(original_storage, is_copy_on_write());

  intrusive_ptr<StorageImpl> new_storage = cow::try_ensure(original_storage);
  ASSERT_THAT(new_storage, testing::NotNull());

  // The result is a different storage.
  ASSERT_THAT(&*new_storage, testing::Ne(&original_storage));
  // But it is also copy-on-write.
  ASSERT_THAT(*new_storage, is_copy_on_write());
  // And they share the same data!
  ASSERT_THAT(new_storage->data(), testing::Eq(original_storage.data()));
}

} // namespace
} // namespace c10::impl

c10/util/UniqueVoidPtr.h

@@ -75,6 +75,11 @@ class UniqueVoidPtr {
    return true;
  }

  /// Casts the context to the requested type, contingent on the
  /// deleter matching.
  ///
  /// Returns null without attempting a cast if the deleter does not
  /// match.
  template <typename T>
  T* cast_context(DeleterFnPtr expected_deleter) const {
    if (get_deleter() != expected_deleter)
      return nullptr;
    return static_cast<T*>(get_context());
  }
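
This is the hook that copy_data_ptr() in try_ensure.cpp relies on, via the DataPtr forwarder added above. A small usage sketch, with the surrounding storage variable being illustrative:

// Recover the typed context only when the deleter vouches for its type.
at::DataPtr& data_ptr = storage.mutable_data_ptr();
if (auto* ctx = data_ptr.cast_context<c10::impl::cow::Context>(
        c10::impl::cow::delete_context)) {
  // The deleter matched, so the void* context is known to be a
  // cow::Context and the cast is safe.
  ctx->increment_refcount();
}
// A mismatched deleter yields nullptr and the context stays untouched.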