Revert "[inductor][custom ops] Add tag to custom ops to preserve stride orders in inductor (#117298)"

This reverts commit 1967394690f144a7ba1717eccec977286cafe2da.

Reverted https://github.com/pytorch/pytorch/pull/117298 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but it is failing on macOS 1967394690, maybe due to a land race ([comment](https://github.com/pytorch/pytorch/pull/117298#issuecomment-1901594120))
This commit is contained in:
PyTorch MergeBot
2024-01-20 02:14:58 +00:00
parent 94f0472579
commit 10923f8720
5 changed files with 13 additions and 141 deletions

View File

@ -42,10 +42,6 @@
desc: |
This tag indicates if an operator doesn't guarantee bitwise equivalence
across different runs of an operator with identical inputs.
- tag: needs_fixed_stride_order
desc: |
This tag indicates that the operator should be passed Tensors following
the same stride permutation as observed in eager when compiled in inductor.
# NOTE [Core ATen Ops]
- tag: core

View File

@ -116,20 +116,7 @@ skip_if_x86_mac = functools.partial(
)
vec_dtypes = [torch.float, torch.bfloat16, torch.float16]
libtest = torch.library.Library("test", "FRAGMENT")
ids = set()
def define_custom_op_for_test(id_, fn_cpu, fn_cuda, fn_meta, tags=()):
    """Define the ``test`` library op ``id_`` and register its kernels once.

    The op is declared with the schema ``<id_>(Tensor self) -> Tensor`` and
    the given ``tags``; ``fn_cpu``, ``fn_cuda`` and ``fn_meta`` are registered
    for the "CPU", "CUDA" and "Meta" dispatch keys respectively.

    Every id that has been registered is remembered in the module-level
    ``ids`` set, so a repeated call with the same ``id_`` does nothing (the
    kernels passed on that later call are ignored).
    """
    if id_ in ids:
        # Already defined by an earlier call; nothing to do.
        return

    libtest.define(f"{id_}(Tensor self) -> Tensor", tags=tags)
    kernels_by_key = (("CPU", fn_cpu), ("CUDA", fn_cuda), ("Meta", fn_meta))
    for dispatch_key, kernel in kernels_by_key:
        libtest.impl(id_, kernel, dispatch_key)
    ids.add(id_)
libfoo = None
f32 = torch.float32
@ -8002,108 +7989,22 @@ class CommonTemplate:
def foo_meta(x):
return torch.empty_like(x)
define_custom_op_for_test("foo", foo_cpu, foo_cuda, foo_meta)
global libfoo
if libfoo is None:
libfoo = torch.library.Library("foo", "DEF")
libfoo.define("custom(Tensor self) -> Tensor")
libfoo.impl("custom", foo_cpu, "CPU")
libfoo.impl("custom", foo_cuda, "CUDA")
libfoo.impl("custom", foo_meta, "Meta")
def fn(x):
a = torch.nn.functional.relu(x)
b = torch.ops.test.foo(a)
b = torch.ops.foo.custom(a)
c = torch.cos(b)
return c
self.common(fn, (torch.randn((16, 32)),), check_lowp=False)
@requires_cuda()
@torch._inductor.config.patch("layout_optimization", True)
@torch._inductor.config.patch("keep_output_stride", False)
@config.patch(implicit_fallbacks=True)
def test_custom_op_fixed_layout_sequential(self):
    """Check the strides seen by a ``needs_fixed_stride_order`` custom op.

    The custom op ``test::bar`` is fed the output of a conv layer and
    asserts, inside its CPU/CUDA kernels, that the incoming tensor has the
    same strides as the conv output computed eagerly.
    """
    import torch.library

    conv = nn.Conv2d(3, 128, 1, stride=1, bias=False).cuda()
    sample = torch.rand(2, 3, 128, 128, device="cuda")
    # Strides of the conv output when run eagerly; the kernels compare
    # against this value.
    eager_stride = conv(sample).stride()

    def check_stride_and_clone(x):
        # Shared CPU/CUDA kernel body: verify the layout, then copy.
        self.assertEqual(x.stride(), eager_stride)
        return x.clone()

    def bar_meta(x):
        return torch.empty_like(x)

    define_custom_op_for_test(
        "bar",
        check_stride_and_clone,
        check_stride_and_clone,
        bar_meta,
        tags=[torch._C.Tag.needs_fixed_stride_order],
    )

    def fn(x):
        return torch.ops.test.bar(conv(x))

    with torch.no_grad():
        # keep_output_stride is patched to False, so inductor would normally
        # be free to use a layout different from eager execution. Because the
        # custom op requires a fixed layout, the stride assertions inside its
        # kernels are expected to pass.
        self.common(fn, (sample,), check_lowp=False)
