Revert "Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. (#159197)"

This reverts commit e444cd24d48b3a46f067974f2cc157f5ed27709f.

Reverted https://github.com/pytorch/pytorch/pull/159197 on behalf of https://github.com/laithsakka due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/159197#issuecomment-3195436668))
Author: PyTorch MergeBot
Date: 2025-08-18 07:22:13 +00:00
Parent: d8d589bd3a
Commit: b82aa3df20
20 changed files with 34 additions and 141 deletions
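
For context on what this revert restores in the Python reference implementation: the torch._prims_common hunks below switch the default (false_if_dde=False) contiguity checks back from a plain bool() evaluation (the eval_eager helper) to guard_size_oblivious. The sketch below illustrates the restored pattern; the guard selection mirrors the diff, while the stride-walking body is a simplified stand-in rather than the verbatim upstream code.

```python
# Simplified sketch of torch._prims_common.is_contiguous after this revert.
# The guard selection matches the hunks below; the loop body is illustrative.
from torch.fx.experimental.symbolic_shapes import (
    guard_or_false,
    guard_or_true,
    guard_size_oblivious,
)


def is_contiguous_sketch(a, false_if_dde: bool = False) -> bool:
    # false_if_dde=True: never guard on data-dependent expressions; unknown
    # conditions conservatively resolve to False / True.
    # false_if_dde=False (default): the revert restores guard_size_oblivious
    # here, where the reverted PR had used a plain bool() evaluation.
    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious

    if maybe_guard_or_false(a.numel() < 2):
        return True

    expected_stride = 1
    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
        # Dimensions of length 1 do not affect contiguity.
        if maybe_guard_or_false(x == 1):
            continue
        if maybe_guard_or_true(y != expected_stride):
            return False
        expected_stride = expected_stride * x
    return True
```

The practical difference: forcing bool() on a SymBool that contains unbacked symbols raises a data-dependent guard error, whereas guard_size_oblivious can often resolve the expression by ignoring the size-0/1 corner cases; guard_or_false/guard_or_true instead fall back to a fixed answer without guarding.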


@@ -18,7 +18,6 @@
#include <ATen/ops/is_set_to_native.h>
#include <ATen/ops/size_native.h>
#include <ATen/ops/stride_native.h>
#include <ATen/ops/sym_is_contiguous_native.h>
#include <ATen/ops/sym_numel_native.h>
#include <ATen/ops/sym_size_native.h>
#include <ATen/ops/sym_storage_offset_native.h>
@@ -58,12 +57,6 @@ c10::SymInt sym_size(const Tensor& self, int64_t dim) {
return self.sym_size(dim);
}
c10::SymBool sym_is_contiguous(
const Tensor& self,
c10::MemoryFormat memory_format) {
return self.sym_is_contiguous(memory_format);
}
c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
return self.sym_stride(dim);
}


@@ -5509,13 +5509,6 @@
tags: core
manual_cpp_binding: True
- func: sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool
variants: function
device_check: NoCheck
device_guard: False
tags: core
manual_cpp_binding: True
- func: sym_numel(Tensor self) -> SymInt
variants: function
device_check: NoCheck


@@ -313,15 +313,8 @@ void TensorImpl::throw_data_ptr_access_error() const {
c10::SymBool TensorImpl::sym_is_contiguous_custom(
at::MemoryFormat memory_format) const {
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
// TO reduce BC breaking and reduce having to introduce
// sym_is_contiguous. call is_contiguous when tensor does not
if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(
this, memory_format);
} else {
return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
this, memory_format);
}
return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
this, memory_format);
}
return sym_is_contiguous_default(memory_format);


@@ -60,10 +60,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
PANIC(is_contiguous);
}
c10::SymBool sym_is_contiguous(const TensorImpl* self, at::MemoryFormat)
const override {
PANIC(sym_is_contiguous);
}
bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
const override {
PANIC(is_strides_like);


@@ -168,9 +168,6 @@ struct C10_API PyInterpreterVTable {
virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
const = 0;
virtual c10::SymBool sym_is_contiguous(
const TensorImpl* self,
at::MemoryFormat) const = 0;
virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
const = 0;
virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;


@@ -208,7 +208,6 @@ xfail_not_implemented = {
"aten::subtract_.Scalar",
"aten::subtract_.Tensor",
"aten::svd.U",
"aten::sym_is_contiguous",
"aten::sym_size.int",
"aten::sym_stride.int",
"aten::sym_numel",


@@ -1958,8 +1958,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return contiguous_data.is_contiguous()
if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
return torch.ops.aten.sym_is_contiguous(contiguous_data)
return NotImplemented
class ExampleTensor3(torch.Tensor):
@@ -1973,8 +1971,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return not_contiguous_data.is_contiguous()
if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
return torch.ops.aten.sym_is_contiguous(not_contiguous_data)
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
@@ -2007,7 +2003,6 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in [
torch.ops.aten.sym_is_contiguous.default,
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.is_strides_like_format.default,

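The test hunks above remove the sym_is_contiguous handlers from the __torch_dispatch__ test subclasses, leaving only aten.is_contiguous interception. For readers unfamiliar with that mechanism, here is a self-contained sketch of the remaining pattern; the class name and the _inner field are illustrative, not taken from the PR.

```python
import torch


class ContiguityQueryTensor(torch.Tensor):
    # Illustrative wrapper subclass: with dispatch_sizes_strides_policy="strides",
    # metadata queries such as aten.is_contiguous are routed to __torch_dispatch__
    # instead of being answered from the wrapper's own (default) strides.
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            inner.shape,
            dtype=inner.dtype,
            device=inner.device,
            dispatch_sizes_strides_policy="strides",
        )

    def __init__(self, inner):
        self._inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func.overloadpacket == torch.ops.aten.is_contiguous:
            # Answer contiguity from the wrapped tensor, as the tests above do.
            return args[0]._inner.is_contiguous()
        return NotImplemented


t = ContiguityQueryTensor(torch.arange(6.0).reshape(2, 3).t())
print(t.is_contiguous())  # expected: False, answered by the wrapped tensor
```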

@@ -97,7 +97,6 @@ _SKIP_PYTHON_BINDINGS = [
"is_sparse_csr",
"size",
"stride",
"sym_is_contiguous",
"sym_size",
"sym_stride",
"sym_storage_offset",


@@ -1560,6 +1560,7 @@ class CatchErrorsWrapper:
frame_state: dict[str, Union[int, FrameStateSizeEntry]],
) -> ConvertFrameReturn:
assert frame_state is not None
input_codes.add(frame.f_code)
is_skipfile = trace_rules.check(frame.f_code)


@@ -265,14 +265,12 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
is_nested_int,
)
def eval_eager(x):
return bool(x)
maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
if maybe_guard_or_false(a.numel() < 2):
return True
@@ -307,13 +305,14 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
if a.ndim != 4:
return False
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
)
def eval_eager(x):
return bool(x)
maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
expected_stride = 1
for idx in (1, 3, 2, 0):
@@ -335,13 +334,14 @@ def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
if a.ndim != 5:
return False
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
)
def eval_eager(x):
return bool(x)
maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
expected_stride = 1
for idx in (1, 4, 3, 2, 0):
@@ -406,7 +406,7 @@ def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool:
# similar to is_contiguous_for_memory_format but return false on data dependency.
def is_contiguous_for_memory_format_or_false( # type: ignore[return]
def contiguous_for_memory_format_or_false( # type: ignore[return]
a: Tensor, *, memory_format: torch.memory_format
) -> bool:
return is_contiguous_for_memory_format(
@@ -550,14 +550,11 @@ def compute_elementwise_output_logical_to_physical_perm(
is_contiguous = True
is_channels_last = True
for t in tensors:
is_contiguous = is_contiguous and is_contiguous_for_memory_format_or_false(
is_contiguous = is_contiguous and contiguous_for_memory_format_or_false(
t, memory_format=torch.contiguous_format
)
is_channels_last = (
is_channels_last
and is_contiguous_for_memory_format_or_false(
t, memory_format=torch.channels_last
)
is_channels_last = is_channels_last and contiguous_for_memory_format_or_false(
t, memory_format=torch.channels_last
)
if is_contiguous and not is_channels_last:


@@ -19,6 +19,7 @@ import torch.utils._pytree as pytree
from torch import sym_float, sym_int
from torch._prims_common import (
BoolLike,
contiguous_for_memory_format_or_false,
DeviceLikeType,
Dim,
DimsSequenceType,
@@ -28,7 +29,6 @@ from torch._prims_common import (
FloatLike,
FloatWithoutSymFloat,
IntLike,
is_contiguous_for_memory_format_or_false,
is_contiguous_or_false,
is_weakly_lesser_type,
Number,
@@ -3000,7 +3000,7 @@ def contiguous(
)
# TODO: make logic consistent with aten contiguous
if is_contiguous_for_memory_format_or_false(a, memory_format=memory_format):
if contiguous_for_memory_format_or_false(a, memory_format=memory_format):
return a
return torch.clone(a, memory_format=memory_format)


@@ -15,11 +15,11 @@ import torch._prims_common as utils
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
contiguous_for_memory_format_or_false,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype,
is_contiguous,
is_contiguous_for_memory_format_or_false,
is_contiguous_or_false,
is_float_dtype,
is_integer_dtype,
@@ -1256,13 +1256,13 @@ def make_fast_binary_impl(
continue
definitely_contiguous = (
definitely_contiguous
and is_contiguous_for_memory_format_or_false(
and contiguous_for_memory_format_or_false(
op, memory_format=torch.contiguous_format
)
)
definitely_channels_last = (
definitely_channels_last
and is_contiguous_for_memory_format_or_false(
and contiguous_for_memory_format_or_false(
op, memory_format=torch.channels_last
)
)


@@ -82,8 +82,6 @@ struct ConcretePyInterpreterVTable final
bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
const override;
c10::SymBool sym_is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
const override;
bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
const override;
bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
@@ -478,33 +476,6 @@ bool ConcretePyInterpreterVTable::is_contiguous(
return PyObject_IsTrue(out.ptr());
}
c10::SymBool ConcretePyInterpreterVTable::sym_is_contiguous(
const c10::TensorImpl* self,
at::MemoryFormat memory_format) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
py::object out;
out = torchDispatchFromTensorImpl(
self,
"sym_is_contiguous",
py::module::import("torch")
.attr("ops")
.attr("aten")
.attr("sym_is_contiguous")
.attr("default")
.ptr(),
"torch.ops.aten",
{py::cast(memory_format)});
if (out.is_none()) {
return self->sym_is_contiguous_default(memory_format);
}
return torch::is_symbool(out) ? out.cast<c10::SymBool>()
: c10::SymBool{py::cast<bool>(out)};
}
bool ConcretePyInterpreterVTable::is_strides_like(
const c10::TensorImpl* self,
at::MemoryFormat memory_format) const {


@@ -33,7 +33,6 @@ using c10::StorageType;
using c10::StreamObjType;
using c10::StringType;
using c10::Symbol;
using c10::SymBoolType;
using c10::SymIntType;
using c10::TensorType;
using c10::TupleType;
@@ -67,7 +66,6 @@ TypePtr SchemaTypeParser::parseBaseType() {
{"int", c10::TypeFactory::get<IntType>()},
{"SymInt", c10::TypeFactory::get<SymIntType>()},
{"bool", c10::TypeFactory::get<BoolType>()},
{"SymBool", c10::TypeFactory::get<SymBoolType>()},
{"None", c10::TypeFactory::get<NoneType>()},
{"NoneType", c10::TypeFactory::get<NoneType>()},
{"Capsule", c10::TypeFactory::get<CapsuleType>()},


@@ -7,7 +7,7 @@ import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import detect_fake_mode
from torch._prims_common import is_contiguous_for_memory_format_or_false
from torch._prims_common import contiguous_for_memory_format_or_false
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx.node import map_aggregate, Node
@@ -57,7 +57,7 @@ def _extract_tensor_metadata(
torch.channels_last_3d,
}
for query_format in memory_formats:
if is_contiguous_for_memory_format_or_false(
if contiguous_for_memory_format_or_false(
result, memory_format=query_format
):
memory_format = query_format


@@ -285,9 +285,7 @@ def layout(func, *args, **kwargs):
return _get_data(args[0]).layout
@register_dispatch_func(
[torch.ops.aten.is_contiguous, torch.ops.aten.sym_is_contiguous]
)
@register_dispatch_func([torch.ops.aten.is_contiguous])
def is_contiguous(func, *args, **kwargs):
data = _get_data(args[0])
if data.is_sparse:


@@ -234,25 +234,14 @@ class NestedTensor(torch.Tensor):
mt = self._min_seqlen_tensor
return None if mt is None else _load_val_from_tensor(mt)
def _is_contiguous_or_false(self):
if self.lengths() is not None:
return False
from torch._prims_common import is_contiguous_for_memory_format_or_false
return is_contiguous_for_memory_format_or_false(
self._values, memory_format=torch.contiguous_format
)
def __repr__(self): # type: ignore[override]
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})"
# TODO: Remove this in favor of the default tensor subclass serialization logic.
# We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.


@@ -516,29 +516,6 @@ register_jagged_func(
)(is_contiguous_general)
@register_jagged_func(
torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
)
def sym_is_contiguous_general(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
# If created from narrow() check for lengths
if inp.lengths() is not None:
return False
new_kwargs["memory_format"] = new_kwargs.get(
"memory_format", torch.contiguous_format
)
if new_kwargs["memory_format"] == torch.preserve_format:
return True
return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
@register_jagged_func(
torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
)


@@ -834,8 +834,7 @@ class _FlopCounterMode(TorchDispatchMode):
kwargs = kwargs if kwargs else {}
# Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
if func in {torch.ops.aten.sym_is_contiguous.default,
torch.ops.aten.is_contiguous.default,
if func in {torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.is_strides_like_format.default,
torch.ops.aten.is_non_overlapping_and_dense.default,


@@ -79,7 +79,6 @@ tensorOptionsT = BaseCppType("at", "TensorOptions")
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
tensorGeometryT = BaseCppType("at", "TensorGeometry")
SymIntT = BaseCppType("c10", "SymInt")
SymBoolT = BaseCppType("c10", "SymBool")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
# Types representing template parameters. Technically, we probably shouldn't
@@ -126,7 +125,6 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTy.Storage: storageT,
BaseTy.Stream: streamT,
BaseTy.SymInt: SymIntT,
BaseTy.SymBool: SymBoolT,
}
# CTypes encode C++ type structure as needed for translation.