Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. [attempt2] (#160869)
[relanding again after fixing internal build]
Summary:
This might cause some new DDEs at call sites that do not use is_contiguous_or_false() or sym_is_contiguous(),
but we want to surface those call sites so they can be handled properly, i.e. by explicitly calling is_contiguous_or_false() instead of is_contiguous() where appropriate.
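For illustration, a call site that only uses contiguity to pick a fast path can be migrated along the lines of this rough sketch (the function below and the source-location arguments to guard_or_false() are assumptions for the example, not the actual upstream API):
```
#include <c10/core/TensorImpl.h>

// Hypothetical call site: contiguity is only an optimization hint here, so an
// undecidable answer should fall back to the general path instead of raising
// a data-dependent error (DDE).
void run_copy(c10::TensorImpl* impl) {
  if (impl->sym_is_contiguous().guard_or_false(__FILE__, __LINE__)) {
    // provably contiguous: take the fast path
  } else {
    // non-contiguous or unknown: take the general strided path
  }
}
```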
I had to fix one issue after removing the implicit size-oblivious reasoning; here is the context.
In https://github.com/pytorch/pytorch/pull/157472 we defined sym_is_contiguous as the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression representing contiguity and is guaranteed not to throw a DDE.
When callers use is_contiguous(), we do sym_is_contiguous().guard_bool().
When callers use is_contiguous_or_false(), we do sym_is_contiguous().guard_or_false().
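Put differently, both checks are thin wrappers over the same symbolic computation; a minimal sketch of that mapping (the helper names and exact signatures below are illustrative, not the real TensorImpl methods):
```
#include <c10/core/TensorImpl.h>

// Sketch only: the real logic lives on TensorImpl itself.
bool is_contiguous_like(const c10::TensorImpl& t, c10::MemoryFormat mf) {
  // guard_bool() guards on the symbolic expression and can throw a
  // data-dependent error (DDE) when it involves unbacked symbols.
  return t.sym_is_contiguous(mf).guard_bool(__FILE__, __LINE__);
}

bool is_contiguous_or_false_like(const c10::TensorImpl& t, c10::MemoryFormat mf) {
  // guard_or_false() never throws: an undecidable expression is reported as
  // "not contiguous", which callers must treat as a conservative answer.
  return t.sym_is_contiguous(mf).guard_or_false(__FILE__, __LINE__);
}
```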
One path that was not handled well was this:
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    // Problematic path: this routed to the Python is_contiguous, which relied
    // on implicit size-oblivious reasoning.
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }
  return sym_is_contiguous_default(memory_format);
}
```
Namely, if sym_is_contiguous_custom is called while matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, we used to call is_contiguous(this, memory_format).
That went through load_pyobj_interpreter and ended up in the Python is_contiguous implementation, which used implicit size-oblivious reasoning.
Once that implicit size-oblivious reasoning was removed, the right thing to do is to call
return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format);
instead; otherwise we would hit a DDE even when the caller uses sym_is_contiguous.
So I had to define sym_is_contiguous on the PyInterpreter and then override it for nested tensors.
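Roughly, the fixed custom path looks like the sketch below (based on the snippet above and the call described here, not the verbatim upstream code):
```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    // Ask the Python interpreter for the *symbolic* answer so that callers
    // using guard_or_false() never hit a data-dependent error here.
    return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(
        this, memory_format);
  }
  return sym_is_contiguous_default(memory_format);
}
```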
Approved by: https://github.com/ezyang
Test Plan:
contbuild & OSS CI, see e444cd24d4
Rollback Plan:
Differential Revision: D80435179
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160869
Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent 5fd6b6a2db
commit 189a054cfb
```
@@ -234,14 +234,25 @@ class NestedTensor(torch.Tensor):
         mt = self._min_seqlen_tensor
         return None if mt is None else _load_val_from_tensor(mt)
 
+    def _is_contiguous_or_false(self):
+        if self.lengths() is not None:
+            return False
+        from torch._prims_common import is_contiguous_for_memory_format_or_false
+
+        return is_contiguous_for_memory_format_or_false(
+            self._values, memory_format=torch.contiguous_format
+        )
+
     def __repr__(self):  # type: ignore[override]
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (
             f", requires_grad={self.requires_grad}" if self.requires_grad else ""
         )
 
         if self.grad_fn:
             grad_fn_str = f", grad_fn={self.grad_fn}"
-        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})"
+        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})"
 
     # TODO: Remove this in favor of the default tensor subclass serialization logic.
     # We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.
```
```
@@ -516,6 +516,29 @@ register_jagged_func(
 )(is_contiguous_general)
 
 
+@register_jagged_func(
+    torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
+)
+def sym_is_contiguous_general(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    # If created from narrow() check for lengths
+    if inp.lengths() is not None:
+        return False
+
+    new_kwargs["memory_format"] = new_kwargs.get(
+        "memory_format", torch.contiguous_format
+    )
+
+    if new_kwargs["memory_format"] == torch.preserve_format:
+        return True
+
+    return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
+
+
 @register_jagged_func(
     torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
 )
```