@requires_cuda()
@config.patch(implicit_fallbacks=True)
def test_custom_op_fixed_layout_channels_last(self):
    """Check the strides seen by ``test::baz`` for a channels_last input.

    A small module applies GELU and dropout, then calls the custom op
    ``test::baz`` (tagged ``needs_fixed_stride_order``). The op's CPU/CUDA
    kernels assert that their input has the strides that the eager
    ``helper`` output has for the same channels_last input.
    """

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.in_layers = nn.Sequential(nn.Dropout(p=0.1))

        def helper(self, x):
            # Everything that runs before the custom op.
            return self.in_layers(F.gelu(x))

        def forward(self, x):
            return torch.ops.test.baz(self.helper(x))

    model = Block().to("cuda").to(memory_format=torch.channels_last)
    input_t = torch.randn(
        [1, 320, 128, 128], dtype=torch.float32, device="cuda"
    ).to(memory_format=torch.channels_last)
    # Reference strides, taken from an eager run of the pre-op part.
    expected_strides = model.helper(input_t).stride()

    def check_stride_and_clone(x):
        # Shared CPU/CUDA kernel body: verify the layout, then copy.
        self.assertEqual(expected_strides, x.stride())
        return x.clone()

    def baz_meta(x):
        return torch.empty_like(x)

    define_custom_op_for_test(
        "baz",
        check_stride_and_clone,
        check_stride_and_clone,
        baz_meta,
        tags=[torch._C.Tag.needs_fixed_stride_order],
    )

    with torch.no_grad():
        # The result is not inspected; the checks live inside the kernels.
        compiled = torch.compile(model)
        compiled(input_t)
def test_buffer_use_after_remove(self):
# https://github.com/pytorch/pytorch/issues/102857

View File

@ -198,9 +198,6 @@ test_failures = {
("cpu", "cuda")
),
"test_zero_element_mutation_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_custom_op_fixed_layout_sequential_dynamic_shapes": TestFailure(
("cpu", "cuda")
),
"test_cat_uint8_dynamic_shapes": TestFailure(
("cpu",)
), # cat on uint8 input is using aten fallback on cpu

View File

@ -958,14 +958,14 @@ def get_abstract_impl(qualname):
return custom_op._get_impl("abstract").func
def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
def _custom_op_with_schema(qualname, schema):
ns, name = qualname.split("::")
schema_str = f"{name}{schema}"
function_schema = FunctionSchema.parse(schema_str)
validate_schema(function_schema)
tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
lib = library.Library(ns, "FRAGMENT")
lib.define(schema_str, tags=tags)
lib.define(schema_str)
ophandle = find_ophandle_or_throw(ns, function_schema.name)
result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
result._register_autograd_kernel_indirection()

View File

@ -47,7 +47,6 @@ from .ir import (
TensorBox,
)
from .lowering import (
constrain_to_fx_strides,
FALLBACK_ALLOW_LIST,
fallback_handler,
fallback_node_due_to_unsupported_type,
@ -672,23 +671,6 @@ class GraphLowering(torch.fx.Interpreter):
# passthrough lowerings from .pattern_matcher
return target(*args, **kwargs)
def get_custom_op_layout_constraints(target, args, kwargs):
    """Return ``(layout_constraint, args, kwargs)`` for an implicit fallback.

    If ``target`` does not carry the ``needs_fixed_stride_order`` tag, the
    constraint is ``None`` and ``args``/``kwargs`` are returned unchanged.
    Otherwise the arguments are passed through ``constrain_to_fx_strides``
    for the current node, and that same function is returned as the
    constraint.
    """
    if torch._C.Tag.needs_fixed_stride_order not in target.tags:
        return None, args, kwargs

    # Custom ops that must keep their stride order and go through the
    # implicit fallback need their arguments constrained to the fx strides.
    # Constrain the current args right here: call_function evaluates the
    # lowering immediately after the fallback is created, without evaluating
    # the layout constraint first.
    constrained_args, constrained_kwargs = constrain_to_fx_strides(
        self.current_node, *args, **kwargs
    )
    # Hand the constraint back as well so it can be registered with the
    # fallback and applied whenever the fallback is used again.
    return constrain_to_fx_strides, constrained_args, constrained_kwargs
if target not in lowerings:
assert isinstance(
target, torch._ops.OpOverload
@ -697,9 +679,6 @@ class GraphLowering(torch.fx.Interpreter):
if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target)
elif config.implicit_fallbacks:
layout_constraint, args, kwargs = get_custom_op_layout_constraints(
target, args, kwargs
)
error = (
MissingOperatorWithDecomp
if get_decompositions([target])
@ -709,8 +688,7 @@ class GraphLowering(torch.fx.Interpreter):
"Creating implicit fallback for:\n%s",
error.operator_str(target, args, kwargs),
)
make_fallback(target, layout_constraint)
make_fallback(target)
elif get_decompositions([target]):
# There isn't a good way to dynamically patch this in
# since AOT Autograd already ran. The error message tells