mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
introduce definitely_contiguous and use it for reshape and tensor meta data computation. (#153432)
when a tensor has unbacked symbols it can be general enough to represent both contiguous and non contiguous tensors. in that case we cant really evaluate is_contiguous. In many places in the code base, we check for is_contiguous to take a fast path. but the general path usually works for both contiguous and not contiguous in that case we probably want to use definitely _contiguous API. This is appleid for reshape in this PR and also to tensor meta data computation, the meta data now will have an attribute that says that its contiguous when its always contiguous. We would store that only if definitely _contiguous is true now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
54f1f29fed
commit
39df901b2a
@ -24,6 +24,7 @@
|
||||
#include <ATen/native/cpu/SerialStackImpl.h>
|
||||
#include <ATen/native/cpu/StackKernel.h>
|
||||
#include <ATen/quantized/QTensorImpl.h>
|
||||
#include <c10/core/Contiguity.h>
|
||||
#include <c10/core/GradMode.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
@ -1993,11 +1994,15 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
|
||||
TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
|
||||
}
|
||||
|
||||
if (self.is_contiguous() && !self.is_mkldnn()) {
|
||||
auto sym_sizes = self.sym_sizes();
|
||||
auto sym_strides = self.sym_strides();
|
||||
auto sym_numel = self.sym_numel();
|
||||
if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) &&
|
||||
!self.is_mkldnn()) {
|
||||
return self.view_symint(proposed_shape);
|
||||
}
|
||||
|
||||
c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
|
||||
c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel);
|
||||
|
||||
if (self.is_mkldnn()) {
|
||||
return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
|
||||
@ -2005,8 +2010,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
|
||||
|
||||
// `computeStride` returns the proper strides to use if this
|
||||
// `reshape` can be just a view.
|
||||
auto stride =
|
||||
at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
|
||||
auto stride = at::detail::computeStride(sym_sizes, sym_strides, shape);
|
||||
|
||||
// NB: Even though we have viewable geometry and the target strides here,
|
||||
// we do not just call `as_strided` on `self` because the backward
|
||||
|
@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015
|
||||
|
||||
|
||||
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025
|
||||
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015
|
||||
|
||||
|
||||
|
||||
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44010000000,0.025
|
||||
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025
|
||||
|
||||
|
||||
|
||||
@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18140000000,0.015
|
||||
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18270000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16220000000,0.015
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16310000000,0.015
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
|
||||
|
||||
|
||||
|
||||
update_hint_regression,compile_time_instruction_count,1681000000,0.02
|
||||
update_hint_regression,compile_time_instruction_count,1700000000,0.02
|
||||
|
||||
|
||||
|
||||
float_args,compile_time_instruction_count,449800000,0.015
|
||||
float_args,compile_time_instruction_count,452500000,0.015
|
||||
|
||||
|
||||
|
||||
@ -54,24 +54,24 @@ symint_sum_loop,compile_time_instruction_count,4262000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2112000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8585000000,0.015
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8672000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015
|
||||
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3859000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10420000000,0.015
|
||||
|
|
@ -12,24 +12,49 @@ namespace c10 {
|
||||
|
||||
template <typename T>
|
||||
bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
|
||||
bool is_contiguous = true;
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
|
||||
return is_contiguous;
|
||||
return true;
|
||||
}
|
||||
T z = 1;
|
||||
|
||||
T expected_stride = 1;
|
||||
// NB: make sure we do signed arithmetic
|
||||
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
|
||||
const auto& size_d = sizes[d];
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) {
|
||||
z *= size_d;
|
||||
} else {
|
||||
is_contiguous = false;
|
||||
break;
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) {
|
||||
return false;
|
||||
}
|
||||
expected_stride *= size_d;
|
||||
}
|
||||
return is_contiguous;
|
||||
return true;
|
||||
}
|
||||
|
||||
// This function will return True if the tensor is contiguous, and False if the
|
||||
// its not or if we can't determine if it is contiguous due to unbacked symbols
|
||||
// (it could be either in that case based on the actual runtime data).
|
||||
template <typename T>
|
||||
bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
|
||||
if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
T expected_stride = 1;
|
||||
// NB: make sure we do signed arithmetic
|
||||
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
|
||||
const auto& size_d = sizes[d];
|
||||
if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) {
|
||||
return false;
|
||||
}
|
||||
expected_stride *= size_d;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -2647,7 +2647,7 @@ graph():
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Received user-specified .* \[None, 5\], conflicting with the inferred .*"
|
||||
r"\[6, int_oo\],.* for inputs\['xs'\]\['data'\]\[0\]\[0\]\.shape\[0\]",
|
||||
r"\[8, int_oo\],.* for inputs\['xs'\]\['data'\]\[0\]\[0\]\.shape\[0\]",
|
||||
):
|
||||
export(Foo(), ({"data": [[x, y]]},), dynamic_shapes=shapes)
|
||||
|
||||
|
@ -3281,6 +3281,39 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
||||
self.assertEqual(result_compiled, result_eager)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
# Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride.
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
||||
)
|
||||
with ctx():
|
||||
# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].
|
||||
# but not anymore since we use definitely_contiguous .
|
||||
# We need a way to mark strides unbacked to avoid the recompilation here.
|
||||
x = torch.randn(10, 10)
|
||||
torch._dynamo.decorators.mark_unbacked(x, 0)
|
||||
torch._dynamo.decorators.mark_unbacked(x, 1)
|
||||
|
||||
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
||||
self.assertExpectedInline(
|
||||
aot_graphs,
|
||||
"""""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
result_compiled = compiled_func(x, torch.tensor([2, 50]))
|
||||
result_eager = func(x, torch.tensor([2, 50]))
|
||||
|
||||
self.assertEqual(result_compiled, result_eager)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
x = torch.randn(4, 4)
|
||||
|
||||
result_eager = func(x, torch.tensor([2, 8]))
|
||||
result_compiled = compiled_func(x, torch.tensor([2, 8]))
|
||||
self.assertEqual(result_compiled, result_eager)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
@unittest.skip("this test fails due to inductor/autograd issue #153041")
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_non_contigious_reshape_failing(self):
|
||||
|
@ -1370,8 +1370,8 @@ def forward(self, crop_camera_1, mask_1):
|
||||
view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
|
||||
bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
|
||||
view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
|
||||
mul_4 = sym_size_int * 3
|
||||
view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None
|
||||
mul_6 = sym_size_int * 3
|
||||
view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None
|
||||
mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
|
||||
_unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
|
||||
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None
|
||||
|
@ -693,6 +693,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
)
|
||||
)
|
||||
else:
|
||||
if current_group < len(remaining):
|
||||
return_getters.append(
|
||||
operator.itemgetter(add_range(current_group, size))
|
||||
)
|
||||
|
@ -259,47 +259,64 @@ def check_all_strides(
|
||||
|
||||
|
||||
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
|
||||
def is_contiguous(a: TensorLikeType) -> bool:
|
||||
def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
|
||||
"""
|
||||
Tests whether a tensor is contiguous or not.
|
||||
|
||||
Tensors are contiguous when they have no elements,
|
||||
one element, or when they have "nested" strides.
|
||||
"""
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
guard_or_true,
|
||||
guard_size_oblivious,
|
||||
)
|
||||
|
||||
if guard_size_oblivious(a.numel() < 2):
|
||||
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
|
||||
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
|
||||
|
||||
if maybe_guard_or_false(a.numel() < 2):
|
||||
return True
|
||||
|
||||
expected_stride = 1
|
||||
for x, y in reversed(tuple(zip(a.shape, a.stride()))):
|
||||
# Skips checking strides when a dimension has length 1
|
||||
if guard_size_oblivious(x == 1):
|
||||
if maybe_guard_or_false(x == 1):
|
||||
continue
|
||||
|
||||
if guard_size_oblivious(y != expected_stride):
|
||||
if maybe_guard_or_true(y != expected_stride):
|
||||
return False
|
||||
expected_stride = expected_stride * x
|
||||
|
||||
# if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can
|
||||
# can assume x is not 0 in expected_stride equation. This is also consistent with make_contiguous_strides_for.
|
||||
expected_stride = expected_stride * sym_max(x, 1)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
|
||||
def is_channels_last_contiguous_2d(a: Tensor) -> bool:
|
||||
def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
|
||||
# NHWC or not channels last 2D contiguous
|
||||
if a.ndim != 4:
|
||||
return False
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
guard_or_true,
|
||||
guard_size_oblivious,
|
||||
)
|
||||
|
||||
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
|
||||
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
|
||||
|
||||
expected_stride = 1
|
||||
for idx in (1, 3, 2, 0):
|
||||
length = a.shape[idx]
|
||||
if guard_size_oblivious(length == 1):
|
||||
if maybe_guard_or_false(length == 1):
|
||||
continue
|
||||
|
||||
stride = a.stride()[idx]
|
||||
if guard_size_oblivious(stride != expected_stride):
|
||||
if maybe_guard_or_true(stride != expected_stride):
|
||||
return False
|
||||
|
||||
expected_stride *= length
|
||||
@ -307,21 +324,28 @@ def is_channels_last_contiguous_2d(a: Tensor) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def is_channels_last_contiguous_3d(a: Tensor) -> bool:
|
||||
def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
|
||||
# NDHWC or not channels last 3D contiguous
|
||||
if a.ndim != 5:
|
||||
return False
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
guard_or_true,
|
||||
guard_size_oblivious,
|
||||
)
|
||||
|
||||
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
|
||||
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
|
||||
|
||||
expected_stride = 1
|
||||
for idx in (1, 4, 3, 2, 0):
|
||||
length = a.shape[idx]
|
||||
if guard_size_oblivious(length == 1):
|
||||
if maybe_guard_or_false(length == 1):
|
||||
continue
|
||||
|
||||
stride = a.stride()[idx]
|
||||
if guard_size_oblivious(stride != expected_stride):
|
||||
if maybe_guard_or_true(stride != expected_stride):
|
||||
return False
|
||||
|
||||
expected_stride *= length
|
||||
@ -345,16 +369,16 @@ def validate_memory_format(memory_format: torch.memory_format):
|
||||
|
||||
|
||||
def is_contiguous_for_memory_format( # type: ignore[return]
|
||||
a: Tensor, *, memory_format: torch.memory_format
|
||||
a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False
|
||||
) -> bool:
|
||||
validate_memory_format(memory_format)
|
||||
|
||||
if memory_format == torch.contiguous_format:
|
||||
return is_contiguous(a)
|
||||
return is_contiguous(a, false_if_dde)
|
||||
if memory_format == torch.channels_last:
|
||||
return is_channels_last_contiguous_2d(a)
|
||||
return is_channels_last_contiguous_2d(a, false_if_dde)
|
||||
if memory_format == torch.channels_last_3d:
|
||||
return is_channels_last_contiguous_3d(a)
|
||||
return is_channels_last_contiguous_3d(a, false_if_dde)
|
||||
|
||||
torch._check(
|
||||
False,
|
||||
@ -362,6 +386,29 @@ def is_contiguous_for_memory_format( # type: ignore[return]
|
||||
)
|
||||
|
||||
|
||||
def definitely_contiguous(a: TensorLikeType) -> bool:
|
||||
return is_contiguous(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_channels_last_contiguous_2d but return false on data dependency.
|
||||
def is_known_channels_last_contiguous_2d(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_2d(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_channels_last_contiguous_3d but return false on data dependency.
|
||||
def is_known_channels_last_contiguous_3d(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_3d(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_contiguous_for_memory_format but return false on data dependency.
|
||||
def definitely_contiguous_for_memory_format( # type: ignore[return]
|
||||
a: Tensor, *, memory_format: torch.memory_format
|
||||
) -> bool:
|
||||
return is_contiguous_for_memory_format(
|
||||
a, memory_format=memory_format, false_if_dde=True
|
||||
)
|
||||
|
||||
|
||||
# NOTE: that tensors with no elements and channels last is ???
|
||||
def is_channels_last_contiguous(a: Tensor) -> bool:
|
||||
"""
|
||||
@ -379,6 +426,13 @@ def is_channels_last_contiguous(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
|
||||
|
||||
|
||||
# similar to is_channels_last_contiguous but return false on data dependency.
|
||||
def is_known_channels_last_contiguous(a: Tensor) -> bool:
|
||||
return is_known_channels_last_contiguous_2d(
|
||||
a
|
||||
) or is_known_channels_last_contiguous_3d(a)
|
||||
|
||||
|
||||
def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
||||
"""
|
||||
True when a tensor is non-overlapping and dense.
|
||||
|
@ -19,6 +19,7 @@ import torch.utils._pytree as pytree
|
||||
from torch import sym_float, sym_int
|
||||
from torch._prims_common import (
|
||||
BoolLike,
|
||||
definitely_contiguous,
|
||||
DeviceLikeType,
|
||||
Dim,
|
||||
DimsSequenceType,
|
||||
@ -3824,7 +3825,7 @@ def _view_simple(a: TensorLikeType, shape, data_dependent_error) -> TensorLikeTy
|
||||
if new_strides is not None:
|
||||
return a.as_strided(shape, new_strides)
|
||||
|
||||
if a.is_contiguous():
|
||||
if definitely_contiguous(a):
|
||||
return a.as_strided(shape, utils.make_contiguous_strides_for(shape))
|
||||
|
||||
raise data_dependent_error
|
||||
|
@ -7,6 +7,7 @@ import torch
|
||||
import torch.fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._prims_common import definitely_contiguous_for_memory_format
|
||||
from torch._subclasses.meta_utils import is_sparse_any
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.node import map_aggregate, Node
|
||||
@ -32,6 +33,10 @@ class TensorMetadata(NamedTuple):
|
||||
qparams: dict[str, Any]
|
||||
|
||||
|
||||
# When include_contiguity is True, we will set contiguity when its always true for the tensor.
|
||||
# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3).
|
||||
# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous,
|
||||
# contiguous, and unknown).
|
||||
def _extract_tensor_metadata(
|
||||
result: torch.Tensor, include_contiguity=True
|
||||
) -> TensorMetadata:
|
||||
@ -52,7 +57,9 @@ def _extract_tensor_metadata(
|
||||
torch.channels_last_3d,
|
||||
}
|
||||
for query_format in memory_formats:
|
||||
if result.is_contiguous(memory_format=query_format):
|
||||
if definitely_contiguous_for_memory_format(
|
||||
result, memory_format=query_format
|
||||
):
|
||||
memory_format = query_format
|
||||
break
|
||||
|
||||
|
Reference in New Issue
Block a user