Add torch._C.Tag.needs_contiguous_strides (#152859)

When an operator carries this tag, Inductor forces that operator's inputs to be contiguous.
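
A minimal usage sketch (op names are placeholders; this mirrors the new test added below) showing how a custom op opts into the tag:

    import torch

    @torch.library.custom_op(
        "mylib::my_kernel",
        mutates_args={},
        tags=[torch._C.Tag.needs_contiguous_strides],
    )
    def my_kernel(x: torch.Tensor) -> torch.Tensor:
        # With the tag, Inductor passes a contiguous tensor here.
        assert x.is_contiguous()
        return x.clone()

    @my_kernel.register_fake
    def _(x):
        return torch.empty_like(x)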

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152859
Approved by: https://github.com/eellison
Author: rzou
Date: 2025-05-06 12:01:55 -07:00
Committed by: PyTorch MergeBot
Parent: 2d25e4d478
Commit: 94ca3a4666
5 changed files with 60 additions and 3 deletions


@@ -46,20 +46,24 @@
   desc: |
     This tag indicates that the operator should be passed Tensors following
     the same strides as observed in eager when compiled in inductor.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.
+- tag: needs_contiguous_strides
+  desc: |
+    This tag indicates that the operator should be passed contiguous Tensors.
+    Failure to do so will result in undefined behavior.
 - tag: needs_fixed_stride_order
   desc: |
     This tag indicates that the operator should be passed Tensors following
     the same stride permutation as observed in eager when compiled in inductor.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.
 - tag: flexible_layout
   desc: |
     This tag indicates that the custom operator can accept inputs with varying
     strides/storage_offset and that when compiled, Inductor is allowed to change
     the strides/storage_offset of inputs to the custom operator.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.
 # NOTE [Core ATen Ops]
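
Here "contiguous" means standard C-contiguous (row-major) strides. A quick illustration of the distinction the tag guards against, using the same transposed-output pattern as the new test:

    import torch

    x = torch.zeros(2, 2)
    print(x.stride(), x.is_contiguous())   # (2, 1) True
    y = x.t()                              # transposed view
    print(y.stride(), y.is_contiguous())   # (1, 2) False
    print(y.contiguous().stride())         # (2, 1)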


@@ -8823,6 +8823,39 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertTrue((d >= 0).all())
         self.assertTrue((d < 1).all())

+    @config.patch(implicit_fallbacks=True)
+    def test_needs_contiguous_strides(self):
+        # Construct a custom op whose output strides are not contiguous
+        @torch.library.custom_op("mylib::myop", mutates_args={})
+        def myop(x: torch.Tensor) -> torch.Tensor:
+            return torch.zeros(2, 2).t()
+
+        @myop.register_fake
+        def _(x):
+            return torch.zeros(2, 2).t()
+
+        # custom op that needs contiguous inputs
+        @torch.library.custom_op(
+            "mylib::second_op",
+            mutates_args={},
+            tags=[torch._C.Tag.needs_contiguous_strides],
+        )
+        def second_op(x: torch.Tensor) -> torch.Tensor:
+            assert x.is_contiguous()
+            return torch.ones(2, 2)
+
+        @second_op.register_fake
+        def _(x):
+            return torch.ones(2, 2)
+
+        def f(x):
+            y = myop(x)
+            return second_op(y)
+
+        # Check that the x.is_contiguous() assertion never gets triggered
+        x = torch.randn(2, 2)
+        _ = torch.compile(f, backend="inductor", fullgraph=True)(x)
+
     @config.patch(implicit_fallbacks=True)
     def test_fallback_mutable_op_basic(self):
         with torch.library._scoped_library("mylib", "FRAGMENT") as m:


@@ -5591,6 +5591,14 @@ class ExternKernel(InputsKernel):
     def require_contiguous(cls, x):  # type: ignore[no-untyped-def]
         return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))

+    @classmethod
+    def require_contiguous_strides(cls, x):  # type: ignore[no-untyped-def]
+        # TODO: combine this with require_contiguous after
+        # https://github.com/pytorch/pytorch/pull/148235 lands.
+        return cls.require_exact_strides(
+            x, FlexibleLayout.contiguous_strides(x.get_size())
+        )
+
     def apply_constraint(self) -> None:
         pass
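
For a given size, the contiguous strides are the ordinary row-major strides. A standalone sketch of that computation (an illustrative helper, not the actual FlexibleLayout.contiguous_strides implementation):

    def contiguous_strides(sizes):
        # Innermost dimension has stride 1; each outer dimension's stride
        # is the product of the sizes to its right.
        strides = [1] * len(sizes)
        for i in range(len(sizes) - 2, -1, -1):
            strides[i] = strides[i + 1] * sizes[i + 1]
        return strides

    assert contiguous_strides([2, 3, 4]) == [12, 4, 1]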


@@ -164,6 +164,8 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
 def tag_to_layout_constraint(tag):
     if tag == torch._C.Tag.needs_exact_strides:
         return constrain_to_fake_tensors
+    if tag == torch._C.Tag.needs_contiguous_strides:
+        return require_contiguous_strides
     if tag == torch._C.Tag.needs_fixed_stride_order:
         return constrain_to_fx_strides
     if tag == torch._C.Tag.flexible_layout:
@@ -2413,6 +2415,15 @@ def require_contiguous(_, *args, **kwargs):
     return args, kwargs


+def require_contiguous_strides(_, *args, **kwargs):
+    # TODO: combine this with require_contiguous after
+    # https://github.com/pytorch/pytorch/pull/148235 lands.
+    args, kwargs = pytree.tree_map_only(
+        ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs)
+    )
+    return args, kwargs
+
+
 def require_channels_last(_, *args, **kwargs):
     args, kwargs = pytree.tree_map_only(
         ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
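
pytree.tree_map_only applies a function only to leaves of a given type and leaves the rest of the (args, kwargs) structure untouched, which is how the constraint reaches every ir.IRNode input. A toy sketch of the same pattern, with ints standing in for ir.IRNode:

    from torch.utils import _pytree as pytree

    args, kwargs = (1, "skip", [2, 3]), {"a": 4}
    args, kwargs = pytree.tree_map_only(int, lambda n: n * 10, (args, kwargs))
    print(args, kwargs)  # (10, 'skip', [20, 30]) {'a': 40}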


@@ -505,6 +505,7 @@ def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]
 tags_by_priority = [
     _C.Tag.needs_exact_strides,
+    _C.Tag.needs_contiguous_strides,
     _C.Tag.needs_fixed_stride_order,
     _C.Tag.flexible_layout,
 ]
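
A sketch of how a priority list like this can resolve a single constraint when an op carries several layout tags (an assumed call site for illustration, not the exact library code): pick the first, i.e. most restrictive, tag the op actually has.

    def pick_layout_tag(op_tags):
        # tags_by_priority is ordered from most to least restrictive,
        # so the first match wins when multiple tags are assigned.
        for tag in tags_by_priority:
            if tag in op_tags:
                return tag
        return None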