mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Remove guard_size_oblivious from default contiguity python check, and add aten.sym_is_contiguous. (#159197)
This might cause new DDEs at call sites that use is_contiguous() where they really want is_contiguous_or_false() or sym_is_contiguous(); the point is to surface those call sites so they can be handled properly by explicitly calling is_contiguous_or_false() instead of is_contiguous() where appropriate. I had to fix one such issue after removing the implicit size-oblivious reasoning.

Context: in https://github.com/pytorch/pytorch/pull/157472 we defined sym_is_contiguous as the function that computes contiguity for dynamic shapes in C++. It returns a symbolic expression representing contiguity and is guaranteed not to throw a DDE.

- When callers use is_contiguous, we do sym_is_contiguous().guard_bool().
- When callers use is_contiguous_or_false, we do sym_is_contiguous().guard_or_false().

One path was not handled well, namely:

```
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }
  return sym_is_contiguous_default(memory_format);
}
```

If sym_is_contiguous_custom is called and matches_python_custom(SizesStridesPolicy::CustomStrides) returns true, we used to call is_contiguous(this, memory_format). That went through load_pyobj_interpreter and ended up in the Python is_contiguous check, which relied on implicit size-oblivious reasoning. Once that implicit reasoning is removed, the right thing is to call pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(this, memory_format) instead; otherwise we would get a DDE even when the caller is using sym_is_contiguous. So I had to define sym_is_contiguous for the PyInterpreter and override it for nested tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159197
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: a84541c73f
Commit: e444cd24d4
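On the Python side, the non-throwing behavior described in the commit message comes from the guard_or_false / guard_or_true helpers in torch.fx.experimental.symbolic_shapes, which the reference-implementation hunks below switch to. Here is a minimal sketch of the difference, assuming a standalone ShapeEnv and its create_unbacked_symint() helper to simulate a data-dependent size; the printed results illustrate the intended fallback semantics and are not taken from this PR:

```python
import torch
from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    guard_or_false,
    guard_or_true,
)

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # data-dependent size, no hint available

# A guard_bool-style evaluation of u0 == 1 would raise a data-dependent error
# (DDE) because the value is unknown; the *_or_* helpers fall back instead.
print(guard_or_false(u0 == 1))  # False: unknown condition assumed False, no guard added
print(guard_or_true(u0 != 1))   # True:  unknown condition assumed True, no guard added

# On plain Python bools they simply evaluate the value:
print(guard_or_false(torch.randn(2, 3).is_contiguous()))  # True
```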
@@ -18,6 +18,7 @@
#include <ATen/ops/is_set_to_native.h>
#include <ATen/ops/size_native.h>
#include <ATen/ops/stride_native.h>
#include <ATen/ops/sym_is_contiguous_native.h>
#include <ATen/ops/sym_numel_native.h>
#include <ATen/ops/sym_size_native.h>
#include <ATen/ops/sym_storage_offset_native.h>
@@ -57,6 +58,12 @@ c10::SymInt sym_size(const Tensor& self, int64_t dim) {
  return self.sym_size(dim);
}

c10::SymBool sym_is_contiguous(
    const Tensor& self,
    c10::MemoryFormat memory_format) {
  return self.sym_is_contiguous(memory_format);
}

c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
  return self.sym_stride(dim);
}
@@ -5509,6 +5509,13 @@
  tags: core
  manual_cpp_binding: True

- func: sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool
  variants: function
  device_check: NoCheck
  device_guard: False
  tags: core
  manual_cpp_binding: True

- func: sym_numel(Tensor self) -> SymInt
  variants: function
  device_check: NoCheck
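The new operator registered above is reachable through the torch.ops path even though the direct Python binding is skipped (see the _SKIP_PYTHON_BINDINGS hunk further down). A small usage sketch on eager tensors; the plain bool results assume the SymBool returned for static shapes collapses to a Python bool:

```python
import torch

t = torch.randn(4, 8)
# memory_format defaults to contiguous_format, per the schema above.
print(torch.ops.aten.sym_is_contiguous(t))      # True for a freshly allocated tensor
print(torch.ops.aten.sym_is_contiguous(t.t()))  # False for a transposed view

nhwc = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
print(torch.ops.aten.sym_is_contiguous(nhwc, torch.channels_last))  # True
```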
@@ -313,8 +313,15 @@ void TensorImpl::throw_data_ptr_access_error() const {
c10::SymBool TensorImpl::sym_is_contiguous_custom(
    at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
    // TO reduce BC breaking and reduce having to introduce
    // sym_is_contiguous. call is_contiguous when tensor does not
    if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
      return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(
          this, memory_format);
    } else {
      return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
          this, memory_format);
    }
  }

  return sym_is_contiguous_default(memory_format);
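For subclasses that set dispatch_sizes_strides_policy="strides", the branch above means contiguity queries reach __torch_dispatch__ as aten.sym_is_contiguous when sizes/strides are symbolic and as aten.is_contiguous otherwise. A minimal, hypothetical wrapper subclass that answers both overload packets (mirroring the updated test_python_dispatch cases later in this diff):

```python
import torch

class WrapsStrides(torch.Tensor):
    """Hypothetical wrapper subclass; not part of this PR."""

    @staticmethod
    def __new__(cls, data):
        # "strides" routes contiguity/stride queries through __torch_dispatch__.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.size(), dtype=data.dtype, dispatch_sizes_strides_policy="strides"
        )

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        # Eager tensors still arrive as aten.is_contiguous; tensors with
        # symbolic sizes/strides now arrive as aten.sym_is_contiguous.
        if func.overloadpacket in (
            torch.ops.aten.is_contiguous,
            torch.ops.aten.sym_is_contiguous,
        ):
            return args[0]._data.is_contiguous()
        return NotImplemented

x = WrapsStrides(torch.randn(3, 4).t())
print(x.is_contiguous())  # False, answered by __torch_dispatch__
```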
@@ -60,6 +60,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
  bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
    PANIC(is_contiguous);
  }
  c10::SymBool sym_is_contiguous(const TensorImpl* self, at::MemoryFormat)
      const override {
    PANIC(sym_is_contiguous);
  }
  bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const override {
    PANIC(is_strides_like);
@@ -168,6 +168,9 @@ struct C10_API PyInterpreterVTable {

  virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual c10::SymBool sym_is_contiguous(
      const TensorImpl* self,
      at::MemoryFormat) const = 0;
  virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
@@ -208,6 +208,7 @@ xfail_not_implemented = {
    "aten::subtract_.Scalar",
    "aten::subtract_.Tensor",
    "aten::svd.U",
    "aten::sym_is_contiguous",
    "aten::sym_size.int",
    "aten::sym_stride.int",
    "aten::sym_numel",
@@ -1958,6 +1958,8 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.is_contiguous:
                    return contiguous_data.is_contiguous()
                if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
                    return torch.ops.aten.sym_is_contiguous(contiguous_data)
                return NotImplemented

        class ExampleTensor3(torch.Tensor):
@@ -1971,6 +1973,8 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.is_contiguous:
                    return not_contiguous_data.is_contiguous()
                if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
                    return torch.ops.aten.sym_is_contiguous(not_contiguous_data)
                return NotImplemented

        err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
@@ -2003,6 +2007,7 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func in [
                    torch.ops.aten.sym_is_contiguous.default,
                    torch.ops.aten.is_contiguous.default,
                    torch.ops.aten.is_contiguous.memory_format,
                    torch.ops.aten.is_strides_like_format.default,
@@ -97,6 +97,7 @@ _SKIP_PYTHON_BINDINGS = [
    "is_sparse_csr",
    "size",
    "stride",
    "sym_is_contiguous",
    "sym_size",
    "sym_stride",
    "sym_storage_offset",
@@ -1560,7 +1560,6 @@ class CatchErrorsWrapper:
        frame_state: dict[str, Union[int, FrameStateSizeEntry]],
    ) -> ConvertFrameReturn:
        assert frame_state is not None

        input_codes.add(frame.f_code)

        is_skipfile = trace_rules.check(frame.f_code)
@@ -265,12 +265,14 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_or_true,
        guard_size_oblivious,
        is_nested_int,
    )

    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
    def eval_eager(x):
        return bool(x)

    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager

    if maybe_guard_or_false(a.numel() < 2):
        return True
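After this hunk, the default false_if_dde=False path evaluates the checks eagerly through bool() (so under symbolic shapes it may guard or raise a data-dependent error), while the *_or_false variants keep the non-throwing fallback. A short sketch using the helpers in torch._prims_common on eager tensors:

```python
import torch
from torch._prims_common import is_contiguous, is_contiguous_or_false

t = torch.randn(4, 4)
print(is_contiguous(t))           # True; under symbolic shapes this may guard or raise a DDE
print(is_contiguous_or_false(t))  # True; under symbolic shapes an unknown answer becomes False

v = t.t()
print(is_contiguous_or_false(v))  # False for a transposed view
```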
@@ -305,14 +307,13 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
    if a.ndim != 4:
        return False

    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_or_true,
        guard_size_oblivious,
    )
    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
    def eval_eager(x):
        return bool(x)

    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager

    expected_stride = 1
    for idx in (1, 3, 2, 0):
@@ -334,14 +335,13 @@ def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
    if a.ndim != 5:
        return False

    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_or_true,
        guard_size_oblivious,
    )
    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
    def eval_eager(x):
        return bool(x)

    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager

    expected_stride = 1
    for idx in (1, 4, 3, 2, 0):
@@ -406,7 +406,7 @@ def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool:


# similar to is_contiguous_for_memory_format but return false on data dependency.
def contiguous_for_memory_format_or_false( # type: ignore[return]
def is_contiguous_for_memory_format_or_false( # type: ignore[return]
    a: Tensor, *, memory_format: torch.memory_format
) -> bool:
    return is_contiguous_for_memory_format(
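The helper is renamed to is_contiguous_for_memory_format_or_false to match the is_* naming of the other predicates in torch._prims_common; its behavior is unchanged. A usage sketch on eager tensors (expected results assume standard contiguous and channels-last strides):

```python
import torch
from torch._prims_common import is_contiguous_for_memory_format_or_false

nchw = torch.empty(2, 3, 4, 5)
nhwc = nchw.to(memory_format=torch.channels_last)

print(is_contiguous_for_memory_format_or_false(nchw, memory_format=torch.contiguous_format))  # True
print(is_contiguous_for_memory_format_or_false(nhwc, memory_format=torch.channels_last))      # True
print(is_contiguous_for_memory_format_or_false(nhwc, memory_format=torch.contiguous_format))  # False
```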
@@ -550,11 +550,14 @@ def compute_elementwise_output_logical_to_physical_perm(
    is_contiguous = True
    is_channels_last = True
    for t in tensors:
        is_contiguous = is_contiguous and contiguous_for_memory_format_or_false(
        is_contiguous = is_contiguous and is_contiguous_for_memory_format_or_false(
            t, memory_format=torch.contiguous_format
        )
        is_channels_last = is_channels_last and contiguous_for_memory_format_or_false(
            t, memory_format=torch.channels_last
        is_channels_last = (
            is_channels_last
            and is_contiguous_for_memory_format_or_false(
                t, memory_format=torch.channels_last
            )
        )

    if is_contiguous and not is_channels_last:
@@ -19,7 +19,6 @@ import torch.utils._pytree as pytree
from torch import sym_float, sym_int
from torch._prims_common import (
    BoolLike,
    contiguous_for_memory_format_or_false,
    DeviceLikeType,
    Dim,
    DimsSequenceType,
@@ -29,6 +28,7 @@ from torch._prims_common import (
    FloatLike,
    FloatWithoutSymFloat,
    IntLike,
    is_contiguous_for_memory_format_or_false,
    is_contiguous_or_false,
    is_weakly_lesser_type,
    Number,
@@ -3000,7 +3000,7 @@ def contiguous(
    )

    # TODO: make logic consistent with aten contiguous
    if contiguous_for_memory_format_or_false(a, memory_format=memory_format):
    if is_contiguous_for_memory_format_or_false(a, memory_format=memory_format):
        return a

    return torch.clone(a, memory_format=memory_format)
@@ -15,11 +15,11 @@ import torch._prims_common as utils
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
    contiguous_for_memory_format_or_false,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_contiguous,
    is_contiguous_for_memory_format_or_false,
    is_contiguous_or_false,
    is_float_dtype,
    is_integer_dtype,
@@ -1256,13 +1256,13 @@ def make_fast_binary_impl(
            continue
        definitely_contiguous = (
            definitely_contiguous
            and contiguous_for_memory_format_or_false(
            and is_contiguous_for_memory_format_or_false(
                op, memory_format=torch.contiguous_format
            )
        )
        definitely_channels_last = (
            definitely_channels_last
            and contiguous_for_memory_format_or_false(
            and is_contiguous_for_memory_format_or_false(
                op, memory_format=torch.channels_last
            )
        )
@@ -82,6 +82,8 @@ struct ConcretePyInterpreterVTable final

  bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  c10::SymBool sym_is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
@@ -476,6 +478,33 @@ bool ConcretePyInterpreterVTable::is_contiguous(
  return PyObject_IsTrue(out.ptr());
}

c10::SymBool ConcretePyInterpreterVTable::sym_is_contiguous(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  py::object out;
  out = torchDispatchFromTensorImpl(
      self,
      "sym_is_contiguous",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_is_contiguous")
          .attr("default")
          .ptr(),
      "torch.ops.aten",
      {py::cast(memory_format)});

  if (out.is_none()) {
    return self->sym_is_contiguous_default(memory_format);
  }

  return torch::is_symbool(out) ? out.cast<c10::SymBool>()
                                : c10::SymBool{py::cast<bool>(out)};
}

bool ConcretePyInterpreterVTable::is_strides_like(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
@@ -33,6 +33,7 @@ using c10::StorageType;
using c10::StreamObjType;
using c10::StringType;
using c10::Symbol;
using c10::SymBoolType;
using c10::SymIntType;
using c10::TensorType;
using c10::TupleType;
@@ -66,6 +67,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
      {"int", c10::TypeFactory::get<IntType>()},
      {"SymInt", c10::TypeFactory::get<SymIntType>()},
      {"bool", c10::TypeFactory::get<BoolType>()},
      {"SymBool", c10::TypeFactory::get<SymBoolType>()},
      {"None", c10::TypeFactory::get<NoneType>()},
      {"NoneType", c10::TypeFactory::get<NoneType>()},
      {"Capsule", c10::TypeFactory::get<CapsuleType>()},
@@ -7,7 +7,7 @@ import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import detect_fake_mode
from torch._prims_common import contiguous_for_memory_format_or_false
from torch._prims_common import is_contiguous_for_memory_format_or_false
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx.node import map_aggregate, Node
@@ -57,7 +57,7 @@ def _extract_tensor_metadata(
        torch.channels_last_3d,
    }
    for query_format in memory_formats:
        if contiguous_for_memory_format_or_false(
        if is_contiguous_for_memory_format_or_false(
            result, memory_format=query_format
        ):
            memory_format = query_format
@@ -285,7 +285,9 @@ def layout(func, *args, **kwargs):
    return _get_data(args[0]).layout


@register_dispatch_func([torch.ops.aten.is_contiguous])
@register_dispatch_func(
    [torch.ops.aten.is_contiguous, torch.ops.aten.sym_is_contiguous]
)
def is_contiguous(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
@@ -234,14 +234,25 @@ class NestedTensor(torch.Tensor):
        mt = self._min_seqlen_tensor
        return None if mt is None else _load_val_from_tensor(mt)

    def _is_contiguous_or_false(self):
        if self.lengths() is not None:
            return False
        from torch._prims_common import is_contiguous_for_memory_format_or_false

        return is_contiguous_for_memory_format_or_false(
            self._values, memory_format=torch.contiguous_format
        )

    def __repr__(self): # type: ignore[override]
        # We should implement this in torch/_tensor_str.py instead
        grad_fn_str = (
            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
        )

        if self.grad_fn:
            grad_fn_str = f", grad_fn={self.grad_fn}"
        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})"

        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})"

    # TODO: Remove this in favor of the default tensor subclass serialization logic.
    # We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.
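With _is_contiguous_or_false(), the NestedTensor repr no longer has to guard on data-dependent offsets while tracing. A small eager sketch (the output format follows the __repr__ above; exact values depend on the random inputs):

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)
print(nt.is_contiguous())  # True for a freshly constructed jagged tensor
print(nt)                  # repr reports contiguous=... computed without guarding
```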
@@ -516,6 +516,29 @@ register_jagged_func(
)(is_contiguous_general)


@register_jagged_func(
    torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
)
def sym_is_contiguous_general(func, *args, **kwargs):
    _, new_kwargs = normalize_function( # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")

    # If created from narrow() check for lengths
    if inp.lengths() is not None:
        return False

    new_kwargs["memory_format"] = new_kwargs.get(
        "memory_format", torch.contiguous_format
    )

    if new_kwargs["memory_format"] == torch.preserve_format:
        return True

    return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)


@register_jagged_func(
    torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
)
@@ -834,7 +834,8 @@ class _FlopCounterMode(TorchDispatchMode):
        kwargs = kwargs if kwargs else {}

        # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
        if func in {torch.ops.aten.is_contiguous.default,
        if func in {torch.ops.aten.sym_is_contiguous.default,
                    torch.ops.aten.is_contiguous.default,
                    torch.ops.aten.is_contiguous.memory_format,
                    torch.ops.aten.is_strides_like_format.default,
                    torch.ops.aten.is_non_overlapping_and_dense.default,
@@ -79,6 +79,7 @@ tensorOptionsT = BaseCppType("at", "TensorOptions")
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
tensorGeometryT = BaseCppType("at", "TensorGeometry")
SymIntT = BaseCppType("c10", "SymInt")
SymBoolT = BaseCppType("c10", "SymBool")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")

# Types representing template parameters. Technically, we probably shouldn't
@@ -125,6 +126,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
    BaseTy.Storage: storageT,
    BaseTy.Stream: streamT,
    BaseTy.SymInt: SymIntT,
    BaseTy.SymBool: SymBoolT,
}

# CTypes encode C++ type structure as needed for translation.