mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Add torch._C.Tag.needs_contiguous_strides (#152859)
This tag forces Inductor to make the tagged operator's inputs contiguous.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152859
Approved by: https://github.com/eellison
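For orientation, a hedged usage sketch of the new tag (the op name and kernel body below are illustrative, not part of this PR): tagging a custom op with torch._C.Tag.needs_contiguous_strides tells Inductor to hand that op contiguous inputs when the surrounding graph is compiled.

```python
import torch

# Hypothetical custom op; only the tags= argument is the feature added here.
@torch.library.custom_op(
    "mylib::scale",
    mutates_args={},
    tags=[torch._C.Tag.needs_contiguous_strides],
)
def scale(x: torch.Tensor) -> torch.Tensor:
    # Safe under the tag: Inductor makes the input contiguous before the call.
    assert x.is_contiguous()
    return x * 2


@scale.register_fake
def _(x):
    return torch.empty_like(x)


def f(x):
    return scale(x.t())  # x.t() is a non-contiguous view


compiled = torch.compile(f, backend="inductor", fullgraph=True)
print(compiled(torch.randn(4, 4)))
```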
@@ -46,20 +46,24 @@
   desc: |
     This tag indicates that the operator should be passed Tensors following
     the same strides as observed in eager when compiled in inductor.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.
+- tag: needs_contiguous_strides
+  desc: |
+    This tag indicates that the operator should be passed contiguous Tensors.
+    Failure to do so will result in undefined behavior.
 - tag: needs_fixed_stride_order
   desc: |
     This tag indicates that the operator should be passed Tensors following
     the same stride permutation as observed in eager when compiled in inductor.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.
 - tag: flexible_layout
   desc: |
     This tag indicates that the custom operator can accept inputs with varying
     strides/storage_offset and that when compiled, Inductor is allowed to change
     the strides/storage_offset of inputs to the custom operator.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.
 
 # NOTE [Core ATen Ops]
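The "undefined behavior" wording above is the usual hazard with kernels that index raw storage as if it were row-major. A small illustrative snippet (not from the PR) of what goes wrong when such a kernel receives a non-contiguous view:

```python
import torch

# A "kernel" that walks storage linearly, as a low-level extension might.
def first_row_from_storage(x: torch.Tensor) -> torch.Tensor:
    linear = torch.as_strided(x, (x.numel(),), (1,))  # raw storage order
    return linear[: x.size(1)]

a = torch.arange(4.0).reshape(2, 2)   # storage: [0, 1, 2, 3]
print(first_row_from_storage(a))      # tensor([0., 1.]) == a[0], correct
print(first_row_from_storage(a.t()))  # tensor([0., 1.]), but a.t()[0] is [0., 2.]
```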
@@ -8823,6 +8823,39 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertTrue((d >= 0).all())
         self.assertTrue((d < 1).all())
 
+    @config.patch(implicit_fallbacks=True)
+    def test_needs_contiguous_strides(self):
+        # Construct a custom op whose output strides are not contiguous
+        @torch.library.custom_op("mylib::myop", mutates_args={})
+        def myop(x: torch.Tensor) -> torch.Tensor:
+            return torch.zeros(2, 2).t()
+
+        @myop.register_fake
+        def _(x):
+            return torch.zeros(2, 2).t()
+
+        # custom op that needs contiguous inputs
+        @torch.library.custom_op(
+            "mylib::second_op",
+            mutates_args={},
+            tags=[torch._C.Tag.needs_contiguous_strides],
+        )
+        def second_op(x: torch.Tensor) -> torch.Tensor:
+            assert x.is_contiguous()
+            return torch.ones(2, 2)
+
+        @second_op.register_fake
+        def _(x):
+            return torch.ones(2, 2)
+
+        def f(x):
+            y = myop(x)
+            return second_op(y)
+
+        # Check that the x.is_contiguous() assertion never gets triggered
+        x = torch.randn(2, 2)
+        _ = torch.compile(f, backend="inductor", fullgraph=True)(x)
+
     @config.patch(implicit_fallbacks=True)
     def test_fallback_mutable_op_basic(self):
         with torch.library._scoped_library("mylib", "FRAGMENT") as m:
@@ -5591,6 +5591,14 @@ class ExternKernel(InputsKernel):
     def require_contiguous(cls, x):  # type: ignore[no-untyped-def]
         return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))
 
+    @classmethod
+    def require_contiguous_strides(cls, x):  # type: ignore[no-untyped-def]
+        # TODO: combine this with require_contiguous after
+        # https://github.com/pytorch/pytorch/pull/148235 lands.
+        return cls.require_exact_strides(
+            x, FlexibleLayout.contiguous_strides(x.get_size())
+        )
+
     def apply_constraint(self) -> None:
         pass
 
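require_contiguous_strides asks for exact row-major strides computed from the node's size. As a standalone reference, a sketch of that stride math (not the FlexibleLayout implementation, which also handles symbolic sizes):

```python
import torch

def contiguous_strides(size: list[int]) -> list[int]:
    # Row-major (C-contiguous) strides: the last dimension varies fastest.
    strides = [1] * len(size)
    for i in range(len(size) - 2, -1, -1):
        strides[i] = strides[i + 1] * size[i + 1]
    return strides

# Matches what a freshly allocated tensor of that size reports: (12, 4, 1).
assert tuple(contiguous_strides([2, 3, 4])) == torch.empty(2, 3, 4).stride()
```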
@@ -164,6 +164,8 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
 def tag_to_layout_constraint(tag):
     if tag == torch._C.Tag.needs_exact_strides:
         return constrain_to_fake_tensors
+    if tag == torch._C.Tag.needs_contiguous_strides:
+        return require_contiguous_strides
     if tag == torch._C.Tag.needs_fixed_stride_order:
         return constrain_to_fx_strides
     if tag == torch._C.Tag.flexible_layout:
@@ -2413,6 +2415,15 @@ def require_contiguous(_, *args, **kwargs):
     return args, kwargs
 
 
+def require_contiguous_strides(_, *args, **kwargs):
+    # TODO: combine this with require_contiguous after
+    # https://github.com/pytorch/pytorch/pull/148235 lands.
+    args, kwargs = pytree.tree_map_only(
+        ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs)
+    )
+    return args, kwargs
+
+
 def require_channels_last(_, *args, **kwargs):
     args, kwargs = pytree.tree_map_only(
         ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
@@ -505,6 +505,7 @@ def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]
 
 tags_by_priority = [
     _C.Tag.needs_exact_strides,
+    _C.Tag.needs_contiguous_strides,
     _C.Tag.needs_fixed_stride_order,
     _C.Tag.flexible_layout,
 ]
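The list above is ordered from most to least restrictive, which is how ties between multiple layout tags on one operator get resolved ("we assume the most restrictive one"). A hedged sketch of that resolution, with the helper name invented for illustration:

```python
from torch import _C

tags_by_priority = [
    _C.Tag.needs_exact_strides,
    _C.Tag.needs_contiguous_strides,
    _C.Tag.needs_fixed_stride_order,
    _C.Tag.flexible_layout,
]

def most_restrictive_layout_tag(op_tags):
    # First hit wins: earlier entries are stricter.
    for tag in tags_by_priority:
        if tag in op_tags:
            return tag
    return None  # no layout tag present; caller picks a default

# An op tagged both flexible_layout and needs_contiguous_strides resolves
# to the stricter needs_contiguous_strides.
assert most_restrictive_layout_tag(
    {_C.Tag.flexible_layout, _C.Tag.needs_contiguous_strides}
) == _C.Tag.needs_contiguous_strides
```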