[attempt 2] Compute contiguity symbolically to avoid DDEs, and introduce C++ sym_is_contiguous (#157472)

Summary:
When we compute contiguity for a tensor with dynamic shapes we:
1) Try to compute it without guarding.
2) If all shapes are hinted, compute it, potentially adding guards.
3) If any input is not hinted, compute it symbolically.

sym_is_contiguous returns a SymBool that can either be evaluated directly or have guard_or_false called
on it to avoid data-dependent errors (DDEs).

Example:
  bool is_contiguous = input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__);
is_contiguous_or_false() is a helper that does exactly that.

In this PR I only handle default contiguity; follow-up changes will cover other memory formats such as channels_last.
We use this pattern in several locations in this PR to avoid DDEs.
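
As a concrete illustration of the pattern, here is a minimal sketch (dispatch_example is a made-up name, not part of this diff; the real call sites, such as the linear and reshape changes below, are in the diff) of how a kernel can branch on contiguity without risking a DDE:

  // Sketch only: dispatch_example is hypothetical, for illustration.
  void dispatch_example(const at::Tensor& input) {
    // guard_or_false: when the SymBool cannot be decided because of unbacked
    // sizes/strides, this returns false instead of throwing a DDE.
    bool is_contiguous =
        input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__);
    if (is_contiguous) {
      // fast path that assumes a contiguous layout
    } else {
      // general path; also taken when contiguity is unknown at trace time
    }
  }

Equivalently, input.is_contiguous_or_false() wraps the sym_is_contiguous()/guard_or_false() combination above.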

Test Plan:
contbuild & OSS CI

Rollback Plan:

Reviewed By: malfet

Differential Revision: D77639021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157472
Approved by: https://github.com/aorenste
This commit is contained in:
Laith Sakka
2025-07-02 23:12:29 +00:00
committed by PyTorch MergeBot
parent d40aaa42ee
commit 7cfd054075
34 changed files with 390 additions and 114 deletions

View File

@ -1 +1 @@
55a75404c9b75cd5fd62ab5d4deafc8c506b3af2
926700d7832caa552ba2e1fc8302f6a2f4d2f6d8

View File

@ -499,8 +499,8 @@ int64_t FunctionalTensorWrapper::dim_custom() const {
int64_t FunctionalTensorWrapper::numel_custom() const {
return value_.unsafeGetTensorImpl()->numel();
}
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
c10::SymBool FunctionalTensorWrapper::sym_is_contiguous_custom(at::MemoryFormat memory_format) const {
return value_.unsafeGetTensorImpl()->sym_is_contiguous(memory_format);
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
return value_.unsafeGetTensorImpl()->sym_sizes();

View File

@ -236,7 +236,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
at::IntArrayRef strides_custom() const override;
int64_t dim_custom() const override;
int64_t numel_custom() const override;
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
c10::SymBool sym_is_contiguous_custom(
at::MemoryFormat memory_format) const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
c10::SymInt sym_size_custom(int64_t d) const override;
c10::SymIntArrayRef sym_strides_custom() const override;

View File

@ -320,11 +320,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
if (!stride.has_value()) {
// With unbacked symints, computeStride could fail even on contiguous
// tensors. In this case, we can use the strides of an empty tensor of
// inferred_size.
TORCH_CHECK(
self.is_contiguous(),
TORCH_SYM_CHECK(
self.sym_is_contiguous(),
"View is not valid from size:",
self.sym_sizes(),
" stride: ",
@ -333,6 +331,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
inferred_size,
" in case of unbacked symbols consider adding torch.check to guide computing strides.");
// With unbacked symints, computeStride could fail even on contiguous
// tensors. In this case, we can use the strides of an empty tensor of
// inferred_size.
stride = at::detail::empty_symint_meta(
inferred_size,
std::nullopt,

View File

@ -84,7 +84,7 @@ IntArrayRef BatchedTensorImpl::strides_custom() const {
// TODO: implement proper contiguity on batched tensor, then put
// sizes_strides_policy back to Default
bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const {
TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
"NYI: querying is_contiguous inside of vmap for memory_format ",
"other than torch.contiguous_format");

View File

@ -82,7 +82,8 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
IntArrayRef strides_custom() const override;
// Override a bunch of methods inherited from TensorImpl to return error
// messages.
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
c10::SymBool sym_is_contiguous_custom(
at::MemoryFormat memory_format) const override;
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;

View File

@ -24,7 +24,7 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
}
}
if (t->is_non_overlapping_and_dense()) {
if (t->is_non_overlapping_and_dense_or_false()) {
return MemOverlap::No;
}
@ -63,7 +63,7 @@ MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) {
if (a->numel() == 0 || b->numel() == 0) {
return MemOverlapStatus::No;
}
if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) {
if (!a->is_non_overlapping_and_dense_or_false() || !b->is_non_overlapping_and_dense_or_false()) {
return MemOverlapStatus::TooHard;
}
// Test for storage equality, rather than pointer equality.

View File

@ -273,7 +273,7 @@ c10::SymInt NestedTensorImpl::sym_numel_custom() const {
return NestedTensorImpl::numel_custom();
}
bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const {
c10::SymBool NestedTensorImpl::sym_is_contiguous_custom(MemoryFormat) const {
return nested_tensor_impl_is_contiguous(this);
}
IntArrayRef NestedTensorImpl::sizes_custom() const {

View File

@ -115,7 +115,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
// with real implementations
int64_t numel_custom() const override;
c10::SymInt sym_numel_custom() const override;
bool is_contiguous_custom(MemoryFormat) const override;
c10::SymBool sym_is_contiguous_custom(MemoryFormat) const override;
int64_t size_custom(int64_t d) const override {
return this->size(d);
}

View File

@ -252,8 +252,7 @@ void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) {
TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset.");
}
bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const {
c10::SymBool SparseCsrTensorImpl::sym_is_contiguous_custom(MemoryFormat) const {
TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous");
}
} // namespace at

View File

@ -86,7 +86,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
protected:
IntArrayRef strides_custom() const override;
SymIntArrayRef sym_strides_custom() const override;
bool is_contiguous_custom(MemoryFormat) const override;
SymBool sym_is_contiguous_custom(MemoryFormat) const override;
public:
void set_size(int64_t dim, int64_t new_size) override;

View File

@ -124,7 +124,7 @@ class TORCH_API TensorBase {
}
TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
if (is_contiguous(memory_format)) {
if (is_contiguous_or_false(memory_format)) {
return *this;
} else {
return __dispatch_contiguous(memory_format);
@ -265,6 +265,25 @@ class TORCH_API TensorBase {
return impl_->is_contiguous(memory_format);
}
// Like is_contiguous, but more dynamic shape-friendly. May return a symbolic representation of
// contiguity instead of SymTrue SymFalse, when results are data-dependent.
c10::SymBool sym_is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
if (impl_->has_symbolic_sizes_strides()) {
return impl_->sym_is_contiguous(memory_format);
}
return impl_->is_contiguous(memory_format);
}
// Like is_contiguous, but more dynamic shape-friendly. Can returns
// false instead of throwing data-dependent errors for tensors with unbacked
// sizes or strides.
bool is_contiguous_or_false(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
if (impl_->has_symbolic_sizes_strides()) {
return impl_->sym_is_contiguous(memory_format).guard_or_false(__FILE__, __LINE__);
}
return impl_->is_contiguous(memory_format);
}
bool is_non_overlapping_and_dense() const {
return impl_->is_non_overlapping_and_dense();
}

View File

@ -126,7 +126,7 @@ SymIntArrayRef BatchedTensorImpl::sym_strides_custom() const {
// TODO: implement proper contiguity on batched tensor, then put
// sizes_strides_policy back to Default
bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const {
TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
"NYI: querying is_contiguous inside of vmap for memory_format ",
"other than torch.contiguous_format");

View File

@ -69,7 +69,7 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
IntArrayRef strides_custom() const override;
SymIntArrayRef sym_strides_custom() const override;
// Override a bunch of methods inherited from TensorImpl to return error messages.
bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override;
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(

View File

@ -93,7 +93,7 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
if (bias->defined() && !input.is_xla()) {
// Also hit the fused path for contiguous 3D input, if not using xla
// backend. Reshaping/flattening has some performance implications on xla.
bool is_contiguous = definitely_contiguous(input.sym_sizes(), input.sym_strides(), input.sym_numel());
bool is_contiguous = input.is_contiguous_or_false();
if (is_contiguous && input_dim == 3) {
return _flatten_nd_linear(input, weight, *bias);
} else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {

View File

@ -113,7 +113,7 @@ Tensor& detach_(Tensor& self) {
}
Tensor contiguous(const Tensor& self, MemoryFormat memory_format) {
if (self.is_contiguous(memory_format)) {
if (self.is_contiguous_or_false(memory_format)) {
return self;
}
TORCH_CHECK(

View File

@ -1998,19 +1998,18 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
}
auto sym_sizes = self.sym_sizes();
auto sym_strides = self.sym_strides();
auto sym_numel = self.sym_numel();
if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) &&
!self.is_mkldnn()) {
if (self.is_contiguous_or_false() && !self.is_mkldnn()) {
return self.view_symint(proposed_shape);
}
auto sym_numel = self.sym_numel();
c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel);
if (self.is_mkldnn()) {
return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
}
auto sym_sizes = self.sym_sizes();
auto sym_strides = self.sym_strides();
// `computeStride` returns the proper strides to use if this
// `reshape` can be just a view.

View File

@ -35,7 +35,7 @@ struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
return c10::fromIntArrayRefKnownNonNegative(strides_);
}
bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
c10::SymBool sym_is_contiguous_custom(c10::MemoryFormat memory_format) const override {
return true;
}

View File

@ -776,7 +776,7 @@ Tensor scaled_dot_product_attention(
#ifdef USE_MPS
const auto any_nested = query_.is_nested() || key.is_nested() || value.is_nested();
const bool any_inputs_require_grad = query_.requires_grad() || key.requires_grad() || value.requires_grad();
const auto all_contiguous = query_.is_contiguous() && key.is_contiguous() && value.is_contiguous();
const auto all_contiguous = query_.is_contiguous_or_false() && key.is_contiguous_or_false() && value.is_contiguous_or_false();
if (query_device_type == DeviceType::MPS && dropout_p == 0.0
&& !(GradMode::is_enabled() && any_inputs_require_grad)
&& (all_contiguous || mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS))

View File

@ -33,7 +33,8 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
return c10::fromIntArrayRefKnownNonNegative(strides_);
}
bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
c10::SymBool sym_is_contiguous_custom(
c10::MemoryFormat memory_format) const override {
(void)memory_format;
return true;
}

View File

@ -12,7 +12,7 @@ namespace c10 {
template <typename T>
bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
if (numel == 0) {
return true;
}
@ -20,11 +20,11 @@ bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) {
if (size_d == 1) {
continue;
}
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) {
if (strides[d] != expected_stride) {
return false;
}
expected_stride *= size_d;
@ -32,29 +32,66 @@ bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
return true;
}
// This function will return True if the tensor is contiguous, and False if the
// its not or if we can't determine if it is contiguous due to unbacked symbols
// (it could be either in that case based on the actual runtime data).
template <typename T>
bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
// Return a SymBool with underlying symbolic expression that represents
// contiguity. Guaranteed not to add guards.
inline static c10::SymBool _compute_contiguous_sym(
ArrayRef<c10::SymInt> sizes,
ArrayRef<c10::SymInt> strides,
const c10::SymInt& numel) {
// If this return true, the tensor is contiguous indeed. Otherwise it could be
// either.
auto is_contiguous_or_false = [&]() {
if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
return true;
}
T expected_stride = 1;
// When calculating the expected stride, we can choose to multiply
// with max(1, size[d]) or size[d]. Regardless, this is ok for this
// function. Why?
// (1) If size[d] == 0, then the tensor is contiguous and if
// we return true or false it won't break this function.
// (2) If size[d] is not 0, then max(1,size[d]) and size[d] are equal.
// Therefore, if we choose to use max(1, size[d]) or size[d] to
// calculate the expected stride, the result is the same.
//
// We symbolically check both paths to maximize the cases where this
// function returns true. This is because make_contiguous_strides_for adds
// the max symbolically, and in some other situations the max might not be
// there. And we want to ensure we return true in both cases.
c10::SymInt expected_stride = 1;
c10::SymInt expected_stride_max = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) {
if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) {
continue;
}
if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) {
if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride)) &&
TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride_max))) {
return false;
}
expected_stride *= size_d;
expected_stride_max *= sizes[d].max(1);
expected_stride *= sizes[d];
}
return true;
};
if (is_contiguous_or_false()) {
return c10::SymBool(true);
}
// Build a single expression that represents contiguity and return it.
c10::SymBool is_empty = sym_eq(numel, 0);
c10::SymBool is_contiguous_cond = true;
c10::SymInt expected_stride = 1;
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto& size_d = sizes[d];
is_contiguous_cond = is_contiguous_cond.sym_and(
size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride)));
expected_stride = expected_stride * size_d;
}
return is_contiguous_cond.sym_or(is_empty);
}
template <typename T>

View File

@ -79,11 +79,44 @@ SymBool SymbolicShapeMeta::compute_contiguous() const {
}
c10::SymIntArrayRef sizes(sizes_);
c10::SymIntArrayRef strides(strides_);
return _compute_contiguous(sizes, strides, numel());
auto result = _compute_contiguous_sym(sizes, strides, numel());
// If the result is already determined without guarding, just return it.
auto maybe_as_bool = result.maybe_as_bool();
if (maybe_as_bool.has_value()) {
return maybe_as_bool.value();
}
auto all_hinted = true;
for (const auto& s : sizes) {
if (!s.has_hint()) {
all_hinted = false;
break;
}
}
if (all_hinted) {
for (const auto& s : strides) {
if (!s.has_hint()) {
all_hinted = false;
break;
}
}
}
if (all_hinted) {
// We avoid going through the slow path if everything is hinted,
// because evaluating a large SymPy expression can be expensive.
// TODO exclude backed_size_oblivious from this path.
return _compute_contiguous<SymInt>(sizes_, strides_, numel());
}
return result;
}
// The rest of them
#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \
#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, fallback) \
SymBool SymbolicShapeMeta::name() const { \
if (!strides_valid_) { \
return false; \
@ -110,11 +143,13 @@ SymBool SymbolicShapeMeta::compute_contiguous() const {
}
// clang-format off
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d)
DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense)
// clang-format on
#undef DEFINE_SYMBOOL_COMPUTE
@ -192,6 +227,7 @@ void SymbolicShapeMeta::set_numel(SymInt val) const {
numel_ = std::move(val);
available_.fetch_or(numel_avail);
}
void SymbolicShapeMeta::set_is_contiguous(SymBool val) const {
std::scoped_lock lock(mutables_);
if (has_is_contiguous()) {
@ -200,6 +236,7 @@ void SymbolicShapeMeta::set_is_contiguous(SymBool val) const {
is_contiguous_ = std::move(val);
available_.fetch_or(is_contiguous_avail);
}
void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const {
std::scoped_lock lock(mutables_);
if (has_is_channels_last_contiguous()) {
@ -208,6 +245,7 @@ void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const {
is_channels_last_contiguous_ = std::move(val);
available_.fetch_or(is_channels_last_contiguous_avail);
}
void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const {
std::scoped_lock lock(mutables_);
if (has_is_channels_last_3d_contiguous()) {
@ -216,6 +254,7 @@ void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const {
is_channels_last_3d_contiguous_ = std::move(val);
available_.fetch_or(is_channels_last_3d_contiguous_avail);
}
void SymbolicShapeMeta::set_is_channels_last(SymBool val) const {
std::scoped_lock lock(mutables_);
if (has_is_channels_last()) {
@ -224,6 +263,7 @@ void SymbolicShapeMeta::set_is_channels_last(SymBool val) const {
is_channels_last_ = std::move(val);
available_.fetch_or(is_channels_last_avail);
}
void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const {
std::scoped_lock lock(mutables_);
if (has_is_channels_last_3d()) {

View File

@ -1,4 +1,5 @@
#pragma once
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/macros/Export.h>
@ -82,6 +83,15 @@ class C10_API SymbolicShapeMeta {
return numel_;
}
const SymBool& is_contiguous(at::MemoryFormat memory_format) const {
if (memory_format == at::MemoryFormat::ChannelsLast) {
return this->is_channels_last_contiguous();
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
return this->is_channels_last_3d_contiguous();
}
return this->is_contiguous();
}
const SymBool& is_contiguous() const {
if (C10_UNLIKELY(!has_is_contiguous())) {
init_is_contiguous();
@ -194,6 +204,7 @@ class C10_API SymbolicShapeMeta {
// Lazily initialized variables, with the corresponding available_ flag
// indicating whether the value has been initialized
mutable std::atomic<int> available_{0};
enum avail {
numel_avail = 1 << 0,
is_contiguous_avail = 1 << 1,

View File

@ -310,12 +310,14 @@ void TensorImpl::throw_data_ptr_access_error() const {
false, "Cannot access data pointer of Tensor that doesn't have storage");
}
bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
c10::SymBool TensorImpl::sym_is_contiguous_custom(
at::MemoryFormat memory_format) const {
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
this, memory_format);
}
return is_contiguous_default(memory_format);
return sym_is_contiguous_default(memory_format);
}
bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const {
@ -326,12 +328,12 @@ bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const {
return is_strides_like_default(memory_format);
}
bool TensorImpl::is_non_overlapping_and_dense_custom() const {
c10::SymBool TensorImpl::sym_is_non_overlapping_and_dense_custom() const {
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
return pyobj_slot_.load_pyobj_interpreter()->is_non_overlapping_and_dense(
this);
}
return is_non_overlapping_and_dense_default();
return sym_is_non_overlapping_and_dense_default();
}
IntArrayRef TensorImpl::sizes_custom() const {

View File

@ -812,6 +812,43 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
}
}
c10::SymBool sym_is_contiguous(
at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const {
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
return sym_is_contiguous_custom(memory_format);
}
return sym_is_contiguous_default(memory_format);
}
template <typename T>
T is_contiguous_default_impl(at::MemoryFormat memory_format) const {
if (!has_symbolic_sizes_strides_) {
if (memory_format == at::MemoryFormat::ChannelsLast) {
return is_channels_last_contiguous_;
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
return is_channels_last_3d_contiguous_;
}
return is_contiguous_;
}
// Handle dynamic shapes.
const auto& symbolic = symbolic_shape_meta().is_contiguous(memory_format);
if constexpr (std::is_same_v<T, bool>) {
return symbolic.guard_bool(__FILE__, __LINE__);
} else {
return symbolic;
}
}
bool is_contiguous_default(at::MemoryFormat memory_format) const {
return is_contiguous_default_impl<bool>(memory_format);
}
c10::SymBool sym_is_contiguous_default(at::MemoryFormat memory_format) const {
return is_contiguous_default_impl<c10::SymBool>(memory_format);
}
/**
* Whether or not a tensor is laid out in contiguous memory.
*
@ -827,30 +864,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return is_contiguous_default(memory_format);
}
// These are factored into separate functions in case subclasses
// want to use them
bool is_contiguous_default(at::MemoryFormat memory_format) const {
if (has_symbolic_sizes_strides_) {
if (memory_format == at::MemoryFormat::ChannelsLast) {
return symbolic_shape_meta().is_channels_last_contiguous().guard_bool(
__FILE__, __LINE__);
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
return symbolic_shape_meta()
.is_channels_last_3d_contiguous()
.guard_bool(__FILE__, __LINE__);
}
return symbolic_shape_meta().is_contiguous().guard_bool(
__FILE__, __LINE__);
}
if (memory_format == at::MemoryFormat::ChannelsLast) {
return is_channels_last_contiguous_;
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
return is_channels_last_3d_contiguous_;
}
return is_contiguous_;
}
bool is_strides_like_default(at::MemoryFormat memory_format) const {
if (has_symbolic_sizes_strides_) {
if (memory_format == at::MemoryFormat::ChannelsLast) {
@ -873,9 +886,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
}
}
SymBool sym_is_non_overlapping_and_dense_default() const {
if (has_symbolic_sizes_strides_) {
return symbolic_shape_meta().is_non_overlapping_and_dense();
} else {
return is_non_overlapping_and_dense_;
}
}
bool is_non_overlapping_and_dense_default() const {
if (has_symbolic_sizes_strides_) {
return symbolic_shape_meta().is_non_overlapping_and_dense().guard_bool(
return sym_is_non_overlapping_and_dense_default().guard_bool(
__FILE__, __LINE__);
} else {
return is_non_overlapping_and_dense_;
@ -968,9 +989,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* for a tensor to have rank, but not well defined sizes.
*/
// sizes_strides_policy_ >= CustomStrides
virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const;
virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const;
virtual bool is_non_overlapping_and_dense_custom() const;
virtual c10::SymBool sym_is_non_overlapping_and_dense_custom() const;
bool is_non_overlapping_and_dense_custom() const {
return sym_is_non_overlapping_and_dense_custom().guard_bool(
__FILE__, __LINE__);
}
virtual c10::SymBool sym_is_contiguous_custom(
at::MemoryFormat memory_format) const;
bool is_contiguous_custom(at::MemoryFormat memory_format) const {
return sym_is_contiguous_custom(memory_format)
.guard_bool(__FILE__, __LINE__);
}
// sizes_strides_policy_ >= CustomSizes
// Currently this method only exists to be overwritten by subclasses such as
// NestedTensorImpl.
@ -1004,7 +1040,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
virtual c10::SymInt sym_storage_offset_custom() const;
public:
/**
/**
* True if this tensor has storage. See storage() for details.
*/
#ifdef DEBUG
@ -1016,11 +1052,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
#endif
bool
has_storage() const
// NOTE: we devirtualize this because it arguably shouldn't be an
// error just to ask subclasses if they have storage.
// This used to throw for most subclasses, but OpaqueTensorImpl
// wanted it to successfully return false, so we went ahead and made
// it a non-error.
// NOTE: we devirtualize this because it arguably shouldn't be an
// error just to ask subclasses if they have storage.
// This used to throw for most subclasses, but OpaqueTensorImpl
// wanted it to successfully return false, so we went ahead and made
// it a non-error.
#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
{
return storage_;
@ -2447,6 +2483,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return is_strides_like(at::MemoryFormat::ChannelsLast3d);
}
bool is_non_overlapping_and_dense_or_false() const {
return sym_is_non_overlapping_and_dense().guard_or_false(
__FILE__, __LINE__);
}
bool is_non_overlapping_and_dense() const {
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
return is_non_overlapping_and_dense_custom();
@ -2454,6 +2495,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return is_non_overlapping_and_dense_default();
}
SymBool sym_is_non_overlapping_and_dense() const {
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
return sym_is_non_overlapping_and_dense_custom();
}
return sym_is_non_overlapping_and_dense_default();
}
// if this returns true, then it is guaranteed that this tensor has symbolic
// sizes/strides
bool has_symbolic_sizes_strides() const {

View File

@ -12,7 +12,8 @@ UndefinedTensorImpl::UndefinedTensorImpl()
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
}
bool UndefinedTensorImpl::is_contiguous_custom(MemoryFormat format) const {
c10::SymBool UndefinedTensorImpl::sym_is_contiguous_custom(
MemoryFormat format) const {
return is_contiguous_default(format);
}
IntArrayRef UndefinedTensorImpl::strides_custom() const {

View File

@ -32,7 +32,7 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
void set_storage_offset(int64_t offset) override;
protected:
bool is_contiguous_custom(MemoryFormat format) const override;
c10::SymBool sym_is_contiguous_custom(MemoryFormat format) const override;
IntArrayRef strides_custom() const override;
SymIntArrayRef sym_strides_custom() const override;

View File

@ -15467,6 +15467,48 @@ class TestExportCustomClass(TorchTestCase):
MyModel(), inps, dynamic_shapes=spec, strict=True
).run_decompositions({})
def test_unbacked_contiguous(self):
class MyModel(torch.nn.Module):
def forward(self, x, mask):
masked_select = x.masked_select(mask)
view = masked_select.view(-1, 1548)
contig = view.contiguous()
return contig + 1
example_inputs = (
torch.randn((768, 1548), dtype=torch.bfloat16),
torch.randint(low=0, high=1, size=(768, 1), dtype=torch.bool),
)
spec = {
"x": [Dim.STATIC, Dim.STATIC],
"mask": [Dim.STATIC, Dim.STATIC],
}
traced = export(MyModel(), example_inputs, strict=True)
self.assertExpectedInline(
traced.graph_module.code,
"""\
def forward(self, x, mask):
masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None
sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
ge = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
le = sym_size_int_1 <= 1188864
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 1188864 on node 'le'"); le = _assert_scalar_default_1 = None
mod = sym_size_int_1 % 1548
eq_2 = mod == 0; mod = None
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(Mod(u0, 1548), 0) on node 'eq_2'"); eq_2 = _assert_scalar_default_2 = None
floordiv = sym_size_int_1 // 1548
mul_2 = 1548 * floordiv; floordiv = None
eq_3 = sym_size_int_1 == mul_2; sym_size_int_1 = mul_2 = None
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(u0, 1548*((u0//1548))) on node 'eq_3'"); eq_3 = _assert_scalar_default_3 = None
view = torch.ops.aten.view.default(masked_select, [-1, 1548]); masked_select = None
add = torch.ops.aten.add.Tensor(view, 1); view = None
return (add,)""",
ignore_empty_lines=True,
)
if __name__ == "__main__":
run_tests()

View File

@ -3336,8 +3336,8 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
mul_19: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
return (mul_19,)""", # noqa: B950
mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
return (mul_21,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
@ -3460,6 +3460,75 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
func(torch.ones(5, 6, 9, 8))
self.assertEqual(cnt.frame_count, 3)
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@fresh_cache()
def test_unbacked_contiguous(self):
cnt = CompileCounterWithBackend("inductor")
def func(x):
contig = x.contiguous()
return (contig + 1) * 100
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.randn(10, 10)
# make x not contiguous.
x = x.t_()
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
with ctx():
compiled_func(x)
self.assertEqual(compiled_func(x), func(x))
y = torch.rand(20, 20).t()
self.assertEqual(compiled_func(y), func(y))
self.assertEqual(cnt.frame_count, 1)
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
return (mul_6,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
with ctx():
# recompilation will happen due to stride specialization.
y = torch.rand(20, 20)
torch._dynamo.decorators.mark_unbacked(y, 0)
torch._dynamo.decorators.mark_unbacked(y, 1)
self.assertEqual(compiled_func(y), func(y))
self.assertEqual(cnt.frame_count, 2)
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
# No clone this time since input is contiguous.
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
return (mul_5,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
instantiate_parametrized_tests(TestUnbacked)

View File

@ -1370,8 +1370,8 @@ def forward(self, crop_camera_1, mask_1):
view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
mul_6 = sym_size_int * 3
view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None
mul_9 = sym_size_int * 3
view_3 = torch.ops.aten.view.default(view_2, [mul_9, 3]); view_2 = mul_9 = None
mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
_unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None

View File

@ -264,7 +264,7 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObjec
auto& self_ = THPVariable_Unpack(self);
auto memory_format = r.memoryformat(0);
// avoids touching the GIL or current device if self is already contiguous
if (self_.is_contiguous(memory_format)) {
if (self_.is_contiguous_or_false(memory_format)) {
// NOTE: this logic is duplicated from VariableType.cpp. Since we need to
// record this call to contiguous() in the trace regardless of whether
// we actually call contiguous here, we need to record this information

View File

@ -195,13 +195,14 @@ bool LTCTensorImpl::is_strides_like_custom(
return false;
}
bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const {
c10::SymBool LTCTensorImpl::sym_is_non_overlapping_and_dense_custom() const {
// This should be true, but false as a temporary fix for a PyTorch core issue,
// according to https://github.com/pytorch/xla/pull/2682.
return false;
}
bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const {
c10::SymBool LTCTensorImpl::sym_is_contiguous_custom(
c10::MemoryFormat _unused) const {
// TODO(ezyang): I don't think this branch is actually necessary
// TODO(ezyang): I don't think this logic is right, shouldn't we pass on
// the memory format?

View File

@ -41,10 +41,11 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
int64_t numel_custom() const override;
int64_t storage_offset_custom() const override;
int64_t dim_custom() const override;
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
bool is_strides_like_custom(at::MemoryFormat memory_format) const override;
bool is_non_overlapping_and_dense_custom() const override;
c10::SymBool sym_is_non_overlapping_and_dense_custom() const override;
c10::SymBool sym_is_contiguous_custom(
at::MemoryFormat memory_format) const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
c10::SymIntArrayRef sym_strides_custom() const override;
c10::SymInt sym_numel_custom() const override;

View File

@ -20,6 +20,9 @@ class ExprPrinter(StrPrinter):
def _print_Mul(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, "*", precedence(expr))
def _print_Not(self, expr: sympy.Expr) -> str:
return f"not ({self._print(expr.args[0])})"
def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
return self.stringify(expr.args, " + ", precedence(expr))