Remove guard_size_oblivious from the default Python contiguity check, and add aten.sym_is_contiguous (#159197)

This might cause new DDEs at call sites that do not use is_contiguous_or_false() or sym_is_contiguous(),
but we want to surface those call sites so they can be handled properly, by explicitly calling is_contiguous_or_false() instead of is_contiguous() where appropriate.
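
As an illustration of the intended call-site pattern (a sketch, not code from this PR; the wrapper function name is made up), using the existing torch._prims_common helper:

```
# Sketch: prefer is_contiguous_or_false() when the check is only a fast-path
# hint, so an undecidable symbolic check degrades to False instead of
# throwing a data-dependent error (DDE).
import torch
from torch._prims_common import is_contiguous_or_false


def can_take_fast_path(t: torch.Tensor) -> bool:  # illustrative helper
    # is_contiguous(t) may throw a DDE on unbacked sizes/strides;
    # the *_or_false variant never does.
    return is_contiguous_or_false(t)


assert can_take_fast_path(torch.randn(4, 8))  # plain eager tensors are unaffected
```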
I had to fix one issue after removing the implicit size-oblivious reasoning; here is the context.

In https://github.com/pytorch/pytorch/pull/157472 we defined sym_is_contiguous as the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression representing contiguity and is guaranteed not to throw a DDE.

When callers use is_contiguous(), we do sym_is_contiguous().guard_bool().
When callers use is_contiguous_or_false(), we do sym_is_contiguous().guard_or_false().
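
To make the difference concrete, here is a minimal sketch (not part of this PR) of how the two collapse a data-dependent SymBool:

```
# Sketch: guard_or_false() returns False for an undecidable symbolic condition
# without recording a guard, while eager bool() evaluation (what a
# guard_bool()-style check does) raises a data-dependent error (DDE).
from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # e.g. what .item()/.nonzero() produce
cond = u0 < 2                            # data-dependent SymBool

print(guard_or_false(cond))  # False: undecidable, but no guard and no DDE

try:
    bool(cond)               # eager evaluation, like guard_bool()
except Exception as e:       # GuardOnDataDependentSymNode in practice
    print("DDE:", type(e).__name__)
```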

One path that was not handled well was this one:
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
Namely, if sym_is_contiguous_custom is called and matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, we used to call is_contiguous(this, memory_format).

This went through load_pyobj_interpreter() and ended up calling the Python is_contiguous, which relied on the implicit size-oblivious reasoning.
Once that implicit reasoning is removed, the right thing to do is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
so that we do not hit a DDE even when the caller is using sym_is_contiguous.

So I had to define sym_is_contiguous on the PyInterpreter, and then override it for nested tensors.
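
For Python tensor subclasses with a custom sizes/strides policy, the new op now shows up in __torch_dispatch__ next to aten.is_contiguous (see the updated tests below). A condensed sketch of that pattern (the class name is illustrative; returning a plain bool is enough, since the C++ side wraps it into a SymBool):

```
# Sketch of the subclass pattern from the updated tests: a wrapper subclass
# with dispatch_sizes_strides_policy="strides" should now answer
# aten.sym_is_contiguous as well as aten.is_contiguous.
import torch


class AlwaysContiguousTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            dtype=elem.dtype,
            device=elem.device,
            dispatch_sizes_strides_policy="strides",
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func.overloadpacket in (
            torch.ops.aten.is_contiguous,
            torch.ops.aten.sym_is_contiguous,
        ):
            return True  # claim contiguity for both the bool and SymBool ops
        return NotImplemented


t = AlwaysContiguousTensor(torch.randn(3, 4))
print(t.is_contiguous())  # routed through __torch_dispatch__ via the policy
```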

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159197
Approved by: https://github.com/ezyang
Laith Sakka (2025-08-15 14:34:05 -07:00), committed by PyTorch MergeBot
parent a84541c73f, commit e444cd24d4
20 changed files with 141 additions and 34 deletions

View File

@ -18,6 +18,7 @@
#include <ATen/ops/is_set_to_native.h>
#include <ATen/ops/size_native.h>
#include <ATen/ops/stride_native.h>
#include <ATen/ops/sym_is_contiguous_native.h>
#include <ATen/ops/sym_numel_native.h>
#include <ATen/ops/sym_size_native.h>
#include <ATen/ops/sym_storage_offset_native.h>
@ -57,6 +58,12 @@ c10::SymInt sym_size(const Tensor& self, int64_t dim) {
return self.sym_size(dim);
}
c10::SymBool sym_is_contiguous(
const Tensor& self,
c10::MemoryFormat memory_format) {
return self.sym_is_contiguous(memory_format);
}
c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
return self.sym_stride(dim);
}

View File

@ -5509,6 +5509,13 @@
tags: core
manual_cpp_binding: True
- func: sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool
variants: function
device_check: NoCheck
device_guard: False
tags: core
manual_cpp_binding: True
- func: sym_numel(Tensor self) -> SymInt
variants: function
device_check: NoCheck

View File

@ -313,8 +313,15 @@ void TensorImpl::throw_data_ptr_access_error() const {
c10::SymBool TensorImpl::sym_is_contiguous_custom(
at::MemoryFormat memory_format) const {
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
this, memory_format);
// To reduce BC breaking and reduce having to introduce
// sym_is_contiguous, call is_contiguous when the tensor does not
// have symbolic sizes/strides.
if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(
this, memory_format);
} else {
return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
this, memory_format);
}
}
return sym_is_contiguous_default(memory_format);

View File

@ -60,6 +60,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
PANIC(is_contiguous);
}
c10::SymBool sym_is_contiguous(const TensorImpl* self, at::MemoryFormat)
const override {
PANIC(sym_is_contiguous);
}
bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
const override {
PANIC(is_strides_like);

View File

@ -168,6 +168,9 @@ struct C10_API PyInterpreterVTable {
virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
const = 0;
virtual c10::SymBool sym_is_contiguous(
const TensorImpl* self,
at::MemoryFormat) const = 0;
virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
const = 0;
virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;

View File

@ -208,6 +208,7 @@ xfail_not_implemented = {
"aten::subtract_.Scalar",
"aten::subtract_.Tensor",
"aten::svd.U",
"aten::sym_is_contiguous",
"aten::sym_size.int",
"aten::sym_stride.int",
"aten::sym_numel",

View File

@ -1958,6 +1958,8 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return contiguous_data.is_contiguous()
if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
return torch.ops.aten.sym_is_contiguous(contiguous_data)
return NotImplemented
class ExampleTensor3(torch.Tensor):
@ -1971,6 +1973,8 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return not_contiguous_data.is_contiguous()
if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
return torch.ops.aten.sym_is_contiguous(not_contiguous_data)
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
@ -2003,6 +2007,7 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in [
torch.ops.aten.sym_is_contiguous.default,
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.is_strides_like_format.default,

View File

@ -97,6 +97,7 @@ _SKIP_PYTHON_BINDINGS = [
"is_sparse_csr",
"size",
"stride",
"sym_is_contiguous",
"sym_size",
"sym_stride",
"sym_storage_offset",

View File

@ -1560,7 +1560,6 @@ class CatchErrorsWrapper:
frame_state: dict[str, Union[int, FrameStateSizeEntry]],
) -> ConvertFrameReturn:
assert frame_state is not None
input_codes.add(frame.f_code)
is_skipfile = trace_rules.check(frame.f_code)

View File

@ -265,12 +265,14 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
is_nested_int,
)
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
def eval_eager(x):
return bool(x)
maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
if maybe_guard_or_false(a.numel() < 2):
return True
@ -305,14 +307,13 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
if a.ndim != 4:
return False
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
)
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
def eval_eager(x):
return bool(x)
maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
expected_stride = 1
for idx in (1, 3, 2, 0):
@ -334,14 +335,13 @@ def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
if a.ndim != 5:
return False
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
)
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
def eval_eager(x):
return bool(x)
maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
expected_stride = 1
for idx in (1, 4, 3, 2, 0):
@ -406,7 +406,7 @@ def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool:
# similar to is_contiguous_for_memory_format but return false on data dependency.
def contiguous_for_memory_format_or_false( # type: ignore[return]
def is_contiguous_for_memory_format_or_false( # type: ignore[return]
a: Tensor, *, memory_format: torch.memory_format
) -> bool:
return is_contiguous_for_memory_format(
@ -550,11 +550,14 @@ def compute_elementwise_output_logical_to_physical_perm(
is_contiguous = True
is_channels_last = True
for t in tensors:
is_contiguous = is_contiguous and contiguous_for_memory_format_or_false(
is_contiguous = is_contiguous and is_contiguous_for_memory_format_or_false(
t, memory_format=torch.contiguous_format
)
is_channels_last = is_channels_last and contiguous_for_memory_format_or_false(
t, memory_format=torch.channels_last
is_channels_last = (
is_channels_last
and is_contiguous_for_memory_format_or_false(
t, memory_format=torch.channels_last
)
)
if is_contiguous and not is_channels_last:

View File

@ -19,7 +19,6 @@ import torch.utils._pytree as pytree
from torch import sym_float, sym_int
from torch._prims_common import (
BoolLike,
contiguous_for_memory_format_or_false,
DeviceLikeType,
Dim,
DimsSequenceType,
@ -29,6 +28,7 @@ from torch._prims_common import (
FloatLike,
FloatWithoutSymFloat,
IntLike,
is_contiguous_for_memory_format_or_false,
is_contiguous_or_false,
is_weakly_lesser_type,
Number,
@ -3000,7 +3000,7 @@ def contiguous(
)
# TODO: make logic consistent with aten contiguous
if contiguous_for_memory_format_or_false(a, memory_format=memory_format):
if is_contiguous_for_memory_format_or_false(a, memory_format=memory_format):
return a
return torch.clone(a, memory_format=memory_format)

View File

@ -15,11 +15,11 @@ import torch._prims_common as utils
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
contiguous_for_memory_format_or_false,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype,
is_contiguous,
is_contiguous_for_memory_format_or_false,
is_contiguous_or_false,
is_float_dtype,
is_integer_dtype,
@ -1256,13 +1256,13 @@ def make_fast_binary_impl(
continue
definitely_contiguous = (
definitely_contiguous
and contiguous_for_memory_format_or_false(
and is_contiguous_for_memory_format_or_false(
op, memory_format=torch.contiguous_format
)
)
definitely_channels_last = (
definitely_channels_last
and contiguous_for_memory_format_or_false(
and is_contiguous_for_memory_format_or_false(
op, memory_format=torch.channels_last
)
)

View File

@ -82,6 +82,8 @@ struct ConcretePyInterpreterVTable final
bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
const override;
c10::SymBool sym_is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
const override;
bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
const override;
bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
@ -476,6 +478,33 @@ bool ConcretePyInterpreterVTable::is_contiguous(
return PyObject_IsTrue(out.ptr());
}
c10::SymBool ConcretePyInterpreterVTable::sym_is_contiguous(
const c10::TensorImpl* self,
at::MemoryFormat memory_format) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
py::object out;
out = torchDispatchFromTensorImpl(
self,
"sym_is_contiguous",
py::module::import("torch")
.attr("ops")
.attr("aten")
.attr("sym_is_contiguous")
.attr("default")
.ptr(),
"torch.ops.aten",
{py::cast(memory_format)});
if (out.is_none()) {
return self->sym_is_contiguous_default(memory_format);
}
return torch::is_symbool(out) ? out.cast<c10::SymBool>()
: c10::SymBool{py::cast<bool>(out)};
}
bool ConcretePyInterpreterVTable::is_strides_like(
const c10::TensorImpl* self,
at::MemoryFormat memory_format) const {

View File

@ -33,6 +33,7 @@ using c10::StorageType;
using c10::StreamObjType;
using c10::StringType;
using c10::Symbol;
using c10::SymBoolType;
using c10::SymIntType;
using c10::TensorType;
using c10::TupleType;
@ -66,6 +67,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
{"int", c10::TypeFactory::get<IntType>()},
{"SymInt", c10::TypeFactory::get<SymIntType>()},
{"bool", c10::TypeFactory::get<BoolType>()},
{"SymBool", c10::TypeFactory::get<SymBoolType>()},
{"None", c10::TypeFactory::get<NoneType>()},
{"NoneType", c10::TypeFactory::get<NoneType>()},
{"Capsule", c10::TypeFactory::get<CapsuleType>()},

View File

@ -7,7 +7,7 @@ import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import detect_fake_mode
from torch._prims_common import contiguous_for_memory_format_or_false
from torch._prims_common import is_contiguous_for_memory_format_or_false
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx.node import map_aggregate, Node
@ -57,7 +57,7 @@ def _extract_tensor_metadata(
torch.channels_last_3d,
}
for query_format in memory_formats:
if contiguous_for_memory_format_or_false(
if is_contiguous_for_memory_format_or_false(
result, memory_format=query_format
):
memory_format = query_format

View File

@ -285,7 +285,9 @@ def layout(func, *args, **kwargs):
return _get_data(args[0]).layout
@register_dispatch_func([torch.ops.aten.is_contiguous])
@register_dispatch_func(
[torch.ops.aten.is_contiguous, torch.ops.aten.sym_is_contiguous]
)
def is_contiguous(func, *args, **kwargs):
data = _get_data(args[0])
if data.is_sparse:

View File

@ -234,14 +234,25 @@ class NestedTensor(torch.Tensor):
mt = self._min_seqlen_tensor
return None if mt is None else _load_val_from_tensor(mt)
def _is_contiguous_or_false(self):
if self.lengths() is not None:
return False
from torch._prims_common import is_contiguous_for_memory_format_or_false
return is_contiguous_for_memory_format_or_false(
self._values, memory_format=torch.contiguous_format
)
def __repr__(self): # type: ignore[override]
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})"
# TODO: Remove this in favor of the default tensor subclass serialization logic.
# We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.

View File

@ -516,6 +516,29 @@ register_jagged_func(
)(is_contiguous_general)
@register_jagged_func(
torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
)
def sym_is_contiguous_general(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
# If created from narrow() check for lengths
if inp.lengths() is not None:
return False
new_kwargs["memory_format"] = new_kwargs.get(
"memory_format", torch.contiguous_format
)
if new_kwargs["memory_format"] == torch.preserve_format:
return True
return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
@register_jagged_func(
torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
)

View File

@ -834,7 +834,8 @@ class _FlopCounterMode(TorchDispatchMode):
kwargs = kwargs if kwargs else {}
# Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
if func in {torch.ops.aten.is_contiguous.default,
if func in {torch.ops.aten.sym_is_contiguous.default,
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.is_strides_like_format.default,
torch.ops.aten.is_non_overlapping_and_dense.default,

View File

@ -79,6 +79,7 @@ tensorOptionsT = BaseCppType("at", "TensorOptions")
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
tensorGeometryT = BaseCppType("at", "TensorGeometry")
SymIntT = BaseCppType("c10", "SymInt")
SymBoolT = BaseCppType("c10", "SymBool")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
# Types representing template parameters. Technically, we probably shouldn't
@ -125,6 +126,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTy.Storage: storageT,
BaseTy.Stream: streamT,
BaseTy.SymInt: SymIntT,
BaseTy.SymBool: SymBoolT,
}
# CTypes encode C++ type structure as needed for translation.