# Owner(s): ["module: functionalization"]

import numpy as np

import torch
import torch._dynamo.testing
import torch._inductor.config as inductor_config
import torch._inductor.test_case
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from torch import Tensor
from torch.testing._internal.logging_utils import logs_to_string


class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
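    # Auto-functionalization: when torch.compile traces a custom op that
    # mutates one of its inputs, it replaces the call with a functional
    # higher-order op and writes the result back afterwards, schematically
    # (a sketch, not the exact graph):
    #     torch.ops.mylib.foo(x)                    # mutates x in place
    # becomes, roughly,
    #     x1 = auto_functionalized(foo, x=x)[1]; x.copy_(x1)
    # This test checks that an op whose mutable argument is optional and
    # defaulted (Tensor(d!)? c=None) compiles when the caller omits it.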
    def test_auto_functionalize_can_with_default(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            def foo_impl(a, b, c=None, d=None, e=-1):
                a + b
                return

            def f(a, mode):
                return torch.ops.mylib.foo(
                    a,
                    0,
                )

            a = torch.tensor([10, 10, 10], dtype=torch.int64)

            torch.compile(f)(a, 0)
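
    # A mutable op declared to return None: compiling a call whose only
    # effect is writing into `out` should work end to end.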
    def test_auto_functionalize_can_with_none_return(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, Tensor(a!) out) -> None")

            def foo_impl(x, out):
                out.copy_(x)

            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
            x = torch.randn(3)
            out = torch.zeros(3)

            @torch.compile
            def f(x, out):
                torch.ops.mylib.foo(x, out)

            f(x, out)
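
    # `self` is a legal schema name for the mutated argument; the
    # functionalization machinery should not trip over it.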
    def test_auto_functionalize_self_as_mutate_arg(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor(a!) self) -> None")

            def foo_impl(self: torch.Tensor) -> None:
                self.sin_()

            x = torch.randn(3)
            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

            @torch.compile(backend="inductor", fullgraph=True)
            def f(x):
                torch.ops.mylib.foo(x)

            f(x)
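
    # A mutable TensorList argument (Tensor(a!)[] out): compiled and eager
    # runs must mutate the output tensors identically.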
    def test_auto_functionalize_tensorlist(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out):
                for o in out:
                    o.copy_(all_gather_output)

            def f(all_gather_output, all_gather_input_split_sizes, dim, out):
                torch.ops.mylib.foo(
                    all_gather_output, all_gather_input_split_sizes, dim, out
                )

            a = torch.ones(4)
            b = [2, 3]
            c = 0
            d = [torch.empty(4) for _ in range(2)]
            orig_args = (a, b, c, d)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
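
    # Schema-level gating: can_auto_functionalize accepts ops that mutate
    # some inputs and return nothing or fresh Tensors; ops that do not
    # mutate, alias a return to an input, or return a Tensor[] are rejected.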
    def test_can_auto_functionalize(self):
        from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

        expected_true = [
            "(Tensor(a!) x) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
        ]
        expected_false = [
            "(Tensor x) -> ()",
            "(Tensor(a) x) -> Tensor(a)",
            "(Tensor(a!) x) -> Tensor(a!)",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])",
        ]
        for schema in expected_true:
            with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
                torch.library.define("mylib::a", schema, lib=lib)

                self.assertTrue(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))

        for schema in expected_false:
            with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
                torch.library.define("mylib::a", schema, lib=lib)
                self.assertFalse(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
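
    # V1 auto_functionalize: the traced graph carries the functionalized op,
    # and the post-grad pass re-inplaces it back into a single mylib.foo call
    # with no copy-backs, as the expected graph below shows.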
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)
            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # Check the graph under static shapes
            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",  # noqa: B950
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
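
    # Same as above, but the op also returns fresh tensors; the re-inplaced
    # graph must still thread the getitem outputs through to the caller.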
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_with_returns_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return y[0] + w, y[1] + n

            @torch.library.impl_abstract("mylib::foo", lib=lib)
            def foo_abstract(x, y, z, w, n):
                return y[0] + w, y[1] + n

            def f(x, y, z, n):
                return torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
                    *compiled_args
                )

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
        return (getitem_4, getitem_5)""",  # noqa: B950
                    ignore_comments=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            eager_out = f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
            self.assertEqual(compiled_out, eager_out)
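
    # The op mutates its input through a NumPy array that views the same
    # storage; functionalization must still observe the mutation. Runs under
    # both V1 and V2.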
    def test_auto_functionalize_on_view(self):
        for value in [True, False]:
            with torch.library._scoped_library(
                "mylib", "FRAGMENT"
            ) as lib, inductor_config.patch({"enable_auto_functionalized_v2": value}):
                torch.library.define(
                    "mylib::foo",
                    "(Tensor(a!) x) -> ()",
                    tags=torch.Tag.pt2_compliant_tag,
                    lib=lib,
                )

                @torch.library.impl("mylib::foo", "cpu", lib=lib)
                @torch._dynamo.disable
                def foo_impl(x):
                    x_np = x.detach().numpy()  # view
                    np.sin(x_np, out=x_np)
                    return

                x = torch.randn(3)
                expected = x.sin()
                torch.ops.mylib.foo(x)
                assert torch.allclose(x, expected)

                @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
                def f(x):
                    x = x.clone()
                    y = x[:]
                    torch.ops.mylib.foo(y)
                    return x

                y = f(x)
                self.assertEqual(y, x.sin())
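
    # Optional mutable arguments (Tensor(a!)?) passed as None, under V1.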
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_optional_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                if x is not None:
                    x.add_(y[0] + w)
                if z is not None:
                    z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = None
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)
            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \
arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
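
    # The fake impl produces unbacked (data-dependent) sizes via
    # new_dynamic_size(); the mutable op must still compile at fullgraph.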
    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_auto_functionalize_op(self):
        @torch.library.custom_op(
            "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"]
        )
        def mk_image(decoder: Tensor) -> Tensor:
            return torch.randn(2, 3, 4, 5)

        @torch.library.register_fake("mylib::mk_image")
        def _(decoder: Tensor) -> Tensor:
            image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)]
            return torch.empty(image_size)

        @torch.compile(fullgraph=True)
        def f(x):
            return torch.ops.mylib.mk_image.default(x)

        x = torch.zeros(100, dtype=torch.int64)
        f(x)
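
    # V2 auto_functionalize: after re-inplacing, the post-grad graph calls
    # foo directly and writes the mutated bases back into the graph inputs
    # via aten.copy_ nodes.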
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_v2(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", dynamic=_dynamic, fullgraph=True)(
                    *compiled_args
                )

            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
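
    # Helpers for the tests below: compile f with the given backend, capture
    # the logged graph, and return [cloned args, result, graph text].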
    def run_aot_eager(self, f, orig_args, _dynamic=False):
        aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
        )

        result = None
        with ctx():
            result = torch.compile(
                f, backend="aot_eager", fullgraph=True, dynamic=_dynamic
            )(*aot_eager_args)

        graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        return [aot_eager_args, result, graph]

    def run_inductor(self, f, orig_args, _dynamic=False):
        compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

        log_stream, ctx = logs_to_string(
            "torch._inductor.compile_fx", "post_grad_graphs"
        )
        result = None
        with ctx():
            result = torch.compile(
                f, backend="inductor", fullgraph=True, dynamic=_dynamic
            )(*compiled_args)

        graph = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()

        return [compiled_args, result, graph]
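
    # V2 counterpart of test_auto_functionalize_with_returns_old: returned
    # tensors flow through while mutated inputs get copy_ write-backs.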
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_with_returns_v2(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return y[0] + w, y[1] + n

            @torch.library.impl_abstract("mylib::foo", lib=lib)
            def foo_abstract(x, y, z, w, n):
                return y[0] + w, y[1] + n

            def f(x, y, z, n):
                return torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)
            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
                    *compiled_args
                )
            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
        return (getitem_4, getitem_5)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            eager_out = f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
            self.assertEqual(compiled_out, eager_out)
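
    # The "extra" tests below pin down how V2 encodes mutated arguments: each
    # mutable arg is recorded against a deduplicated list of bases
    # (_all_bases) as a base index plus, for views, (size, stride,
    # storage_offset), as the expected graphs show.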
    # foo takes two inputs that are not views.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra1(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x, y):
                torch.ops.mylib.foo(x, y)
                return x + y

            orig_args = (torch.randn(2), torch.randn(2))

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(
                f, orig_args, _dynamic
            )
            [inductor_args, result2, graph_inductor] = self.run_inductor(
                f, orig_args, _dynamic
            )
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
        getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
        return (add,)""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None
        return (add,)""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
        return (add,)""",
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1)
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
        return (add,)""",
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

    # foo takes two views on the same input; the function has no return.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra2(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x):
                a = x[0]
                b = x[1]
                torch.ops.mylib.foo(a, b)
                return

            orig_args = [torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(
                f, orig_args, _dynamic
            )
            [inductor_args, result2, graph_inductor] = self.run_inductor(
                f, orig_args, _dynamic
            )
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1])
        getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            # 2. Run with inductor backend

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

    # foo takes two views on the same input; the function returns both views
    # and the input.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra3(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x):
                a = x[0]
                b = x[1]
                torch.ops.mylib.foo(a, b)
                return (a, b, x)

            orig_args = [torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_aot,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None
        return (select_2, select_3)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            # 2. Run with inductor backend

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_inductor,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
        return (select_2, select_3)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

    # foo takes a mutable list with views in addition to other args.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra4(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!)[] y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y[0].sin_()

            def f(x, y, z):
                a = x[0]
                b = z[0]
                torch.ops.mylib.foo(a, [b, y])

            orig_args = [torch.randn(2), torch.randn(2), torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args[2], eager_args[2])
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_aot,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_length = 2, _y_0_base_index = 1, _y_0_size = (), _y_0_stride = (), _y_0_storage_offset = 0, _y_1_base_index = 2, _all_bases = [arg0_1, arg1_1, arg2_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]
        getitem_3: "f32[2][1]cpu" = auto_functionalized_v2[3]; auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None
        copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_3); arg2_1 = getitem_3 = copy__2 = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            # 2. Run with inductor backend

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_inductor,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, [as_strided_default_1, arg2_1]); as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
        copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__2 = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_optional_v2(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                if x is not None:
                    x.add_(y[0] + w)
                if z is not None:
                    z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = None
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
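
    # Wrappers that re-run the tests above under inference mode and with
    # dynamic shapes.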
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode1_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra1()

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode2_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra2()

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode3_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra3()

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode4_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra4()

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic_v2(self):
        self.test_auto_functionalize_v2(_dynamic=True)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic2_v2(self):
        self.test_auto_functionalize_extra1(_dynamic=True)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic3_v2(self):
        self.test_auto_functionalize_extra2(_dynamic=True)

    # The graph input itself is a view of another tensor.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_graph_input_is_view(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x):
                pass

            @torch.compile(fullgraph=True, dynamic=False, backend="aot_eager")
            def f(x):
                a = x[0]
                torch.ops.mylib.foo(a)
                return

            x = torch.tensor([[1, 2], [3, 4]])
            # This would fail if auto_functionalized_v2 used clone, rather than
            # clone_preserve_strides, to clone not-inplaced args.
            f(x[1])


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()