mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PyTorch] Add c10::MaybeOwned and Tensor::expect_contiguous (#53317)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53317 This seems like it might help in cases where we have to call `Tensor::contiguous`, but we expect that the tensor in question will be contiguous a good portion of the time. ghstack-source-id: 123203771 Test Plan: Profiled AdIndexer on inline_cvr; time spent in clip_ranges_gather_sigrid_hash_each_feature<int> was cut in half from 1.37% to 0.66% Reviewed By: smessmer Differential Revision: D26738036 fbshipit-source-id: b5db10783ccd103dae0ab3e79338a83b5e507ebb
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8acb74c405
commit
0606057af3
@ -15,6 +15,7 @@
|
||||
#include <c10/core/WrapDimMinimal.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Deprecated.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
|
||||
@ -125,6 +126,24 @@ class TORCH_API Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Should be used if *this can reasonably be expected to be contiguous and
|
||||
/// performance is important.
|
||||
/// Compared to contiguous, it saves a reference count
|
||||
/// increment/decrement if *this is already contiguous, at the cost
|
||||
/// in all cases of an extra pointer of stack usage, an extra branch
|
||||
/// to access, and an extra branch at destruction time.
|
||||
c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const & {
|
||||
if (is_contiguous(memory_format)) {
|
||||
return c10::MaybeOwned<Tensor>::borrowed(*this);
|
||||
} else {
|
||||
return c10::MaybeOwned<Tensor>::owned(__dispatch_contiguous(memory_format));
|
||||
}
|
||||
}
|
||||
|
||||
// Use .contiguous() instead. Trying to borrow from a prvalue Tensor
|
||||
// will only lead to trouble and dangling references.
|
||||
c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
|
||||
|
||||
bool is_complex() const {
|
||||
return at::isComplexType(this->scalar_type());
|
||||
}
|
||||
|
94
c10/test/util/MaybeOwned_test.cpp
Normal file
94
c10/test/util/MaybeOwned_test.cpp
Normal file
@ -0,0 +1,94 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
template<typename T>
|
||||
using MaybeOwned = c10::MaybeOwned<T>;
|
||||
|
||||
TEST(MaybeOwnedTest, SimpleDereferencingInt) {
|
||||
int x = 123;
|
||||
auto borrowed = MaybeOwned<int>::borrowed(x);
|
||||
auto owned = MaybeOwned<int>::owned(c10::in_place, x);
|
||||
EXPECT_EQ(*borrowed, x);
|
||||
EXPECT_EQ(*owned, x);
|
||||
EXPECT_EQ(&*borrowed, &x);
|
||||
EXPECT_NE(&*owned, &x);
|
||||
}
|
||||
|
||||
TEST(MaybeOwnedTest, SimpleDereferencingString) {
|
||||
std::string x = "hello";
|
||||
std::string y = x;
|
||||
auto borrowed = MaybeOwned<std::string>::borrowed(x);
|
||||
auto owned = MaybeOwned<std::string>::owned(c10::in_place, x);
|
||||
auto owned2 = MaybeOwned<std::string>::owned(std::move(y));
|
||||
EXPECT_EQ(*borrowed, x);
|
||||
EXPECT_EQ(*owned, x);
|
||||
EXPECT_EQ(*owned2, x);
|
||||
EXPECT_EQ(&*borrowed, &x);
|
||||
EXPECT_NE(&*owned, &x);
|
||||
EXPECT_NE(&*owned2, &x);
|
||||
|
||||
EXPECT_EQ(borrowed->size(), x.size());
|
||||
EXPECT_EQ(owned->size(), x.size());
|
||||
EXPECT_EQ(owned2->size(), x.size());
|
||||
}
|
||||
|
||||
TEST(MaybeOwnedTest, MoveConstructor) {
|
||||
std::string x = "hello";
|
||||
auto borrowed = MaybeOwned<std::string>::borrowed(x);
|
||||
auto owned = MaybeOwned<std::string>::owned(c10::in_place, x);
|
||||
auto owned2 = MaybeOwned<std::string>::owned(std::string(x));
|
||||
|
||||
auto movedBorrowed(std::move(borrowed));
|
||||
auto movedOwned(std::move(owned));
|
||||
auto movedOwned2(std::move(owned2));
|
||||
|
||||
for (auto *mo : {&movedBorrowed, &movedOwned, &movedOwned2}) {
|
||||
EXPECT_EQ(**mo, x);
|
||||
EXPECT_EQ((*mo)->size(), x.size());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MaybeOwnedTest, MoveAssignmentIntoOwned) {
|
||||
std::string x = "hello";
|
||||
auto borrowed = MaybeOwned<std::string>::borrowed(x);
|
||||
auto owned = MaybeOwned<std::string>::owned(c10::in_place, x);
|
||||
auto owned2 = MaybeOwned<std::string>::owned(std::string(x));
|
||||
|
||||
auto movedBorrowed = MaybeOwned<std::string>::owned(c10::in_place, "");
|
||||
auto movedOwned = MaybeOwned<std::string>::owned(c10::in_place, "");
|
||||
auto movedOwned2 = MaybeOwned<std::string>::owned(c10::in_place, "");
|
||||
|
||||
movedBorrowed = std::move(borrowed);
|
||||
movedOwned = std::move(owned);
|
||||
movedOwned2 = std::move(owned2);
|
||||
|
||||
for (auto *mo : {&movedBorrowed, &movedOwned, &movedOwned2}) {
|
||||
EXPECT_EQ(**mo, x);
|
||||
EXPECT_EQ((*mo)->size(), x.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(MaybeOwnedTest, MoveAssignmentIntoBorrowed) {
|
||||
std::string x = "hello";
|
||||
auto borrowed = MaybeOwned<std::string>::borrowed(x);
|
||||
auto owned = MaybeOwned<std::string>::owned(c10::in_place, x);
|
||||
auto owned2 = MaybeOwned<std::string>::owned(std::string(x));
|
||||
|
||||
std::string y = "goodbye";
|
||||
auto movedBorrowed = MaybeOwned<std::string>::borrowed(y);
|
||||
auto movedOwned = MaybeOwned<std::string>::borrowed(y);
|
||||
auto movedOwned2 = MaybeOwned<std::string>::borrowed(y);
|
||||
|
||||
movedBorrowed = std::move(borrowed);
|
||||
movedOwned = std::move(owned);
|
||||
movedOwned2 = std::move(owned2);
|
||||
|
||||
for (auto *mo : {&movedBorrowed, &movedOwned, &movedOwned2}) {
|
||||
EXPECT_EQ(**mo, x);
|
||||
EXPECT_EQ((*mo)->size(), x.size());
|
||||
}
|
||||
}
|
103
c10/util/MaybeOwned.h
Normal file
103
c10/util/MaybeOwned.h
Normal file
@ -0,0 +1,103 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/in_place.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// A smart pointer around either a borrowed or owned T. Maintains an
|
||||
/// internal raw pointer when constructed with borrowed(), with all
|
||||
/// the attendant lifetime concerns. Compare to Rust's
|
||||
/// std::borrow::Cow
|
||||
/// (https://doc.rust-lang.org/std/borrow/enum.Cow.html), but note
|
||||
/// that it is probably not suitable for general use because C++ has
|
||||
/// no borrow checking. Included here to support
|
||||
/// Tensor::expect_contiguous.
|
||||
template <typename T>
|
||||
class MaybeOwned final {
|
||||
bool isBorrowed_;
|
||||
union {
|
||||
const T *borrow_;
|
||||
T own_;
|
||||
};
|
||||
|
||||
/// Don't use this; use borrowed() instead.
|
||||
explicit MaybeOwned(const T& t) : isBorrowed_(true), borrow_(&t) {}
|
||||
|
||||
/// Don't use this; use owned() instead.
|
||||
explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible<T>::value)
|
||||
: isBorrowed_(false), own_(std::move(t)) {}
|
||||
|
||||
/// Don't use this; use owned() instead.
|
||||
template <class... Args>
|
||||
explicit MaybeOwned(in_place_t, Args&&... args)
|
||||
: isBorrowed_(false)
|
||||
, own_(std::forward<Args>(args)...) {}
|
||||
|
||||
public:
|
||||
|
||||
MaybeOwned(const MaybeOwned&) = delete;
|
||||
MaybeOwned& operator=(const MaybeOwned&) = delete;
|
||||
|
||||
MaybeOwned(MaybeOwned&& rhs) noexcept(std::is_nothrow_move_constructible<T>::value)
|
||||
: isBorrowed_(rhs.isBorrowed_) {
|
||||
if (rhs.isBorrowed_) {
|
||||
borrow_ = rhs.borrow_;
|
||||
} else {
|
||||
new (&own_) T(std::move(rhs.own_));
|
||||
}
|
||||
}
|
||||
|
||||
MaybeOwned& operator=(MaybeOwned&& rhs) noexcept(std::is_nothrow_move_assignable<T>::value) {
|
||||
if (!isBorrowed_) {
|
||||
if (rhs.isBorrowed_) {
|
||||
own_.~T();
|
||||
borrow_ = rhs.borrow_;
|
||||
isBorrowed_ = true;
|
||||
} else {
|
||||
own_ = std::move(rhs.own_);
|
||||
}
|
||||
} else {
|
||||
if (rhs.isBorrowed_) {
|
||||
borrow_ = rhs.borrow_;
|
||||
} else {
|
||||
new (&own_) T(std::move(rhs.own_));
|
||||
isBorrowed_ = false;
|
||||
}
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isBorrowed_ == rhs.isBorrowed_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
static MaybeOwned borrowed(const T& t) {
|
||||
return MaybeOwned(t);
|
||||
}
|
||||
|
||||
static MaybeOwned owned(T&& t) noexcept(std::is_nothrow_move_constructible<T>::value) {
|
||||
return MaybeOwned(std::move(t));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
static MaybeOwned owned(in_place_t, Args&&... args) {
|
||||
return MaybeOwned(in_place, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
~MaybeOwned() {
|
||||
if (!isBorrowed_) {
|
||||
own_.~T();
|
||||
}
|
||||
}
|
||||
|
||||
const T& operator*() const {
|
||||
return isBorrowed_ ? *borrow_ : own_;
|
||||
}
|
||||
|
||||
const T* operator->() const {
|
||||
return isBorrowed_ ? borrow_ : &own_;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace c10
|
Reference in New Issue
Block a user