Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Redesign custom op functionalization for better re-inplacing (#134409)
- The new implementation (auto_functionalized_v2) is enabled by default but can be disabled using an inductor flag.
- In export mode the old implementation is still used.

**Motivation**

The previous functionalization fails to re-inplace arguments when they are views over other tensors; see issue https://github.com/pytorch/pytorch/issues/131192. The new functionalization makes re-inplacing views much easier.

**A) Functionalization pass**

Consider the program:
```
func(t)
    x = t[0]
    y = t[1]
    foo(x, y)  # custom operator with x, y mutable
    return (x, y, t)
```
- To functionalize `foo`, we generate a function that operates on the base tensors of the inputs (x.base() and y.base()) and record how to regenerate each view from its base; for argument x we record `ViewInfo = (x.base(), x.size(), x.stride(), x.storage_offset())`.
- Due to some limitations of the torch.export argument format, we have to generate a lot of arguments, but this is something we can simplify in the future. For the example above we get the following call:
```
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
```
- In the code above:
  - `_all_bases` is the list of unique bases of all of foo's arguments.
  - For each argument x, the fields `_x_base_index`, `_x_size`, `_x_stride`, `_x_storage_offset` can be used to regenerate x from `_all_bases[_x_base_index]` (or from a copy of that base).
  - The output of auto_functionalized is foo's output, followed by one tensor per base in `_all_bases`: a copy of that base after applying the mutations of all the arguments that are views of it.
  - Every use of a base in `_all_bases` (or of a view of it) that occurs after the call to foo is replaced with a view of the corresponding new output.

For the function above, after functionalization we get:
```
def forward(self, arg0_1: "f32[2][1]cpu"):
    auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
    getitem_1: "f32[2][1]cpu" = auto_functionalized[1]; auto_functionalized = None
    copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
    # No stacktrace found for following nodes
    select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
    select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None
    return (select_2, select_3)
```

**B) Semantics of auto_functionalized**

The new semantics of auto_functionalized_v2 are:
1. For each base in `_all_bases`, copy the base, producing one copy per base. (If a base is re-inplaced, we do not need to copy it.)
2. For each argument, regenerate the argument from the copy of its base using the view information above.
3. Return the original foo output followed by the new bases.

**C) Re-inplace pass**

Since auto_functionalized_v2 copies the bases rather than the individual arguments, what we actually re-inplace is the bases (the pass runs just like before, but on the bases instead of the args):
1. For each base b in `_all_bases`, check whether there is any use of b (or of its aliases/views) after auto_functionalized and before b is overwritten with a copy. If there is none, re-inplace it, i.e. avoid copying it in step 1 above.
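To make the B) semantics concrete, here is a minimal, hypothetical Python sketch of what auto_functionalized_v2 does at runtime. The helper names (`regenerate_view`, `functional_call`) and the simplified `ViewInfo` tuple are illustrative assumptions, not the actual implementation in torch/_higher_order_ops/auto_functionalize.py:

```python
import torch
from typing import List, NamedTuple, Optional, Tuple

# Illustrative stand-in for the ViewInfo recorded per mutable argument.
# size/stride/storage_offset are None when the argument *is* the base
# rather than a view of it.
class ViewInfo(NamedTuple):
    base_index: int
    size: Optional[Tuple[int, ...]] = None
    stride: Optional[Tuple[int, ...]] = None
    storage_offset: Optional[int] = None

def regenerate_view(bases: List[torch.Tensor], info: ViewInfo) -> torch.Tensor:
    """Rebuild an argument from its (possibly copied) base tensor."""
    base = bases[info.base_index]
    if info.size is None:
        return base  # the argument was the base itself
    return torch.as_strided(base, info.size, info.stride, info.storage_offset)

def functional_call(op, arg_infos: List[ViewInfo], all_bases: List[torch.Tensor],
                    bases_to_copy=None):
    """Sketch of auto_functionalized_v2: copy the bases, replay the views on
    the copies, run the mutating op, and return (op_output, new_bases)."""
    if bases_to_copy is None:
        # By default every base is copied; re-inplaced bases are skipped.
        bases_to_copy = set(range(len(all_bases)))
    new_bases = [b.clone() if i in bases_to_copy else b
                 for i, b in enumerate(all_bases)]
    args = [regenerate_view(new_bases, info) for info in arg_infos]
    out = op(*args)  # mutations land in new_bases, not in all_bases
    return out, new_bases

# Example mirroring the commit message: foo mutates two views of one base.
def foo(x, y):
    x.sin_()
    y.sin_()

t = torch.randn(2)
infos = [ViewInfo(0, (), (), 0), ViewInfo(0, (), (), 1)]  # x = t[0], y = t[1]
_, (new_t,) = functional_call(foo, infos, [t])
assert torch.allclose(new_t, t.sin())  # t itself is untouched
```

For the C) re-inplace pass, the per-base decision boils down to a liveness check on the FX graph. The sketch below is a simplified, assumed version of that check; the real pass (torch/_inductor/fx_passes/reinplace.py, exercised by the tests in this diff) also tracks aliases and views of the base, which this sketch deliberately omits:

```python
import torch

def base_can_be_reinplaced(graph: torch.fx.Graph, hop_node: torch.fx.Node,
                           base: torch.fx.Node) -> bool:
    """A base in _all_bases may be mutated in place iff nothing reads the
    original base between the auto_functionalized_v2 call and the copy_
    node that writes the functional result back into it."""
    seen_hop = False
    for node in graph.nodes:
        if node is hop_node:
            seen_hop = True
            continue
        if not seen_hop:
            continue
        if (node.op == "call_function"
                and node.target is torch.ops.aten.copy_.default
                and node.args
                and node.args[0] is base):
            # The write-back overwrites the base; later uses see the new value.
            return True
        if base in node.all_input_nodes:
            return False  # something still reads the old value of the base
    return True
```

If the check succeeds, the base index can be dropped from `bases_to_copy` in the first sketch, which is exactly what allows the view arguments from issue #131192 to be re-inplaced.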
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134409 Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: 195ac85fb6
Commit: c8ab9b06a2
@ -4,6 +4,7 @@ import numpy as np
|
||||
|
||||
import torch
|
||||
import torch._dynamo.testing
|
||||
import torch._inductor.config as inductor_config
|
||||
import torch._inductor.test_case
|
||||
import torch.onnx.operators
|
||||
import torch.utils._pytree as pytree
|
||||
@ -126,6 +127,7 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
|
||||
for schema in expected_true:
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define("mylib::a", schema, lib=lib)
|
||||
|
||||
self.assertTrue(
|
||||
can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
|
||||
)
|
||||
@ -139,7 +141,8 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
|
||||
)
|
||||
self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
|
||||
|
||||
def test_auto_functionalize(self):
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=False)
|
||||
def test_auto_functionalize_old(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
@ -162,9 +165,7 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
|
||||
z = torch.randn(3)
|
||||
n = torch.randn(3)
|
||||
orig_args = (x, y, z, n)
|
||||
|
||||
compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
@ -192,7 +193,8 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
|
||||
f(*eager_args)
|
||||
self.assertEqual(compiled_args, eager_args)
|
||||
|
||||
def test_auto_functionalize_with_returns(self):
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=False)
|
||||
def test_auto_functionalize_with_returns_old(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
@ -237,14 +239,12 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
|
||||
self.assertExpectedInline(
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", \
|
||||
arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
# No stacktrace found for following nodes
|
||||
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \
|
||||
arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
|
||||
getitem_4: "f32[3][1]cpu" = foo_default[0]
|
||||
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
|
||||
return (getitem_4, getitem_5)""",
|
||||
return (getitem_4, getitem_5)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
)
|
||||
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
@ -253,37 +253,41 @@ arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
|
||||
self.assertEqual(compiled_out, eager_out)
|
||||
|
||||
def test_auto_functionalize_on_view(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
for value in [True, False]:
|
||||
with torch.library._scoped_library(
|
||||
"mylib", "FRAGMENT"
|
||||
) as lib, inductor_config.patch({"enable_auto_functionalized_v2": value}):
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x):
|
||||
x_np = x.detach().numpy() # view
|
||||
np.sin(x_np, out=x_np)
|
||||
return
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x):
|
||||
x_np = x.detach().numpy() # view
|
||||
np.sin(x_np, out=x_np)
|
||||
return
|
||||
|
||||
x = torch.randn(3)
|
||||
expected = x.sin()
|
||||
torch.ops.mylib.foo(x)
|
||||
assert torch.allclose(x, expected)
|
||||
x = torch.randn(3)
|
||||
expected = x.sin()
|
||||
torch.ops.mylib.foo(x)
|
||||
assert torch.allclose(x, expected)
|
||||
|
||||
@torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
|
||||
def f(x):
|
||||
x = x.clone()
|
||||
y = x[:]
|
||||
torch.ops.mylib.foo(y)
|
||||
return x
|
||||
@torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
|
||||
def f(x):
|
||||
x = x.clone()
|
||||
y = x[:]
|
||||
torch.ops.mylib.foo(y)
|
||||
return x
|
||||
|
||||
y = f(x)
|
||||
self.assertEqual(y, x.sin())
|
||||
y = f(x)
|
||||
self.assertEqual(y, x.sin())
|
||||
|
||||
def test_auto_functionalize_optional(self):
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=False)
|
||||
def test_auto_functionalize_optional_old(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
@ -308,14 +312,12 @@ arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
|
||||
z = torch.randn(3)
|
||||
n = torch.randn(3)
|
||||
orig_args = (x, y, z, n)
|
||||
|
||||
compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
with ctx():
|
||||
torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
post_grad_graphs = "\n".join(
|
||||
log_stream.getvalue().strip().split("\n")[3:]
|
||||
@ -356,6 +358,647 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
|
||||
x = torch.zeros(100, dtype=torch.int64)
|
||||
f(x)
|
||||
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_auto_functionalize_v2(self, _dynamic=False):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x, y, z, w, n):
|
||||
x.add_(y[0] + w)
|
||||
z.add_(y[1] + n)
|
||||
|
||||
def f(x, y, z, n):
|
||||
torch.ops.mylib.foo(x, y, z, 2, n)
|
||||
|
||||
x = torch.randn(3)
|
||||
y = (torch.randn(3), torch.randn(3))
|
||||
z = torch.randn(3)
|
||||
n = torch.randn(3)
|
||||
orig_args = (x, y, z, n)
|
||||
|
||||
compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
with ctx():
|
||||
torch.compile(f, backend="inductor", dynamic=_dynamic, fullgraph=True)(
|
||||
*compiled_args
|
||||
)
|
||||
|
||||
post_grad_graphs = "\n".join(
|
||||
log_stream.getvalue().strip().split("\n")[3:]
|
||||
).strip()
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
if _dynamic:
|
||||
self.assertExpectedInline(
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
|
||||
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
|
||||
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
f(*eager_args)
|
||||
self.assertEqual(compiled_args, eager_args)
|
||||
|
||||
def run_aot_eager(self, f, orig_args, _dynamic=False):
|
||||
aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
||||
)
|
||||
|
||||
result = None
|
||||
with ctx():
|
||||
result = torch.compile(
|
||||
f, backend="aot_eager", fullgraph=True, dynamic=_dynamic
|
||||
)(*aot_eager_args)
|
||||
|
||||
graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
||||
return [aot_eager_args, result, graph]
|
||||
|
||||
def run_inductor(self, f, orig_args, _dynamic=False):
|
||||
compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
result = None
|
||||
with ctx():
|
||||
result = torch.compile(
|
||||
f, backend="inductor", fullgraph=True, dynamic=_dynamic
|
||||
)(*compiled_args)
|
||||
|
||||
graph = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
|
||||
|
||||
return [compiled_args, result, graph]
|
||||
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_auto_functionalize_with_returns_v2(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x, y, z, w, n):
|
||||
x.add_(y[0] + w)
|
||||
z.add_(y[1] + n)
|
||||
return y[0] + w, y[1] + n
|
||||
|
||||
@torch.library.impl_abstract("mylib::foo", lib=lib)
|
||||
def foo_abstract(x, y, z, w, n):
|
||||
return y[0] + w, y[1] + n
|
||||
|
||||
def f(x, y, z, n):
|
||||
return torch.ops.mylib.foo(x, y, z, 2, n)
|
||||
|
||||
x = torch.randn(3)
|
||||
y = (torch.randn(3), torch.randn(3))
|
||||
z = torch.randn(3)
|
||||
n = torch.randn(3)
|
||||
orig_args = (x, y, z, n)
|
||||
compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
with ctx():
|
||||
compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
|
||||
*compiled_args
|
||||
)
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
post_grad_graphs = "\n".join(
|
||||
log_stream.getvalue().strip().split("\n")[3:]
|
||||
).strip()
|
||||
self.assertExpectedInline(
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None
|
||||
getitem_4: "f32[3][1]cpu" = foo_default[0]
|
||||
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
|
||||
|
||||
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
|
||||
return (getitem_4, getitem_5)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
eager_out = f(*eager_args)
|
||||
self.assertEqual(compiled_args, eager_args)
|
||||
self.assertEqual(compiled_out, eager_out)
|
||||
|
||||
# foo takes two inputs that are not views.
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_auto_functionalize_extra1(self, _dynamic=False):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x, Tensor(b!) y) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x, y):
|
||||
x.sin_()
|
||||
y.sin_()
|
||||
|
||||
def f(x, y):
|
||||
torch.ops.mylib.foo(x, y)
|
||||
return x + y
|
||||
|
||||
orig_args = (torch.randn(2), torch.randn(2))
|
||||
|
||||
[aot_eager_args, result1, graph_aot] = self.run_aot_eager(
|
||||
f, orig_args, _dynamic
|
||||
)
|
||||
[inductor_args, result2, graph_inductor] = self.run_inductor(
|
||||
f, orig_args, _dynamic
|
||||
)
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
result3 = f(*eager_args)
|
||||
|
||||
self.assertEqual(inductor_args, eager_args)
|
||||
self.assertEqual(inductor_args, aot_eager_args)
|
||||
|
||||
self.assertEqual(result3, result1)
|
||||
self.assertEqual(result3, result2)
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
if _dynamic:
|
||||
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
|
||||
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
|
||||
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1])
|
||||
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None
|
||||
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
if _dynamic:
|
||||
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
|
||||
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
|
||||
return (add,)""",
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None
|
||||
add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1)
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
|
||||
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
|
||||
return (add,)""",
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
# foo takes two views on the same input, function does not have return.
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_auto_functionalize_extra2(self, _dynamic=False):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x, Tensor(b!) y) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x, y):
|
||||
x.sin_()
|
||||
y.sin_()
|
||||
|
||||
def f(x):
|
||||
a = x[0]
|
||||
b = x[1]
|
||||
torch.ops.mylib.foo(a, b)
|
||||
return
|
||||
|
||||
orig_args = [torch.randn(2)]
|
||||
|
||||
[aot_eager_args, result1, graph_aot] = self.run_aot_eager(
|
||||
f, orig_args, _dynamic
|
||||
)
|
||||
[inductor_args, result2, graph_inductor] = self.run_inductor(
|
||||
f, orig_args, _dynamic
|
||||
)
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
result3 = f(*eager_args)
|
||||
|
||||
self.assertEqual(inductor_args, eager_args)
|
||||
self.assertEqual(inductor_args, aot_eager_args)
|
||||
|
||||
self.assertEqual(result3, result1)
|
||||
self.assertEqual(result3, result2)
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
if _dynamic:
|
||||
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1])
|
||||
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
|
||||
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
# 2. Run with inductor backend
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
if _dynamic:
|
||||
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
|
||||
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1)
|
||||
foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
|
||||
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1)
|
||||
foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
# foo takes two views on the same input, function returns both views and the input
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_auto_functionalize_extra3(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x, Tensor(b!) y) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x, y):
|
||||
x.sin_()
|
||||
y.sin_()
|
||||
|
||||
def f(x):
|
||||
a = x[0]
|
||||
b = x[1]
|
||||
torch.ops.mylib.foo(a, b)
|
||||
return (a, b, x)
|
||||
|
||||
orig_args = [torch.randn(2)]
|
||||
|
||||
[aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
|
||||
[inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
result3 = f(*eager_args)
|
||||
|
||||
self.assertEqual(inductor_args, eager_args)
|
||||
self.assertEqual(inductor_args, aot_eager_args)
|
||||
|
||||
self.assertEqual(result3, result1)
|
||||
self.assertEqual(result3, result2)
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
|
||||
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
|
||||
select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
|
||||
select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None
|
||||
return (select_2, select_3)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
# 2. Run with inductor backend
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
|
||||
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1)
|
||||
foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None
|
||||
select_2: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
|
||||
select_3: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
|
||||
return (select_2, select_3)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
# foo takes a mutable list with views in addition to other args.
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_auto_functionalize_extra4(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x, Tensor(b!)[] y) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x, y):
|
||||
x.sin_()
|
||||
y[0].sin_()
|
||||
|
||||
def f(x, y, z):
|
||||
a = x[0]
|
||||
b = z[0]
|
||||
torch.ops.mylib.foo(a, [b, y])
|
||||
|
||||
orig_args = [torch.randn(2), torch.randn(2), torch.randn(2)]
|
||||
|
||||
[aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
|
||||
[inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
result3 = f(*eager_args)
|
||||
|
||||
self.assertEqual(inductor_args[2], eager_args[2])
|
||||
self.assertEqual(inductor_args, aot_eager_args)
|
||||
|
||||
self.assertEqual(result3, result1)
|
||||
self.assertEqual(result3, result2)
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_length = 2, _y_0_base_index = 1, _y_0_size = (), _y_0_stride = (), _y_0_storage_offset = 0, _y_1_base_index = 2, _all_bases = [arg0_1, arg1_1, arg2_1])
|
||||
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]
|
||||
getitem_3: "f32[2][1]cpu" = auto_functionalized_v2[3]; auto_functionalized_v2 = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
|
||||
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None
|
||||
copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_3); arg2_1 = getitem_3 = copy__2 = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
# 2. Run with inductor backend
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"):
|
||||
as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
|
||||
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
|
||||
foo_default = torch.ops.mylib.foo.default(as_strided_default, [as_strided_default_1, arg2_1]); as_strided_default = as_strided_default_1 = foo_default = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
|
||||
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
|
||||
copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__2 = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_auto_functionalize_optional_v2(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x, y, z, w, n):
|
||||
if x is not None:
|
||||
x.add_(y[0] + w)
|
||||
if z is not None:
|
||||
z.add_(y[1] + n)
|
||||
|
||||
def f(x, y, z, n):
|
||||
torch.ops.mylib.foo(x, y, z, 2, n)
|
||||
|
||||
x = None
|
||||
y = (torch.randn(3), torch.randn(3))
|
||||
z = torch.randn(3)
|
||||
n = torch.randn(3)
|
||||
orig_args = (x, y, z, n)
|
||||
|
||||
compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
with ctx():
|
||||
torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
post_grad_graphs = "\n".join(
|
||||
log_stream.getvalue().strip().split("\n")[3:]
|
||||
).strip()
|
||||
self.assertExpectedInline(
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
|
||||
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
f(*eager_args)
|
||||
self.assertEqual(compiled_args, eager_args)
|
||||
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=False)
|
||||
def test_inference_mode1_v2(self):
|
||||
with torch.inference_mode():
|
||||
self.test_auto_functionalize_extra1()
|
||||
|
||||
# In inference mode we do not support inplacing views yet.
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_inference_mode2_v2(self):
|
||||
with torch.inference_mode(), torch.library._scoped_library(
|
||||
"mylib", "FRAGMENT"
|
||||
) as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor(a!) x, Tensor(b!) y) -> ()",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
@torch._dynamo.disable
|
||||
def foo_impl(x, y):
|
||||
x.sin_()
|
||||
y.sin_()
|
||||
|
||||
def f(x):
|
||||
a = x[0]
|
||||
b = x[1]
|
||||
torch.ops.mylib.foo(a, b)
|
||||
return
|
||||
|
||||
orig_args = [torch.randn(2)]
|
||||
|
||||
[aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
|
||||
[inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
result3 = f(*eager_args)
|
||||
|
||||
self.assertEqual(inductor_args, eager_args)
|
||||
self.assertEqual(inductor_args, aot_eager_args)
|
||||
|
||||
self.assertEqual(result3, result1)
|
||||
self.assertEqual(result3, result2)
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(
|
||||
graph_aot,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
|
||||
select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]); select = select_1 = None
|
||||
getitem_1: "f32[][]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[][]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0); getitem_1 = None
|
||||
select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1); select_scatter = getitem_2 = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1); arg0_1 = select_scatter_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
# 2. Run with inductor backend
|
||||
|
||||
if torch._dynamo.config.assume_static_by_default:
|
||||
self.assertExpectedInline(
|
||||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
|
||||
select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
|
||||
clone_default: "f32[][]cpu" = torch.ops.aten.clone.default(select); select = None
|
||||
clone_default_1: "f32[][]cpu" = torch.ops.aten.clone.default(select_1); select_1 = None
|
||||
foo_default = torch.ops.mylib.foo.default(clone_default, clone_default_1); foo_default = None
|
||||
select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, clone_default, 0, 0); clone_default = None
|
||||
select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, clone_default_1, 0, 1); select_scatter_default = clone_default_1 = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1); arg0_1 = select_scatter_default_1 = copy_ = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_dynamic_v2(self):
|
||||
self.test_auto_functionalize_v2(_dynamic=True)
|
||||
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_dynamic2_v2(self):
|
||||
self.test_auto_functionalize_extra1(_dynamic=True)
|
||||
|
||||
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
|
||||
def test_dynamic3_v2(self):
|
||||
self.test_auto_functionalize_extra2(_dynamic=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
@ -3,10 +3,14 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch._inductor.config as inductor_config
|
||||
from functorch import make_fx
|
||||
from torch import Tensor
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._higher_order_ops.auto_functionalize import (
|
||||
auto_functionalized,
|
||||
auto_functionalized_v2,
|
||||
)
|
||||
from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core
|
||||
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -70,6 +74,11 @@ else:
|
||||
return
|
||||
|
||||
|
||||
@torch.library.custom_op("test_view::boo", mutates_args={"x"})
|
||||
def boo(x: torch.Tensor) -> None:
|
||||
x.sin_()
|
||||
|
||||
|
||||
class TestReinplacingPassCorrectness(InductorTestCase):
|
||||
def setUp(self):
|
||||
counters.clear()
|
||||
@ -124,7 +133,7 @@ class TestReinplacingPassCorrectness(InductorTestCase):
|
||||
|
||||
self._test(f)
|
||||
|
||||
def test_counters(self):
|
||||
def test_counters_functionalize_old(self):
|
||||
counters.clear()
|
||||
|
||||
def f(x):
|
||||
@ -143,21 +152,176 @@ class TestReinplacingPassCorrectness(InductorTestCase):
|
||||
# IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
|
||||
self.assertEqual(num_reinplacing_failures(), 1)
|
||||
|
||||
def test_counters_functionalize_v2(self):
|
||||
counters.clear()
|
||||
|
||||
def f(x):
|
||||
out = torch.empty_like(x)
|
||||
_, new_out = auto_functionalized_v2(
|
||||
sin._opoverload,
|
||||
x=x,
|
||||
_result_base_index=0,
|
||||
_result_size=(3,),
|
||||
_result_stride=(1,),
|
||||
_result_storage_offset=0,
|
||||
_all_bases=[out],
|
||||
)
|
||||
y = out * new_out
|
||||
return new_out, y
|
||||
|
||||
x = torch.randn(3, device=device)
|
||||
gm = make_fx(f, tracing_mode="fake")(x)
|
||||
reinplace_inplaceable_ops_core(gm.graph)
|
||||
|
||||
# We shouldn't have been able to reinplace `out` because it was used after
|
||||
# auto_functionalized. Note that this usually doesn't happen in practice;
|
||||
# we're artificially creating this example to test the counter.
|
||||
# IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
|
||||
self.assertEqual(num_reinplacing_failures(), 1)
|
||||
|
||||
def get_not_inplaced_count(self, graph):
|
||||
counter = 0
|
||||
auto_functionalized_found = False
|
||||
for node in graph.nodes:
|
||||
if (node.target == torch.ops.higher_order.auto_functionalized) or (
|
||||
node.target == torch.ops.higher_order.auto_functionalized_v2
|
||||
):
|
||||
auto_functionalized_found = True
|
||||
counter += len(node.meta["only_clone_these_tensors"])
|
||||
assert auto_functionalized_found
|
||||
return counter
|
||||
|
||||
def test_view_inplaced_functionalize_v2(self):
|
||||
def f(arg0_1):
|
||||
select = torch.ops.aten.select.int(arg0_1, 0, 0)
|
||||
auto_functionalized = auto_functionalized_v2(
|
||||
torch.ops.test_view.boo.default,
|
||||
_x_base_index=0,
|
||||
_x_size=(3,),
|
||||
_x_stride=(1,),
|
||||
_x_storage_offset=0,
|
||||
_all_bases=[arg0_1],
|
||||
)
|
||||
getitem_1 = auto_functionalized[1]
|
||||
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
|
||||
return ()
|
||||
|
||||
x1 = torch.randn(3, device=device)
|
||||
gm = make_fx(f, tracing_mode="fake")(x1)
|
||||
reinplace_inplaceable_ops_core(gm.graph)
|
||||
|
||||
self.assertEqual(self.get_not_inplaced_count(gm.graph), 0)
|
||||
|
||||
# introduce a view another_view that is used `after` the copy
|
||||
def test_view_inplaced2_functionalize_v2(self):
|
||||
def f(arg0_1):
|
||||
select = torch.ops.aten.select.int(arg0_1, 0, 0)
|
||||
another_view = arg0_1[2]
|
||||
auto_functionalized = auto_functionalized_v2(
|
||||
torch.ops.test_view.boo.default,
|
||||
_x_base_index=0,
|
||||
_x_size=(3,),
|
||||
_x_stride=(1,),
|
||||
_x_storage_offset=0,
|
||||
_all_bases=[arg0_1],
|
||||
)
|
||||
getitem_1 = auto_functionalized[1]
|
||||
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
|
||||
return another_view
|
||||
|
||||
x1 = torch.randn(3, device=device)
|
||||
gm = make_fx(f, tracing_mode="fake")(x1)
|
||||
reinplace_inplaceable_ops_core(gm.graph)
|
||||
|
||||
self.assertEqual(self.get_not_inplaced_count(gm.graph), 0)
|
||||
|
||||
# introduce a view another_view that is used `before` the copy
|
||||
def test_views_not_inplaced_functionalize_v2(self):
|
||||
def f(arg0_1):
|
||||
select = torch.ops.aten.select.int(arg0_1, 0, 0)
|
||||
another_view = arg0_1[2]
|
||||
auto_functionalized = auto_functionalized_v2(
|
||||
torch.ops.test_view.boo.default,
|
||||
_x_base_index=0,
|
||||
_x_size=(3,),
|
||||
_x_stride=(1,),
|
||||
_x_storage_offset=0,
|
||||
_all_bases=[arg0_1],
|
||||
)
|
||||
getitem_1 = auto_functionalized[1]
|
||||
use_another_view = another_view * 10
|
||||
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
|
||||
return use_another_view
|
||||
|
||||
x1 = torch.randn(3, device=device)
|
||||
gm = make_fx(f, tracing_mode="fake")(x1)
|
||||
reinplace_inplaceable_ops_core(gm.graph)
|
||||
|
||||
self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)
|
||||
|
||||
# a view over input without copy node, inplace not allowed
|
||||
def test_views_not_inplaced2_functionalize_v2(self):
|
||||
def f(arg0_1):
|
||||
select = torch.ops.aten.select.int(arg0_1, 0, 0)
|
||||
another_view = arg0_1[2]
|
||||
auto_functionalized = auto_functionalized_v2(
|
||||
torch.ops.test_view.boo.default,
|
||||
_x_base_index=0,
|
||||
_x_size=(3,),
|
||||
_x_stride=(1,),
|
||||
_x_storage_offset=0,
|
||||
_all_bases=[arg0_1],
|
||||
)
|
||||
getitem_1 = auto_functionalized[1]
|
||||
return
|
||||
|
||||
x1 = torch.randn(3, device=device)
|
||||
gm = make_fx(f, tracing_mode="fake")(x1)
|
||||
reinplace_inplaceable_ops_core(gm.graph)
|
||||
|
||||
self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)
|
||||
|
||||
# no copy nodes, view over local, with a use for another view
|
||||
def test_views_not_inplaced3_functionalize_v2(self):
|
||||
def f(arg0_1):
|
||||
a = torch.ones(10)
|
||||
another_view = a[2]
|
||||
auto_functionalized = auto_functionalized_v2(
|
||||
torch.ops.test_view.boo.default,
|
||||
_x_base_index=0,
|
||||
_x_size=(),
|
||||
_x_stride=(),
|
||||
_x_storage_offset=0,
|
||||
_all_bases=[a],
|
||||
)
|
||||
getitem_1 = auto_functionalized[1]
|
||||
return another_view
|
||||
|
||||
x1 = torch.randn(3, device=device)
|
||||
gm = make_fx(f, tracing_mode="fake")(x1)
|
||||
reinplace_inplaceable_ops_core(gm.graph)
|
||||
|
||||
self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)
|
||||
|
||||
def test_multi_output_intermediate(self):
|
||||
for requires_grad in [False, True]:
|
||||
counters.clear()
|
||||
for enable_v2 in [False, True]:
|
||||
with inductor_config.patch(
|
||||
{"enable_auto_functionalized_v2": enable_v2}
|
||||
):
|
||||
counters.clear()
|
||||
|
||||
def f(x):
|
||||
out1 = torch.empty_like(x)
|
||||
out2 = torch.empty_like(x)
|
||||
sin_cos(x, out1, out2)
|
||||
return out1, out2, x**2
|
||||
def f(x):
|
||||
out1 = torch.empty_like(x)
|
||||
out2 = torch.empty_like(x)
|
||||
sin_cos(x, out1, out2)
|
||||
return out1, out2, x**2
|
||||
|
||||
x = torch.randn(3, device=device, requires_grad=requires_grad)
|
||||
res1, res2, _ = torch.compile(f)(x)
|
||||
self.assertEqual(res1, x.sin())
|
||||
self.assertEqual(res2, x.cos())
|
||||
self.assertEqual(num_reinplacing_failures(), 0)
|
||||
x = torch.randn(3, device=device, requires_grad=requires_grad)
|
||||
res1, res2, _ = torch.compile(f)(x)
|
||||
self.assertEqual(res1, x.sin())
|
||||
self.assertEqual(res2, x.cos())
|
||||
self.assertEqual(num_reinplacing_failures(), 0)
|
||||
|
||||
def test_multiple_mutations(self):
|
||||
counters.clear()
|
||||
@ -190,31 +354,59 @@ class TestReinplacingPassCorrectness(InductorTestCase):
|
||||
self.assertEqual(result, x.sin().sin().sin())
|
||||
self.assertEqual(num_reinplacing_failures(), 0)
|
||||
|
||||
def test_lists(self):
|
||||
@torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
|
||||
def mutate_op(y: List[Tensor]) -> None:
|
||||
y[0].add_(2)
|
||||
y[1].add_(3)
|
||||
def test_lists_functionalize_v2(self):
|
||||
with inductor_config.patch({"enable_auto_functionalized_v2": True}):
|
||||
|
||||
@torch.compile(fullgraph=True, dynamic=False, backend="inductor")
|
||||
def f(b):
|
||||
mutate_op([b[0], b[1]])
|
||||
@torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
|
||||
def mutate_op(y: List[Tensor]) -> None:
|
||||
y[0].add_(2)
|
||||
y[1].add_(3)
|
||||
|
||||
x1 = torch.tensor([0.3, 0.4], device=device)
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
with ctx():
|
||||
torch.compile(f, backend="inductor", fullgraph=True)(x1)
|
||||
post_grad_graphs = "\n".join(
|
||||
log_stream.getvalue().strip().split("\n")[3:]
|
||||
).strip()
|
||||
@torch.compile(fullgraph=True, dynamic=False, backend="inductor")
|
||||
def f(b):
|
||||
mutate_op([b[0], b[1]])
|
||||
|
||||
# Can't reinplace on views yet (1 for the "entire list" failing to reinplace)
|
||||
self.assertEqual(num_reinplacing_failures(), 1)
|
||||
x1 = torch.tensor([0.3, 0.4], device=device)
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
with ctx():
|
||||
torch.compile(f, backend="inductor", fullgraph=True)(x1)
|
||||
post_grad_graphs = "\n".join(
|
||||
log_stream.getvalue().strip().split("\n")[3:]
|
||||
).strip()
|
||||
|
||||
# Both list inputs failed to reinplace. So we should have emitted clones for them.
|
||||
self.assertEqual(post_grad_graphs.count("aten.clone"), 2)
|
||||
# We can inplace the base y. no clones emitted.
|
||||
self.assertEqual(num_reinplacing_failures(), 0)
|
||||
self.assertEqual(post_grad_graphs.count("aten.clone"), 0)
|
||||
|
||||
def test_lists_old_functionalize(self):
|
||||
with inductor_config.patch({"enable_auto_functionalized_v2": False}):
|
||||
|
||||
@torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
|
||||
def mutate_op(y: List[Tensor]) -> None:
|
||||
y[0].add_(2)
|
||||
y[1].add_(3)
|
||||
|
||||
@torch.compile(fullgraph=True, dynamic=False, backend="inductor")
|
||||
def f(b):
|
||||
mutate_op([b[0], b[1]])
|
||||
|
||||
x1 = torch.tensor([0.3, 0.4], device=device)
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
with ctx():
|
||||
torch.compile(f, backend="inductor", fullgraph=True)(x1)
|
||||
post_grad_graphs = "\n".join(
|
||||
log_stream.getvalue().strip().split("\n")[3:]
|
||||
).strip()
|
||||
|
||||
# Can't reinplace on views yet (1 for the "entire list" failing to reinplace)
|
||||
self.assertEqual(num_reinplacing_failures(), 1)
|
||||
|
||||
# Both list inputs failed to reinplace. So we should have emitted clones for them.
|
||||
self.assertEqual(post_grad_graphs.count("aten.clone"), 2)
|
||||
|
||||
@parametrize(
|
||||
"factory_op",
|
||||
|
@ -184,7 +184,6 @@ def aot_dispatch_base_graph(
|
||||
# As long as we opted to remove input mutations, then
|
||||
# there should be *NO* mutating ops in the graph at this point.
|
||||
copy_count = assert_functional_graph(fw_module.graph)
|
||||
|
||||
fw_module.graph.eliminate_dead_code()
|
||||
fw_module.recompile()
|
||||
|
||||
|
@ -804,16 +804,17 @@ def solve_min_cut(
|
||||
if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
|
||||
}
|
||||
ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
|
||||
print("Ops banned from rematerialization: ", ops_ignored)
|
||||
print("Ops banned from re-materialization: ", ops_ignored)
|
||||
print()
|
||||
|
||||
def can_fuse_into_auto_functionalized(a, b):
|
||||
if b.target != torch.ops.higher_order.auto_functionalized:
|
||||
return False
|
||||
mutable_op = b.args[0]
|
||||
mutable_arg_names = (
|
||||
torch._higher_order_ops.auto_functionalize.get_mutable_arg_names(mutable_op)
|
||||
)
|
||||
(
|
||||
mutable_arg_names,
|
||||
_,
|
||||
) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op)
|
||||
for name in mutable_arg_names:
|
||||
arg = b.kwargs[name]
|
||||
if a is arg:
|
||||
|
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -17,6 +18,142 @@ from torch.fx.experimental.proxy_tensor import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ViewInfo:
|
||||
base_index: int
|
||||
size: Optional[Sequence[Union[int, torch.SymInt]]] = None
|
||||
stride: Optional[Sequence[Union[int, torch.SymInt]]] = None
|
||||
storage_offset: Optional[int] = None
|
||||
# When is_view is false, the tensor is the base, and
|
||||
# size, stride and storage_offset are all None.
|
||||
is_view: bool = True
|
||||
|
||||
def regenerate_view(self, bases_list: List[Tensor]):
|
||||
if not self.is_view:
|
||||
return bases_list[self.base_index]
|
||||
|
||||
assert self.stride is not None
|
||||
assert self.size is not None
|
||||
assert self.storage_offset is not None
|
||||
|
||||
return torch.as_strided(
|
||||
bases_list[self.base_index],
|
||||
self.size,
|
||||
self.stride,
|
||||
self.storage_offset,
|
||||
)
|
||||
|
||||
|
||||
def write_view_information_to_args(
|
||||
mutable_arg_names: List[str],
|
||||
mutable_arg_types: List[torch.Type],
|
||||
kwargs: Dict[str, Any],
|
||||
arg_to_base_index: Dict[str, Any],
|
||||
):
|
||||
"""
|
||||
This function writes the view information into kwargs. It reads mutable_args from kwargs.
|
||||
and uses arg_to_base_index and tensor information to write ViewInfo into kwargs.
|
||||
mutable_arg_names: mutable custom operator arg names.
|
||||
mutable_arg_types: mutable custom operator arg types.
|
||||
kwargs: the original custom operator args.
|
||||
arg_to_base_index: maps mutable_arg_name to int | [int] that refers to the base tensor that
|
||||
corresponds to the input tensor
|
||||
"""
|
||||
|
||||
def write_single_view(prefix: str, tensor: Tensor, base_index: int):
|
||||
assert f"{prefix}_base_index" not in kwargs
|
||||
assert f"{prefix}_size" not in kwargs
|
||||
assert f"{prefix}_stride" not in kwargs
|
||||
assert f"{prefix}_storage_offset" not in kwargs
|
||||
|
||||
if tensor is None:
|
||||
kwargs[f"{prefix}_base_index"] = None
|
||||
elif tensor._base is None:
|
||||
# if the tensor is the base (not view), for simplicity we do not serialize view meta.
|
||||
kwargs[f"{prefix}_base_index"] = base_index
|
||||
else:
|
||||
kwargs[f"{prefix}_base_index"] = base_index
|
||||
kwargs[f"{prefix}_size"] = tensor.size()
|
||||
kwargs[f"{prefix}_stride"] = tensor.stride()
|
||||
kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset()
|
||||
|
||||
for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
|
||||
arg = kwargs[arg_name]
|
||||
if isinstance(arg_type, torch.ListType):
|
||||
if arg is None:
|
||||
kwargs[f"_{arg_name}_length"] = None
|
||||
|
||||
kwargs[f"_{arg_name}_length"] = len(arg)
|
||||
for i, elem in enumerate(arg):
|
||||
write_single_view(
|
||||
f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i]
|
||||
)
|
||||
|
||||
elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
|
||||
write_single_view(
|
||||
f"_{arg_name}",
|
||||
kwargs[arg_name],
|
||||
arg_to_base_index.get(arg_name, None),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported type {arg_type}")
|
||||
|
||||
|
||||
# Returns a dict of arg_name -> ViewInfo | [ViewInfo]
|
||||
def read_view_information_from_args(
|
||||
mutable_arg_names: List[str],
|
||||
mutable_arg_types: List[torch.Type],
|
||||
kwargs: Dict[str, Any],
|
||||
all_bases: List[Tensor],
|
||||
):
|
||||
"""
|
||||
This reads the view information added by `write_view_information_to_args` from kwargs, pop them,
|
||||
and returns a dict arg_name -> ViewInfo | [ViewInfo](if the input is list). that maps each mutable arg
|
||||
to its view information.
|
||||
mutable_arg_names: mutable custom operator arg names.
|
||||
mutable_arg_types: mutable custom operator arg types.
|
||||
kwargs : args of auto_functionalize(custom_op, kwargs)
|
||||
"""
|
||||
|
||||
def get_arg(name):
|
||||
return kwargs.pop(name)
|
||||
|
||||
def read_single_view(prefix):
|
||||
base_index = get_arg(f"{prefix}_base_index")
|
||||
if base_index is None:
|
||||
return None
|
||||
elif f"{prefix}_size" not in kwargs:
|
||||
assert f"{prefix}_stride" not in kwargs
|
||||
assert f"{prefix}_storage_offset" not in kwargs
|
||||
|
||||
# This means that the argument is the base tensor
|
||||
return ViewInfo(base_index, all_bases[base_index], is_view=False)
|
||||
|
||||
else:
|
||||
size = get_arg(f"{prefix}_size")
|
||||
stride = get_arg(f"{prefix}_stride")
|
||||
storage_offset = get_arg(f"{prefix}_storage_offset")
|
||||
return ViewInfo(base_index, size, stride, storage_offset, is_view=True)
|
||||
|
||||
args_view_info: Dict[str, Any] = {}
|
||||
for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
|
||||
if isinstance(arg_type, torch.ListType):
|
||||
length = get_arg(f"_{arg_name}_length")
|
||||
if length is None:
|
||||
# The whole list is None.
|
||||
args_view_info[arg_name] = None
|
||||
else:
|
||||
args_view_info[arg_name] = [
|
||||
read_single_view(f"_{arg_name}_{i}") for i in range(length)
|
||||
]
|
||||
|
||||
elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
|
||||
args_view_info[arg_name] = read_single_view(f"_{arg_name}")
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported type {arg_type}")
|
||||
return args_view_info
|
||||
|
||||
|
||||
# NOTE: [auto-functionalizing custom ops]
|
||||
# Users may wish to torch.compile custom ops that mutate their inputs.
|
||||
# torch.compile will automatically support this op without anyone needing
|
||||
@ -34,6 +171,9 @@ from torch.fx.experimental.proxy_tensor import (
|
||||
# This HOP effectively runs the functional version of the op when
|
||||
# called: it clones inputs that will be mutated, runs the op, and
|
||||
# then returns (output, Tensors with the new values)
|
||||
#
|
||||
# auto_functionalize_v2 is an improved version of auto_functionalize that better handle
|
||||
# re-inplacing views.
|
||||
|
||||
|
||||
class AutoFunctionalized(HigherOrderOperator):
@@ -71,6 +211,38 @@ class AutoFunctionalized(HigherOrderOperator):
auto_functionalized = AutoFunctionalized()
auto_functionalized.__module__ = "torch.ops.higher_order"

auto_functionalized.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized.fallthrough(DispatchKey.AutogradCUDA)


class AutoFunctionalizedV2(HigherOrderOperator):
    """auto_functionalized_v2(_mutable_op, **kwargs)

    This HOP runs a "functional" version of _mutable_op.
    Unlike AutoFunctionalized, this version is improved to better handle
    view tensors. This version is only used in non-export mode.
    """

    def __init__(self) -> None:
        super().__init__("auto_functionalized_v2")

    def __call__(
        self,
        /,
        _mutable_op: OpOverload,
        **kwargs: Any,
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        assert can_auto_functionalize(_mutable_op)
        assert isinstance(kwargs, dict)
        return super().__call__(_mutable_op, **kwargs)


auto_functionalized_v2 = AutoFunctionalizedV2()
auto_functionalized_v2.__module__ = "torch.ops.higher_order"

auto_functionalized_v2.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized_v2.fallthrough(DispatchKey.AutogradCUDA)


def can_auto_functionalize(op: OperatorBase) -> bool:
    if not isinstance(op, OpOverload):
@@ -120,6 +292,253 @@ def can_auto_functionalize(op: OperatorBase) -> bool:
    return True

def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]:
    """
    Returns the list of argument names that get mutated according to the
    schema and their types.
    """
    mutable_args_names = [
        arg.name
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]

    mutable_args_types = [
        arg.type
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]
    return mutable_args_names, mutable_args_types

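A quick sketch of what this helper keys off: the `(a!)`-style alias annotations in an op schema mark the write-aliased (mutable) arguments. The library and op names below are made up, and the import assumes this patch is applied.

```
import torch
from torch._higher_order_ops.auto_functionalize import get_mutable_args

# "(a!)" / "(b!)" mark x and y as write-aliased (mutable) in the schema;
# z has no alias annotation, so it is not reported.
torch.library.define("mylib::foo", "(Tensor(a!) x, Tensor(b!) y, Tensor z) -> ()")

names, types = get_mutable_args(torch.ops.mylib.foo.default)
print(names)  # expected: ['x', 'y']
```
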
def do_auto_functionalize(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """Functionalizes a call to op(*args, **kwargs) by emitting a call to
    `outs = auto_functionalized(op, normalized_kwargs)`
    and replacing the mutated (args, kwargs) with the corresponding outputs.

    The normalized_kwargs are just the (args, kwargs), but all in kwarg form.
    This makes handling easier for the auto_functionalized HOP.
    """
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}
    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # if it's out of bounds we don't need to do anything
            # as it means the optional arg was passed with its default
            # value
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized(
            op, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # List of the names of args that get mutated (according to the schema)
    mutable_args_names, _ = get_mutable_args(op)

    unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
        : -len(mutable_args_names)
    ]
    unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out):
        # Can be None if input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        orig_arg = normalized_kwargs[name]

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]


def do_auto_functionalize_v2(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}

    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # if it's out of bounds we don't need to do anything
            # as it means the optional arg was passed with its default
            # value
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    # List of the names of args that get mutated (according to the schema)
    mutable_args_names, mutable_args_types = get_mutable_args(op)

    # A list of all bases of mutable args without duplication
    all_bases = []
    all_bases_addresses: list[int] = []

    # Map arg_name to the index of its base in all_bases.
    arg_to_base_index: Dict[str, Any] = {}

    def update_dict(tensor, arg_name, index=None):
        base = tensor if tensor._base is None else tensor._base

        def set_result(base_index):
            if index is None:
                arg_to_base_index[arg_name] = base_index
            else:
                arg_to_base_index[arg_name][index] = base_index

        if base._cdata not in all_bases_addresses:
            all_bases_addresses.append(base._cdata)
            all_bases.append(base)
            set_result(len(all_bases) - 1)
        else:
            set_result(all_bases_addresses.index(base._cdata))

    for arg_name in mutable_args_names:
        arg = normalized_kwargs[arg_name]
        if arg is None:
            continue

        if isinstance(arg, list):
            arg_to_base_index[arg_name] = {}
            for i, tensor in enumerate(arg):
                if tensor is None:
                    # no base to record for a None element
                    arg_to_base_index[arg_name][i] = None
                    continue

                update_dict(tensor, arg_name, i)

        else:
            update_dict(arg, arg_name)

    # add the view metadata for each arg into the kwargs.
    write_view_information_to_args(
        mutable_args_names,
        mutable_args_types,
        normalized_kwargs,
        arg_to_base_index,
    )

    # remove mutated args from the kwargs (it's a function of _all_bases now)
    for arg_name in mutable_args_names:
        del normalized_kwargs[arg_name]  # type: ignore[arg-type]

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    all_basis_unwrapped = ctx.unwrap_tensors(all_bases)

    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized_v2(
            op, **dict(unwrapped_kwargs, _all_bases=all_basis_unwrapped)  # type: ignore[arg-type]
        )

    unwrapped_actual_out: Union[Any, Tuple[Any]] = (
        unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)]
    )

    unwrapped_mutable_out = (
        [] if len(all_bases) == 0 else unwrapped_outs[-len(all_bases) :]
    )

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for orig_arg, unwrapped_out in zip(all_bases, unwrapped_mutable_out):
        # Can be None if input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]


# auto_functionalize functions
@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_dense(
    _mutable_op: OpOverload,
@@ -129,7 +548,7 @@ def auto_functionalized_dense(
    new_kwargs = dict(**kwargs)
    result = []

    _mutable_args_names = get_mutable_arg_names(_mutable_op)
    _mutable_args_names, _ = get_mutable_args(_mutable_op)
    for name in _mutable_args_names:
        if (
            _only_clone_these_tensors is not None
@@ -184,115 +603,104 @@ def auto_functionalized_proxy(
    return result

auto_functionalized.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized.fallthrough(DispatchKey.AutogradCUDA)


def get_mutable_arg_names(op: OpOverload) -> List[str]:
    """
    Returns the list of argument names that get mutated according to the
    schema.
    """
    mutable_args_names = [
        arg.name
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]
    return mutable_args_names


def do_auto_functionalize(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """Functionalizes a call to op(*args, **kwargs) by emitting a call to
    `outs = auto_functionalized(op, normalized_kwargs)`
    and replacing the mutated (args, kwargs) with the corresponding outputs.

    The normalized_kwargs are just the (args, kwargs), but all in kwarg form.
    This makes handling easier for the auto_functionalized HOP.
    """
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}
    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # if its out of bounds we don't need to do anything
            # as it means the the optional arg was passed with its default
            # value
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized(
            op, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # List of the name of args that get mutated (according to the schema)
    mutable_args_names = get_mutable_arg_names(op)

    unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
        : -len(mutable_args_names)
    ]
    unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out):
        # Can be None if input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        orig_arg = normalized_kwargs[name]

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]


@auto_functionalized.py_functionalize_impl
def auto_functionalized_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)


# auto_functionalized_v2 functions
@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_v2_dense(
    _mutable_op: OpOverload,
    _only_clone_these_bases: Optional[Tuple[int, ...]] = None,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    all_bases: List[Tensor] = kwargs.pop("_all_bases", [])
    mutable_args_names, mutable_args_types = get_mutable_args(_mutable_op)
    args_view_info = read_view_information_from_args(
        mutable_args_names, mutable_args_types, kwargs, all_bases
    )

    if _only_clone_these_bases is None:
        _only_clone_these_bases = tuple(range(len(all_bases)))

    def maybe_copy(i, t):
        if t is None:
            return None
        if i in _only_clone_these_bases:
            return t.clone()
        else:
            return t

    all_bases_new = [maybe_copy(i, t) for i, t in enumerate(all_bases)]

    # create new args
    new_kwargs = dict(**kwargs)

    # re-generate all inputs from all_bases_new using args_view_info and add them to new_kwargs.
    for arg_name in mutable_args_names:
        if args_view_info[arg_name] is None:
            new_kwargs[arg_name] = None
        elif isinstance(args_view_info[arg_name], list):
            new_kwargs[arg_name] = []
            for i, elem in enumerate(args_view_info[arg_name]):
                if elem is None:
                    new_kwargs[arg_name].append(None)
                else:
                    view_info = args_view_info[arg_name][i]
                    new_kwargs[arg_name].append(
                        view_info.regenerate_view(all_bases_new)
                    )
        else:
            new_kwargs[arg_name] = args_view_info[arg_name].regenerate_view(
                all_bases_new
            )

    out = _mutable_op(**new_kwargs)

    if isinstance(out, tuple):
        return (*out, *all_bases_new)  # type: ignore[return-value]
    else:
        return (out, *all_bases_new)  # type: ignore[return-value]

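A rough, self-contained illustration of the dense semantics above, done by hand with plain tensor ops: copy the base, regenerate each mutable arg as a view over the copy, run the mutation, and return the new base. It assumes `regenerate_view` amounts to an `as_strided` over the chosen base; the shapes and the "op" are made up.

```
import torch

base = torch.zeros(2)
base_copy = base.clone()                 # step 1: copy the base
x = base_copy.as_strided((), (), 0)      # step 2: regenerate view "x" = base[0]
y = base_copy.as_strided((), (), 1)      # step 2: regenerate view "y" = base[1]
x.add_(1)                                # step 3: run the mutable op on the views
y.add_(2)
# The HOP then returns (op_output, base_copy); the original base is untouched.
print(base, base_copy)                   # tensor([0., 0.]) tensor([1., 2.])
```
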
@auto_functionalized_v2.py_impl(FakeTensorMode)
def auto_functionalized_v2_fake(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Dict[str, Any],
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with mode:
        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
        return result


@auto_functionalized_v2.py_impl(ProxyTorchDispatchMode)
def auto_functionalized_v2_proxy(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Dict[str, Any],
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with disable_proxy_modes_tracing():
        out = auto_functionalized_v2(_mutable_op, **kwargs)

    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        auto_functionalized_v2,
        (_mutable_op,),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


@auto_functionalized_v2.py_functionalize_impl
def auto_functionalized_v2_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized_v2(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)

@@ -25,6 +25,11 @@ def autotune_remote_cache_default() -> Optional[bool]:
    return None


# Enable auto_functionalized_v2 (enabled by default)
enable_auto_functionalized_v2 = (
    os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1"
)

# add some debug printouts
debug = False

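For reference, a rough sketch of how this flag might be toggled from user code (the script name is made up; treat the snippet as illustrative rather than a recommended workflow):

```
import torch._inductor.config as inductor_config

# Fall back to the original auto_functionalized implementation for this process
# (the flag is consulted by FunctionalTensorMode when functionalizing).
inductor_config.enable_auto_functionalized_v2 = False

# Alternatively, set the environment variable read above before torch is imported:
#   TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2=1 python train.py
```
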
@@ -811,7 +811,7 @@ def decompose_auto_functionalized(graph):
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized),
        pass_dict=graph_pass,
    )
    def replacement(match: Match, *args, **kwargs):
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense

        only_clone_these_tensors = tuple(
@@ -849,11 +849,42 @@ def decompose_auto_functionalized(graph):

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    @register_graph_pattern(
        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2),
        pass_dict=graph_pass,
    )
    def _(match: Match, *args, **kwargs):
        from torch._higher_order_ops.auto_functionalize import (
            auto_functionalized_v2_dense,
        )

        only_clone_these_bases = tuple(
            match.nodes[0].meta.get("only_clone_these_tensors", [])
        )

        flat_args, spec = pytree.tree_flatten((args, kwargs))

        # NB: we combine (args, kwargs) into flat args for replacing.
        # This is because replace_by_example uses make_fx, which does not support
        # tracing a function with kwargs.
        def decomp(*flat_args):
            args, kwargs = pytree.tree_unflatten(flat_args, spec)
            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)

        match.replace_by_example(decomp, flat_args, run_functional_passes=False)

    graph_pass.apply(graph)

    for node in graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.auto_functionalized
    ):
        raise AssertionError("auto_functionalized was not removed")

    for node in graph.find_nodes(
        op="call_function", target=torch.ops.higher_order.auto_functionalized_v2
    ):
        raise AssertionError("auto_functionalized_v2 was not removed")

    for node in graph.find_nodes(
        op="call_function",
        target=torch.ops.higher_order.triton_kernel_wrapper_functional,
@@ -467,6 +467,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
        if get_node_storage(mutated_arg) is None:
            return False
        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]

        if mutated_arg.op in ("placeholder", "get_attr"):
            # Get the first copy_ node that mutates the mutated_arg.
            copy_node = copy_nodes.get(mutated_arg, None)
@@ -482,6 +483,9 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:

            return True
        elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes):
            # This should never happen in auto_functionalize_v2 non-inference mode,
            # since all mutated args are bases.

            # If the mutated arg is a view of any of the inputs of the graph,
            # do not allow inplacing.
            # This would require a more sophisticated algorithm to handle
@@ -491,9 +495,30 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
                node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg
            )

    def log_inplace_results(
        node_name,
        old_tensors_to_clone,
        tensors_to_clone,
        possibly_missed_reinplacing_opportunities,
    ):
        log.info(
            "For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
            "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
            "memory usage and performance.",
            node_name,
            old_tensors_to_clone,
            tensors_to_clone,
            possibly_missed_reinplacing_opportunities,
        )
        torch._dynamo.utils.counters["inductor"][
            "possibly_missed_reinplacing_opportunities"
        ] += len(possibly_missed_reinplacing_opportunities)

    replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}

    def reinplace_and_refine_tensors_to_clone(old_tensors_to_clone, kwargs, node_name):
    def reinplace_and_refine_tensors_to_clone(
        old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False
    ):
        tensors_to_clone: List[str] = []
        storage_of_reinplaced_args = set()
        possibly_missed_reinplacing_opportunities = []
@@ -507,6 +532,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:

        for arg in old_tensors_to_clone:
            assert arg in kwargs

            mutated_arg = kwargs[arg]

            # Let's say we have:
@@ -523,12 +549,18 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
                mutated_arg
            )
            if should_attempt_reinplace and can_inplace(node, mutated_arg):
                # In general, we probably do not need those optimizations.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                for user in node.users:
                    if user.target == operator.getitem and user.args[1] == arg:
                        replace_dict[user] = mutated_arg
                if not auto_functionalize_v2:
                    for user in node.users:
                        # For auto_functionalize_v2, arg is the index of the base, where the base
                        # at index i corresponds to the output at index size(out) + i.
                        # This used to compare strings with integers for auto_functionalize_v2; not sure
                        # if it was needed for inplaceable_triton_ops?
                        if user.target == operator.getitem and user.args[1] == arg:
                            replace_dict[user] = mutated_arg

                if isinstance(mutated_arg, (list, tuple)):
                    for a in mutated_arg:
@@ -540,18 +572,12 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
                possibly_missed_reinplacing_opportunities.append(arg)
                tensors_to_clone.append(arg)

        log.info(
            "For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
            "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
            "memory usage and performance.",
        log_inplace_results(
            node_name,
            old_tensors_to_clone,
            tensors_to_clone,
            possibly_missed_reinplacing_opportunities,
        )
        torch._dynamo.utils.counters["inductor"][
            "possibly_missed_reinplacing_opportunities"
        ] += len(possibly_missed_reinplacing_opportunities)
        return tensors_to_clone

    for node in graph.nodes:
@@ -565,17 +591,37 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
            if copy_node is not None:
                replace_dict[copy_node] = copy_node.args[0]
            node.target = inplaceable_op.inplace_op
        elif node.target == torch.ops.higher_order.auto_functionalized_v2:
            _mutable_op = node.args[0]
            kwargs = node.kwargs

            all_bases = kwargs["_all_bases"]
            bases_to_clone = range(len(all_bases))
            base_tensors_dct = dict(enumerate(all_bases))
            new_bases_to_clone: List[int] = reinplace_and_refine_tensors_to_clone(
                bases_to_clone,
                base_tensors_dct,
                node.target,
                auto_functionalize_v2=True,
            )
            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = new_bases_to_clone
        elif node.target == torch.ops.higher_order.auto_functionalized:
            _mutable_op = node.args[0]
            from torch._higher_order_ops.auto_functionalize import get_mutable_arg_names
            from torch._higher_order_ops.auto_functionalize import get_mutable_args

            tensors_to_clone = get_mutable_arg_names(_mutable_op)
            tensors_to_clone, _ = get_mutable_args(_mutable_op)
            # Don't try to reinplace Optional[Tensor] args that are None.
            tensors_to_clone = [
                t for t in tensors_to_clone if node.kwargs[t] is not None
            ]
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                tensors_to_clone, node.kwargs, _mutable_op._name
                tensors_to_clone,
                node.kwargs,
                _mutable_op._name,
                auto_functionalize_v2=False,
            )

            # Stash the metadata. There is a pass later on where we decompose
@@ -2,9 +2,10 @@
import contextlib
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, Union
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union

import torch
import torch._inductor.config as inductor_config
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
@@ -434,6 +435,7 @@ class FunctionalTensorMode(TorchDispatchMode):
            from torch._higher_order_ops.auto_functionalize import (
                can_auto_functionalize,
                do_auto_functionalize,
                do_auto_functionalize_v2,
            )

            if can_auto_functionalize(
@@ -444,7 +446,10 @@ class FunctionalTensorMode(TorchDispatchMode):
                # it doesn't matter what mode we use here because
                # the implementation of do_auto_functionalize doesn't
                # interact with FunctionalTensorMode at all
                return do_auto_functionalize(func, args, kwargs)
                if self.export or not inductor_config.enable_auto_functionalized_v2:
                    return do_auto_functionalize(func, args, kwargs)
                else:
                    return do_auto_functionalize_v2(func, args, kwargs)

            from torch._higher_order_ops.effects import handle_effects, has_effects
@@ -611,7 +616,7 @@ class BaseFunctionalizeAPI(ABC):
    @abstractmethod
    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    ) -> Any:
        pass

    @abstractmethod
@@ -654,8 +659,8 @@ class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
        )

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]
    ) -> Any:
        return torch.utils._pytree.tree_map_only(
            FunctionalTensor, FunctionalTensor.from_functional, args
        )
@@ -748,9 +753,11 @@ class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(
            inner_f,
            remove="mutations_and_views"
            if self.interpreter.functionalize_add_back_views()
            else "mutations",
            remove=(
                "mutations_and_views"
                if self.interpreter.functionalize_add_back_views()
                else "mutations"
            ),
        )

    def redispatch_to_next(self) -> ContextManager:
@@ -9,7 +9,7 @@
import torch
from torch._higher_order_ops.auto_functionalize import (
    auto_functionalized,
    get_mutable_arg_names,
    auto_functionalized_v2,
)
from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
from torch.export import ExportedProgram
@@ -36,10 +36,13 @@ def unsafe_remove_auto_functionalized_pass(
        if not isinstance(module, torch.fx.GraphModule):
            continue
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target is auto_functionalized:
            if (
                node.op == "call_function" and node.target is auto_functionalized
            ) or (
                node.op == "call_function" and node.target is auto_functionalized_v2
            ):
                func = node.args[0]
                assert isinstance(func, torch._ops.OpOverload)
                mutable_args_names = get_mutable_arg_names(func)
                # re-inplace everything
                node.meta["only_clone_these_tensors"] = []
    decompose_auto_functionalized(ep.graph)