Add torch._C.Tag.needs_contiguous_strides (#152859)
This tag forces Inductor to make the inputs to the tagged operator contiguous before calling it.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152859
Approved by: https://github.com/eellison
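Usage follows the existing layout tags: the author of a custom operator attaches the tag at registration time, and Inductor then guarantees contiguous inputs when the op is compiled. A minimal sketch, modeled on the test added in this PR (the op name "mylib::scale" and its body are illustrative only):

    import torch

    # Hypothetical custom op that only works on contiguous inputs. The new tag
    # asks Inductor to make the argument contiguous before the call.
    @torch.library.custom_op(
        "mylib::scale",
        mutates_args={},
        tags=[torch._C.Tag.needs_contiguous_strides],
    )
    def scale(x: torch.Tensor) -> torch.Tensor:
        assert x.is_contiguous()
        return x * 2.0

    @scale.register_fake
    def _(x):
        # Fake (meta) implementation so the op can be traced by torch.compile.
        return torch.empty_like(x)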
@@ -46,20 +46,24 @@
   desc: |
     This tag indicates that the operator should be passed Tensors following
     the same strides as observed in eager when compiled in inductor.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.
+- tag: needs_contiguous_strides
+  desc: |
+    This tag indicates that the operator should be passed contiguous Tensors.
+    Failure to do so will result in undefined behavior.
 - tag: needs_fixed_stride_order
   desc: |
     This tag indicates that the operator should be passed Tensors following
     the same stride permutation as observed in eager when compiled in inductor.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.
 - tag: flexible_layout
   desc: |
     This tag indicates that the custom operator can accept inputs with varying
     strides/storage_offset and that when compiled, Inductor is allowed to change
     the strides/storage_offset of inputs to the custom operator.
-    Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout}
+    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
     can apply; if multiple are assigned then we assume the most restrictive one.

 # NOTE [Core ATen Ops]
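For a concrete sense of what the new tag promises: a contiguous tensor is one whose strides are the standard row-major strides for its shape, while a transposed view is not. A plain PyTorch illustration (no Inductor involved):

    import torch

    x = torch.zeros(2, 3)
    print(x.stride(), x.is_contiguous())   # (3, 1) True  -> row-major layout
    y = x.t()                              # shape (3, 2), same storage
    print(y.stride(), y.is_contiguous())   # (1, 3) False -> transposed view
    print(y.contiguous().stride())         # (2, 1) -> the layout the tag guarantees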
@@ -8823,6 +8823,39 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertTrue((d >= 0).all())
         self.assertTrue((d < 1).all())

+    @config.patch(implicit_fallbacks=True)
+    def test_needs_contiguous_strides(self):
+        # Construct a custom op whose output strides are not contiguous
+        @torch.library.custom_op("mylib::myop", mutates_args={})
+        def myop(x: torch.Tensor) -> torch.Tensor:
+            return torch.zeros(2, 2).t()
+
+        @myop.register_fake
+        def _(x):
+            return torch.zeros(2, 2).t()
+
+        # custom op that needs contiguous inputs
+        @torch.library.custom_op(
+            "mylib::second_op",
+            mutates_args={},
+            tags=[torch._C.Tag.needs_contiguous_strides],
+        )
+        def second_op(x: torch.Tensor) -> torch.Tensor:
+            assert x.is_contiguous()
+            return torch.ones(2, 2)
+
+        @second_op.register_fake
+        def _(x):
+            return torch.ones(2, 2)
+
+        def f(x):
+            y = myop(x)
+            return second_op(y)
+
+        # Check that the x.is_contiguous() assertion never gets triggered
+        x = torch.randn(2, 2)
+        _ = torch.compile(f, backend="inductor", fullgraph=True)(x)
+
     @config.patch(implicit_fallbacks=True)
     def test_fallback_mutable_op_basic(self):
         with torch.library._scoped_library("mylib", "FRAGMENT") as m:
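The first op deliberately returns a transposed, hence non-contiguous, tensor, so second_op's assertion can only pass if Inductor, prompted by the tag, makes the input contiguous before the call. The non-contiguity of the producer is easy to confirm in plain PyTorch:

    import torch

    out = torch.zeros(2, 2).t()
    print(out.stride())          # (1, 2): column-major view of the 2x2 buffer
    print(out.is_contiguous())   # False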
@@ -5591,6 +5591,14 @@ class ExternKernel(InputsKernel):
     def require_contiguous(cls, x):  # type: ignore[no-untyped-def]
         return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))

+    @classmethod
+    def require_contiguous_strides(cls, x):  # type: ignore[no-untyped-def]
+        # TODO: combine this with require_contiguous after
+        # https://github.com/pytorch/pytorch/pull/148235 lands.
+        return cls.require_exact_strides(
+            x, FlexibleLayout.contiguous_strides(x.get_size())
+        )
+
     def apply_constraint(self) -> None:
         pass

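require_contiguous_strides simply asks require_exact_strides for the row-major strides of the node's size. A stand-alone sketch of what contiguous strides look like for a given shape (written here for illustration only; the real FlexibleLayout.contiguous_strides operates on Inductor's, possibly symbolic, sizes):

    def contiguous_strides_for(size):
        # Row-major strides: the innermost dimension has stride 1 and each outer
        # dimension's stride is the product of all inner dimension sizes.
        strides = [1] * len(size)
        for i in range(len(size) - 2, -1, -1):
            strides[i] = strides[i + 1] * size[i + 1]
        return strides

    print(contiguous_strides_for([2, 3, 4]))  # [12, 4, 1]
    print(contiguous_strides_for([5]))        # [1]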
@@ -164,6 +164,8 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
 def tag_to_layout_constraint(tag):
     if tag == torch._C.Tag.needs_exact_strides:
         return constrain_to_fake_tensors
+    if tag == torch._C.Tag.needs_contiguous_strides:
+        return require_contiguous_strides
     if tag == torch._C.Tag.needs_fixed_stride_order:
         return constrain_to_fx_strides
     if tag == torch._C.Tag.flexible_layout:
@@ -2413,6 +2415,15 @@ def require_contiguous(_, *args, **kwargs):
     return args, kwargs


+def require_contiguous_strides(_, *args, **kwargs):
+    # TODO: combine this with require_contiguous after
+    # https://github.com/pytorch/pytorch/pull/148235 lands.
+    args, kwargs = pytree.tree_map_only(
+        ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs)
+    )
+    return args, kwargs
+
+
 def require_channels_last(_, *args, **kwargs):
     args, kwargs = pytree.tree_map_only(
         ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
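The constraint rewrites only the ir.IRNode leaves of the argument tree and passes every other argument through unchanged, which is exactly what pytree.tree_map_only does. Assuming pytree here refers to torch.utils._pytree, its behavior on ordinary containers looks like this:

    import torch.utils._pytree as pytree

    args = (1, "skip", [2.0, 3], {"k": 4})
    # Double only the int leaves; other leaves and the container structure survive.
    doubled = pytree.tree_map_only(int, lambda v: v * 2, args)
    print(doubled)  # (2, 'skip', [2.0, 6], {'k': 8})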
@@ -505,6 +505,7 @@ def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]

 tags_by_priority = [
     _C.Tag.needs_exact_strides,
+    _C.Tag.needs_contiguous_strides,
     _C.Tag.needs_fixed_stride_order,
     _C.Tag.flexible_layout,
 ]
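tags_by_priority orders the layout tags from most to least restrictive, matching the "most restrictive wins" rule stated in the tag descriptions above. A sketch of how such a list can resolve an operator that carries several layout tags (illustrative only, not the actual lookup code):

    def most_restrictive_layout_tag(op_tags, tags_by_priority):
        # Return the first (i.e. most restrictive) layout tag the operator carries.
        for tag in tags_by_priority:
            if tag in op_tags:
                return tag
        return None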