Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. [attempt2] (#160869)

[relanding again after fixing internal build]
Summary:
This might cause some new DDEs on call sites that do not use is_contiguous_or_false() or sym_is_contiguous()
but want to find those call sites to handle this properly by calling  is_contiguous_or_false() and not is_contiguous() explitly when appropriate.
I had to fix one issue after removing the implicit size oblivious reasoning. here is context

we defined in this https://github.com/pytorch/pytorch/pull/157472 sym_is_contiguous to be the function computing contiguity for dynamic shapes in c++. It returns a symbolic expression that represents contiguity and guaranteed not to throw a DDE.

when people call is_contiguous we do sym_is_contiguous().guard_bool()
when people call is_contiguous_or_false we do sym_is_contiguous().guard_or_false()

one issue not handled well was this path
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }

  return sym_is_contiguous_default(memory_format);
}
```
namely if we call sym_is_contiguous_custom but we have matches_python_custom(SizesStridesPolicy::CustomStrides) return true , then we used to call is_contiguous(this, memory_format);

This used to go through the load_pyobj_interpreter and end up calling the python is_contiguous call which used implicit size oblivious reasoning.
once we removed that implicit size oblivious reasoning, the right thing we want is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
otherwise we would get DDE even if the caller is doing sym_is_contiguous.

so I had to define it for pyinterpreter, and then I had to override it for nested tensors.

Approved by: https://github.com/ezyang

Test Plan:
contbuild & OSS CI, see e444cd24d4

Rollback Plan:

Differential Revision: D80435179

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160869
Approved by: https://github.com/ezyang
This commit is contained in:
Laith Sakka
2025-09-08 22:59:13 +00:00
committed by PyTorch MergeBot
parent 5fd6b6a2db
commit 189a054cfb
21 changed files with 142 additions and 34 deletions

View File

@ -234,14 +234,25 @@ class NestedTensor(torch.Tensor):
mt = self._min_seqlen_tensor
return None if mt is None else _load_val_from_tensor(mt)
def _is_contiguous_or_false(self):
if self.lengths() is not None:
return False
from torch._prims_common import is_contiguous_for_memory_format_or_false
return is_contiguous_for_memory_format_or_false(
self._values, memory_format=torch.contiguous_format
)
def __repr__(self): # type: ignore[override]
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})"
# TODO: Remove this in favor of the default tensor subclass serialization logic.
# We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.

View File

@ -516,6 +516,29 @@ register_jagged_func(
)(is_contiguous_general)
@register_jagged_func(
torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
)
def sym_is_contiguous_general(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
# If created from narrow() check for lengths
if inp.lengths() is not None:
return False
new_kwargs["memory_format"] = new_kwargs.get(
"memory_format", torch.contiguous_format
)
if new_kwargs["memory_format"] == torch.preserve_format:
return True
return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
@register_jagged_func(
torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
)