Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Consistent compute numel/contiguous strategy with SymInts (#85858)
Previously, our handling of contiguity was inconsistent in the following ways:

- is_strides_like 2d/3d and is_non_overlapping_and_dense were always computed from sizes_and_strides_, even if you had symbolic ints
- Furthermore, even if you set a custom policy for strides, these quantities were not overridable by subclasses
- Furthermore, we didn't even store these fields on ExtraMeta
- We duplicated implementations of compute_contiguous (plain, channels last, channels last 3d)
- We inconsistently called refresh_numel()/refresh_contiguous(), versus recomputing it ourselves

This refactor establishes a consistent strategy for all of the boolean fields and for numel computation. After this refactor:

- All layout boolean fields are interposable via strides policy and can be overridden from Python; you will never access a garbage field
- All layout boolean fields are stored on ExtraMeta
- You can always call refresh_numel/refresh_contiguous, no matter whether your Tensor has symbolic sizes/strides or not
- The numel/layout boolean fields are always populated consistently with the sizes/strides fields (either on TensorImpl or ExtraMeta), even if you have a custom policy
- There is only one implementation of the actual computation logic

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D39907696](https://our.internmc.facebook.com/intern/diff/D39907696)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85858
Approved by: https://github.com/albanD
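To make the last point concrete, here is a minimal, self-contained sketch (not the actual c10 code; the `TensorMeta` holder and default values are hypothetical) of the pattern the diff below adopts: a single templated implementation of the contiguity check, which the real code instantiates once for concrete `int64_t` sizes/strides and once for `c10::SymInt`.

```cpp
#include <cstdint>
#include <vector>

// One implementation of the contiguity computation, parameterized over the
// integer type. PyTorch instantiates the real version with int64_t (concrete
// shapes) and c10::SymInt (symbolic shapes); plain int64_t is used here so
// the sketch compiles on its own.
template <typename T>
bool compute_contiguous(const std::vector<T>& sizes,
                        const std::vector<T>& strides,
                        T numel) {
  if (numel == 0) {
    return true;  // zero-element tensors are trivially contiguous
  }
  T expected_stride = 1;
  // NB: signed loop index so decrementing past dimension 0 terminates
  for (int64_t d = static_cast<int64_t>(sizes.size()) - 1; d >= 0; d--) {
    if (sizes[d] != 1) {
      if (strides[d] != expected_stride) {
        return false;
      }
      expected_stride *= sizes[d];
    }
  }
  return true;
}

// Hypothetical stand-in for TensorImpl: refresh_contiguous() always consults
// the same field set that stores the sizes/strides, mirroring the
// COMPUTE_WITH_SIZES_STRIDES_NUMEL dispatch macro introduced in the diff.
struct TensorMeta {
  std::vector<int64_t> sizes{2, 3};
  std::vector<int64_t> strides{3, 1};
  int64_t numel = 6;
  bool is_contiguous = true;

  void refresh_contiguous() {
    is_contiguous = compute_contiguous<int64_t>(sizes, strides, numel);
  }
};

int main() {
  TensorMeta m;
  m.refresh_contiguous();  // true for the default row-major strides
  return m.is_contiguous ? 0 : 1;
}
```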
Committed by: PyTorch MergeBot
Parent: 84a06d7193
Commit: 3b6588ab74
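The diff also wraps each cached layout flag in its own strong boolean type (built on the rollbear strong_type library in the real code, via `strong_bool<struct tag_>` aliases). A minimal hand-rolled equivalent is sketched below, only to illustrate why the refactor bothers with distinct types: the compiler now rejects code that passes, say, an `is_channels_last` value where an `is_contiguous` value is expected.

```cpp
#include <iostream>

// Simplified stand-in for the strong_bool alias in the diff, which is built
// on strong::type<bool, Tag, ...>. Each flag gets a distinct tag type, so the
// flags cannot be swapped accidentally at a call site.
template <typename Tag>
class StrongBool {
 public:
  constexpr explicit StrongBool(bool v) : value_(v) {}
  constexpr explicit operator bool() const { return value_; }

 private:
  bool value_;
};

using bool_is_contiguous = StrongBool<struct bool_is_contiguous_>;
using bool_is_channels_last = StrongBool<struct bool_is_channels_last_>;

// A setter in the style of the set_fields lambda in refresh_contiguous():
// mixing up the argument order is a compile-time error, not a silent bug.
void set_fields(bool_is_contiguous contiguous,
                bool_is_channels_last channels_last) {
  std::cout << "contiguous=" << bool(contiguous)
            << " channels_last=" << bool(channels_last) << "\n";
}

int main() {
  set_fields(bool_is_contiguous(true), bool_is_channels_last(false));
  // set_fields(bool_is_channels_last(false), bool_is_contiguous(true));  // does not compile
  return 0;
}
```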
@ -284,7 +284,7 @@ at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::IntArrayRef s
  auto inferred_size = at::infer_size_dv(size, self.numel());
  auto stride = at::detail::computeStride(self.sizes(), self.strides(), inferred_size);
  TORCH_INTERNAL_ASSERT(stride.has_value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride.value());
  out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
  return out;
}
@ -114,10 +114,11 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
|
||||
// input
|
||||
// 3. All helper functions have similar comments, only 1st helper function is
|
||||
// commented here.
|
||||
template <typename T>
|
||||
inline bool is_channels_last_strides_2d_s4(
|
||||
const IntArrayRef sizes,
|
||||
const IntArrayRef strides) {
|
||||
int64_t min = 0;
|
||||
const ArrayRef<T> sizes,
|
||||
const ArrayRef<T> strides) {
|
||||
T min = 0;
|
||||
// special case for trivial C dimension. default to NCHW
|
||||
if (strides[1] == 0) {
|
||||
return false;
|
||||
@ -155,10 +156,11 @@ inline bool is_channels_last_strides_2d_s4(
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool is_channels_last_strides_3d_s5(
|
||||
const IntArrayRef sizes,
|
||||
const IntArrayRef strides) {
|
||||
int64_t min = 0;
|
||||
const ArrayRef<T> sizes,
|
||||
const ArrayRef<T> strides) {
|
||||
T min = 0;
|
||||
if (strides[1] == 0) {
|
||||
return false;
|
||||
}
|
||||
@ -230,9 +232,10 @@ inline bool is_channels_last_strides_3d_s5(
|
||||
// implementation. Please check the helper functions
|
||||
// (is_channels_last_strides_*d_s*) for more details.
|
||||
|
||||
template <typename T>
|
||||
inline bool is_channels_last_strides_2d(
|
||||
const IntArrayRef sizes,
|
||||
const IntArrayRef strides) {
|
||||
const ArrayRef<T> sizes,
|
||||
const ArrayRef<T> strides) {
|
||||
switch (sizes.size()) {
|
||||
case 4:
|
||||
return is_channels_last_strides_2d_s4(sizes, strides);
|
||||
@ -244,9 +247,10 @@ inline bool is_channels_last_strides_2d(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool is_channels_last_strides_3d(
|
||||
const IntArrayRef sizes,
|
||||
const IntArrayRef strides) {
|
||||
const ArrayRef<T> sizes,
|
||||
const ArrayRef<T> strides) {
|
||||
switch (sizes.size()) {
|
||||
case 5:
|
||||
return is_channels_last_strides_3d_s5(sizes, strides);
|
||||
@ -258,4 +262,16 @@ inline bool is_channels_last_strides_3d(
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_channels_last_strides_2d(
|
||||
const IntArrayRef sizes,
|
||||
const IntArrayRef strides) {
|
||||
return is_channels_last_strides_2d<int64_t>(sizes, strides);
|
||||
}
|
||||
|
||||
inline bool is_channels_last_strides_3d(
|
||||
const IntArrayRef sizes,
|
||||
const IntArrayRef strides) {
|
||||
return is_channels_last_strides_3d<int64_t>(sizes, strides);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
@ -227,15 +227,20 @@ void TensorImpl::HandleResize() {
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorImpl::compute_contiguous() const {
|
||||
template <typename T>
|
||||
bool_is_contiguous _compute_contiguous(
|
||||
ArrayRef<T> sizes,
|
||||
ArrayRef<T> strides,
|
||||
T numel) {
|
||||
bool is_contiguous = true;
|
||||
if (is_empty())
|
||||
return is_contiguous;
|
||||
int64_t z = 1;
|
||||
for (int64_t d = dim() - 1; d >= 0; d--) {
|
||||
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
|
||||
if (numel == 0)
|
||||
return bool_is_contiguous(is_contiguous);
|
||||
T z = 1;
|
||||
// NB: make sure we do signed arithmetic
|
||||
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
|
||||
const auto size_d = sizes[d];
|
||||
if (size_d != 1) {
|
||||
if (sizes_and_strides_.stride_at_unchecked(d) == z) {
|
||||
if (strides[d] == z) {
|
||||
z *= size_d;
|
||||
} else {
|
||||
is_contiguous = false;
|
||||
@ -243,103 +248,150 @@ bool TensorImpl::compute_contiguous() const {
|
||||
}
|
||||
}
|
||||
}
|
||||
return is_contiguous;
|
||||
return bool_is_contiguous(is_contiguous);
|
||||
}
|
||||
|
||||
bool TensorImpl::compute_channels_last_contiguous_2d() const {
|
||||
// NB: intentionally bypass the normal accessors; we always want to be
|
||||
// consistent with what is actually stored on the struct
|
||||
#define COMPUTE_WITH_SIZES_STRIDES_NUMEL(TEMPLATE) \
|
||||
(has_symbolic_sizes_strides_ \
|
||||
? TEMPLATE<c10::SymInt>( \
|
||||
extra_meta_->sizes_, extra_meta_->strides_, extra_meta_->numel_) \
|
||||
: TEMPLATE<int64_t>( \
|
||||
sizes_and_strides_.sizes_arrayref(), \
|
||||
sizes_and_strides_.strides_arrayref(), \
|
||||
numel_))
|
||||
|
||||
#define COMPUTE_WITH_SIZES_STRIDES(TEMPLATE) \
|
||||
(has_symbolic_sizes_strides_ \
|
||||
? TEMPLATE<c10::SymInt>(extra_meta_->sizes_, extra_meta_->strides_) \
|
||||
: TEMPLATE<int64_t>( \
|
||||
sizes_and_strides_.sizes_arrayref(), \
|
||||
sizes_and_strides_.strides_arrayref()))
|
||||
|
||||
bool_is_contiguous TensorImpl::compute_contiguous() const {
|
||||
return COMPUTE_WITH_SIZES_STRIDES_NUMEL(_compute_contiguous);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool_is_channels_last_contiguous _compute_channels_last_contiguous_2d(
|
||||
ArrayRef<T> sizes,
|
||||
ArrayRef<T> strides) {
|
||||
// Please don't combine these code, constant array is used here to let
|
||||
// compiler fully unroll the loop to get better performance
|
||||
switch (sizes_and_strides_.size()) {
|
||||
switch (sizes.size()) {
|
||||
case 4: {
|
||||
int64_t expected = 1;
|
||||
T expected = 1;
|
||||
for (auto& d : {1, 3, 2, 0}) {
|
||||
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
|
||||
const auto size_d = sizes[d];
|
||||
if (size_d != 1) {
|
||||
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
|
||||
return false;
|
||||
if (strides[d] != expected) {
|
||||
return bool_is_channels_last_contiguous(false);
|
||||
}
|
||||
expected *= size_d;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return bool_is_channels_last_contiguous(true);
|
||||
}
|
||||
// NOLINTNEXTLINE(bugprone-branch-clone)
|
||||
case 3:
|
||||
// TODO dim == 3 case will be enabled once it is fully tested
|
||||
return false;
|
||||
return bool_is_channels_last_contiguous(false);
|
||||
default:
|
||||
return false;
|
||||
return bool_is_channels_last_contiguous(false);
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorImpl::compute_channels_last_contiguous_3d() const {
|
||||
bool_is_channels_last_contiguous TensorImpl::
|
||||
compute_channels_last_contiguous_2d() const {
|
||||
return COMPUTE_WITH_SIZES_STRIDES(_compute_channels_last_contiguous_2d);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool_is_channels_last_3d_contiguous _compute_channels_last_contiguous_3d(
|
||||
ArrayRef<T> sizes,
|
||||
ArrayRef<T> strides) {
|
||||
// Please don't combine these code, constant array is used here to let
|
||||
// compiler fully unroll the loop to get better performance
|
||||
switch (sizes_and_strides_.size()) {
|
||||
switch (sizes.size()) {
|
||||
case 5: {
|
||||
int64_t expected = 1;
|
||||
T expected = 1;
|
||||
for (auto& d : {1, 4, 3, 2, 0}) {
|
||||
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
|
||||
const auto size_d = sizes[d];
|
||||
if (size_d != 1) {
|
||||
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
|
||||
return false;
|
||||
if (strides[d] != expected) {
|
||||
return bool_is_channels_last_3d_contiguous(false);
|
||||
}
|
||||
expected *= size_d;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return bool_is_channels_last_3d_contiguous(true);
|
||||
}
|
||||
// NOLINTNEXTLINE(bugprone-branch-clone)
|
||||
case 4:
|
||||
// TODO dim == 4 case will be enabled once it is fully tested
|
||||
return false;
|
||||
return bool_is_channels_last_3d_contiguous(false);
|
||||
default:
|
||||
return false;
|
||||
return bool_is_channels_last_3d_contiguous(false);
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorImpl::compute_strides_like_channels_last_2d() const {
|
||||
return is_channels_last_strides_2d(
|
||||
TensorImpl::sizes(), TensorImpl::strides());
|
||||
bool_is_channels_last_3d_contiguous TensorImpl::
|
||||
compute_channels_last_contiguous_3d() const {
|
||||
return COMPUTE_WITH_SIZES_STRIDES(_compute_channels_last_contiguous_3d);
|
||||
}
|
||||
|
||||
bool TensorImpl::compute_strides_like_channels_last_3d() const {
|
||||
return is_channels_last_strides_3d(
|
||||
TensorImpl::sizes(), TensorImpl::strides());
|
||||
bool_is_channels_last TensorImpl::compute_strides_like_channels_last_2d()
|
||||
const {
|
||||
return bool_is_channels_last(
|
||||
COMPUTE_WITH_SIZES_STRIDES(is_channels_last_strides_2d));
|
||||
}
|
||||
|
||||
bool TensorImpl::compute_non_overlapping_and_dense() const {
|
||||
if (dim() == 1) {
|
||||
return sizes_and_strides_.size_at_unchecked(0) < 2 ||
|
||||
sizes_and_strides_.stride_at_unchecked(0) == 1;
|
||||
bool_is_channels_last_3d TensorImpl::compute_strides_like_channels_last_3d()
|
||||
const {
|
||||
return bool_is_channels_last_3d(
|
||||
COMPUTE_WITH_SIZES_STRIDES(is_channels_last_strides_3d));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool_is_non_overlapping_and_dense _compute_non_overlapping_and_dense(
|
||||
ArrayRef<T> sizes,
|
||||
ArrayRef<T> strides) {
|
||||
auto dim = sizes.size();
|
||||
if (dim == 1) {
|
||||
return bool_is_non_overlapping_and_dense(sizes[0] < 2 || strides[0] == 1);
|
||||
}
|
||||
SmallVector<int64_t, 5> perm;
|
||||
perm.resize(dim());
|
||||
for (const auto i : c10::irange(dim())) {
|
||||
perm.resize(dim);
|
||||
for (const auto i : c10::irange(dim)) {
|
||||
perm[i] = i;
|
||||
}
|
||||
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
|
||||
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
|
||||
if (sizes_and_strides_.size_at_unchecked(a) < 2) {
|
||||
if (sizes[a] < 2) {
|
||||
return false;
|
||||
} else if (sizes_and_strides_.size_at_unchecked(b) < 2) {
|
||||
} else if (sizes[b] < 2) {
|
||||
return true;
|
||||
}
|
||||
return sizes_and_strides_.stride_at_unchecked(a) <
|
||||
sizes_and_strides_.stride_at_unchecked(b);
|
||||
return strides[a] < strides[b];
|
||||
});
|
||||
int64_t require_stride = 1;
|
||||
for (const auto i : c10::irange(dim())) {
|
||||
const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]);
|
||||
T require_stride = 1;
|
||||
for (const auto i : c10::irange(dim)) {
|
||||
const auto size_perm_i = sizes[perm[i]];
|
||||
if (size_perm_i < 2) {
|
||||
return true;
|
||||
return bool_is_non_overlapping_and_dense(true);
|
||||
}
|
||||
if (sizes_and_strides_.stride_at_unchecked(perm[i]) != require_stride) {
|
||||
return false;
|
||||
if (strides[perm[i]] != require_stride) {
|
||||
return bool_is_non_overlapping_and_dense(false);
|
||||
}
|
||||
require_stride *= size_perm_i;
|
||||
}
|
||||
return true;
|
||||
return bool_is_non_overlapping_and_dense(true);
|
||||
}
|
||||
|
||||
bool_is_non_overlapping_and_dense TensorImpl::
|
||||
compute_non_overlapping_and_dense() const {
|
||||
return COMPUTE_WITH_SIZES_STRIDES(_compute_non_overlapping_and_dense);
|
||||
}
|
||||
|
||||
void TensorImpl::release_resources() {
|
||||
@ -390,12 +442,25 @@ impl::PyInterpreter& TensorImpl::load_pyobj_interpreter() const {
|
||||
|
||||
bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
|
||||
// TODO: pass memory_format to is_contiguous call
|
||||
return load_pyobj_interpreter()->is_contiguous(this);
|
||||
return load_pyobj_interpreter()->is_contiguous(this, memory_format);
|
||||
}
|
||||
return is_contiguous_default(memory_format);
|
||||
}
|
||||
|
||||
bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
|
||||
return load_pyobj_interpreter()->is_strides_like(this, memory_format);
|
||||
}
|
||||
return is_strides_like_default(memory_format);
|
||||
}
|
||||
|
||||
bool TensorImpl::is_non_overlapping_and_dense_custom() const {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
|
||||
return load_pyobj_interpreter()->is_non_overlapping_and_dense(this);
|
||||
}
|
||||
return is_non_overlapping_and_dense_default();
|
||||
}
|
||||
|
||||
IntArrayRef TensorImpl::sizes_custom() const {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
return load_pyobj_interpreter()->sizes(this);
|
||||
@ -575,12 +640,8 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach_core(
|
||||
/*version_counter=*/std::forward<VariableVersion>(version_counter),
|
||||
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
||||
|
||||
// We currently don't support refresh_numel() and refresh_contiguous(). It's
|
||||
// plausible that we could support it, but currently done to unblock.
|
||||
if (!has_symbolic_sizes_strides()) {
|
||||
impl->refresh_numel();
|
||||
impl->refresh_contiguous();
|
||||
}
|
||||
impl->refresh_numel();
|
||||
impl->refresh_contiguous();
|
||||
return impl;
|
||||
}
|
||||
|
||||
@ -634,8 +695,8 @@ void TensorImpl::copy_generic_tensor_metadata(
|
||||
dest_impl->extra_meta_ = src_impl->extra_meta_->clone();
|
||||
}
|
||||
|
||||
// NB: symbolic sizes and strides are copied, but custom policy is
|
||||
// NOT (you have no Python object to dispatch to!)
|
||||
// NB: symbolic sizes and strides are copied as is custom policy, but python
|
||||
// policy is NOT (you have no Python object to dispatch to!)
|
||||
// NB: subclass relevant policy doesn't have to be copied; the
|
||||
// constructor sets this up
|
||||
|
||||
@ -885,33 +946,6 @@ void TensorImpl::ShareExternalPointer(
|
||||
}
|
||||
}
|
||||
|
||||
bool _compute_contiguous(const ExtraMeta& extra_meta, std::vector<int> order) {
|
||||
if (order.size() != extra_meta.sizes_.size())
|
||||
return false;
|
||||
bool is_contiguous = true;
|
||||
if (extra_meta.numel_ == 0)
|
||||
return is_contiguous;
|
||||
SymInt z = 1;
|
||||
for (auto d : order) {
|
||||
const auto size_d = extra_meta.sizes_.at(d);
|
||||
if (size_d != 1) {
|
||||
if (extra_meta.strides_.at(d) == z) {
|
||||
z *= size_d;
|
||||
} else {
|
||||
is_contiguous = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return is_contiguous;
|
||||
}
|
||||
|
||||
bool _compute_contiguous(const ExtraMeta& extra_meta) {
|
||||
std::vector<int> order(extra_meta.sizes_.size());
|
||||
std::iota(order.rbegin(), order.rend(), 0);
|
||||
return _compute_contiguous(extra_meta, order);
|
||||
}
|
||||
|
||||
void clone_symvec(SymIntArrayRef src, SymDimVector& dst) {
|
||||
dst.clear();
|
||||
dst.reserve(src.size());
|
||||
@ -952,16 +986,8 @@ void TensorImpl::set_sizes_and_strides(
|
||||
if (storage_offset.has_value())
|
||||
extra_meta_->storage_offset_ = storage_offset->clone();
|
||||
|
||||
SymInt numel = 1;
|
||||
for (const auto& s : sizes) {
|
||||
numel *= s;
|
||||
}
|
||||
extra_meta_->numel_ = numel;
|
||||
extra_meta_->is_contiguous_ = _compute_contiguous(*extra_meta_);
|
||||
extra_meta_->is_channels_last_contiguous_ =
|
||||
_compute_contiguous(*extra_meta_, {1, 3, 2, 0});
|
||||
extra_meta_->is_channels_last_3d_contiguous_ =
|
||||
_compute_contiguous(*extra_meta_, {1, 4, 3, 2, 0});
|
||||
refresh_numel();
|
||||
refresh_contiguous();
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/python_stub.h>
|
||||
#include <c10/util/safe_numerics.h>
|
||||
#include <c10/util/strong_type.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
@ -225,16 +226,42 @@ struct C10_API NamedTensorMetaInterface {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using strong_bool = strong::
|
||||
type<bool, T, strong::regular, strong::iostreamable, strong::boolean>;
|
||||
|
||||
// For ease of copy pasting
|
||||
#if 0
|
||||
is_contiguous
|
||||
is_channels_last_contiguous
|
||||
is_channels_last_3d_contiguous
|
||||
is_channels_last
|
||||
is_channels_last_3d
|
||||
is_non_overlapping_and_dense
|
||||
#endif
|
||||
|
||||
using bool_is_contiguous = strong_bool<struct bool_is_contiguous_>;
|
||||
using bool_is_channels_last_contiguous =
|
||||
strong_bool<struct bool_is_channels_last_contiguous_>;
|
||||
using bool_is_channels_last_3d_contiguous =
|
||||
strong_bool<struct bool_is_channels_last_3d_contiguous_>;
|
||||
using bool_is_channels_last = strong_bool<struct bool_is_channels_last_>;
|
||||
using bool_is_channels_last_3d = strong_bool<struct bool_is_channels_last_3d_>;
|
||||
using bool_is_non_overlapping_and_dense =
|
||||
strong_bool<struct bool_is_non_overlapping_and_dense_>;
|
||||
|
||||
struct C10_API ExtraMeta {
|
||||
SymDimVector sizes_ = {0};
|
||||
SymDimVector strides_ = {1};
|
||||
SymInt numel_ = 1;
|
||||
SymInt storage_offset_ = 0;
|
||||
bool is_contiguous_ = true;
|
||||
bool is_channels_last_contiguous_ = false;
|
||||
bool is_channels_last_3d_contiguous_ = false;
|
||||
// TODO:
|
||||
// SymBool is_contiguous_;
|
||||
// TODO: make these all SymBool
|
||||
bool_is_contiguous is_contiguous_{true};
|
||||
bool_is_channels_last_contiguous is_channels_last_contiguous_{false};
|
||||
bool_is_channels_last_3d_contiguous is_channels_last_3d_contiguous_{false};
|
||||
bool_is_channels_last is_channels_last_{false};
|
||||
bool_is_channels_last_3d is_channels_last_3d_{false};
|
||||
bool_is_non_overlapping_and_dense is_non_overlapping_and_dense_{true};
|
||||
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
|
||||
|
||||
ExtraMeta() {}
|
||||
@ -244,9 +271,12 @@ struct C10_API ExtraMeta {
|
||||
SymDimVector strides,
|
||||
SymInt numel,
|
||||
SymInt storage_offset,
|
||||
bool is_contiguous,
|
||||
bool is_channels_last_contiguous,
|
||||
bool is_channels_last_3d_contiguous,
|
||||
bool_is_contiguous is_contiguous,
|
||||
bool_is_channels_last_contiguous is_channels_last_contiguous,
|
||||
bool_is_channels_last_3d_contiguous is_channels_last_3d_contiguous,
|
||||
bool_is_channels_last is_channels_last,
|
||||
bool_is_channels_last_3d is_channels_last_3d,
|
||||
bool_is_non_overlapping_and_dense is_non_overlapping_and_dense,
|
||||
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta)
|
||||
: sizes_(std::move(sizes)),
|
||||
strides_(std::move(strides)),
|
||||
@ -255,6 +285,9 @@ struct C10_API ExtraMeta {
|
||||
is_contiguous_(is_contiguous),
|
||||
is_channels_last_contiguous_(is_channels_last_contiguous),
|
||||
is_channels_last_3d_contiguous_(is_channels_last_3d_contiguous),
|
||||
is_channels_last_(is_channels_last),
|
||||
is_channels_last_3d_(is_channels_last_3d),
|
||||
is_non_overlapping_and_dense_(is_non_overlapping_and_dense),
|
||||
named_tensor_meta_(std::move(named_tensor_meta)) {}
|
||||
|
||||
std::unique_ptr<ExtraMeta> clone() const {
|
||||
@ -266,6 +299,9 @@ struct C10_API ExtraMeta {
|
||||
is_contiguous_,
|
||||
is_channels_last_contiguous_,
|
||||
is_channels_last_3d_contiguous_,
|
||||
is_channels_last_,
|
||||
is_channels_last_3d_,
|
||||
is_non_overlapping_and_dense_,
|
||||
named_tensor_meta_ ? named_tensor_meta_->clone() : nullptr);
|
||||
}
|
||||
};
|
||||
@ -782,11 +818,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
bool is_contiguous_default(at::MemoryFormat memory_format) const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
return extra_meta_->is_channels_last_contiguous_;
|
||||
return bool(extra_meta_->is_channels_last_contiguous_);
|
||||
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
|
||||
return extra_meta_->is_channels_last_3d_contiguous_;
|
||||
return bool(extra_meta_->is_channels_last_3d_contiguous_);
|
||||
}
|
||||
return extra_meta_->is_contiguous_;
|
||||
return bool(extra_meta_->is_contiguous_);
|
||||
}
|
||||
|
||||
if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
@ -797,6 +833,34 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return is_contiguous_;
|
||||
}
|
||||
|
||||
bool is_strides_like_default(at::MemoryFormat memory_format) const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
return bool(extra_meta_->is_channels_last_);
|
||||
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
|
||||
return bool(extra_meta_->is_channels_last_3d_);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
return is_channels_last_;
|
||||
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
|
||||
return is_channels_last_3d_;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_non_overlapping_and_dense_default() const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
return bool(extra_meta_->is_non_overlapping_and_dense_);
|
||||
} else {
|
||||
return is_non_overlapping_and_dense_;
|
||||
}
|
||||
}
|
||||
|
||||
// NB: these dim accessor functions don't have _default(), as you can use
|
||||
// sizes_default/strides_default
|
||||
/**
|
||||
@ -884,6 +948,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
*/
|
||||
// sizes_strides_policy_ >= CustomStrides
|
||||
virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const;
|
||||
virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const;
|
||||
virtual bool is_non_overlapping_and_dense_custom() const;
|
||||
// sizes_strides_policy_ >= CustomSizes
|
||||
// Currently this method only exists to be overwritten by subclasses such as
|
||||
// NestedTensorImpl.
|
||||
@ -1589,7 +1655,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
sizes_and_strides_.set_sizes(new_size);
|
||||
|
||||
refresh_numel();
|
||||
empty_tensor_restride(MemoryFormat::Contiguous);
|
||||
empty_tensor_restride(
|
||||
MemoryFormat::Contiguous); // calls refresh_contiguous()
|
||||
}
|
||||
|
||||
/**
|
||||
@ -2273,16 +2340,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
refresh_contiguous();
|
||||
}
|
||||
|
||||
bool is_strides_like(at::MemoryFormat memory_format) const {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
|
||||
return is_strides_like_custom(memory_format);
|
||||
}
|
||||
return is_strides_like_default(memory_format);
|
||||
}
|
||||
|
||||
bool is_strides_like_channels_last() const {
|
||||
return is_channels_last_;
|
||||
return is_strides_like(at::MemoryFormat::ChannelsLast);
|
||||
}
|
||||
|
||||
bool is_strides_like_channels_last_3d() const {
|
||||
return is_channels_last_3d_;
|
||||
return is_strides_like(at::MemoryFormat::ChannelsLast3d);
|
||||
}
|
||||
|
||||
bool is_non_overlapping_and_dense() const {
|
||||
return is_non_overlapping_and_dense_;
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
|
||||
return is_non_overlapping_and_dense_custom();
|
||||
}
|
||||
return is_non_overlapping_and_dense_default();
|
||||
}
|
||||
|
||||
bool has_symbolic_sizes_strides() const {
|
||||
@ -2360,12 +2437,16 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
/**
|
||||
* Compute the number of elements based on the sizes of a tensor.
|
||||
*/
|
||||
// NB: This is ONLY called when sizes_and_strides_ is used directly; if
|
||||
// we are virtualizing, then numel calls are virtualized as well, and this
|
||||
// should never get called
|
||||
int64_t compute_numel() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_);
|
||||
#if C10_HAS_BUILTIN_OVERFLOW() && !defined(C10_MOBILE)
|
||||
// Use overflow checks if supported by the compiler
|
||||
return safe_compute_numel();
|
||||
#else
|
||||
return c10::multiply_integers(sizes());
|
||||
return c10::multiply_integers(sizes_and_strides_.sizes_arrayref());
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -2375,8 +2456,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* using a sparse layout has multiple dimensions with large sizes.
|
||||
*/
|
||||
int64_t safe_compute_numel() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_);
|
||||
uint64_t n = 1;
|
||||
bool overflows = c10::safe_multiplies_u64(sizes(), &n);
|
||||
bool overflows =
|
||||
c10::safe_multiplies_u64(sizes_and_strides_.sizes_arrayref(), &n);
|
||||
constexpr auto numel_max = std::min(
|
||||
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
|
||||
static_cast<uint64_t>(std::numeric_limits<size_t>::max()));
|
||||
@ -2386,21 +2469,31 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return static_cast<int64_t>(n);
|
||||
}
|
||||
|
||||
SymInt compute_sym_numel() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(has_symbolic_sizes_strides_);
|
||||
SymInt numel = 1;
|
||||
for (const auto& s : extra_meta_->sizes_) {
|
||||
numel *= s;
|
||||
}
|
||||
return numel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute whether or not a tensor is contiguous based on the sizes and
|
||||
* strides of a tensor.
|
||||
*/
|
||||
bool compute_contiguous() const;
|
||||
bool_is_contiguous compute_contiguous() const;
|
||||
|
||||
bool compute_channels_last_contiguous_2d() const;
|
||||
bool_is_channels_last_contiguous compute_channels_last_contiguous_2d() const;
|
||||
|
||||
bool compute_channels_last_contiguous_3d() const;
|
||||
bool_is_channels_last_3d_contiguous compute_channels_last_contiguous_3d()
|
||||
const;
|
||||
|
||||
bool compute_strides_like_channels_last_2d() const;
|
||||
bool_is_channels_last compute_strides_like_channels_last_2d() const;
|
||||
|
||||
bool compute_strides_like_channels_last_3d() const;
|
||||
bool_is_channels_last_3d compute_strides_like_channels_last_3d() const;
|
||||
|
||||
bool compute_non_overlapping_and_dense() const;
|
||||
bool_is_non_overlapping_and_dense compute_non_overlapping_and_dense() const;
|
||||
|
||||
protected:
|
||||
/**
|
||||
@ -2410,9 +2503,20 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* For tensors with sparse layouts, use safe_refresh_numel() instead
|
||||
* because it will catch integer overflow that may occur for tensors
|
||||
* with sparse layouts and large dimensions.
|
||||
*
|
||||
* NB: We may uselessly recompute cached numel even in situations where
|
||||
* it is completely never used (e.g., if CustomSizes for Python). However,
|
||||
* we still must keep it up to date in case the Python overload
|
||||
* returns None (in which case we will consult the field here). This also
|
||||
* implies that sizes/strides will never be complete garbage; in the
|
||||
* very worst case scenario, it will reflect a 1-dim zero size tensor.
|
||||
*/
|
||||
void refresh_numel() {
|
||||
numel_ = compute_numel();
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
extra_meta_->numel_ = compute_sym_numel();
|
||||
} else {
|
||||
numel_ = compute_numel();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -2422,7 +2526,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* overflow when computing numel.
|
||||
*/
|
||||
void safe_refresh_numel() {
|
||||
numel_ = safe_compute_numel();
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
// NB: sym numel is done with symbolic integers, which handle overflow
|
||||
// checking
|
||||
extra_meta_->numel_ = compute_sym_numel();
|
||||
} else {
|
||||
numel_ = safe_compute_numel();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -2430,49 +2540,94 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* or strides.
|
||||
*/
|
||||
void refresh_contiguous() {
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"refresh_contiguous() called on tensor with symbolic shape")
|
||||
auto set_fields =
|
||||
[&](bool_is_contiguous is_contiguous,
|
||||
bool_is_channels_last_contiguous is_channels_last_contiguous,
|
||||
bool_is_channels_last_3d_contiguous is_channels_last_3d_contiguous,
|
||||
bool_is_channels_last is_channels_last,
|
||||
bool_is_channels_last_3d is_channels_last_3d,
|
||||
bool_is_non_overlapping_and_dense is_non_overlapping_and_dense) {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
extra_meta_->is_contiguous_ = is_contiguous;
|
||||
extra_meta_->is_channels_last_contiguous_ =
|
||||
is_channels_last_contiguous;
|
||||
extra_meta_->is_channels_last_3d_contiguous_ =
|
||||
is_channels_last_3d_contiguous;
|
||||
extra_meta_->is_channels_last_ = is_channels_last;
|
||||
extra_meta_->is_channels_last_3d_ = is_channels_last_3d;
|
||||
extra_meta_->is_non_overlapping_and_dense_ =
|
||||
is_non_overlapping_and_dense;
|
||||
} else {
|
||||
is_contiguous_ = bool(is_contiguous);
|
||||
is_channels_last_contiguous_ = bool(is_channels_last_contiguous);
|
||||
is_channels_last_3d_contiguous_ =
|
||||
bool(is_channels_last_3d_contiguous);
|
||||
is_channels_last_ = bool(is_channels_last);
|
||||
is_channels_last_3d_ = bool(is_channels_last_3d);
|
||||
is_non_overlapping_and_dense_ = bool(is_non_overlapping_and_dense);
|
||||
}
|
||||
};
|
||||
|
||||
is_contiguous_ = compute_contiguous();
|
||||
auto is_contiguous = compute_contiguous();
|
||||
// Note:
|
||||
// Dim 0, 1, 2 will never be a channels last 2d/3d format
|
||||
// Dim 3+ is possibly be a channels last 2d format (Dim 4 only at this
|
||||
// point) Dim 4+ is possibly be a channels last 3d format (Dim 5 only at
|
||||
// this point)
|
||||
switch (dim()) {
|
||||
case 4:
|
||||
is_channels_last_contiguous_ = compute_channels_last_contiguous_2d();
|
||||
is_channels_last_3d_contiguous_ = false;
|
||||
is_channels_last_ = compute_strides_like_channels_last_2d();
|
||||
is_channels_last_3d_ = false;
|
||||
is_non_overlapping_and_dense_ = is_contiguous_ ||
|
||||
is_channels_last_contiguous_ || compute_non_overlapping_and_dense();
|
||||
case 4: {
|
||||
auto is_channels_last_contiguous =
|
||||
compute_channels_last_contiguous_2d();
|
||||
set_fields(
|
||||
is_contiguous,
|
||||
is_channels_last_contiguous,
|
||||
bool_is_channels_last_3d_contiguous(false),
|
||||
compute_strides_like_channels_last_2d(),
|
||||
bool_is_channels_last_3d(false),
|
||||
bool_is_non_overlapping_and_dense(
|
||||
is_contiguous || is_channels_last_contiguous ||
|
||||
compute_non_overlapping_and_dense()));
|
||||
break;
|
||||
case 5:
|
||||
is_channels_last_contiguous_ = compute_channels_last_contiguous_2d();
|
||||
is_channels_last_3d_contiguous_ = !is_channels_last_contiguous_ &&
|
||||
compute_channels_last_contiguous_3d();
|
||||
is_channels_last_ = !is_channels_last_3d_contiguous_ &&
|
||||
compute_strides_like_channels_last_2d();
|
||||
is_channels_last_3d_ =
|
||||
!is_channels_last_ && compute_strides_like_channels_last_3d();
|
||||
is_non_overlapping_and_dense_ = is_contiguous_ ||
|
||||
is_channels_last_contiguous_ || is_channels_last_3d_contiguous_ ||
|
||||
compute_non_overlapping_and_dense();
|
||||
}
|
||||
case 5: {
|
||||
auto is_channels_last_contiguous =
|
||||
compute_channels_last_contiguous_2d();
|
||||
auto is_channels_last_3d_contiguous =
|
||||
bool_is_channels_last_3d_contiguous(
|
||||
!is_channels_last_contiguous &&
|
||||
compute_channels_last_contiguous_3d());
|
||||
auto is_channels_last = bool_is_channels_last(
|
||||
!is_channels_last_3d_contiguous &&
|
||||
compute_strides_like_channels_last_2d());
|
||||
auto is_channels_last_3d = bool_is_channels_last_3d(
|
||||
!is_channels_last && compute_strides_like_channels_last_3d());
|
||||
auto is_non_overlapping_and_dense = bool_is_non_overlapping_and_dense(
|
||||
is_contiguous || is_channels_last_contiguous ||
|
||||
is_channels_last_3d_contiguous ||
|
||||
compute_non_overlapping_and_dense());
|
||||
set_fields(
|
||||
is_contiguous,
|
||||
is_channels_last_contiguous,
|
||||
is_channels_last_3d_contiguous,
|
||||
is_channels_last,
|
||||
is_channels_last_3d,
|
||||
is_non_overlapping_and_dense);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
is_channels_last_contiguous_ = false;
|
||||
is_channels_last_3d_contiguous_ = false;
|
||||
// is_channels_last_ and is_channels_last_3d_ are suggested
|
||||
// memory_format. Being channels_last_contiguous doesn't necessarily
|
||||
// mean the tensor is strided like channels_last: for strides on channel
|
||||
// dimension could suggest desired memory_layout, but it doesn't affect
|
||||
// memory storage
|
||||
is_channels_last_ = false;
|
||||
is_channels_last_3d_ = false;
|
||||
is_non_overlapping_and_dense_ =
|
||||
is_contiguous_ || compute_non_overlapping_and_dense();
|
||||
set_fields(
|
||||
is_contiguous,
|
||||
bool_is_channels_last_contiguous(false),
|
||||
bool_is_channels_last_3d_contiguous(false),
|
||||
bool_is_channels_last(false),
|
||||
bool_is_channels_last_3d(false),
|
||||
bool_is_non_overlapping_and_dense(
|
||||
is_contiguous || compute_non_overlapping_and_dense()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,9 +34,16 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
    PANIC(python_dispatcher);
  }

  bool is_contiguous(const TensorImpl* self) const override {
  bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
    PANIC(is_contiguous);
  }
  bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const override {
    PANIC(is_strides_like);
  }
  bool is_non_overlapping_and_dense(const TensorImpl* self) const override {
    PANIC(is_non_overlapping_and_dense);
  }
  c10::Device device(const TensorImpl* self) const override {
    PANIC(device);
  }
@ -2,6 +2,7 @@

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>

@ -146,7 +147,11 @@ struct C10_API PyInterpreterVTable {
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const = 0;

  virtual bool is_contiguous(const TensorImpl* self) const = 0;
  virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
  virtual c10::Device device(const TensorImpl* self) const = 0;
  virtual int64_t dim(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0;
@ -33,20 +33,12 @@
|
||||
#endif
|
||||
|
||||
#ifndef STRONG_HAS_STD_FORMAT
|
||||
#if __has_include(<format>)
|
||||
#define STRONG_HAS_STD_FORMAT 1
|
||||
#else
|
||||
#define STRONG_HAS_STD_FORMAT 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef STRONG_HAS_FMT_FORMAT
|
||||
#if __has_include(<fmt/format.h>)
|
||||
#define STRONG_HAS_FMT_FORMAT 1
|
||||
#else
|
||||
#define STRONG_HAS_FMT_FORMAT 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if STRONG_HAS_STD_FORMAT
|
||||
#include <format>
|
||||
|
@ -630,7 +630,7 @@ class FakeTensorOperatorInvariants(TestCase):
        for schema in self.get_all_aten_schemas():
            if "_like" == schema.name[-5:]:
                op = self.get_aten_op(schema)
                self.assertTrue(op in torch._subclasses.fake_tensor._like_tensor_constructors)
                self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)


if __name__ == "__main__":
    run_tests()
@ -1402,6 +1402,38 @@ $0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memo
        with self.assertRaisesRegex(TypeError, err_msg):
            e.contiguous()

    def test_fancy_strides(self):
        calls = []

        class ExampleTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, data):
                return TestPythonDispatch.subclass_helper(cls, data, False, dispatch_sizes_strides_policy="strides")

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func in [
                    torch.ops.aten.is_contiguous.default,
                    torch.ops.aten.is_contiguous.memory_format,
                    torch.ops.aten.is_strides_like_format.default,
                    torch.ops.aten.is_non_overlapping_and_dense.default,
                    torch.ops.aten.stride.default
                ]:
                    calls.append((func, list(args)[1:]))
                    return None
                with no_dispatch():
                    return func(*args, **kwargs)

        e = ExampleTensor(torch.randn(2, 2))
        self.assertFalse(e.is_contiguous(memory_format=torch.channels_last))
        self.assertEqual(calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])])
        calls.clear()
        self.assertFalse(torch.ops.aten.is_strides_like_format.default(e, torch.channels_last))
        self.assertEqual(calls, [(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])])
        calls.clear()
        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e))
        self.assertEqual(calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])])

    def test_device_slowpath(self):
        for use_wrapper_subclass in [True]:
            class ExampleTensor1(torch.Tensor):
@ -67,7 +67,7 @@ static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObje
|
||||
}
|
||||
PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
|
||||
if (opt_memory_format.has_value()) {
|
||||
PyTuple_SET_ITEM(tuple.get(), 3, torch::utils::getTHPMemoryFormat(opt_memory_format.value()).release().ptr());
|
||||
PyTuple_SET_ITEM(tuple.get(), 3, torch::utils::getTHPMemoryFormat(opt_memory_format.value()));
|
||||
} else {
|
||||
Py_INCREF(Py_None);
|
||||
PyTuple_SET_ITEM(tuple.get(), 3, Py_None);
|
||||
|
@ -132,8 +132,7 @@ std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
|
||||
return py::reinterpret_borrow<py::object>(
|
||||
reinterpret_cast<PyObject*>(obj));
|
||||
} else if (match(c10::MemoryFormatType::Kind)) {
|
||||
return torch::utils::getTHPMemoryFormat(
|
||||
static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
|
||||
return py::cast(static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
|
||||
} else {
|
||||
return torch::jit::toPyObject(arguments[idx]);
|
||||
}
|
||||
@ -241,7 +240,9 @@ struct ConcretePyInterpreterVTable final
|
||||
c10::DispatchKeySet,
|
||||
torch::jit::Stack* stack) const override;
|
||||
|
||||
bool is_contiguous(const TensorImpl* self) const override;
|
||||
bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override;
|
||||
bool is_strides_like(const TensorImpl* self, at::MemoryFormat) const override;
|
||||
bool is_non_overlapping_and_dense(const TensorImpl* self) const override;
|
||||
c10::Device device(const TensorImpl* self) const override;
|
||||
int64_t dim(const TensorImpl* self) const override;
|
||||
c10::IntArrayRef strides(const TensorImpl* self) const override;
|
||||
@ -2179,7 +2180,12 @@ py::object torchDispatchFromTensorImpl(
|
||||
const c10::TensorImpl* self,
|
||||
const char* func_name,
|
||||
PyObject* torch_api_function,
|
||||
const char* module_name) {
|
||||
const char* module_name,
|
||||
// WARNING: MUST NOT BE TENSOR ARGS
|
||||
c10::SmallVector<py::object, 1> extra_args = {}) {
|
||||
if (torch_api_function == nullptr) {
|
||||
throw python_error();
|
||||
}
|
||||
TORCH_CHECK(
|
||||
PyGILState_Check(),
|
||||
"GIL must be held before you call parseIValuesToPyArgsKwargs");
|
||||
@ -2194,8 +2200,16 @@ py::object torchDispatchFromTensorImpl(
|
||||
// NB: this may not be a python tensor if you got here from a mode!
|
||||
// TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
|
||||
append_overloaded_tensor(&overloaded_args, self_p.ptr());
|
||||
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
|
||||
auto args =
|
||||
py::reinterpret_steal<py::object>(PyTuple_New(1 + extra_args.size()));
|
||||
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
|
||||
int64_t i = 1;
|
||||
for (auto& a : extra_args) {
|
||||
if (a.ptr() == nullptr)
|
||||
throw python_error();
|
||||
PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
|
||||
i++;
|
||||
}
|
||||
|
||||
py::dict kwargs;
|
||||
|
||||
@ -2320,7 +2334,7 @@ void ConcretePyInterpreterVTable::python_dispatcher(
|
||||
py::object obj = py::reinterpret_steal<py::object>(
|
||||
PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));
|
||||
|
||||
if (obj == nullptr) {
|
||||
if (obj.ptr() == nullptr) {
|
||||
throw python_error();
|
||||
}
|
||||
|
||||
@ -2353,24 +2367,107 @@ c10::intrusive_ptr<TensorImpl> ConcretePyInterpreterVTable::detach(
|
||||
}
|
||||
|
||||
bool ConcretePyInterpreterVTable::is_contiguous(
|
||||
const c10::TensorImpl* self,
|
||||
at::MemoryFormat memory_format) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard;
|
||||
|
||||
py::object out;
|
||||
if (memory_format == at::MemoryFormat::Contiguous) {
|
||||
// For backwards compatibility
|
||||
out = torchDispatchFromTensorImpl(
|
||||
self,
|
||||
"is_contiguous",
|
||||
py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr("aten")
|
||||
.attr("is_contiguous")
|
||||
.attr("default")
|
||||
.ptr(),
|
||||
"torch.ops.aten");
|
||||
} else {
|
||||
out = torchDispatchFromTensorImpl(
|
||||
self,
|
||||
"is_contiguous",
|
||||
py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr("aten")
|
||||
.attr("is_contiguous")
|
||||
.attr("memory_format")
|
||||
.ptr(),
|
||||
"torch.ops.aten",
|
||||
{py::cast(memory_format)});
|
||||
}
|
||||
|
||||
if (out.is(py::none())) {
|
||||
return self->is_contiguous_default(memory_format);
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
PyBool_Check(out.ptr()),
|
||||
"is_contiguous returned invalid type ",
|
||||
py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
|
||||
", expected bool");
|
||||
|
||||
return PyObject_IsTrue(out.ptr());
|
||||
}
|
||||
|
||||
bool ConcretePyInterpreterVTable::is_strides_like(
|
||||
const c10::TensorImpl* self,
|
||||
at::MemoryFormat memory_format) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard;
|
||||
|
||||
auto out = torchDispatchFromTensorImpl(
|
||||
self,
|
||||
"is_strides_like",
|
||||
py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr("aten")
|
||||
// NB: intentionally suffixed with _format to avoid
|
||||
// triggering matches against "_like" suffix
|
||||
.attr("is_strides_like_format")
|
||||
.attr("default")
|
||||
.ptr(),
|
||||
"torch.ops.aten",
|
||||
{py::cast(memory_format)});
|
||||
|
||||
if (out.is(py::none())) {
|
||||
return self->is_strides_like_default(memory_format);
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
PyBool_Check(out.ptr()),
|
||||
"is_strides_like_format returned invalid type ",
|
||||
py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
|
||||
", expected bool");
|
||||
|
||||
return PyObject_IsTrue(out.ptr());
|
||||
}
|
||||
|
||||
bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
|
||||
const c10::TensorImpl* self) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard;
|
||||
|
||||
auto out = torchDispatchFromTensorImpl(
|
||||
self,
|
||||
"is_contiguous",
|
||||
"is_non_overlapping_and_dense",
|
||||
py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr("aten")
|
||||
.attr("is_contiguous")
|
||||
.attr("is_non_overlapping_and_dense")
|
||||
.attr("default")
|
||||
.ptr(),
|
||||
"torch.ops.aten");
|
||||
|
||||
if (out.is(py::none())) {
|
||||
return self->is_non_overlapping_and_dense_default();
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
PyBool_Check(out.ptr()),
|
||||
"is_contiguous returned invalid type ",
|
||||
"is_non_overlapping_and_dense returned invalid type ",
|
||||
py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
|
||||
", expected bool");
|
||||
|
||||
|
@ -561,6 +561,34 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
          pack(stack, result);
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA(
            "aten::is_contiguous.memory_format(Tensor self, MemoryFormat memory_format) -> bool"),
        [](Stack& stack) {
          auto memory_format = pop(stack).toMemoryFormat();
          auto t = pop(stack).toTensor();
          push(stack, t.is_contiguous(memory_format));
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        // NB: intentionally suffixed with extra _format to prevent tests for
        // "_like" suffix from triggering on this
        TORCH_SELECTIVE_SCHEMA(
            "aten::is_strides_like_format(Tensor self, MemoryFormat memory_format) -> bool"),
        [](Stack& stack) {
          auto memory_format = pop(stack).toMemoryFormat();
          auto t = pop(stack).toTensor();
          push(stack, t.unsafeGetTensorImpl()->is_strides_like(memory_format));
        },
        aliasAnalysisFromSchema()),
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA(
            "aten::is_non_overlapping_and_dense(Tensor self) -> bool"),
        [](Stack& stack) {
          auto t = pop(stack).toTensor();
          push(stack, t.unsafeGetTensorImpl()->is_non_overlapping_and_dense());
        },
        aliasAnalysisFromSchema()),
    // these ops are generic over the list element type.
    // CREATING GENERIC_LIST_OPS
    OperatorGeneratorArgs(
@ -85,9 +85,6 @@ LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
|
||||
c10::scalarTypeToTypeMeta(tensor.dtype()),
|
||||
backendDeviceToAtenDevice(tensor.GetDevice())),
|
||||
tensor_(c10::make_intrusive<LazyTensor>(std::move(tensor))) {
|
||||
// This is a temporary fix for a PyTorch core issue,
|
||||
// according to https://github.com/pytorch/xla/pull/2682.
|
||||
is_non_overlapping_and_dense_ = false;
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
|
||||
}
|
||||
|
||||
@ -185,12 +182,33 @@ int64_t LTCTensorImpl::numel_custom() const {
|
||||
return numel_default();
|
||||
}
|
||||
|
||||
int64_t LTCTensorImpl::storage_offset_custom() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool LTCTensorImpl::is_strides_like_custom(
|
||||
c10::MemoryFormat memory_format) const {
|
||||
TORCH_INTERNAL_ASSERT(memory_format != at::MemoryFormat::Contiguous);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const {
|
||||
// This should be true, but false as a temporary fix for a PyTorch core issue,
|
||||
// according to https://github.com/pytorch/xla/pull/2682.
|
||||
return false;
|
||||
}
|
||||
|
||||
bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const {
|
||||
// TODO(ezyang): I don't think this branch is actually necessary
|
||||
// TODO(ezyang): I don't think this logic is right, shouldn't we pass on
|
||||
// the memory format?
|
||||
if (tensor_->CurrentTensorData()) {
|
||||
return tensor_->CurrentTensorData()->is_contiguous();
|
||||
}
|
||||
// Only check that the storage is already contiguous.
|
||||
CHECK(is_contiguous_) << "Non-contiguous storage for lazy tensor";
|
||||
// TODO: I don't think logic is right, we should check the requested memory
|
||||
// format before returning true
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -39,12 +39,15 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
|
||||
|
||||
at::IntArrayRef sizes_custom() const override;
|
||||
at::IntArrayRef strides_custom() const override;
|
||||
int64_t dim_custom() const override;
|
||||
int64_t numel_custom() const override;
|
||||
int64_t storage_offset_custom() const override;
|
||||
int64_t dim_custom() const override;
|
||||
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
||||
bool is_strides_like_custom(at::MemoryFormat memory_format) const override;
|
||||
bool is_non_overlapping_and_dense_custom() const override;
|
||||
|
||||
virtual c10::SymIntArrayRef sym_sizes_custom() const override;
|
||||
virtual c10::SymIntArrayRef sym_strides_custom() const override;
|
||||
c10::SymIntArrayRef sym_sizes_custom() const override;
|
||||
c10::SymIntArrayRef sym_strides_custom() const override;
|
||||
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
const at::Storage& storage() const override {
|
||||
|
@ -12,6 +12,8 @@
|
||||
#include <torch/csrc/Device.h>
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/MemoryFormat.h>
|
||||
#include <torch/csrc/utils/tensor_memoryformats.h>
|
||||
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
@ -107,6 +109,28 @@ struct TORCH_PYTHON_API type_caster<at::IntArrayRef> {
|
||||
std::vector<int64_t> v_value;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TORCH_PYTHON_API type_caster<at::MemoryFormat> {
|
||||
public:
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
PYBIND11_TYPE_CASTER(at::MemoryFormat, _("at::MemoryFormat"));
|
||||
|
||||
bool load(handle src, bool) {
|
||||
PyObject* obj = src.ptr();
|
||||
if (THPMemoryFormat_Check(obj)) {
|
||||
value = reinterpret_cast<THPMemoryFormat*>(obj)->memory_format;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
static handle cast(
|
||||
at::MemoryFormat src,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
return handle(torch::utils::getTHPMemoryFormat(src));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_caster<at::Device> {
|
||||
public:
|
||||
|
@ -17,9 +17,11 @@ std::array<PyObject*, static_cast<int>(at::MemoryFormat::NumOptions)>
    memory_format_registry = {};
} // anonymous namespace

py::object getTHPMemoryFormat(at::MemoryFormat memory_format) {
PyObject* getTHPMemoryFormat(at::MemoryFormat memory_format) {
  return py::reinterpret_borrow<py::object>(
      memory_format_registry[static_cast<size_t>(memory_format)]);
      memory_format_registry[static_cast<size_t>(memory_format)])
      .release()
      .ptr();
}

void initializeMemoryFormats() {
@ -1,13 +1,13 @@
#pragma once

#include <c10/core/MemoryFormat.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_stub.h>

namespace torch {
namespace utils {

void initializeMemoryFormats();
py::object getTHPMemoryFormat(c10::MemoryFormat);
PyObject* getTHPMemoryFormat(c10::MemoryFormat);

} // namespace utils
} // namespace torch
@ -266,7 +266,27 @@ def is_contiguous(func, *args, **kwargs):
        raise ValueError(
            "MaskedTensors with sparse data do not have is_contiguous"
        )
    return data.is_contiguous()
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_strides_like_format])
def is_strides_like_format(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_strides_like_format"
        )
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
def is_non_overlapping_and_dense(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
        )
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.contiguous])