Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add Inductor config for default stride behavior (#135238)
By default, Inductor is allowed to manipulate the layout (strides + storage offset) of input tensors to custom operators. We want to change the default so that Inductor respects the stride order of input tensors to custom operators.

This PR adds a config to toggle the behavior; in the next PR up the stack we'll change the default. We also make the following change:
- We add a new operator tag (flexible_layout), which means that Inductor is allowed to manipulate the layout. When we flip the default, users can specify that they want the old behavior by using this tag.

This is a reland of https://github.com/pytorch/pytorch/pull/126986, which was previously reverted due to silent incorrectness. We've since fixed the silent incorrectness (https://github.com/pytorch/pytorch/pull/133639).

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135238
Approved by: https://github.com/albanD
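As a rough sketch of the user-facing surface (the library name "mylib", the op name, and the kernel below are invented for illustration; only the tag names and the config knob come from this PR), a custom-op author can opt a specific operator into the old behavior by tagging it:

import torch

# Define a toy custom op and tag it flexible_layout: Inductor may then change
# the strides/storage_offset of this op's inputs when compiling.
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define(
    "my_op(Tensor x) -> Tensor",
    tags=(torch._C.Tag.flexible_layout,),
)
lib.impl("my_op", lambda x: x.clone(), "CompositeExplicitAutograd")

# Untagged custom ops instead follow the new config added by this PR,
# torch._inductor.config.custom_op_default_layout_constraint, which this PR
# leaves at "flexible_layout" (the follow-up PR flips it).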
@@ -46,6 +46,15 @@
   desc: |
     This tag indicates that the operator should be passed Tensors following
     the same stride permutation as observed in eager when compiled in inductor.
+    Only one of {needs_fixed_stride_order, flexible_layout} can apply; if
+    multiple are assigned then we assume the most restrictive one.
+- tag: flexible_layout
+  desc: |
+    This tag indicates that the custom operator can accept inputs with varying
+    strides/storage_offset and that when compiled, Inductor is allowed to change
+    the strides/storage_offset of inputs to the custom operator.
+    Only one of {needs_fixed_stride_order, flexible_layout} can apply; if
+    multiple are assigned then we assume the most restrictive one.

 # NOTE [Core ATen Ops]
 - tag: core
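Both descriptions state the same tie-breaking rule: if an operator somehow carries both tags, the more restrictive one wins. A minimal, self-contained check of that rule (the op name is made up; it reuses the scoped-library helper and the tag-inference function exercised by the new test further down):

import torch
from torch._inductor.lowering import get_layout_constraint_tag

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    # Deliberately apply both tags; needs_fixed_stride_order is the stricter one.
    lib.define(
        "both_tags(Tensor x) -> Tensor",
        tags=(torch._C.Tag.needs_fixed_stride_order, torch._C.Tag.flexible_layout),
    )
    tag = get_layout_constraint_tag(torch.ops.mylib.both_tags.default)
    assert tag == torch._C.Tag.needs_fixed_stride_order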
@@ -3428,6 +3428,26 @@ Please use `add.register_fake` to add an fake impl.""",
         self.assertTrue(called)
         self.assertEqual(result, w * 2 * 3 * 42)

+    def test_layout_constraint_tags(self):
+        needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order
+        flexible_layout = torch._C.Tag.flexible_layout
+        # (tags, the result of the tag inference)
+        tests = [
+            ({needs_fixed_stride_order}, needs_fixed_stride_order),
+            ({flexible_layout}, flexible_layout),
+            # If no tags are provided, then the following is the default
+            (set(), flexible_layout),
+            # If multiple tags are provided, then we use the most constrained tag.
+            ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order),
+        ]
+        from torch._inductor.lowering import get_layout_constraint_tag
+
+        for tags, expected in tests:
+            with torch.library._scoped_library("mylib", "FRAGMENT") as m:
+                m.define("foobar(Tensor x) -> Tensor", tags=tags)
+                result = get_layout_constraint_tag(torch.ops.mylib.foobar.default)
+                self.assertEqual(result, expected)
+
     @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
     def test_library_register_vmap(self):
         for mode in ["function", "qualname", "opoverload", "c_opdef"]:
@@ -65,6 +65,13 @@ force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "
 # sleep in inductor for testing
 sleep_sec_TESTING_ONLY: Optional[int] = None

+# The default layout constraint for custom operators.
+# This must be the name of one of the layout constraint tags
+# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}),
+# If the custom op does not have a layout constraint tag already
+# then we assume the following applies.
+custom_op_default_layout_constraint = "flexible_layout"
+
 # use cpp wrapper instead of python wrapper
 cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"

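This hunk adds the config knob named in the commit message (it appears to be torch/_inductor/config.py). A hedged sketch of flipping the default for untagged custom ops, scoped with the config module's patch() context manager so the change does not leak elsewhere:

import torch._inductor.config as inductor_config

# Make untagged custom ops keep their eager stride order under torch.compile.
with inductor_config.patch(custom_op_default_layout_constraint="needs_fixed_stride_order"):
    ...  # compile and run code that calls untagged custom ops here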
@@ -101,11 +101,30 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
         _maybe_layout_constraints[fn] = None
         return None
     # We lazily register tag-based layout constraints.
-    if torch._C.Tag.needs_fixed_stride_order in fn.tags:
-        _maybe_layout_constraints[fn] = constrain_to_fx_strides
-        return _maybe_layout_constraints[fn]
-    _maybe_layout_constraints[fn] = None
-    return None
+
+    def handle_layout_constraint_tag(tag):
+        if tag is torch._C.Tag.needs_fixed_stride_order:
+            _maybe_layout_constraints[fn] = constrain_to_fx_strides
+            return _maybe_layout_constraints[fn]
+        elif tag is torch._C.Tag.flexible_layout:
+            _maybe_layout_constraints[fn] = None
+            return None
+        else:
+            raise AssertionError(f"Unknown layout constraint tag: {tag}")
+
+    tag = get_layout_constraint_tag(fn)
+    return handle_layout_constraint_tag(tag)
+
+
+def get_layout_constraint_tag(fn):
+    tags_by_priority = [
+        torch._C.Tag.needs_fixed_stride_order,
+        torch._C.Tag.flexible_layout,
+    ]
+    for tag in tags_by_priority:
+        if tag in fn.tags:
+            return tag
+    return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)


 def assert_nyi(cond, msg):
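This hunk is torch/_inductor/lowering.py (the new test imports get_layout_constraint_tag from there): get_layout_constraint_tag checks the two tags in priority order and otherwise falls back to config.custom_op_default_layout_constraint, while maybe_layout_constraints maps needs_fixed_stride_order to the constrain_to_fx_strides constraint and flexible_layout to no constraint. A small sketch of the fallback path for an untagged op (op name invented; assumes the config default shipped in this PR):

import torch
import torch._inductor.config as inductor_config
from torch._inductor.lowering import get_layout_constraint_tag

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    lib.define("untagged(Tensor x) -> Tensor")  # no layout constraint tag at all
    op = torch.ops.mylib.untagged.default

    # With this PR's default config, untagged ops resolve to flexible_layout ...
    assert get_layout_constraint_tag(op) == torch._C.Tag.flexible_layout

    # ... and changing the config changes what untagged ops resolve to.
    with inductor_config.patch(custom_op_default_layout_constraint="needs_fixed_stride_order"):
        assert get_layout_constraint_tag(op) == torch._C.Tag.needs_fixed_stride_order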