Add torch._C.Tag.needs_contiguous_strides (#152859)

This tag makes Inductor force the inputs of tagged operators to be contiguous.
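
For example, a minimal usage sketch adapted from the new test in this PR (the op name `mylib::second_op` is just illustrative):

```python
import torch

# Tagging a custom op with needs_contiguous_strides tells Inductor to make
# the op's inputs contiguous before calling it.
@torch.library.custom_op(
    "mylib::second_op",
    mutates_args={},
    tags=[torch._C.Tag.needs_contiguous_strides],
)
def second_op(x: torch.Tensor) -> torch.Tensor:
    assert x.is_contiguous()
    return torch.ones(2, 2)

@second_op.register_fake
def _(x):
    return torch.ones(2, 2)
```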

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152859
Approved by: https://github.com/eellison
rzou
2025-05-06 12:01:55 -07:00
committed by PyTorch MergeBot
parent 2d25e4d478
commit 94ca3a4666
5 changed files with 60 additions and 3 deletions

@@ -46,20 +46,24 @@
  desc: |
    This tag indicates that the operator should be passed Tensors following
    the same strides as observed in eager when compiled in inductor.
    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
- tag: needs_contiguous_strides
  desc: |
    This tag indicates that the operator should be passed contiguous Tensors.
    Failure to do so will result in undefined behavior.
- tag: needs_fixed_stride_order
  desc: |
    This tag indicates that the operator should be passed Tensors following
    the same stride permutation as observed in eager when compiled in inductor.
    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
- tag: flexible_layout
  desc: |
    This tag indicates that the custom operator can accept inputs with varying
    strides/storage_offset and that when compiled, Inductor is allowed to change
    the strides/storage_offset of inputs to the custom operator.
    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
# NOTE [Core ATen Ops]

@@ -8823,6 +8823,39 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        self.assertTrue((d >= 0).all())
        self.assertTrue((d < 1).all())

    @config.patch(implicit_fallbacks=True)
    def test_needs_contiguous_strides(self):
        # Construct a custom op whose output strides are not contiguous
        @torch.library.custom_op("mylib::myop", mutates_args={})
        def myop(x: torch.Tensor) -> torch.Tensor:
            return torch.zeros(2, 2).t()

        @myop.register_fake
        def _(x):
            return torch.zeros(2, 2).t()

        # custom op that needs contiguous inputs
        @torch.library.custom_op(
            "mylib::second_op",
            mutates_args={},
            tags=[torch._C.Tag.needs_contiguous_strides],
        )
        def second_op(x: torch.Tensor) -> torch.Tensor:
            assert x.is_contiguous()
            return torch.ones(2, 2)

        @second_op.register_fake
        def _(x):
            return torch.ones(2, 2)

        def f(x):
            y = myop(x)
            return second_op(y)

        # Check that the x.is_contiguous() assertion never gets triggered
        x = torch.randn(2, 2)
        _ = torch.compile(f, backend="inductor", fullgraph=True)(x)

    @config.patch(implicit_fallbacks=True)
    def test_fallback_mutable_op_basic(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as m:

@@ -5591,6 +5591,14 @@ class ExternKernel(InputsKernel):
    def require_contiguous(cls, x):  # type: ignore[no-untyped-def]
        return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))

    @classmethod
    def require_contiguous_strides(cls, x):  # type: ignore[no-untyped-def]
        # TODO: combine this with require_contiguous after
        # https://github.com/pytorch/pytorch/pull/148235 lands.
        return cls.require_exact_strides(
            x, FlexibleLayout.contiguous_strides(x.get_size())
        )

    def apply_constraint(self) -> None:
        pass

@@ -164,6 +164,8 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
def tag_to_layout_constraint(tag):
    if tag == torch._C.Tag.needs_exact_strides:
        return constrain_to_fake_tensors
    if tag == torch._C.Tag.needs_contiguous_strides:
        return require_contiguous_strides
    if tag == torch._C.Tag.needs_fixed_stride_order:
        return constrain_to_fx_strides
    if tag == torch._C.Tag.flexible_layout:
@@ -2413,6 +2415,15 @@ def require_contiguous(_, *args, **kwargs):
    return args, kwargs


def require_contiguous_strides(_, *args, **kwargs):
    # TODO: combine this with require_contiguous after
    # https://github.com/pytorch/pytorch/pull/148235 lands.
    args, kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs)
    )
    return args, kwargs


def require_channels_last(_, *args, **kwargs):
    args, kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)

@@ -505,6 +505,7 @@ def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]
tags_by_priority = [
    _C.Tag.needs_exact_strides,
    _C.Tag.needs_contiguous_strides,
    _C.Tag.needs_fixed_stride_order,
    _C.Tag.flexible_layout,
]
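
For reference, a minimal sketch of how this priority list resolves multiple layout tags on one operator (the "most restrictive one" wins, per tags.yaml); the helper name `pick_layout_tag` is hypothetical, not an actual PyTorch function:

```python
from torch import _C

# Hypothetical helper: given the set of tags on an op, return the most
# restrictive layout tag, mirroring tags_by_priority above.
def pick_layout_tag(op_tags):
    tags_by_priority = [
        _C.Tag.needs_exact_strides,
        _C.Tag.needs_contiguous_strides,
        _C.Tag.needs_fixed_stride_order,
        _C.Tag.flexible_layout,
    ]
    for tag in tags_by_priority:
        if tag in op_tags:
            return tag
    return None
```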