introduce definitely_contiguous and use it for reshape and tensor meta data computation. (#153432)

when a tensor has unbacked symbols it can be general enough to represent both contiguous and non contiguous tensors.
in that case we cant really evaluate is_contiguous. In many places in the code base, we check for is_contiguous to take a fast path. but the general path usually works for both contiguous and not contiguous in that case we probably want
to use definitely _contiguous API.

This is appleid for reshape in this PR and also to  tensor meta data computation, the meta data now will have an attribute that says that its contiguous when its always contiguous. We would store that only if definitely _contiguous is true  now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Laith Sakka
2025-05-27 13:24:57 -07:00
committed by PyTorch MergeBot
parent 54f1f29fed
commit 39df901b2a
10 changed files with 178 additions and 53 deletions

View File

@ -24,6 +24,7 @@
#include <ATen/native/cpu/SerialStackImpl.h> #include <ATen/native/cpu/SerialStackImpl.h>
#include <ATen/native/cpu/StackKernel.h> #include <ATen/native/cpu/StackKernel.h>
#include <ATen/quantized/QTensorImpl.h> #include <ATen/quantized/QTensorImpl.h>
#include <c10/core/Contiguity.h>
#include <c10/core/GradMode.h> #include <c10/core/GradMode.h>
#include <c10/util/Exception.h> #include <c10/util/Exception.h>
#include <c10/util/SmallVector.h> #include <c10/util/SmallVector.h>
@ -1993,11 +1994,15 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
} }
if (self.is_contiguous() && !self.is_mkldnn()) { auto sym_sizes = self.sym_sizes();
auto sym_strides = self.sym_strides();
auto sym_numel = self.sym_numel();
if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) &&
!self.is_mkldnn()) {
return self.view_symint(proposed_shape); return self.view_symint(proposed_shape);
} }
c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel()); c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel);
if (self.is_mkldnn()) { if (self.is_mkldnn()) {
return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape)); return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
@ -2005,8 +2010,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
// `computeStride` returns the proper strides to use if this // `computeStride` returns the proper strides to use if this
// `reshape` can be just a view. // `reshape` can be just a view.
auto stride = auto stride = at::detail::computeStride(sym_sizes, sym_strides, shape);
at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
// NB: Even though we have viewable geometry and the target strides here, // NB: Even though we have viewable geometry and the target strides here,
// we do not just call `as_strided` on `self` because the backward // we do not just call `as_strided` on `self` because the backward

View File

@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015
add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025 add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025
@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44010000000,0.025 add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025
@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18140000000,0.015 basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18270000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16220000000,0.015 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16310000000,0.015
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
update_hint_regression,compile_time_instruction_count,1681000000,0.02 update_hint_regression,compile_time_instruction_count,1700000000,0.02
float_args,compile_time_instruction_count,449800000,0.015 float_args,compile_time_instruction_count,452500000,0.015
@ -54,24 +54,24 @@ symint_sum_loop,compile_time_instruction_count,4262000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015 aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2112000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015 aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8585000000,0.015 aotdispatcher_partitioner_cpu,compile_time_instruction_count,8672000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015 aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015 aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3859000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015 aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10420000000,0.015

1 add_loop_eager compile_time_instruction_count 2953000000 0.015
2 add_loop_eager_dynamic compile_time_instruction_count 5808000000 5738000000 0.025
3 add_loop_inductor compile_time_instruction_count 29370000000 0.015
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 44010000000 44490000000 0.025
5 add_loop_inductor_gpu compile_time_instruction_count 25900000000 0.015
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 939900000 0.015
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 18140000000 18270000000 0.015
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 16220000000 16310000000 0.015
10 update_hint_regression compile_time_instruction_count 1681000000 1700000000 0.02
11 float_args compile_time_instruction_count 449800000 452500000 0.015
12 sum_floordiv_regression compile_time_instruction_count 998600000 0.015
13 symint_sum compile_time_instruction_count 3252000000 0.015
14 symint_sum_loop compile_time_instruction_count 4262000000 0.015
15 aotdispatcher_inference_nosubclass_cpu compile_time_instruction_count 2091000000 2112000000 0.015
16 aotdispatcher_inference_subclass_cpu compile_time_instruction_count 5981000000 6022000000 0.015
22
23
24
25
26
27
28
29
30
31
32
34
35
36
37
38
39
40
41
42
43
44
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

