Consistent compute numel/contiguous strategy with SymInts (#85858)

Previously, our handling for contiguity was inconsistent in the following ways:

- is_strides_like 2d/3d and is_non_overlapping_and_dense were always computed
  from sizes_and_strides_, even if you had symbolic ints
- Furthermore, even if you set a custom policy for strides, these quantities
  were not overridable by subclasses
- Furthermore, we didn't even store these fields on ExtraMeta
- We duplicated the implementation of compute_contiguous (plain, channels
  last, channels last 3d)
- We inconsistently called refresh_numel()/refresh_contiguous() in some
  places and recomputed the values by hand in others

This refactor establishes a consistent strategy for all of the layout boolean
fields and for numel computation (see the sketch after the list below).  After
this refactor:

- All layout boolean fields are interposable via strides policy
  and can be overridden from Python; you will never access a garbage field
- All layout boolean fields are on ExtraMeta
- You can always call refresh_numel()/refresh_contiguous(), whether or not
  your Tensor has symbolic sizes/strides
- The numel/layout boolean fields are always populated consistently with
  the sizes/strides fields (either on Tensor or ExtraMeta), even if you
  have a custom policy
- There is only one implementation of the actual computation logic

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D39907696](https://our.internmc.facebook.com/intern/diff/D39907696)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85858
Approved by: https://github.com/albanD
Author: Edward Z. Yang
Date: 2022-09-30 10:01:35 -07:00
Committed by: PyTorch MergeBot
Parent: 84a06d7193
Commit: 3b6588ab74
18 changed files with 617 additions and 192 deletions

View File

@ -284,7 +284,7 @@ at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::IntArrayRef s
auto inferred_size = at::infer_size_dv(size, self.numel());
auto stride = at::detail::computeStride(self.sizes(), self.strides(), inferred_size);
TORCH_INTERNAL_ASSERT(stride.has_value());
out.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride.value());
out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
return out;
}

View File

@ -114,10 +114,11 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
// input
// 3. All helper functions have similar comments, only 1st helper function is
// commented here.
template <typename T>
inline bool is_channels_last_strides_2d_s4(
const IntArrayRef sizes,
const IntArrayRef strides) {
int64_t min = 0;
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
T min = 0;
// special case for trivial C dimension. default to NCHW
if (strides[1] == 0) {
return false;
@ -155,10 +156,11 @@ inline bool is_channels_last_strides_2d_s4(
return true;
}
template <typename T>
inline bool is_channels_last_strides_3d_s5(
const IntArrayRef sizes,
const IntArrayRef strides) {
int64_t min = 0;
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
T min = 0;
if (strides[1] == 0) {
return false;
}
@ -230,9 +232,10 @@ inline bool is_channels_last_strides_3d_s5(
// implementation. Please check the helper functions
// (is_channels_last_strides_*d_s*) for more details.
template <typename T>
inline bool is_channels_last_strides_2d(
const IntArrayRef sizes,
const IntArrayRef strides) {
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
switch (sizes.size()) {
case 4:
return is_channels_last_strides_2d_s4(sizes, strides);
@ -244,9 +247,10 @@ inline bool is_channels_last_strides_2d(
}
}
template <typename T>
inline bool is_channels_last_strides_3d(
const IntArrayRef sizes,
const IntArrayRef strides) {
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
switch (sizes.size()) {
case 5:
return is_channels_last_strides_3d_s5(sizes, strides);
@ -258,4 +262,16 @@ inline bool is_channels_last_strides_3d(
}
}
inline bool is_channels_last_strides_2d(
const IntArrayRef sizes,
const IntArrayRef strides) {
return is_channels_last_strides_2d<int64_t>(sizes, strides);
}
inline bool is_channels_last_strides_3d(
const IntArrayRef sizes,
const IntArrayRef strides) {
return is_channels_last_strides_3d<int64_t>(sizes, strides);
}
} // namespace c10
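
For reference, here is a simplified standalone sketch of the NHWC stride-order
test that the templated helpers above perform, together with the kind of plain
int64_t entry point the diff keeps so existing IntArrayRef callers still
resolve. channels_last_like_2d is a hypothetical name, std::vector stands in
for ArrayRef, and the extra tie-breaking rules for size-1 dimensions in the
real is_channels_last_strides_2d_s4 are omitted here.

```cpp
#include <cstdint>
#include <initializer_list>
#include <vector>

// Simplified stride-order test: in an NHWC ("channels last") layout the
// strides, visited in the order C, W, H, N, must be non-decreasing. The
// real c10 helper adds disambiguation rules for size-1 dimensions that
// this sketch leaves out.
template <typename T>
bool channels_last_like_2d(const std::vector<T>& sizes,
                           const std::vector<T>& strides) {
  if (sizes.size() != 4 || strides[1] == 0) {
    return false;  // only rank 4 considered; trivial C falls back to NCHW
  }
  T min = 0;
  for (int d : {1, 3, 2, 0}) {  // C, W, H, N
    if (sizes[d] == 0 || strides[d] < min) {
      return false;
    }
    min = strides[d];
    if (sizes[d] > 1) {
      min = min * sizes[d];
    }
  }
  return true;
}

// Analogous to the non-template overloads above: a plain int64_t entry
// point that forwards to the template, so callers that do not care about
// SymInt keep working unchanged.
inline bool channels_last_like_2d(const std::vector<int64_t>& sizes,
                                  const std::vector<int64_t>& strides) {
  return channels_last_like_2d<int64_t>(sizes, strides);
}

int main() {
  // N=1, C=3, H=4, W=5 stored channels-last: strides are {60, 1, 15, 3}.
  std::vector<int64_t> sizes{1, 3, 4, 5}, strides{60, 1, 15, 3};
  return channels_last_like_2d(sizes, strides) ? 0 : 1;
}
```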

View File

@ -227,15 +227,20 @@ void TensorImpl::HandleResize() {
}
}
bool TensorImpl::compute_contiguous() const {
template <typename T>
bool_is_contiguous _compute_contiguous(
ArrayRef<T> sizes,
ArrayRef<T> strides,
T numel) {
bool is_contiguous = true;
if (is_empty())
return is_contiguous;
int64_t z = 1;
for (int64_t d = dim() - 1; d >= 0; d--) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (numel == 0)
return bool_is_contiguous(is_contiguous);
T z = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto size_d = sizes[d];
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) == z) {
if (strides[d] == z) {
z *= size_d;
} else {
is_contiguous = false;
@ -243,103 +248,150 @@ bool TensorImpl::compute_contiguous() const {
}
}
}
return is_contiguous;
return bool_is_contiguous(is_contiguous);
}
bool TensorImpl::compute_channels_last_contiguous_2d() const {
// NB: intentionally bypass the normal accessors; we always want to be
// consistent with what is actually stored on the struct
#define COMPUTE_WITH_SIZES_STRIDES_NUMEL(TEMPLATE) \
(has_symbolic_sizes_strides_ \
? TEMPLATE<c10::SymInt>( \
extra_meta_->sizes_, extra_meta_->strides_, extra_meta_->numel_) \
: TEMPLATE<int64_t>( \
sizes_and_strides_.sizes_arrayref(), \
sizes_and_strides_.strides_arrayref(), \
numel_))
#define COMPUTE_WITH_SIZES_STRIDES(TEMPLATE) \
(has_symbolic_sizes_strides_ \
? TEMPLATE<c10::SymInt>(extra_meta_->sizes_, extra_meta_->strides_) \
: TEMPLATE<int64_t>( \
sizes_and_strides_.sizes_arrayref(), \
sizes_and_strides_.strides_arrayref()))
bool_is_contiguous TensorImpl::compute_contiguous() const {
return COMPUTE_WITH_SIZES_STRIDES_NUMEL(_compute_contiguous);
}
template <typename T>
bool_is_channels_last_contiguous _compute_channels_last_contiguous_2d(
ArrayRef<T> sizes,
ArrayRef<T> strides) {
// Please don't combine these code, constant array is used here to let
// compiler fully unroll the loop to get better performance
switch (sizes_and_strides_.size()) {
switch (sizes.size()) {
case 4: {
int64_t expected = 1;
T expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
const auto size_d = sizes[d];
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
if (strides[d] != expected) {
return bool_is_channels_last_contiguous(false);
}
expected *= size_d;
}
}
return true;
return bool_is_channels_last_contiguous(true);
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 3:
// TODO dim == 3 case will be enabled once it is fully tested
return false;
return bool_is_channels_last_contiguous(false);
default:
return false;
return bool_is_channels_last_contiguous(false);
}
}
bool TensorImpl::compute_channels_last_contiguous_3d() const {
bool_is_channels_last_contiguous TensorImpl::
compute_channels_last_contiguous_2d() const {
return COMPUTE_WITH_SIZES_STRIDES(_compute_channels_last_contiguous_2d);
}
template <typename T>
bool_is_channels_last_3d_contiguous _compute_channels_last_contiguous_3d(
ArrayRef<T> sizes,
ArrayRef<T> strides) {
// Please don't combine these code, constant array is used here to let
// compiler fully unroll the loop to get better performance
switch (sizes_and_strides_.size()) {
switch (sizes.size()) {
case 5: {
int64_t expected = 1;
T expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
const auto size_d = sizes[d];
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
if (strides[d] != expected) {
return bool_is_channels_last_3d_contiguous(false);
}
expected *= size_d;
}
}
return true;
return bool_is_channels_last_3d_contiguous(true);
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 4:
// TODO dim == 4 case will be enabled once it is fully tested
return false;
return bool_is_channels_last_3d_contiguous(false);
default:
return false;
return bool_is_channels_last_3d_contiguous(false);
}
}
bool TensorImpl::compute_strides_like_channels_last_2d() const {
return is_channels_last_strides_2d(
TensorImpl::sizes(), TensorImpl::strides());
bool_is_channels_last_3d_contiguous TensorImpl::
compute_channels_last_contiguous_3d() const {
return COMPUTE_WITH_SIZES_STRIDES(_compute_channels_last_contiguous_3d);
}
bool TensorImpl::compute_strides_like_channels_last_3d() const {
return is_channels_last_strides_3d(
TensorImpl::sizes(), TensorImpl::strides());
bool_is_channels_last TensorImpl::compute_strides_like_channels_last_2d()
const {
return bool_is_channels_last(
COMPUTE_WITH_SIZES_STRIDES(is_channels_last_strides_2d));
}
bool TensorImpl::compute_non_overlapping_and_dense() const {
if (dim() == 1) {
return sizes_and_strides_.size_at_unchecked(0) < 2 ||
sizes_and_strides_.stride_at_unchecked(0) == 1;
bool_is_channels_last_3d TensorImpl::compute_strides_like_channels_last_3d()
const {
return bool_is_channels_last_3d(
COMPUTE_WITH_SIZES_STRIDES(is_channels_last_strides_3d));
}
template <typename T>
bool_is_non_overlapping_and_dense _compute_non_overlapping_and_dense(
ArrayRef<T> sizes,
ArrayRef<T> strides) {
auto dim = sizes.size();
if (dim == 1) {
return bool_is_non_overlapping_and_dense(sizes[0] < 2 || strides[0] == 1);
}
SmallVector<int64_t, 5> perm;
perm.resize(dim());
for (const auto i : c10::irange(dim())) {
perm.resize(dim);
for (const auto i : c10::irange(dim)) {
perm[i] = i;
}
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
if (sizes_and_strides_.size_at_unchecked(a) < 2) {
if (sizes[a] < 2) {
return false;
} else if (sizes_and_strides_.size_at_unchecked(b) < 2) {
} else if (sizes[b] < 2) {
return true;
}
return sizes_and_strides_.stride_at_unchecked(a) <
sizes_and_strides_.stride_at_unchecked(b);
return strides[a] < strides[b];
});
int64_t require_stride = 1;
for (const auto i : c10::irange(dim())) {
const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]);
T require_stride = 1;
for (const auto i : c10::irange(dim)) {
const auto size_perm_i = sizes[perm[i]];
if (size_perm_i < 2) {
return true;
return bool_is_non_overlapping_and_dense(true);
}
if (sizes_and_strides_.stride_at_unchecked(perm[i]) != require_stride) {
return false;
if (strides[perm[i]] != require_stride) {
return bool_is_non_overlapping_and_dense(false);
}
require_stride *= size_perm_i;
}
return true;
return bool_is_non_overlapping_and_dense(true);
}
bool_is_non_overlapping_and_dense TensorImpl::
compute_non_overlapping_and_dense() const {
return COMPUTE_WITH_SIZES_STRIDES(_compute_non_overlapping_and_dense);
}
void TensorImpl::release_resources() {
@ -390,12 +442,25 @@ impl::PyInterpreter& TensorImpl::load_pyobj_interpreter() const {
bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
// TODO: pass memory_format to is_contiguous call
return load_pyobj_interpreter()->is_contiguous(this);
return load_pyobj_interpreter()->is_contiguous(this, memory_format);
}
return is_contiguous_default(memory_format);
}
bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const {
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
return load_pyobj_interpreter()->is_strides_like(this, memory_format);
}
return is_strides_like_default(memory_format);
}
bool TensorImpl::is_non_overlapping_and_dense_custom() const {
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
return load_pyobj_interpreter()->is_non_overlapping_and_dense(this);
}
return is_non_overlapping_and_dense_default();
}
IntArrayRef TensorImpl::sizes_custom() const {
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
return load_pyobj_interpreter()->sizes(this);
@ -575,12 +640,8 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach_core(
/*version_counter=*/std::forward<VariableVersion>(version_counter),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
// We currently don't support refresh_numel() and refresh_contiguous(). It's
// plausible that we could support it, but currently done to unblock.
if (!has_symbolic_sizes_strides()) {
impl->refresh_numel();
impl->refresh_contiguous();
}
impl->refresh_numel();
impl->refresh_contiguous();
return impl;
}
@ -634,8 +695,8 @@ void TensorImpl::copy_generic_tensor_metadata(
dest_impl->extra_meta_ = src_impl->extra_meta_->clone();
}
// NB: symbolic sizes and strides are copied, but custom policy is
// NOT (you have no Python object to dispatch to!)
// NB: symbolic sizes and strides are copied as is custom policy, but python
// policy is NOT (you have no Python object to dispatch to!)
// NB: subclass relevant policy doesn't have to be copied; the
// constructor sets this up
@ -885,33 +946,6 @@ void TensorImpl::ShareExternalPointer(
}
}
bool _compute_contiguous(const ExtraMeta& extra_meta, std::vector<int> order) {
if (order.size() != extra_meta.sizes_.size())
return false;
bool is_contiguous = true;
if (extra_meta.numel_ == 0)
return is_contiguous;
SymInt z = 1;
for (auto d : order) {
const auto size_d = extra_meta.sizes_.at(d);
if (size_d != 1) {
if (extra_meta.strides_.at(d) == z) {
z *= size_d;
} else {
is_contiguous = false;
break;
}
}
}
return is_contiguous;
}
bool _compute_contiguous(const ExtraMeta& extra_meta) {
std::vector<int> order(extra_meta.sizes_.size());
std::iota(order.rbegin(), order.rend(), 0);
return _compute_contiguous(extra_meta, order);
}
void clone_symvec(SymIntArrayRef src, SymDimVector& dst) {
dst.clear();
dst.reserve(src.size());
@ -952,16 +986,8 @@ void TensorImpl::set_sizes_and_strides(
if (storage_offset.has_value())
extra_meta_->storage_offset_ = storage_offset->clone();
SymInt numel = 1;
for (const auto& s : sizes) {
numel *= s;
}
extra_meta_->numel_ = numel;
extra_meta_->is_contiguous_ = _compute_contiguous(*extra_meta_);
extra_meta_->is_channels_last_contiguous_ =
_compute_contiguous(*extra_meta_, {1, 3, 2, 0});
extra_meta_->is_channels_last_3d_contiguous_ =
_compute_contiguous(*extra_meta_, {1, 4, 3, 2, 0});
refresh_numel();
refresh_contiguous();
}
namespace impl {

View File

@ -21,6 +21,7 @@
#include <c10/util/irange.h>
#include <c10/util/python_stub.h>
#include <c10/util/safe_numerics.h>
#include <c10/util/strong_type.h>
#include <algorithm>
#include <atomic>
@ -225,16 +226,42 @@ struct C10_API NamedTensorMetaInterface {
};
};
template <typename T>
using strong_bool = strong::
type<bool, T, strong::regular, strong::iostreamable, strong::boolean>;
// For ease of copy pasting
#if 0
is_contiguous
is_channels_last_contiguous
is_channels_last_3d_contiguous
is_channels_last
is_channels_last_3d
is_non_overlapping_and_dense
#endif
using bool_is_contiguous = strong_bool<struct bool_is_contiguous_>;
using bool_is_channels_last_contiguous =
strong_bool<struct bool_is_channels_last_contiguous_>;
using bool_is_channels_last_3d_contiguous =
strong_bool<struct bool_is_channels_last_3d_contiguous_>;
using bool_is_channels_last = strong_bool<struct bool_is_channels_last_>;
using bool_is_channels_last_3d = strong_bool<struct bool_is_channels_last_3d_>;
using bool_is_non_overlapping_and_dense =
strong_bool<struct bool_is_non_overlapping_and_dense_>;
struct C10_API ExtraMeta {
SymDimVector sizes_ = {0};
SymDimVector strides_ = {1};
SymInt numel_ = 1;
SymInt storage_offset_ = 0;
bool is_contiguous_ = true;
bool is_channels_last_contiguous_ = false;
bool is_channels_last_3d_contiguous_ = false;
// TODO:
// SymBool is_contiguous_;
// TODO: make these all SymBool
bool_is_contiguous is_contiguous_{true};
bool_is_channels_last_contiguous is_channels_last_contiguous_{false};
bool_is_channels_last_3d_contiguous is_channels_last_3d_contiguous_{false};
bool_is_channels_last is_channels_last_{false};
bool_is_channels_last_3d is_channels_last_3d_{false};
bool_is_non_overlapping_and_dense is_non_overlapping_and_dense_{true};
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
ExtraMeta() {}
@ -244,9 +271,12 @@ struct C10_API ExtraMeta {
SymDimVector strides,
SymInt numel,
SymInt storage_offset,
bool is_contiguous,
bool is_channels_last_contiguous,
bool is_channels_last_3d_contiguous,
bool_is_contiguous is_contiguous,
bool_is_channels_last_contiguous is_channels_last_contiguous,
bool_is_channels_last_3d_contiguous is_channels_last_3d_contiguous,
bool_is_channels_last is_channels_last,
bool_is_channels_last_3d is_channels_last_3d,
bool_is_non_overlapping_and_dense is_non_overlapping_and_dense,
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta)
: sizes_(std::move(sizes)),
strides_(std::move(strides)),
@ -255,6 +285,9 @@ struct C10_API ExtraMeta {
is_contiguous_(is_contiguous),
is_channels_last_contiguous_(is_channels_last_contiguous),
is_channels_last_3d_contiguous_(is_channels_last_3d_contiguous),
is_channels_last_(is_channels_last),
is_channels_last_3d_(is_channels_last_3d),
is_non_overlapping_and_dense_(is_non_overlapping_and_dense),
named_tensor_meta_(std::move(named_tensor_meta)) {}
std::unique_ptr<ExtraMeta> clone() const {
@ -266,6 +299,9 @@ struct C10_API ExtraMeta {
is_contiguous_,
is_channels_last_contiguous_,
is_channels_last_3d_contiguous_,
is_channels_last_,
is_channels_last_3d_,
is_non_overlapping_and_dense_,
named_tensor_meta_ ? named_tensor_meta_->clone() : nullptr);
}
};
@ -782,11 +818,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_contiguous_default(at::MemoryFormat memory_format) const {
if (has_symbolic_sizes_strides_) {
if (memory_format == at::MemoryFormat::ChannelsLast) {
return extra_meta_->is_channels_last_contiguous_;
return bool(extra_meta_->is_channels_last_contiguous_);
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
return extra_meta_->is_channels_last_3d_contiguous_;
return bool(extra_meta_->is_channels_last_3d_contiguous_);
}
return extra_meta_->is_contiguous_;
return bool(extra_meta_->is_contiguous_);
}
if (memory_format == at::MemoryFormat::ChannelsLast) {
@ -797,6 +833,34 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return is_contiguous_;
}
bool is_strides_like_default(at::MemoryFormat memory_format) const {
if (has_symbolic_sizes_strides_) {
if (memory_format == at::MemoryFormat::ChannelsLast) {
return bool(extra_meta_->is_channels_last_);
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
return bool(extra_meta_->is_channels_last_3d_);
} else {
return false;
}
}
if (memory_format == at::MemoryFormat::ChannelsLast) {
return is_channels_last_;
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
return is_channels_last_3d_;
} else {
return false;
}
}
bool is_non_overlapping_and_dense_default() const {
if (has_symbolic_sizes_strides_) {
return bool(extra_meta_->is_non_overlapping_and_dense_);
} else {
return is_non_overlapping_and_dense_;
}
}
// NB: these dim accessor functions don't have _default(), as you can use
// sizes_default/strides_default
/**
@ -884,6 +948,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/
// sizes_strides_policy_ >= CustomStrides
virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const;
virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const;
virtual bool is_non_overlapping_and_dense_custom() const;
// sizes_strides_policy_ >= CustomSizes
// Currently this method only exists to be overwritten by subclasses such as
// NestedTensorImpl.
@ -1589,7 +1655,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
sizes_and_strides_.set_sizes(new_size);
refresh_numel();
empty_tensor_restride(MemoryFormat::Contiguous);
empty_tensor_restride(
MemoryFormat::Contiguous); // calls refresh_contiguous()
}
/**
@ -2273,16 +2340,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
refresh_contiguous();
}
bool is_strides_like(at::MemoryFormat memory_format) const {
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
return is_strides_like_custom(memory_format);
}
return is_strides_like_default(memory_format);
}
bool is_strides_like_channels_last() const {
return is_channels_last_;
return is_strides_like(at::MemoryFormat::ChannelsLast);
}
bool is_strides_like_channels_last_3d() const {
return is_channels_last_3d_;
return is_strides_like(at::MemoryFormat::ChannelsLast3d);
}
bool is_non_overlapping_and_dense() const {
return is_non_overlapping_and_dense_;
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
return is_non_overlapping_and_dense_custom();
}
return is_non_overlapping_and_dense_default();
}
bool has_symbolic_sizes_strides() const {
@ -2360,12 +2437,16 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
/**
* Compute the number of elements based on the sizes of a tensor.
*/
// NB: This is ONLY called when sizes_and_strides_ is used directly; if
// we are virtualizing, then numel calls are virtualized as well, and this
// should never get called
int64_t compute_numel() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_);
#if C10_HAS_BUILTIN_OVERFLOW() && !defined(C10_MOBILE)
// Use overflow checks if supported by the compiler
return safe_compute_numel();
#else
return c10::multiply_integers(sizes());
return c10::multiply_integers(sizes_and_strides_.sizes_arrayref());
#endif
}
@ -2375,8 +2456,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* using a sparse layout has multiple dimensions with large sizes.
*/
int64_t safe_compute_numel() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_);
uint64_t n = 1;
bool overflows = c10::safe_multiplies_u64(sizes(), &n);
bool overflows =
c10::safe_multiplies_u64(sizes_and_strides_.sizes_arrayref(), &n);
constexpr auto numel_max = std::min(
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
static_cast<uint64_t>(std::numeric_limits<size_t>::max()));
@ -2386,21 +2469,31 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return static_cast<int64_t>(n);
}
SymInt compute_sym_numel() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(has_symbolic_sizes_strides_);
SymInt numel = 1;
for (const auto& s : extra_meta_->sizes_) {
numel *= s;
}
return numel;
}
/**
* Compute whether or not a tensor is contiguous based on the sizes and
* strides of a tensor.
*/
bool compute_contiguous() const;
bool_is_contiguous compute_contiguous() const;
bool compute_channels_last_contiguous_2d() const;
bool_is_channels_last_contiguous compute_channels_last_contiguous_2d() const;
bool compute_channels_last_contiguous_3d() const;
bool_is_channels_last_3d_contiguous compute_channels_last_contiguous_3d()
const;
bool compute_strides_like_channels_last_2d() const;
bool_is_channels_last compute_strides_like_channels_last_2d() const;
bool compute_strides_like_channels_last_3d() const;
bool_is_channels_last_3d compute_strides_like_channels_last_3d() const;
bool compute_non_overlapping_and_dense() const;
bool_is_non_overlapping_and_dense compute_non_overlapping_and_dense() const;
protected:
/**
@ -2410,9 +2503,20 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* For tensors with sparse layouts, use safe_refresh_numel() instead
* because it will catch integer overflow that may occur for tensors
* with sparse layouts and large dimensions.
*
* NB: We may uselessly recompute cached numel even in situations where
* it is completely never used (e.g., if CustomSizes for Python). However,
* we still must keep it up to date in case the Python overload
* returns None (in which case we will consult the field here). This also
* implies that sizes/strides will never be complete garbage; in the
* very worst case scenario, it will reflect a 1-dim zero size tensor.
*/
void refresh_numel() {
numel_ = compute_numel();
if (has_symbolic_sizes_strides_) {
extra_meta_->numel_ = compute_sym_numel();
} else {
numel_ = compute_numel();
}
}
/**
@ -2422,7 +2526,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* overflow when computing numel.
*/
void safe_refresh_numel() {
numel_ = safe_compute_numel();
if (has_symbolic_sizes_strides_) {
// NB: sym numel is done with symbolic integers, which handle overflow
// checking
extra_meta_->numel_ = compute_sym_numel();
} else {
numel_ = safe_compute_numel();
}
}
/**
@ -2430,49 +2540,94 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* or strides.
*/
void refresh_contiguous() {
TORCH_CHECK(
!has_symbolic_sizes_strides_,
"refresh_contiguous() called on tensor with symbolic shape")
auto set_fields =
[&](bool_is_contiguous is_contiguous,
bool_is_channels_last_contiguous is_channels_last_contiguous,
bool_is_channels_last_3d_contiguous is_channels_last_3d_contiguous,
bool_is_channels_last is_channels_last,
bool_is_channels_last_3d is_channels_last_3d,
bool_is_non_overlapping_and_dense is_non_overlapping_and_dense) {
if (has_symbolic_sizes_strides_) {
extra_meta_->is_contiguous_ = is_contiguous;
extra_meta_->is_channels_last_contiguous_ =
is_channels_last_contiguous;
extra_meta_->is_channels_last_3d_contiguous_ =
is_channels_last_3d_contiguous;
extra_meta_->is_channels_last_ = is_channels_last;
extra_meta_->is_channels_last_3d_ = is_channels_last_3d;
extra_meta_->is_non_overlapping_and_dense_ =
is_non_overlapping_and_dense;
} else {
is_contiguous_ = bool(is_contiguous);
is_channels_last_contiguous_ = bool(is_channels_last_contiguous);
is_channels_last_3d_contiguous_ =
bool(is_channels_last_3d_contiguous);
is_channels_last_ = bool(is_channels_last);
is_channels_last_3d_ = bool(is_channels_last_3d);
is_non_overlapping_and_dense_ = bool(is_non_overlapping_and_dense);
}
};
is_contiguous_ = compute_contiguous();
auto is_contiguous = compute_contiguous();
// Note:
// Dim 0, 1, 2 will never be a channels last 2d/3d format
// Dim 3+ is possibly be a channels last 2d format (Dim 4 only at this
// point) Dim 4+ is possibly be a channels last 3d format (Dim 5 only at
// this point)
switch (dim()) {
case 4:
is_channels_last_contiguous_ = compute_channels_last_contiguous_2d();
is_channels_last_3d_contiguous_ = false;
is_channels_last_ = compute_strides_like_channels_last_2d();
is_channels_last_3d_ = false;
is_non_overlapping_and_dense_ = is_contiguous_ ||
is_channels_last_contiguous_ || compute_non_overlapping_and_dense();
case 4: {
auto is_channels_last_contiguous =
compute_channels_last_contiguous_2d();
set_fields(
is_contiguous,
is_channels_last_contiguous,
bool_is_channels_last_3d_contiguous(false),
compute_strides_like_channels_last_2d(),
bool_is_channels_last_3d(false),
bool_is_non_overlapping_and_dense(
is_contiguous || is_channels_last_contiguous ||
compute_non_overlapping_and_dense()));
break;
case 5:
is_channels_last_contiguous_ = compute_channels_last_contiguous_2d();
is_channels_last_3d_contiguous_ = !is_channels_last_contiguous_ &&
compute_channels_last_contiguous_3d();
is_channels_last_ = !is_channels_last_3d_contiguous_ &&
compute_strides_like_channels_last_2d();
is_channels_last_3d_ =
!is_channels_last_ && compute_strides_like_channels_last_3d();
is_non_overlapping_and_dense_ = is_contiguous_ ||
is_channels_last_contiguous_ || is_channels_last_3d_contiguous_ ||
compute_non_overlapping_and_dense();
}
case 5: {
auto is_channels_last_contiguous =
compute_channels_last_contiguous_2d();
auto is_channels_last_3d_contiguous =
bool_is_channels_last_3d_contiguous(
!is_channels_last_contiguous &&
compute_channels_last_contiguous_3d());
auto is_channels_last = bool_is_channels_last(
!is_channels_last_3d_contiguous &&
compute_strides_like_channels_last_2d());
auto is_channels_last_3d = bool_is_channels_last_3d(
!is_channels_last && compute_strides_like_channels_last_3d());
auto is_non_overlapping_and_dense = bool_is_non_overlapping_and_dense(
is_contiguous || is_channels_last_contiguous ||
is_channels_last_3d_contiguous ||
compute_non_overlapping_and_dense());
set_fields(
is_contiguous,
is_channels_last_contiguous,
is_channels_last_3d_contiguous,
is_channels_last,
is_channels_last_3d,
is_non_overlapping_and_dense);
break;
}
default:
is_channels_last_contiguous_ = false;
is_channels_last_3d_contiguous_ = false;
// is_channels_last_ and is_channels_last_3d_ are suggested
// memory_format. Being channels_last_contiguous doesn't necessarily
// mean the tensor is strided like channels_last: for strides on channel
// dimension could suggest desired memory_layout, but it doesn't affect
// memory storage
is_channels_last_ = false;
is_channels_last_3d_ = false;
is_non_overlapping_and_dense_ =
is_contiguous_ || compute_non_overlapping_and_dense();
set_fields(
is_contiguous,
bool_is_channels_last_contiguous(false),
bool_is_channels_last_3d_contiguous(false),
bool_is_channels_last(false),
bool_is_channels_last_3d(false),
bool_is_non_overlapping_and_dense(
is_contiguous || compute_non_overlapping_and_dense()));
}
}

View File

@ -34,9 +34,16 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
PANIC(python_dispatcher);
}
bool is_contiguous(const TensorImpl* self) const override {
bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
PANIC(is_contiguous);
}
bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
const override {
PANIC(is_strides_like);
}
bool is_non_overlapping_and_dense(const TensorImpl* self) const override {
PANIC(is_non_overlapping_and_dense);
}
c10::Device device(const TensorImpl* self) const override {
PANIC(device);
}

View File

@ -2,6 +2,7 @@
#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
@ -146,7 +147,11 @@ struct C10_API PyInterpreterVTable {
c10::DispatchKeySet,
torch::jit::Stack* stack) const = 0;
virtual bool is_contiguous(const TensorImpl* self) const = 0;
virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
const = 0;
virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
const = 0;
virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
virtual c10::Device device(const TensorImpl* self) const = 0;
virtual int64_t dim(const TensorImpl* self) const = 0;
virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0;

View File

@ -33,20 +33,12 @@
#endif
#ifndef STRONG_HAS_STD_FORMAT
#if __has_include(<format>)
#define STRONG_HAS_STD_FORMAT 1
#else
#define STRONG_HAS_STD_FORMAT 0
#endif
#endif
#ifndef STRONG_HAS_FMT_FORMAT
#if __has_include(<fmt/format.h>)
#define STRONG_HAS_FMT_FORMAT 1
#else
#define STRONG_HAS_FMT_FORMAT 0
#endif
#endif
#if STRONG_HAS_STD_FORMAT
#include <format>

View File

@ -630,7 +630,7 @@ class FakeTensorOperatorInvariants(TestCase):
for schema in self.get_all_aten_schemas():
if "_like" == schema.name[-5:]:
op = self.get_aten_op(schema)
self.assertTrue(op in torch._subclasses.fake_tensor._like_tensor_constructors)
self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)
if __name__ == "__main__":
run_tests()

View File

@ -1402,6 +1402,38 @@ $0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memo
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
def test_fancy_strides(self):
calls = []
class ExampleTensor(torch.Tensor):
@staticmethod
def __new__(cls, data):
return TestPythonDispatch.subclass_helper(cls, data, False, dispatch_sizes_strides_policy="strides")
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in [
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.is_strides_like_format.default,
torch.ops.aten.is_non_overlapping_and_dense.default,
torch.ops.aten.stride.default
]:
calls.append((func, list(args)[1:]))
return None
with no_dispatch():
return func(*args, **kwargs)
e = ExampleTensor(torch.randn(2, 2))
self.assertFalse(e.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])])
calls.clear()
self.assertFalse(torch.ops.aten.is_strides_like_format.default(e, torch.channels_last))
self.assertEqual(calls, [(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])])
calls.clear()
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e))
self.assertEqual(calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])])
def test_device_slowpath(self):
for use_wrapper_subclass in [True]:
class ExampleTensor1(torch.Tensor):

View File

@ -67,7 +67,7 @@ static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObje
}
PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
if (opt_memory_format.has_value()) {
PyTuple_SET_ITEM(tuple.get(), 3, torch::utils::getTHPMemoryFormat(opt_memory_format.value()).release().ptr());
PyTuple_SET_ITEM(tuple.get(), 3, torch::utils::getTHPMemoryFormat(opt_memory_format.value()));
} else {
Py_INCREF(Py_None);
PyTuple_SET_ITEM(tuple.get(), 3, Py_None);

View File

@ -132,8 +132,7 @@ std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
return py::reinterpret_borrow<py::object>(
reinterpret_cast<PyObject*>(obj));
} else if (match(c10::MemoryFormatType::Kind)) {
return torch::utils::getTHPMemoryFormat(
static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
return py::cast(static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
} else {
return torch::jit::toPyObject(arguments[idx]);
}
@ -241,7 +240,9 @@ struct ConcretePyInterpreterVTable final
c10::DispatchKeySet,
torch::jit::Stack* stack) const override;
bool is_contiguous(const TensorImpl* self) const override;
bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override;
bool is_strides_like(const TensorImpl* self, at::MemoryFormat) const override;
bool is_non_overlapping_and_dense(const TensorImpl* self) const override;
c10::Device device(const TensorImpl* self) const override;
int64_t dim(const TensorImpl* self) const override;
c10::IntArrayRef strides(const TensorImpl* self) const override;
@ -2179,7 +2180,12 @@ py::object torchDispatchFromTensorImpl(
const c10::TensorImpl* self,
const char* func_name,
PyObject* torch_api_function,
const char* module_name) {
const char* module_name,
// WARNING: MUST NOT BE TENSOR ARGS
c10::SmallVector<py::object, 1> extra_args = {}) {
if (torch_api_function == nullptr) {
throw python_error();
}
TORCH_CHECK(
PyGILState_Check(),
"GIL must be held before you call parseIValuesToPyArgsKwargs");
@ -2194,8 +2200,16 @@ py::object torchDispatchFromTensorImpl(
// NB: this may not be a python tensor if you got here from a mode!
// TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
append_overloaded_tensor(&overloaded_args, self_p.ptr());
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
auto args =
py::reinterpret_steal<py::object>(PyTuple_New(1 + extra_args.size()));
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
int64_t i = 1;
for (auto& a : extra_args) {
if (a.ptr() == nullptr)
throw python_error();
PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
i++;
}
py::dict kwargs;
@ -2320,7 +2334,7 @@ void ConcretePyInterpreterVTable::python_dispatcher(
py::object obj = py::reinterpret_steal<py::object>(
PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));
if (obj == nullptr) {
if (obj.ptr() == nullptr) {
throw python_error();
}
@ -2353,24 +2367,107 @@ c10::intrusive_ptr<TensorImpl> ConcretePyInterpreterVTable::detach(
}
bool ConcretePyInterpreterVTable::is_contiguous(
const c10::TensorImpl* self,
at::MemoryFormat memory_format) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
py::object out;
if (memory_format == at::MemoryFormat::Contiguous) {
// For backwards compatibility
out = torchDispatchFromTensorImpl(
self,
"is_contiguous",
py::module::import("torch")
.attr("ops")
.attr("aten")
.attr("is_contiguous")
.attr("default")
.ptr(),
"torch.ops.aten");
} else {
out = torchDispatchFromTensorImpl(
self,
"is_contiguous",
py::module::import("torch")
.attr("ops")
.attr("aten")
.attr("is_contiguous")
.attr("memory_format")
.ptr(),
"torch.ops.aten",
{py::cast(memory_format)});
}
if (out.is(py::none())) {
return self->is_contiguous_default(memory_format);
}
TORCH_CHECK(
PyBool_Check(out.ptr()),
"is_contiguous returned invalid type ",
py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
", expected bool");
return PyObject_IsTrue(out.ptr());
}
bool ConcretePyInterpreterVTable::is_strides_like(
const c10::TensorImpl* self,
at::MemoryFormat memory_format) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
auto out = torchDispatchFromTensorImpl(
self,
"is_strides_like",
py::module::import("torch")
.attr("ops")
.attr("aten")
// NB: intentionally suffixed with _format to avoid
// triggering matches against "_like" suffix
.attr("is_strides_like_format")
.attr("default")
.ptr(),
"torch.ops.aten",
{py::cast(memory_format)});
if (out.is(py::none())) {
return self->is_strides_like_default(memory_format);
}
TORCH_CHECK(
PyBool_Check(out.ptr()),
"is_strides_like_format returned invalid type ",
py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
", expected bool");
return PyObject_IsTrue(out.ptr());
}
bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
auto out = torchDispatchFromTensorImpl(
self,
"is_contiguous",
"is_non_overlapping_and_dense",
py::module::import("torch")
.attr("ops")
.attr("aten")
.attr("is_contiguous")
.attr("is_non_overlapping_and_dense")
.attr("default")
.ptr(),
"torch.ops.aten");
if (out.is(py::none())) {
return self->is_non_overlapping_and_dense_default();
}
TORCH_CHECK(
PyBool_Check(out.ptr()),
"is_contiguous returned invalid type ",
"is_non_overlapping_and_dense returned invalid type ",
py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
", expected bool");

View File

@ -561,6 +561,34 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
pack(stack, result);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::is_contiguous.memory_format(Tensor self, MemoryFormat memory_format) -> bool"),
[](Stack& stack) {
auto memory_format = pop(stack).toMemoryFormat();
auto t = pop(stack).toTensor();
push(stack, t.is_contiguous(memory_format));
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
// NB: intentionally suffixed with extra _format to prevent tests for
// "_like" suffix from triggering on this
TORCH_SELECTIVE_SCHEMA(
"aten::is_strides_like_format(Tensor self, MemoryFormat memory_format) -> bool"),
[](Stack& stack) {
auto memory_format = pop(stack).toMemoryFormat();
auto t = pop(stack).toTensor();
push(stack, t.unsafeGetTensorImpl()->is_strides_like(memory_format));
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::is_non_overlapping_and_dense(Tensor self) -> bool"),
[](Stack& stack) {
auto t = pop(stack).toTensor();
push(stack, t.unsafeGetTensorImpl()->is_non_overlapping_and_dense());
},
aliasAnalysisFromSchema()),
// these ops are generic over the list element type.
// CREATING GENERIC_LIST_OPS
OperatorGeneratorArgs(

View File

@ -85,9 +85,6 @@ LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
c10::scalarTypeToTypeMeta(tensor.dtype()),
backendDeviceToAtenDevice(tensor.GetDevice())),
tensor_(c10::make_intrusive<LazyTensor>(std::move(tensor))) {
// This is a temporary fix for a PyTorch core issue,
// according to https://github.com/pytorch/xla/pull/2682.
is_non_overlapping_and_dense_ = false;
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
}
@ -185,12 +182,33 @@ int64_t LTCTensorImpl::numel_custom() const {
return numel_default();
}
int64_t LTCTensorImpl::storage_offset_custom() const {
return 0;
}
bool LTCTensorImpl::is_strides_like_custom(
c10::MemoryFormat memory_format) const {
TORCH_INTERNAL_ASSERT(memory_format != at::MemoryFormat::Contiguous);
return false;
}
bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const {
// This should be true, but false as a temporary fix for a PyTorch core issue,
// according to https://github.com/pytorch/xla/pull/2682.
return false;
}
bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const {
// TODO(ezyang): I don't think this branch is actually necessary
// TODO(ezyang): I don't think this logic is right, shouldn't we pass on
// the memory format?
if (tensor_->CurrentTensorData()) {
return tensor_->CurrentTensorData()->is_contiguous();
}
// Only check that the storage is already contiguous.
CHECK(is_contiguous_) << "Non-contiguous storage for lazy tensor";
// TODO: I don't think logic is right, we should check the requested memory
// format before returning true
return true;
}

View File

@ -39,12 +39,15 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
at::IntArrayRef sizes_custom() const override;
at::IntArrayRef strides_custom() const override;
int64_t dim_custom() const override;
int64_t numel_custom() const override;
int64_t storage_offset_custom() const override;
int64_t dim_custom() const override;
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
bool is_strides_like_custom(at::MemoryFormat memory_format) const override;
bool is_non_overlapping_and_dense_custom() const override;
virtual c10::SymIntArrayRef sym_sizes_custom() const override;
virtual c10::SymIntArrayRef sym_strides_custom() const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
c10::SymIntArrayRef sym_strides_custom() const override;
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
const at::Storage& storage() const override {

View File

@ -12,6 +12,8 @@
#include <torch/csrc/Device.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
#include <stdexcept>
#include <utility>
@ -107,6 +109,28 @@ struct TORCH_PYTHON_API type_caster<at::IntArrayRef> {
std::vector<int64_t> v_value;
};
template <>
struct TORCH_PYTHON_API type_caster<at::MemoryFormat> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
PYBIND11_TYPE_CASTER(at::MemoryFormat, _("at::MemoryFormat"));
bool load(handle src, bool) {
PyObject* obj = src.ptr();
if (THPMemoryFormat_Check(obj)) {
value = reinterpret_cast<THPMemoryFormat*>(obj)->memory_format;
return true;
}
return false;
}
static handle cast(
at::MemoryFormat src,
return_value_policy /* policy */,
handle /* parent */) {
return handle(torch::utils::getTHPMemoryFormat(src));
}
};
template <>
struct type_caster<at::Device> {
public:

View File

@ -17,9 +17,11 @@ std::array<PyObject*, static_cast<int>(at::MemoryFormat::NumOptions)>
memory_format_registry = {};
} // anonymous namespace
py::object getTHPMemoryFormat(at::MemoryFormat memory_format) {
PyObject* getTHPMemoryFormat(at::MemoryFormat memory_format) {
return py::reinterpret_borrow<py::object>(
memory_format_registry[static_cast<size_t>(memory_format)]);
memory_format_registry[static_cast<size_t>(memory_format)])
.release()
.ptr();
}
void initializeMemoryFormats() {

View File

@ -1,13 +1,13 @@
#pragma once
#include <c10/core/MemoryFormat.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_stub.h>
namespace torch {
namespace utils {
void initializeMemoryFormats();
py::object getTHPMemoryFormat(c10::MemoryFormat);
PyObject* getTHPMemoryFormat(c10::MemoryFormat);
} // namespace utils
} // namespace torch

View File

@ -266,7 +266,27 @@ def is_contiguous(func, *args, **kwargs):
raise ValueError(
"MaskedTensors with sparse data do not have is_contiguous"
)
return data.is_contiguous()
return func(data, *args[1:], **kwargs)
@register_dispatch_func([torch.ops.aten.is_strides_like_format])
def is_strides_like_format(func, *args, **kwargs):
data = _get_data(args[0])
if data.is_sparse:
raise ValueError(
"MaskedTensors with sparse data do not have is_strides_like_format"
)
return func(data, *args[1:], **kwargs)
@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
def is_non_overlapping_and_dense(func, *args, **kwargs):
data = _get_data(args[0])
if data.is_sparse:
raise ValueError(
"MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
)
return func(data, *args[1:], **kwargs)
@register_dispatch_func([torch.ops.aten.contiguous])