Revert "[inductor]Let output or input_as_strided match exact strides (#130956)"

This reverts commit a63efee5cd422db0aabe5d02d2fe35fef9be7978.

Reverted https://github.com/pytorch/pytorch/pull/130956 on behalf of https://github.com/ZainRizvi due to sorry but this seems to cause internal tests to fail. Please see D61771533 for details ([comment](https://github.com/pytorch/pytorch/pull/130956#issuecomment-2310490049))
PyTorch MergeBot
2024-08-26 15:31:23 +00:00
parent 1c4780e69a
commit 17e8a51ff2
3 changed files with 50 additions and 193 deletions

test/inductor/test_torchinductor.py

@@ -7257,24 +7257,6 @@ class CommonTemplate:
             [torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)],
         )
 
-    def test_exact_stride(self):
-        full = torch.randn((16, 16), device=self.device)
-        view = torch.as_strided(full, (16, 8), full.stride())
-
-        def fn(x):
-            result = x + x
-            result_strided = torch.empty_strided(
-                x.size(), x.stride(), device=self.device
-            )
-            result_strided[:] = result
-            return result_strided
-
-        self.common(fn, [view])
-        reference_out = fn(view)
-        compiled_fn = torch.compile(fn)
-        actual_out = compiled_fn(view)
-        self.assertEqual(reference_out.stride(), actual_out.stride())
-
     def test_like_channels_last(self):
         def foo():
             randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32)
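
For context, the removed test exercised exactly the behavior being reverted: with an as_strided input, the compiled function's output strides were expected to match eager mode exactly, not merely in relative order. A standalone sketch of the same check (assuming "cpu" in place of self.device; after this revert the two stride tuples are no longer guaranteed to be identical):

import torch

def fn(x):
    result = x + x
    # materialize the result into a buffer that reuses x's strides
    result_strided = torch.empty_strided(x.size(), x.stride(), device="cpu")
    result_strided[:] = result
    return result_strided

full = torch.randn((16, 16))
view = torch.as_strided(full, (16, 8), full.stride())  # stride (16, 1): not contiguous for this size

reference_out = fn(view)
actual_out = torch.compile(fn)(view)
print(reference_out.stride(), actual_out.stride())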

torch/_inductor/graph.py

@@ -186,9 +186,7 @@ def getattr_recursive(
     return attr_itr
 
 
-def mark_nodes_dislike_padding(
-    g: Graph, user_visible_outputs: Optional[Dict[str, None]]
-) -> None:
+def mark_nodes_dislike_padding(g: Graph) -> None:
     """
     Nodes like convolution/convolution_backward want its input to be dense.
    If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction.
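
The docstring above states the trade-off this pass encodes. A small illustration of what padded versus dense strides look like in eager terms (the padded stride values below are hypothetical, chosen only to show the gap):

import torch

# Dense layout: each stride is the product of the trailing sizes.
dense = torch.empty((8, 3, 32, 32))
print(dense.stride())  # (3072, 1024, 32, 1)

# Padded layout: rows padded from 32 to 40 elements, leaving gaps.
padded = torch.empty_strided((8, 3, 32, 32), (3840, 1280, 40, 1))
print(padded.is_contiguous())  # False

# Convolution-like ops want dense inputs, so a padded input first
# pays for an extra copy kernel:
print(padded.contiguous().stride())  # (3072, 1024, 32, 1)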
@@ -235,9 +233,7 @@ def mark_nodes_dislike_padding(
         op = _get_overload_packet(cur)
         if not op:
             continue
-        if op in ops_dislike_padding or (
-            user_visible_outputs and cur.name in user_visible_outputs
-        ):
+        if op in ops_dislike_padding:
             cur.meta["dislike_padding"] = True
 
         if cur.meta.get("dislike_padding", False):
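
A toy version of the marking pass, written against a plain torch.fx graph rather than inductor's internal traversal; before this revert, node names present in user_visible_outputs were marked the same way as the conv-like ops:

import torch
from torch.fx import symbolic_trace

def f(x, w):
    return torch.conv2d(x, w).relu()

gm = symbolic_trace(f)
ops_dislike_padding = {torch.conv2d}  # stand-in for inductor's overload-packet set
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target in ops_dislike_padding:
        node.meta["dislike_padding"] = True

print([n.name for n in gm.graph.nodes if n.meta.get("dislike_padding")])  # ['conv2d']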
@@ -419,11 +415,11 @@ class GraphLowering(torch.fx.Interpreter):
         self.nodes_prefer_channels_last = (
             self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
         )
+        mark_nodes_dislike_padding(gm.graph)
         self._warned_fallback = {"aten.convolution_backward"}
         self.user_visible_outputs = (
             user_visible_outputs if user_visible_outputs is not None else {}
         )
-        mark_nodes_dislike_padding(gm.graph, user_visible_outputs)
         self.cache_key: str = ""  # This is the cache key for the compiled artifact
         self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
         self.cache_linemap: List[
@@ -1356,47 +1352,26 @@ class GraphLowering(torch.fx.Interpreter):
                 n.meta["val"], torch.Tensor
             ):
                 strides = n.meta["val"].stride()
-                if len(strides):
-                    allow_padding = (
-                        n.name not in self.user_visible_outputs
-                        and not is_input_for_as_strided
-                    )
-                    dense = torch._prims_common.is_non_overlapping_and_dense(
-                        n.meta["val"]
-                    )
-                    unbacked_symbols_in_strides = (
-                        len(free_unbacked_symbols(strides)) > 0
-                    )
+                dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
+                unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0
+                # requiring a stride order for a non-dense output wouldn't
+                # recreate the same strides, and would fail with view, defer for now.
+                if not unbacked_symbols_in_strides and dense and len(strides):
+                    stride_order = ir.get_stride_order(strides)
                     if (
-                        not unbacked_symbols_in_strides
-                        and dense
-                        and len(result.get_size()) == 4
+                        len(result.get_size()) == 4
                         and n in self.nodes_prefer_channels_last
                         and n.name not in self.user_visible_outputs
                         and not is_input_for_as_strided
                     ):
-                        strides = ir.FlexibleLayout.stride_ordered_for_memory_format(
-                            result.get_size(), torch.channels_last
-                        )
-                    if not unbacked_symbols_in_strides and len(strides):
-                        # To avoid converting possible view ops to a copy kernel, we use the previous
-                        # require_exact_strides to handle views. But ultimately it's better to require
-                        # the right strides at the tensor definition.
-                        if n.meta["val"]._is_view() or isinstance(
-                            result.data, ir.BaseView
-                        ):
-                            result = ir.ExternKernel.require_stride_order(
-                                result,
-                                ir.get_stride_order(strides),
-                                allow_padding=allow_padding,
-                            )
-                        else:
-                            strides = [
-                                s.node.expr if isinstance(s, torch.SymInt) else s
-                                for s in strides
-                            ]
-                            result = ir.ExternKernel.require_exact_strides(
-                                result, strides, allow_padding=allow_padding
-                            )
+                        stride_order = ir.NHWC_STRIDE_ORDER
+
+                    allow_padding = (
+                        n.name not in self.user_visible_outputs
+                        and not is_input_for_as_strided
+                    )
+                    result = ir.ExternKernel.require_stride_order(
+                        result, stride_order, allow_padding=allow_padding
+                    )
 
         # Realize if (1) any user need inputs realized, or (2) there is
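
The crux of the hunk above is returning to require_stride_order, which constrains only the relative order of strides, in place of require_exact_strides, which pinned the exact values. A simplified stand-in for ir.get_stride_order makes the distinction concrete (illustrative only, not inductor's implementation):

import torch

def stride_order(strides):
    # rank each dim by its stride, smallest (innermost) first
    sorted_idx = sorted(range(len(strides)), key=lambda i: strides[i])
    order = [0] * len(strides)
    for rank, i in enumerate(sorted_idx):
        order[i] = rank
    return order

t = torch.empty((2, 3, 4, 5)).to(memory_format=torch.channels_last)
print(t.stride())                # (60, 1, 15, 3)
print(stride_order(t.stride()))  # [3, 0, 2, 1], i.e. NHWC

# Same stride order, different exact strides: a stride-order constraint
# accepts both layouts, an exact-strides constraint only the first.
u = torch.empty_strided((2, 3, 4, 5), (120, 1, 30, 6))
print(stride_order(u.stride()))  # also [3, 0, 2, 1]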

torch/_inductor/ir.py

@@ -745,22 +745,6 @@ def get_reduction_combine_fn(
     raise NotImplementedError(f"unknown reduction_type={reduction_type}")
 
 
-def significant_strides_equal(
-    strides1: Sequence[_IntLike], strides2: Sequence[_IntLike], size: Sequence[_IntLike]
-) -> bool:
-    """
-    Returns true if the strides are equal, ignoring dimensions of size 1.
-    """
-    non_1_indices = [
-        i
-        for i, dim in enumerate(size)
-        if V.graph.sizevars.size_hint(dim, fallback=2) != 1
-    ]
-    strides1 = [V.graph.sizevars.size_hint(strides1[i]) for i in non_1_indices]
-    strides2 = [V.graph.sizevars.size_hint(strides2[i]) for i in non_1_indices]
-    return strides1 == strides2
-
-
 @dataclasses.dataclass
 class Reduction(Loops):
     reduction_ranges: List[Expr]
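
The removed significant_strides_equal compared two stride tuples while skipping size-1 dimensions, whose strides never affect addressing (the size_hint calls handle symbolic shapes). A plain-integer sketch of the same idea:

import torch

def significant_strides_equal(strides1, strides2, size):
    # compare strides only at dimensions whose size is not 1
    non_1 = [i for i, d in enumerate(size) if d != 1]
    return [strides1[i] for i in non_1] == [strides2[i] for i in non_1]

# A size-1 dim may carry any stride without changing element addresses:
a = torch.empty_strided((1, 4, 4), (16, 4, 1))
b = torch.empty_strided((1, 4, 4), (999, 4, 1))
print(significant_strides_equal(a.stride(), b.stride(), a.size()))  # True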
@@ -2101,7 +2085,6 @@ def as_storage_and_layout(
     want_contiguous: bool = False,
     stride_order: Optional[Sequence[Union[int, Integer]]] = None,
     allow_padding: bool = False,
-    exact_strides: Optional[Sequence[Union[int, Integer]]] = None,
 ) -> Tuple[StorageBox, Layout]:
     """
     Try to simplify x into a StorageBox and a Layout.
@@ -2116,7 +2099,6 @@
             want_contiguous=want_contiguous,
             stride_order=stride_order,
             allow_padding=allow_padding,
-            exact_strides=exact_strides,
         )
     if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
         if freeze:
@@ -2127,10 +2109,6 @@
             x.data.freeze_layout_with_stride_order(
                 stride_order, allow_padding=allow_padding
             )
-        elif exact_strides is not None:
-            x.data.freeze_layout_with_exact_strides(
-                exact_strides, allow_padding=allow_padding
-            )
         else:
             x.data.decide_layout()
     return x, x.data.layout
@@ -3224,19 +3202,6 @@ class FlexibleLayout(Layout):
             self.offset,
         )
 
-    def as_exact_strides(self, exact_strides, allow_padding=False):
-        new_stride = exact_strides
-        if self.should_pad_strides() and allow_padding:
-            new_stride = self._pad_strides(new_stride, self.size, self.dtype)
-
-        return FixedLayout(
-            self.device,
-            self.dtype,
-            self.size,
-            new_stride,
-            self.offset,
-        )
-
     def as_fill_order(self, order):
         new_stride = self.fill_ordered(self.size, order)
         if self.should_pad_strides():
@@ -3463,12 +3428,6 @@ class Buffer(IRNode):
         assert isinstance(self.layout, FlexibleLayout)
         self.layout = self.layout.as_same_order(stride)
 
-    def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False):
-        assert isinstance(self.layout, FlexibleLayout)
-        self.layout = self.layout.as_exact_strides(
-            exact_strides, allow_padding=allow_padding
-        )
-
     def is_zero_elements(self):
         return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]
@@ -4721,22 +4680,15 @@ class ExternKernel(InputsKernel):
             return cls.copy_input(x)
 
     @classmethod
-    def require_strides(
-        cls,
-        x,
-        order: Optional[Sequence[int]] = None,
-        exact_strides: Optional[Sequence[_IntLike]] = None,
-        allow_padding=False,
-    ):
-        assert order is not None or exact_strides is not None
+    def require_stride_order(cls, x, order, allow_padding=False):
         if x.get_numel() == 0:  # Layout doesn't matter
             return x
+
-        # require x to have the layout
+        # require x to have the layout as strided_ordered as order
         if is_storage_and_layout(x):
             while isinstance(x.get_layout(), NonOwningLayout):
                 x = x.get_layout().view
             if isinstance(x.get_layout(), FlexibleLayout):
-                if order:
-                    # If the the FlexibleLayout already has the size and stride in the required order,
-                    # freeze it to a FixedLayout by using its current size and stride.
-                    # The behavior of using its current size and stride or the given order can be different
+                # If the the FlexibleLayout already has the size and stride in the required order,
+                # freeze it to a FixedLayout by using its current size and stride.
+                # The behavior of using its current size and stride or the given order can be different
@@ -4759,55 +4711,22 @@ class ExternKernel(InputsKernel):
                     allow_padding=allow_padding,
                 )
                 return x
-            else:
-                # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides.
-                as_storage_and_layout(
-                    x,
-                    freeze=True,
-                    want_contiguous=False,
-                    stride_order=None,
-                    allow_padding=allow_padding,
-                    exact_strides=exact_strides,
-                )
-                return x
-            elif isinstance(x.get_layout(), FixedLayout) and (
-                (order and x.get_layout().is_stride_ordered(order))
-                or (
-                    exact_strides
-                    and significant_strides_equal(
-                        exact_strides, x.get_layout().stride, x.get_size()
-                    )
-                )
-            ):
+            elif isinstance(
+                x.get_layout(), FixedLayout
+            ) and x.get_layout().is_stride_ordered(order):
                 return x
             elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE):
                 if isinstance(x.get_layout().real_layout(), FlexibleLayout):
                     raise AssertionError(
                         "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout"
                     )
-                elif isinstance(x.get_layout().real_layout(), FixedLayout) and (
-                    (order and x.get_layout().real_layout().is_stride_ordered(order))
-                    or (
-                        exact_strides
-                        and significant_strides_equal(
-                            exact_strides,
-                            x.get_layout().real_layout().stride,
-                            x.get_size(),
-                        )
-                    )
-                ):
+                elif isinstance(
+                    x.get_layout().real_layout(), FixedLayout
+                ) and x.get_layout().real_layout().is_stride_ordered(order):
                     return x
 
         # TODO - Storage to InputBuffer
-        if isinstance(x, InputBuffer) and (
-            (order and x.get_layout().is_stride_ordered(order))
-            or (
-                exact_strides
-                and significant_strides_equal(
-                    exact_strides, x.get_layout().stride, x.get_size()
-                )
-            )
-        ):
+        if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
            return x
        if (
            isinstance(x, TensorBox)
@@ -4818,14 +4737,7 @@
         ):
             try:
                 x.data = cls.convert_to_reinterpret_view(x.data)
-                if order:
-                    return cls.require_stride_order(
-                        x, order, allow_padding=allow_padding
-                    )
-                elif exact_strides:
-                    return cls.require_exact_strides(
-                        x, exact_strides, allow_padding=allow_padding
-                    )
+                return cls.require_stride_order(x, order, allow_padding=allow_padding)
             except NotImplementedError:
                 pass
         # Although this is a clone, inductor is good about fusing clones into previous
@@ -4837,22 +4749,10 @@
             want_contiguous=False,
             stride_order=order,
             allow_padding=allow_padding,
-            exact_strides=exact_strides,
         )
-        if order:
-            assert is_stride_order_storage_and_layout(x, order)
+        assert is_stride_order_storage_and_layout(x, order)
         return x
 
-    @classmethod
-    def require_exact_strides(cls, x, exact_strides, allow_padding=False):
-        return cls.require_strides(
-            x, exact_strides=exact_strides, allow_padding=allow_padding
-        )
-
-    @classmethod
-    def require_stride_order(cls, x, order, allow_padding=False):
-        return cls.require_strides(x, order=order, allow_padding=allow_padding)
-
     @classmethod
     def require_channels_last(cls, x):
         return cls.require_stride_order(x, NHWC_STRIDE_ORDER)
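
For reference, require_channels_last delegates to require_stride_order with NHWC_STRIDE_ORDER (defined as [3, 0, 2, 1] in inductor's ir.py: channels innermost, then width, height, batch). The eager-mode layout it corresponds to:

import torch

x = torch.randn(2, 3, 8, 8)
y = x.contiguous(memory_format=torch.channels_last)
print(y.stride())  # (192, 1, 24, 3): C innermost, then W, H, N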