Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[inductor]Let output or input_as_strided match exact strides (#130956)"
This reverts commit a63efee5cd422db0aabe5d02d2fe35fef9be7978. Reverted https://github.com/pytorch/pytorch/pull/130956 on behalf of https://github.com/ZainRizvi due to sorry but this seems to cause internal tests to fail. Please see D61771533 for details ([comment](https://github.com/pytorch/pytorch/pull/130956#issuecomment-2310490049))
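Background for readers of the diff below: #130956 taught Inductor to reproduce a graph output's *exact* strides rather than only its stride order, which matters for non-dense outputs such as those built with `as_strided`. Reverting it restores the stride-order-only guarantee. The deleted `test_exact_stride` test (first hunk) exercises exactly this behavior; the following standalone sketch is adapted from that test and is illustrative, not part of the commit (it assumes a build where `torch.compile` is available):

```python
import torch

# A non-dense (16, 8) view that keeps the strides of the full (16, 16) buffer.
full = torch.randn((16, 16))
view = torch.as_strided(full, (16, 8), full.stride())

def fn(x):
    result = x + x
    # Request an output carrying exactly the input's non-default strides.
    result_strided = torch.empty_strided(x.size(), x.stride())
    result_strided[:] = result
    return result_strided

eager_out = fn(view)
compiled_out = torch.compile(fn)(view)
# With #130956 applied, the compiled strides matched eager exactly; after this
# revert, Inductor only preserves the stride *order*, so they may differ.
print(eager_out.stride(), compiled_out.stride())
```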
test/inductor/test_torchinductor.py

@@ -7257,24 +7257,6 @@ class CommonTemplate:
             [torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)],
         )

-    def test_exact_stride(self):
-        full = torch.randn((16, 16), device=self.device)
-        view = torch.as_strided(full, (16, 8), full.stride())
-
-        def fn(x):
-            result = x + x
-            result_strided = torch.empty_strided(
-                x.size(), x.stride(), device=self.device
-            )
-            result_strided[:] = result
-            return result_strided
-
-        self.common(fn, [view])
-        reference_out = fn(view)
-        compiled_fn = torch.compile(fn)
-        actual_out = compiled_fn(view)
-        self.assertEqual(reference_out.stride(), actual_out.stride())
-
     def test_like_channels_last(self):
         def foo():
             randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32)
torch/_inductor/graph.py

@@ -186,9 +186,7 @@ def getattr_recursive(
     return attr_itr


-def mark_nodes_dislike_padding(
-    g: Graph, user_visible_outputs: Optional[Dict[str, None]]
-) -> None:
+def mark_nodes_dislike_padding(g: Graph) -> None:
     """
     Nodes like convolution/convolution_backward want its input to be dense.
     If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction.
@@ -235,9 +233,7 @@ def mark_nodes_dislike_padding(
         op = _get_overload_packet(cur)
         if not op:
             continue
-        if op in ops_dislike_padding or (
-            user_visible_outputs and cur.name in user_visible_outputs
-        ):
+        if op in ops_dislike_padding:
             cur.meta["dislike_padding"] = True

         if cur.meta.get("dislike_padding", False):
@@ -419,11 +415,11 @@ class GraphLowering(torch.fx.Interpreter):
         self.nodes_prefer_channels_last = (
             self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
         )
+        mark_nodes_dislike_padding(gm.graph)
         self._warned_fallback = {"aten.convolution_backward"}
         self.user_visible_outputs = (
             user_visible_outputs if user_visible_outputs is not None else {}
         )
-        mark_nodes_dislike_padding(gm.graph, user_visible_outputs)
         self.cache_key: str = ""  # This is the cache key for the compiled artifact
         self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
         self.cache_linemap: List[
@@ -1356,47 +1352,26 @@ class GraphLowering(torch.fx.Interpreter):
                 n.meta["val"], torch.Tensor
             ):
                 strides = n.meta["val"].stride()
-                if len(strides):
-                    allow_padding = (
-                        n.name not in self.user_visible_outputs
-                        and not is_input_for_as_strided
-                    )
-                    dense = torch._prims_common.is_non_overlapping_and_dense(
-                        n.meta["val"]
-                    )
-                    unbacked_symbols_in_strides = (
-                        len(free_unbacked_symbols(strides)) > 0
-                    )
+                dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
+                unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0
+                # requiring a stride order for a non-dense output wouldn't
+                # recreate the same strides, and would fail with view, defer for now.
+                if not unbacked_symbols_in_strides and dense and len(strides):
+                    stride_order = ir.get_stride_order(strides)
                     if (
-                        not unbacked_symbols_in_strides
-                        and dense
-                        and len(result.get_size()) == 4
+                        len(result.get_size()) == 4
                         and n in self.nodes_prefer_channels_last
                         and n.name not in self.user_visible_outputs
                         and not is_input_for_as_strided
                     ):
-                        strides = ir.FlexibleLayout.stride_ordered_for_memory_format(
-                            result.get_size(), torch.channels_last
-                        )
-                    if not unbacked_symbols_in_strides and len(strides):
-                        # To avoid converting possible view ops to a copy kernel, we use the previous
-                        # require_exact_strides to handle views. But ultimately it's better to require
-                        # the right strides at the tensor definition.
-                        if n.meta["val"]._is_view() or isinstance(
-                            result.data, ir.BaseView
-                        ):
-                            result = ir.ExternKernel.require_stride_order(
-                                result,
-                                ir.get_stride_order(strides),
-                                allow_padding=allow_padding,
-                            )
-                        else:
-                            strides = [
-                                s.node.expr if isinstance(s, torch.SymInt) else s
-                                for s in strides
-                            ]
-                            result = ir.ExternKernel.require_exact_strides(
-                                result, strides, allow_padding=allow_padding
-                            )
+                        stride_order = ir.NHWC_STRIDE_ORDER
+
+                    allow_padding = (
+                        n.name not in self.user_visible_outputs
+                        and not is_input_for_as_strided
+                    )
+                    result = ir.ExternKernel.require_stride_order(
+                        result, stride_order, allow_padding=allow_padding
+                    )

             # Realize if (1) any user need inputs realized, or (2) there is
torch/_inductor/ir.py

@@ -745,22 +745,6 @@ def get_reduction_combine_fn(
     raise NotImplementedError(f"unknown reduction_type={reduction_type}")


-def significant_strides_equal(
-    strides1: Sequence[_IntLike], strides2: Sequence[_IntLike], size: Sequence[_IntLike]
-) -> bool:
-    """
-    Returns true if the strides are equal, ignoring dimensions of size 1.
-    """
-    non_1_indices = [
-        i
-        for i, dim in enumerate(size)
-        if V.graph.sizevars.size_hint(dim, fallback=2) != 1
-    ]
-    strides1 = [V.graph.sizevars.size_hint(strides1[i]) for i in non_1_indices]
-    strides2 = [V.graph.sizevars.size_hint(strides2[i]) for i in non_1_indices]
-    return strides1 == strides2
-
-
 @dataclasses.dataclass
 class Reduction(Loops):
     reduction_ranges: List[Expr]
@@ -2101,7 +2085,6 @@ def as_storage_and_layout(
     want_contiguous: bool = False,
     stride_order: Optional[Sequence[Union[int, Integer]]] = None,
     allow_padding: bool = False,
-    exact_strides: Optional[Sequence[Union[int, Integer]]] = None,
 ) -> Tuple[StorageBox, Layout]:
     """
     Try to simplify x into a StorageBox and a Layout.
@@ -2116,7 +2099,6 @@ def as_storage_and_layout(
             want_contiguous=want_contiguous,
             stride_order=stride_order,
             allow_padding=allow_padding,
-            exact_strides=exact_strides,
         )
     if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
         if freeze:
@@ -2127,10 +2109,6 @@ def as_storage_and_layout(
                 x.data.freeze_layout_with_stride_order(
                     stride_order, allow_padding=allow_padding
                 )
-            elif exact_strides is not None:
-                x.data.freeze_layout_with_exact_strides(
-                    exact_strides, allow_padding=allow_padding
-                )
             else:
                 x.data.decide_layout()
         return x, x.data.layout
@@ -3224,19 +3202,6 @@ class FlexibleLayout(Layout):
             self.offset,
         )

-    def as_exact_strides(self, exact_strides, allow_padding=False):
-        new_stride = exact_strides
-        if self.should_pad_strides() and allow_padding:
-            new_stride = self._pad_strides(new_stride, self.size, self.dtype)
-
-        return FixedLayout(
-            self.device,
-            self.dtype,
-            self.size,
-            new_stride,
-            self.offset,
-        )
-
     def as_fill_order(self, order):
         new_stride = self.fill_ordered(self.size, order)
         if self.should_pad_strides():
@@ -3463,12 +3428,6 @@ class Buffer(IRNode):
         assert isinstance(self.layout, FlexibleLayout)
         self.layout = self.layout.as_same_order(stride)

-    def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False):
-        assert isinstance(self.layout, FlexibleLayout)
-        self.layout = self.layout.as_exact_strides(
-            exact_strides, allow_padding=allow_padding
-        )
-
     def is_zero_elements(self):
         return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]

@@ -4721,22 +4680,15 @@ class ExternKernel(InputsKernel):
         return cls.copy_input(x)

     @classmethod
-    def require_strides(
-        cls,
-        x,
-        order: Optional[Sequence[int]] = None,
-        exact_strides: Optional[Sequence[_IntLike]] = None,
-        allow_padding=False,
-    ):
-        assert order is not None or exact_strides is not None
+    def require_stride_order(cls, x, order, allow_padding=False):
         if x.get_numel() == 0:  # Layout doesn't matter
             return x
-        # require x to have the layout
+
+        # require x to have the layout as strided_ordered as order
         if is_storage_and_layout(x):
             while isinstance(x.get_layout(), NonOwningLayout):
                 x = x.get_layout().view
             if isinstance(x.get_layout(), FlexibleLayout):
-                if order:
                 # If the the FlexibleLayout already has the size and stride in the required order,
                 # freeze it to a FixedLayout by using its current size and stride.
                 # The behavior of using its current size and stride or the given order can be different
@@ -4759,55 +4711,22 @@ class ExternKernel(InputsKernel):
                     allow_padding=allow_padding,
                 )
                 return x
-                else:
-                    # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides.
-                    as_storage_and_layout(
-                        x,
-                        freeze=True,
-                        want_contiguous=False,
-                        stride_order=None,
-                        allow_padding=allow_padding,
-                        exact_strides=exact_strides,
-                    )
-                    return x
-            elif isinstance(x.get_layout(), FixedLayout) and (
-                (order and x.get_layout().is_stride_ordered(order))
-                or (
-                    exact_strides
-                    and significant_strides_equal(
-                        exact_strides, x.get_layout().stride, x.get_size()
-                    )
-                )
-            ):
+            elif isinstance(
+                x.get_layout(), FixedLayout
+            ) and x.get_layout().is_stride_ordered(order):
                 return x
             elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE):
                 if isinstance(x.get_layout().real_layout(), FlexibleLayout):
                     raise AssertionError(
                         "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout"
                     )
-                elif isinstance(x.get_layout().real_layout(), FixedLayout) and (
-                    (order and x.get_layout().real_layout().is_stride_ordered(order))
-                    or (
-                        exact_strides
-                        and significant_strides_equal(
-                            exact_strides,
-                            x.get_layout().real_layout().stride,
-                            x.get_size(),
-                        )
-                    )
-                ):
+                elif isinstance(
+                    x.get_layout().real_layout(), FixedLayout
+                ) and x.get_layout().real_layout().is_stride_ordered(order):
                     return x

         # TODO - Storage to InputBuffer
-        if isinstance(x, InputBuffer) and (
-            (order and x.get_layout().is_stride_ordered(order))
-            or (
-                exact_strides
-                and significant_strides_equal(
-                    exact_strides, x.get_layout().stride, x.get_size()
-                )
-            )
-        ):
+        if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
             return x
         if (
             isinstance(x, TensorBox)
@@ -4818,14 +4737,7 @@ class ExternKernel(InputsKernel):
         ):
             try:
                 x.data = cls.convert_to_reinterpret_view(x.data)
-                if order:
-                    return cls.require_stride_order(
-                        x, order, allow_padding=allow_padding
-                    )
-                elif exact_strides:
-                    return cls.require_exact_strides(
-                        x, exact_strides, allow_padding=allow_padding
-                    )
+                return cls.require_stride_order(x, order, allow_padding=allow_padding)
             except NotImplementedError:
                 pass
         # Although this is a clone, inductor is good about fusing clones into previous
@@ -4837,22 +4749,10 @@ class ExternKernel(InputsKernel):
             want_contiguous=False,
             stride_order=order,
             allow_padding=allow_padding,
-            exact_strides=exact_strides,
         )
-        if order:
         assert is_stride_order_storage_and_layout(x, order)
         return x

     @classmethod
-    def require_exact_strides(cls, x, exact_strides, allow_padding=False):
-        return cls.require_strides(
-            x, exact_strides=exact_strides, allow_padding=allow_padding
-        )
-
-    @classmethod
-    def require_stride_order(cls, x, order, allow_padding=False):
-        return cls.require_strides(x, order=order, allow_padding=allow_padding)
-
-    @classmethod
     def require_channels_last(cls, x):
         return cls.require_stride_order(x, NHWC_STRIDE_ORDER)
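One helper this revert deletes, `significant_strides_equal` (the ir.py hunk at line 745 above), compares two stride tuples while ignoring every dimension of size 1, since the stride of a size-1 dimension never affects addressing. A dependency-free sketch of the same predicate, using plain integers in place of Inductor's symbolic `size_hint` values:

```python
from typing import Sequence

def significant_strides_equal(strides1: Sequence[int],
                              strides2: Sequence[int],
                              size: Sequence[int]) -> bool:
    """Plain-integer sketch of the removed Inductor helper: compare strides,
    ignoring dimensions of size 1 (their strides never affect addressing)."""
    non_1 = [i for i, dim in enumerate(size) if dim != 1]
    return [strides1[i] for i in non_1] == [strides2[i] for i in non_1]

# For a (2, 1, 3) tensor, the stride of the middle size-1 dim is irrelevant.
assert significant_strides_equal((3, 3, 1), (3, 99, 1), (2, 1, 3))
assert not significant_strides_equal((3, 3, 1), (1, 1, 2), (2, 1, 3))
```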
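The largest ir.py hunks undo an API consolidation: #130956 had folded `require_stride_order` and a new `require_exact_strides` into a single `require_strides` worker taking either a stride order or exact strides, fronted by thin classmethods. A minimal sketch of that wrapper-over-worker shape (illustrative class name and prints, not Inductor's real types):

```python
from typing import Optional, Sequence


class RequireStridesSketch:
    """Illustrative stand-in for the ExternKernel classmethods removed above."""

    @classmethod
    def require_strides(cls, x, order: Optional[Sequence[int]] = None,
                        exact_strides: Optional[Sequence[int]] = None):
        # Mirrors the removed assert: at least one constraint must be supplied.
        assert order is not None or exact_strides is not None
        if order is not None:
            print(f"freeze layout of {x!r} to stride order {order}")
        else:
            print(f"freeze layout of {x!r} to exact strides {exact_strides}")
        return x

    @classmethod
    def require_stride_order(cls, x, order):
        return cls.require_strides(x, order=order)

    @classmethod
    def require_exact_strides(cls, x, exact_strides):
        return cls.require_strides(x, exact_strides=exact_strides)


RequireStridesSketch.require_stride_order("buf0", [3, 2, 1, 0])
RequireStridesSketch.require_exact_strides("buf1", (64, 1, 16, 4))
```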