View File

@ -12,24 +12,49 @@ namespace c10 {
template <typename T> template <typename T>
bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) { bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
bool is_contiguous = true;
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) { if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
return is_contiguous; return true;
} }
T z = 1;
T expected_stride = 1;
// NB: make sure we do signed arithmetic // NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto& size_d = sizes[d]; const auto& size_d = sizes[d];
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) { if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) { continue;
z *= size_d;
} else {
is_contiguous = false;
break;
}
} }
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) {
return false;
}
expected_stride *= size_d;
} }
return is_contiguous; return true;
}
// This function will return True if the tensor is contiguous, and False if the
// its not or if we can't determine if it is contiguous due to unbacked symbols
// (it could be either in that case based on the actual runtime data).
template <typename T>
bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
return true;
}
T expected_stride = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) {
continue;
}
if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) {
return false;
}
expected_stride *= size_d;
}
return true;
} }
template <typename T> template <typename T>

View File

@ -2647,7 +2647,7 @@ graph():
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
r"Received user-specified .* \[None, 5\], conflicting with the inferred .*" r"Received user-specified .* \[None, 5\], conflicting with the inferred .*"
r"\[6, int_oo\],.* for inputs\['xs'\]\['data'\]\[0\]\[0\]\.shape\[0\]", r"\[8, int_oo\],.* for inputs\['xs'\]\['data'\]\[0\]\[0\]\.shape\[0\]",
): ):
export(Foo(), ({"data": [[x, y]]},), dynamic_shapes=shapes) export(Foo(), ({"data": [[x, y]]},), dynamic_shapes=shapes)

View File

@ -3281,6 +3281,39 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertEqual(result_compiled, result_eager) self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.frame_count, 1)
# Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride.
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
)
with ctx():
# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].
# but not anymore since we use definitely_contiguous .
# We need a way to mark strides unbacked to avoid the recompilation here.
x = torch.randn(10, 10)
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
result_compiled = compiled_func(x, torch.tensor([2, 50]))
result_eager = func(x, torch.tensor([2, 50]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 2)
x = torch.randn(4, 4)
result_eager = func(x, torch.tensor([2, 8]))
result_compiled = compiled_func(x, torch.tensor([2, 8]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 2)
@unittest.skip("this test fails due to inductor/autograd issue #153041") @unittest.skip("this test fails due to inductor/autograd issue #153041")
@torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_non_contigious_reshape_failing(self): def test_unbacked_non_contigious_reshape_failing(self):

View File

@ -1370,8 +1370,8 @@ def forward(self, crop_camera_1, mask_1):
view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
mul_4 = sym_size_int * 3 mul_6 = sym_size_int * 3
view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None
mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
_unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None

View File

@ -693,9 +693,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
) )
) )
else: else:
return_getters.append( if current_group < len(remaining):
operator.itemgetter(add_range(current_group, size)) return_getters.append(
) operator.itemgetter(add_range(current_group, size))
)
return_getters_groups.append(return_getters) return_getters_groups.append(return_getters)
assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), ( assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (

View File

@ -259,47 +259,64 @@ def check_all_strides(
# This function is equivalent to compute_contiguous() from TensorImpl.cpp # This function is equivalent to compute_contiguous() from TensorImpl.cpp
def is_contiguous(a: TensorLikeType) -> bool: def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
""" """
Tests whether a tensor is contiguous or not. Tests whether a tensor is contiguous or not.
Tensors are contiguous when they have no elements, Tensors are contiguous when they have no elements,
one element, or when they have "nested" strides. one element, or when they have "nested" strides.
""" """
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
)
if guard_size_oblivious(a.numel() < 2): maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
if maybe_guard_or_false(a.numel() < 2):
return True return True
expected_stride = 1 expected_stride = 1
for x, y in reversed(tuple(zip(a.shape, a.stride()))): for x, y in reversed(tuple(zip(a.shape, a.stride()))):
# Skips checking strides when a dimension has length 1 # Skips checking strides when a dimension has length 1
if guard_size_oblivious(x == 1): if maybe_guard_or_false(x == 1):
continue continue
if guard_size_oblivious(y != expected_stride): if maybe_guard_or_true(y != expected_stride):
return False return False
expected_stride = expected_stride * x
# if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can
# can assume x is not 0 in expected_stride equation. This is also consistent with make_contiguous_strides_for.
expected_stride = expected_stride * sym_max(x, 1)
return True return True
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
def is_channels_last_contiguous_2d(a: Tensor) -> bool: def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
# NHWC or not channels last 2D contiguous # NHWC or not channels last 2D contiguous
if a.ndim != 4: if a.ndim != 4:
return False return False
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
)
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
expected_stride = 1 expected_stride = 1
for idx in (1, 3, 2, 0): for idx in (1, 3, 2, 0):
length = a.shape[idx] length = a.shape[idx]
if guard_size_oblivious(length == 1): if maybe_guard_or_false(length == 1):
continue continue
stride = a.stride()[idx] stride = a.stride()[idx]
if guard_size_oblivious(stride != expected_stride): if maybe_guard_or_true(stride != expected_stride):
return False return False
expected_stride *= length expected_stride *= length
@ -307,21 +324,28 @@ def is_channels_last_contiguous_2d(a: Tensor) -> bool:
return True return True
def is_channels_last_contiguous_3d(a: Tensor) -> bool: def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
# NDHWC or not channels last 3D contiguous # NDHWC or not channels last 3D contiguous
if a.ndim != 5: if a.ndim != 5:
return False return False
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
guard_size_oblivious,
)
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
expected_stride = 1 expected_stride = 1
for idx in (1, 4, 3, 2, 0): for idx in (1, 4, 3, 2, 0):
length = a.shape[idx] length = a.shape[idx]
if guard_size_oblivious(length == 1): if maybe_guard_or_false(length == 1):
continue continue
stride = a.stride()[idx] stride = a.stride()[idx]
if guard_size_oblivious(stride != expected_stride): if maybe_guard_or_true(stride != expected_stride):
return False return False
expected_stride *= length expected_stride *= length
@ -345,16 +369,16 @@ def validate_memory_format(memory_format: torch.memory_format):
def is_contiguous_for_memory_format( # type: ignore[return] def is_contiguous_for_memory_format( # type: ignore[return]
a: Tensor, *, memory_format: torch.memory_format a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False
) -> bool: ) -> bool:
validate_memory_format(memory_format) validate_memory_format(memory_format)
if memory_format == torch.contiguous_format: if memory_format == torch.contiguous_format:
return is_contiguous(a) return is_contiguous(a, false_if_dde)
if memory_format == torch.channels_last: if memory_format == torch.channels_last:
return is_channels_last_contiguous_2d(a) return is_channels_last_contiguous_2d(a, false_if_dde)
if memory_format == torch.channels_last_3d: if memory_format == torch.channels_last_3d:
return is_channels_last_contiguous_3d(a) return is_channels_last_contiguous_3d(a, false_if_dde)
torch._check( torch._check(
False, False,
@ -362,6 +386,29 @@ def is_contiguous_for_memory_format( # type: ignore[return]
) )
def definitely_contiguous(a: TensorLikeType) -> bool:
return is_contiguous(a, false_if_dde=True)
# similar to is_channels_last_contiguous_2d but return false on data dependency.
def is_known_channels_last_contiguous_2d(a: Tensor) -> bool:
return is_channels_last_contiguous_2d(a, false_if_dde=True)
# similar to is_channels_last_contiguous_3d but return false on data dependency.
def is_known_channels_last_contiguous_3d(a: Tensor) -> bool:
return is_channels_last_contiguous_3d(a, false_if_dde=True)
# similar to is_contiguous_for_memory_format but return false on data dependency.
def definitely_contiguous_for_memory_format( # type: ignore[return]
a: Tensor, *, memory_format: torch.memory_format
) -> bool:
return is_contiguous_for_memory_format(
a, memory_format=memory_format, false_if_dde=True
)
# NOTE: that tensors with no elements and channels last is ??? # NOTE: that tensors with no elements and channels last is ???
def is_channels_last_contiguous(a: Tensor) -> bool: def is_channels_last_contiguous(a: Tensor) -> bool:
""" """
@ -379,6 +426,13 @@ def is_channels_last_contiguous(a: Tensor) -> bool:
return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
# similar to is_channels_last_contiguous but return false on data dependency.
def is_known_channels_last_contiguous(a: Tensor) -> bool:
return is_known_channels_last_contiguous_2d(
a
) or is_known_channels_last_contiguous_3d(a)
def is_non_overlapping_and_dense(a: Tensor) -> bool: def is_non_overlapping_and_dense(a: Tensor) -> bool:
""" """
True when a tensor is non-overlapping and dense. True when a tensor is non-overlapping and dense.

View File

@ -19,6 +19,7 @@ import torch.utils._pytree as pytree
from torch import sym_float, sym_int from torch import sym_float, sym_int
from torch._prims_common import ( from torch._prims_common import (
BoolLike, BoolLike,
definitely_contiguous,
DeviceLikeType, DeviceLikeType,
Dim, Dim,
DimsSequenceType, DimsSequenceType,
@ -3824,7 +3825,7 @@ def _view_simple(a: TensorLikeType, shape, data_dependent_error) -> TensorLikeTy
if new_strides is not None: if new_strides is not None:
return a.as_strided(shape, new_strides) return a.as_strided(shape, new_strides)
if a.is_contiguous(): if definitely_contiguous(a):
return a.as_strided(shape, utils.make_contiguous_strides_for(shape)) return a.as_strided(shape, utils.make_contiguous_strides_for(shape))
raise data_dependent_error raise data_dependent_error

View File

@ -7,6 +7,7 @@ import torch
import torch.fx import torch.fx
from torch._dispatch.python import enable_python_dispatcher from torch._dispatch.python import enable_python_dispatcher
from torch._guards import detect_fake_mode from torch._guards import detect_fake_mode
from torch._prims_common import definitely_contiguous_for_memory_format
from torch._subclasses.meta_utils import is_sparse_any from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
from torch.fx.node import map_aggregate, Node from torch.fx.node import map_aggregate, Node
@ -32,6 +33,10 @@ class TensorMetadata(NamedTuple):
qparams: dict[str, Any] qparams: dict[str, Any]
# When include_contiguity is True, we will set contiguity when its always true for the tensor.
# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3).
# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous,
# contiguous, and unknown).
def _extract_tensor_metadata( def _extract_tensor_metadata(
result: torch.Tensor, include_contiguity=True result: torch.Tensor, include_contiguity=True
) -> TensorMetadata: ) -> TensorMetadata:
@ -52,7 +57,9 @@ def _extract_tensor_metadata(
torch.channels_last_3d, torch.channels_last_3d,
} }
for query_format in memory_formats: for query_format in memory_formats:
if result.is_contiguous(memory_format=query_format): if definitely_contiguous_for_memory_format(
result, memory_format=query_format
):
memory_format = query_format memory_format = query_format
break break