Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
add a reinplacing FX pass (#80897)
Adds a "reinplacing" FX transform, that goes through an FX graph and tries to convert out-of-place op calls into inplace calls whenever possible. Followups from this PR include: - Set up torch bench, and run the whole torchbench suite using AOTAutograd + functionalize + rein placing transforms to surface any issues (this is what I'm currently working on). Right now, I have some basic unit tests just to sanity check that the general logic makes sense. - Add any missing inplace ops. This is mostly the `*_scatter*` ops, e.g. `diagonal_scatter_`, because these ops will commonly show up an FX graph after running functionalization. The criteria for when you can swap an op `b = a.add(...)` with `a.add_(...)` is: (1) An inplace variant of the operator with the same schema needs to exist (`aten.add` -> `aten.add_`) (2) `a` (**or any of its aliases**) can't be used as an input to any other operators later on in the graph (3) `a` can't be one of the inputs to the entire graph. It also can't be an **alias** of any of the inputs *** *** One thing to note: (3) means that we can't technically guarantee that we'll get back **all** memory usage that we lost from functionalization. Functionalization converts input mutations into out-of-place calls, and then adds a `copy_()` to the end of the graph to preserve semantics. I added logic to handle `copy_()` in this PR because it it's a pretty important optimizations in the context of `functionalization()`: any program that performs input mutations will have a `copy_()` in it after running functionalization. There are some examples in the test file, but I think staring at an example of where re-inplacing is/isn't allowed to run is helpful: ``` // Before functionalization def foo(a): tmp1 = a.add_(1) tmp2 = a.add(2) // After functionalization def foo(a) tmp1 = a.add(1) tmp2 = a.add(2) .... a.copy_(tmp1) // After re-inplacing def foo(a) // first add() is safe to re-inplace even though a is a program input, // because a's data is overwritten later by a copy_() tmp1 = a.add_(1) // second add() is NOT safe to re-inplace, because: // (1) a and tmp1 are aliased. Note that they weren't aliased in the original program, but they are now that we've done some re-inplacing. // (2) tmp1 is used as an input later in the program tmp2 = a.add(2) .... a.copy_(tmp1) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/80897 Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 46b83f66ec
Commit: 3ef7a6921d
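For reference, the pass is exposed as `torch.fx.passes.reinplace.reinplace(gm, *sample_args)`. Below is a minimal sketch (not part of the diff) of how the pass is invoked, mirroring the `test_reinplace_basic` case added in `test/test_fx_reinplace_pass.py` in this PR; the function `f` and the variable names are illustrative:

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace

def f(x):
    a = x.clone()  # `a` is an intermediate tensor, not a graph input
    b = a.add(1)   # no alias of `a` is used afterwards, so this is eligible to become add_()
    return b

inpt = torch.ones(2)
# Trace with make_fx, then run the reinplacing pass over the traced graph module.
gm = reinplace(make_fx(f)(inpt), inpt)
print(gm.code)  # the out-of-place aten.add.Tensor should now show up as aten.add_.Tensor
```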
@@ -9,6 +9,12 @@
 
 #include <c10/util/irange.h>
 
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/_to_copy.h>
+#endif
+
 namespace at {
 
 void FunctionalTensorWrapper::set_constructor_metadata() {
@@ -205,7 +211,9 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
   if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
     // .to() should not re-entrantly go through functionalization.
     at::AutoDispatchSkipFunctionalize guard;
-    value_ = value_.to(c10::TensorOptions().dtype(dtype()).layout(layout()));
+    // and we want _to_copy() to show up in the graph, not the composite .to() operator
+    // (this can happen if autograd has already run by the time we enter this code)
+    value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
   }
 }
 
@@ -59,6 +59,7 @@ torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> No
 torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
 torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument
 torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
+torch.fx.passes.reinplace.reinplace(gm, *sample_args)
 torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None)
 torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
 torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
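The hunk above registers `torch.fx.passes.reinplace.reinplace(gm, *sample_args)` in the FX public-API list. As a second illustrative sketch (again mirroring a case from the new test file, `test_reinplace_with_view`), this shows criterion (2) from the commit message blocking re-inplacing when an alias of the input is used later in the graph:

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace

def f(x):
    a = x.clone()
    a_view = a.view(-1)  # an alias of `a`
    b = a.add(1)         # NOT reinplaced: the alias `a_view` is still used below
    c = a_view.add(1)    # reinplaced: nothing uses this memory afterwards
    return c

inpt = torch.ones(2)
gm = reinplace(make_fx(f)(inpt), inpt)
print(gm.code)  # expect aten.add.Tensor for the first add and aten.add_.Tensor for the second
```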
@@ -5,6 +5,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorc
 from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
 from torch.utils._pytree import tree_map
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.passes.reinplace import reinplace
 
 import unittest
 
@ -17,59 +18,72 @@ def are_aliased(x, y):
|
|||||||
return y._base is x
|
return y._base is x
|
||||||
return x._base is y._base
|
return x._base is y._base
|
||||||
|
|
||||||
|
# We can unify testing and use functionalize() here instead
|
||||||
|
# if/when functorch moves into core.
|
||||||
|
# This is basically a crappy version of `functionalize()` for single-tensor-arg inputs.
|
||||||
|
def _functionalize(f, *, reapply_views: bool):
|
||||||
|
def wrapped(a):
|
||||||
|
input_functional = torch._to_functional_tensor(a)
|
||||||
|
torch._enable_functionalization(reapply_views=reapply_views)
|
||||||
|
try:
|
||||||
|
out = f(input_functional)
|
||||||
|
finally:
|
||||||
|
torch._disable_functionalization()
|
||||||
|
torch._sync(input_functional)
|
||||||
|
inpt_new = torch._from_functional_tensor(input_functional)
|
||||||
|
if inpt_new is not a:
|
||||||
|
# Existing deficiency in functionalize():
|
||||||
|
# we don't correctly mutate input metadata (yet?)
|
||||||
|
if inpt_new.shape == a.shape:
|
||||||
|
a.copy_(inpt_new)
|
||||||
|
tree_map(torch._sync, out)
|
||||||
|
out_unwrapped = tree_map(torch._from_functional_tensor, out)
|
||||||
|
return out_unwrapped
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
|
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
|
||||||
class TestFunctionalization(TestCase):
|
class TestFunctionalization(TestCase):
|
||||||
# We can unify testing and use functionalize() here instead
|
|
||||||
# if/when functorch moves into core.
|
|
||||||
def _functionalize(self, f, *, reapply_views: bool):
|
|
||||||
def wrapped(a):
|
|
||||||
input_functional = torch._to_functional_tensor(a)
|
|
||||||
torch._enable_functionalization(reapply_views=reapply_views)
|
|
||||||
try:
|
|
||||||
out = f(input_functional)
|
|
||||||
finally:
|
|
||||||
torch._disable_functionalization()
|
|
||||||
torch._sync(input_functional)
|
|
||||||
tree_map(torch._sync, out)
|
|
||||||
out_unwrapped = tree_map(torch._from_functional_tensor, out)
|
|
||||||
return out_unwrapped
|
|
||||||
|
|
||||||
return wrapped
|
def get_logs(self, func, inpt, *, reapply_views=False, run_reinplace=False):
|
||||||
|
inpt_clone = inpt.clone()
|
||||||
def get_logs(self, func, inpt, *, reapply_views=False):
|
traced_f = make_fx(_functionalize(func, reapply_views=reapply_views))(inpt)
|
||||||
traced_f = make_fx(self._functionalize(func, reapply_views=reapply_views))(inpt)
|
if run_reinplace:
|
||||||
|
traced_f = reinplace(traced_f, inpt_clone)
|
||||||
return traced_f.code
|
return traced_f.code
|
||||||
|
|
||||||
def assert_functionalization(self, func, inpt, *, reapply_views=False):
|
def assert_functionalization(self, func, inpt, *, reapply_views=False, mutated_input_metadata=False):
|
||||||
input_clone = inpt.clone()
|
input_clone = inpt.clone()
|
||||||
input_clone2 = inpt.clone()
|
input_clone2 = inpt.clone()
|
||||||
input_functional = torch._to_functional_tensor(input_clone2)
|
input_clone3 = inpt.clone()
|
||||||
|
|
||||||
# Compare outputs (and mutated inputs), with and without functionalization.
|
# Compare outputs (and mutated inputs), with and without functionalization.
|
||||||
out_ref = func(inpt)
|
out_ref = func(inpt)
|
||||||
|
out_functional = _functionalize(func, reapply_views=reapply_views)(input_clone)
|
||||||
|
# The reinplacing pass is only valid to run with reapply_views=True.
|
||||||
|
functional_func = make_fx(_functionalize(func, reapply_views=True))(input_clone2)
|
||||||
|
reinplace_func = reinplace(make_fx(_functionalize(func, reapply_views=True))(input_clone2), input_clone2)
|
||||||
|
|
||||||
torch._enable_functionalization(reapply_views=reapply_views)
|
# NOTE: for now, need to pass in fresh inputs here, because make_fx
|
||||||
try:
|
# will directly mutate the inputs that you trace with.
|
||||||
out_functional = func(input_functional)
|
# Once this is fixed we can clean this up.
|
||||||
finally:
|
out_reinplace = reinplace_func(input_clone3)
|
||||||
torch._disable_functionalization()
|
|
||||||
|
|
||||||
# We need to sync the input tensors first, in case there are any queued mutations left.
|
# functionalize() deficiency: input metadata mutations aren't propagated properly,
|
||||||
torch._sync(input_functional)
|
# so we just need to skip checks here for the tests that exercise that.
|
||||||
self.assertEqual(inpt, torch._from_functional_tensor(input_functional)) # input mutations should still occur
|
if not mutated_input_metadata:
|
||||||
|
self.assertEqual(inpt, input_clone) # input mutations should still occur
|
||||||
|
self.assertEqual(inpt, input_clone3)
|
||||||
|
|
||||||
# Handle tests with multi-tensor outputs
|
# Handle tests with multi-tensor outputs
|
||||||
if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
|
if isinstance(out_ref, tuple):
|
||||||
out_refs, out_functionals = list(out_ref), list(out_functional)
|
out_refs, out_functionals, out_reinplaces = list(out_ref), list(out_functional), list(out_reinplace)
|
||||||
else:
|
else:
|
||||||
out_refs, out_functionals = [out_ref], [out_functional]
|
out_refs, out_functionals, out_reinplaces = [out_ref], [out_functional], [out_reinplace]
|
||||||
|
|
||||||
for out_ref_, out_functional_ in zip(out_refs, out_functionals):
|
for out_ref_, out_functional_, out_reinplace_ in zip(out_refs, out_functionals, out_reinplaces):
|
||||||
self.assertEqual(out_ref_.size(), out_functional_.size())
|
self.assertEqual(out_ref_, out_functional_)
|
||||||
torch._sync(out_functional_)
|
self.assertEqual(out_ref_, out_reinplace_)
|
||||||
out_functional_unwrapped = torch._from_functional_tensor(out_functional_)
|
|
||||||
self.assertEqual(out_ref_, out_functional_unwrapped)
|
|
||||||
|
|
||||||
def test_save_for_backwards_segfault(self):
|
def test_save_for_backwards_segfault(self):
|
||||||
inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
|
inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
|
||||||
@ -104,10 +118,27 @@ class TestFunctionalization(TestCase):
|
|||||||
def forward(self, a_1):
|
def forward(self, a_1):
|
||||||
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
|
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
|
||||||
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
|
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
|
||||||
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
|
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
|
||||||
mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1); view_copy_default_1 = None
|
mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1)
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
|
||||||
|
return add_tensor
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
|
||||||
|
view_default = torch.ops.aten.view.default(a_1, [4, 2])
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(view_default, empty); view_default = empty = None
|
||||||
|
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
|
||||||
|
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1)
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None
|
||||||
return add_tensor
|
return add_tensor
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@ -134,6 +165,21 @@ def forward(self, a_1):
|
|||||||
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
|
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
|
||||||
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
|
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
|
||||||
return mul_tensor
|
return mul_tensor
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
|
||||||
|
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
|
||||||
|
empty_1 = torch.ops.aten.empty.SymInt([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(view_default, empty); view_default = empty = None
|
||||||
|
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
|
||||||
|
return mul_tensor
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def test_multi_out(self):
|
def test_multi_out(self):
|
||||||
@ -157,6 +203,20 @@ def forward(self, a_1):
|
|||||||
getitem = aminmax_default[0]
|
getitem = aminmax_default[0]
|
||||||
getitem_1 = aminmax_default[1]; aminmax_default = None
|
getitem_1 = aminmax_default[1]; aminmax_default = None
|
||||||
return getitem
|
return getitem
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.arange(8, dtype=torch.float32), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
empty_1 = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
|
||||||
|
getitem = aminmax_default[0]
|
||||||
|
getitem_1 = aminmax_default[1]; aminmax_default = None
|
||||||
|
return getitem
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def test_tensor_ctr(self):
|
def test_tensor_ctr(self):
|
||||||
@ -165,7 +225,38 @@ def forward(self, a_1):
|
|||||||
z = y.view(-1)
|
z = y.view(-1)
|
||||||
z.add_(1)
|
z.add_(1)
|
||||||
return y
|
return y
|
||||||
self.assert_functionalization(f, torch.arange(3, dtype=torch.float32))
|
|
||||||
|
inpt = torch.arange(3, dtype=torch.float32)
|
||||||
|
self.assert_functionalization(f, inpt)
|
||||||
|
|
||||||
|
logs = self.get_logs(f, inpt)
|
||||||
|
self.assertExpectedInline(logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
_tensor_constant0 = self._tensor_constant0
|
||||||
|
lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
|
||||||
|
view_copy_default = torch.ops.aten.view_copy.default(lift_fresh, [-1]); lift_fresh = None
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
|
||||||
|
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [3]); add_tensor = None
|
||||||
|
return view_copy_default_1
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
_tensor_constant0 = self._tensor_constant0
|
||||||
|
lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
|
||||||
|
view_default = torch.ops.aten.view.default(lift_fresh, [-1]); lift_fresh = None
|
||||||
|
add_tensor = torch.ops.aten.add_.Tensor(view_default, 1)
|
||||||
|
view_default_1 = torch.ops.aten.view.default(view_default, [3]); view_default = None
|
||||||
|
return view_default_1
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
def test_inplace_on_non_view(self):
|
def test_inplace_on_non_view(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -185,9 +276,25 @@ def forward(self, a_1):
|
|||||||
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
|
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
|
||||||
add_tensor = torch.ops.aten.add.Tensor(a_1, fill_scalar); a_1 = fill_scalar = None
|
add_tensor = torch.ops.aten.add.Tensor(a_1, fill_scalar); fill_scalar = None
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None
|
||||||
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
|
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
|
||||||
return view_copy_default_1
|
return view_copy_default_1
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
|
||||||
|
view_default = torch.ops.aten.view.default(a_1, [4, 2])
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(a_1, empty); empty = None
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None
|
||||||
|
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]); add_tensor = None
|
||||||
|
return view_default_1
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Some ops that are mutable are neither inplace nor out= ops.
|
# Some ops that are mutable are neither inplace nor out= ops.
|
||||||
@ -202,13 +309,14 @@ def forward(self, a_1):
|
|||||||
|
|
||||||
|
|
||||||
def forward(self, a_1):
|
def forward(self, a_1):
|
||||||
_fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0); a_1 = None
|
_fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0)
|
||||||
getitem = _fused_moving_avg_obs_fq_helper_functional_default[0]
|
getitem = _fused_moving_avg_obs_fq_helper_functional_default[0]
|
||||||
getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1]
|
getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1]
|
||||||
getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2]
|
getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2]
|
||||||
getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3]
|
getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3]
|
||||||
getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4]
|
getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4]
|
||||||
getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None
|
getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None
|
||||||
return (getitem, getitem_1)
|
return (getitem, getitem_1)
|
||||||
""") # noqa: B950
|
""") # noqa: B950
|
||||||
|
|
||||||
@ -226,7 +334,8 @@ def forward(self, a_1):
|
|||||||
def forward(self, a_1):
|
def forward(self, a_1):
|
||||||
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
|
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
|
||||||
add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None
|
add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None
|
||||||
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); a_1 = add_tensor = None
|
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); add_tensor = None
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, as_strided_scatter_default); a_1 = None
|
||||||
return as_strided_scatter_default
|
return as_strided_scatter_default
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@ -263,11 +372,23 @@ def forward(self, a_1):
|
|||||||
return cat_default
|
return cat_default
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.SymInt([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
|
||||||
|
return cat_default
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
def test_diagonal(self):
|
def test_diagonal(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
# test: view ops that take a subset of the original tensor (select/diagonal)
|
# test: view ops that take a subset of the original tensor (select/diagonal)
|
||||||
tmp = torch.ones(2)
|
tmp = torch.ones(2)
|
||||||
y = x.diagonal()
|
y = x.clone().diagonal()
|
||||||
y.add_(tmp)
|
y.add_(tmp)
|
||||||
z = x * x
|
z = x * x
|
||||||
return z
|
return z
|
||||||
@ -280,10 +401,25 @@ def forward(self, a_1):
|
|||||||
def forward(self, a_1):
|
def forward(self, a_1):
|
||||||
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
|
clone_default = torch.ops.aten.clone.default(a_1)
|
||||||
|
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(clone_default); clone_default = None
|
||||||
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
|
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
|
||||||
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); a_1 = add_tensor = None
|
mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
|
||||||
mul_tensor = torch.ops.aten.mul.Tensor(diagonal_scatter_default, diagonal_scatter_default); diagonal_scatter_default = None
|
return mul_tensor
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
|
||||||
|
clone_default = torch.ops.aten.clone.default(a_1)
|
||||||
|
diagonal_default = torch.ops.aten.diagonal.default(clone_default); clone_default = None
|
||||||
|
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, empty); diagonal_default = empty = None
|
||||||
|
mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
|
||||||
return mul_tensor
|
return mul_tensor
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@ -296,6 +432,20 @@ def forward(self, a_1):
|
|||||||
return x
|
return x
|
||||||
x = torch.ones(2, 2)
|
x = torch.ones(2, 2)
|
||||||
self.assert_functionalization(f, x)
|
self.assert_functionalization(f, x)
|
||||||
|
logs = self.get_logs(f, torch.ones(2, 2))
|
||||||
|
self.assertExpectedInline(logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||||
|
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
|
||||||
|
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); add_tensor = None
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, diagonal_scatter_default); a_1 = None
|
||||||
|
return diagonal_scatter_default
|
||||||
|
""")
|
||||||
|
|
||||||
def test_split(self):
|
def test_split(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -324,8 +474,9 @@ def forward(self, a_1):
|
|||||||
getitem_2 = split_copy_tensor_1[0]
|
getitem_2 = split_copy_tensor_1[0]
|
||||||
getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
|
getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
|
||||||
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None
|
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None
|
||||||
slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); a_1 = diagonal_scatter_default = None
|
slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); diagonal_scatter_default = None
|
||||||
mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default); slice_scatter_default = None
|
mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default)
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, slice_scatter_default); a_1 = slice_scatter_default = None
|
||||||
return add_tensor
|
return add_tensor
|
||||||
""") # noqa: B950
|
""") # noqa: B950
|
||||||
|
|
||||||
@ -337,7 +488,7 @@ def forward(self, a_1):
|
|||||||
y = x[0]
|
y = x[0]
|
||||||
y.add_(tmp)
|
y.add_(tmp)
|
||||||
return x
|
return x
|
||||||
self.assert_functionalization(f, torch.ones(4, 2))
|
self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
|
||||||
logs = self.get_logs(f, torch.ones(4, 2))
|
logs = self.get_logs(f, torch.ones(4, 2))
|
||||||
self.assertExpectedInline(logs, """\
|
self.assertExpectedInline(logs, """\
|
||||||
|
|
||||||
@ -372,13 +523,14 @@ def forward(self, a_1):
|
|||||||
|
|
||||||
|
|
||||||
def forward(self, a_1):
|
def forward(self, a_1):
|
||||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [8]); a_1 = None
|
view_copy_default = torch.ops.aten.view_copy.default(a_1, [8])
|
||||||
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||||
arange = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
|
arange = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
|
||||||
empty_1 = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
empty_1 = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||||
arange_1 = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
|
arange_1 = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
|
||||||
index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
|
index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
|
||||||
view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
|
view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
|
||||||
return index_put_default
|
return index_put_default
|
||||||
""") # noqa: B950
|
""") # noqa: B950
|
||||||
|
|
||||||
@ -400,11 +552,12 @@ def forward(self, a_1):
|
|||||||
def forward(self, a_1):
|
def forward(self, a_1):
|
||||||
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
|
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
|
||||||
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
|
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
|
||||||
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
|
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
|
||||||
div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None
|
div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None
|
||||||
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
|
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
|
||||||
return div_tensor
|
return div_tensor
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@ -413,7 +566,9 @@ def forward(self, a_1):
|
|||||||
def f(x):
|
def f(x):
|
||||||
# ops like ge_() are allowed to change the dtype of the input.
|
# ops like ge_() are allowed to change the dtype of the input.
|
||||||
# functionalization should pick up on that.
|
# functionalization should pick up on that.
|
||||||
return x.ge_(0)
|
y = x.clone()
|
||||||
|
out = y.ge_(0)
|
||||||
|
return out
|
||||||
self.assert_functionalization(f, torch.ones(4, 2))
|
self.assert_functionalization(f, torch.ones(4, 2))
|
||||||
logs = self.get_logs(f, torch.ones(4, 2))
|
logs = self.get_logs(f, torch.ones(4, 2))
|
||||||
self.assertExpectedInline(logs, """\
|
self.assertExpectedInline(logs, """\
|
||||||
@ -421,11 +576,24 @@ def forward(self, a_1):
|
|||||||
|
|
||||||
|
|
||||||
def forward(self, a_1):
|
def forward(self, a_1):
|
||||||
ge_scalar = torch.ops.aten.ge.Scalar(a_1, 0); a_1 = None
|
clone_default = torch.ops.aten.clone.default(a_1); a_1 = None
|
||||||
to_dtype_layout = torch.ops.aten.to.dtype_layout(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
|
ge_scalar = torch.ops.aten.ge.Scalar(clone_default, 0); clone_default = None
|
||||||
return to_dtype_layout
|
_to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
|
||||||
|
return _to_copy_default
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
clone_default = torch.ops.aten.clone.default(a_1); a_1 = None
|
||||||
|
ge_scalar = torch.ops.aten.ge_.Scalar(clone_default, 0)
|
||||||
|
_to_copy_default = torch.ops.aten._to_copy.default(clone_default, dtype = torch.float32, layout = torch.strided); clone_default = None
|
||||||
|
return _to_copy_default
|
||||||
|
""") # noqa: B950
|
||||||
|
|
||||||
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
||||||
def test_metadata_change_out_op(self):
|
def test_metadata_change_out_op(self):
|
||||||
def f(t, y):
|
def f(t, y):
|
||||||
@ -514,6 +682,44 @@ def forward(self, a_1):
|
|||||||
return add_tensor_1
|
return add_tensor_1
|
||||||
""") # noqa: B950
|
""") # noqa: B950
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||||
|
view_default = torch.ops.aten.view.default(add_tensor, [8])
|
||||||
|
_reshape_alias_default = torch.ops.aten._reshape_alias.default(view_default, [2, 4], [4, 1]); view_default = None
|
||||||
|
transpose_int = torch.ops.aten.transpose.int(_reshape_alias_default, 1, 0)
|
||||||
|
unsqueeze_default = torch.ops.aten.unsqueeze.default(transpose_int, 0); transpose_int = None
|
||||||
|
squeeze_default = torch.ops.aten.squeeze.default(unsqueeze_default); unsqueeze_default = None
|
||||||
|
split_tensor = torch.ops.aten.split.Tensor(squeeze_default, 2); squeeze_default = None
|
||||||
|
getitem = split_tensor[0]
|
||||||
|
getitem_1 = split_tensor[1]; split_tensor = None
|
||||||
|
add_tensor_1 = torch.ops.aten.add_.Tensor(getitem, empty); empty = None
|
||||||
|
select_int = torch.ops.aten.select.int(_reshape_alias_default, 0, 0); _reshape_alias_default = None
|
||||||
|
clone_default = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
|
||||||
|
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
|
||||||
|
view_default_1 = torch.ops.aten.view.default(add_tensor, [8]); add_tensor = None
|
||||||
|
_reshape_alias_default_1 = torch.ops.aten._reshape_alias.default(view_default_1, [2, 4], [4, 1]); view_default_1 = None
|
||||||
|
transpose_int_1 = torch.ops.aten.transpose.int(_reshape_alias_default_1, 1, 0); _reshape_alias_default_1 = None
|
||||||
|
unsqueeze_default_1 = torch.ops.aten.unsqueeze.default(transpose_int_1, 0); transpose_int_1 = None
|
||||||
|
squeeze_default_1 = torch.ops.aten.squeeze.default(unsqueeze_default_1); unsqueeze_default_1 = None
|
||||||
|
unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(squeeze_default_1, 0); squeeze_default_1 = None
|
||||||
|
squeeze_dim = torch.ops.aten.squeeze.dim(unsqueeze_default_2, 0); unsqueeze_default_2 = None
|
||||||
|
transpose_int_2 = torch.ops.aten.transpose.int(squeeze_dim, 1, 0); squeeze_dim = None
|
||||||
|
_reshape_alias_default_2 = torch.ops.aten._reshape_alias.default(transpose_int_2, [8], [1]); transpose_int_2 = None
|
||||||
|
view_default_2 = torch.ops.aten.view.default(_reshape_alias_default_2, [4, 2]); _reshape_alias_default_2 = None
|
||||||
|
view_default_3 = torch.ops.aten.view.default(view_default_2, [8]); view_default_2 = None
|
||||||
|
_reshape_alias_default_3 = torch.ops.aten._reshape_alias.default(view_default_3, [2, 4], [4, 1]); view_default_3 = None
|
||||||
|
select_int_1 = torch.ops.aten.select.int(_reshape_alias_default_3, 0, 0); _reshape_alias_default_3 = None
|
||||||
|
add_tensor_2 = torch.ops.aten.add.Tensor(select_int_1, _unsafe_view_default); select_int_1 = _unsafe_view_default = None
|
||||||
|
return getitem
|
||||||
|
""")
|
||||||
|
|
||||||
def test_reapply_views_simple(self):
|
def test_reapply_views_simple(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
tmp = torch.ones(4, 2)
|
tmp = torch.ones(4, 2)
|
||||||
@ -530,10 +736,11 @@ def forward(self, a_1):
|
|||||||
def forward(self, a_1):
|
def forward(self, a_1):
|
||||||
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||||
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
|
view_default = torch.ops.aten.view.default(a_1, [4, 2])
|
||||||
add_tensor = torch.ops.aten.add.Tensor(view_default, fill_scalar); view_default = fill_scalar = None
|
add_tensor = torch.ops.aten.add.Tensor(view_default, fill_scalar); view_default = fill_scalar = None
|
||||||
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
|
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
|
||||||
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1); view_default_1 = None
|
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1)
|
||||||
|
copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None
|
||||||
return add_tensor
|
return add_tensor
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@ -564,8 +771,6 @@ def forward(self, a_1):
|
|||||||
def test_copy_(self):
|
def test_copy_(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
tmp = torch.zeros(2, 2)
|
tmp = torch.zeros(2, 2)
|
||||||
# NOTE: LoggingTensor isn't a mode, which means that the diagonal call
|
|
||||||
# will not be logged. This is fine for testing.
|
|
||||||
tmp_slice = tmp.diagonal()
|
tmp_slice = tmp.diagonal()
|
||||||
y = tmp_slice.copy_(x)
|
y = tmp_slice.copy_(x)
|
||||||
z = y.add_(x)
|
z = y.add_(x)
|
||||||
@ -587,6 +792,21 @@ def forward(self, a_1):
|
|||||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
||||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||||
return add_tensor
|
return add_tensor
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
zero_default = torch.ops.aten.zero_.default(empty)
|
||||||
|
diagonal_default = torch.ops.aten.diagonal.default(empty)
|
||||||
|
diagonal_default_1 = torch.ops.aten.diagonal.default(empty); empty = None
|
||||||
|
copy_default = torch.ops.aten.copy_.default(diagonal_default_1, a_1)
|
||||||
|
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default_1, a_1); a_1 = None
|
||||||
|
return diagonal_default_1
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Test 2: copy_() with same dtype, different shape
|
# Test 2: copy_() with same dtype, different shape
|
||||||
@ -604,6 +824,21 @@ def forward(self, a_1):
|
|||||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
||||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||||
return add_tensor
|
return add_tensor
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
zero_default = torch.ops.aten.zero_.default(empty)
|
||||||
|
diagonal_default = torch.ops.aten.diagonal.default(empty)
|
||||||
|
diagonal_default_1 = torch.ops.aten.diagonal.default(empty); empty = None
|
||||||
|
copy_default = torch.ops.aten.copy_.default(diagonal_default_1, a_1)
|
||||||
|
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default_1, a_1); a_1 = None
|
||||||
|
return diagonal_default_1
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Test 3: copy_() with different dtype, same shape
|
# Test 3: copy_() with different dtype, same shape
|
||||||
@ -621,6 +856,21 @@ def forward(self, a_1):
|
|||||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
||||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||||
return add_tensor
|
return add_tensor
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
zero_default = torch.ops.aten.zero_.default(empty)
|
||||||
|
diagonal_default = torch.ops.aten.diagonal.default(empty)
|
||||||
|
diagonal_default_1 = torch.ops.aten.diagonal.default(empty); empty = None
|
||||||
|
copy_default = torch.ops.aten.copy_.default(diagonal_default_1, a_1)
|
||||||
|
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default_1, a_1); a_1 = None
|
||||||
|
return diagonal_default_1
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Test 4: copy_() with different dtype, different shape
|
# Test 4: copy_() with different dtype, different shape
|
||||||
@ -638,6 +888,21 @@ def forward(self, a_1):
|
|||||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
||||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||||
return add_tensor
|
return add_tensor
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
zero_default = torch.ops.aten.zero_.default(empty)
|
||||||
|
diagonal_default = torch.ops.aten.diagonal.default(empty)
|
||||||
|
diagonal_default_1 = torch.ops.aten.diagonal.default(empty); empty = None
|
||||||
|
copy_default = torch.ops.aten.copy_.default(diagonal_default_1, a_1)
|
||||||
|
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default_1, a_1); a_1 = None
|
||||||
|
return diagonal_default_1
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def test_expand_symint(self):
|
def test_expand_symint(self):
|
||||||
@ -676,6 +941,18 @@ def forward(self, a_1):
|
|||||||
fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None
|
fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None
|
||||||
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None
|
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None
|
||||||
return diagonal_scatter_default
|
return diagonal_scatter_default
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||||
|
diagonal_default = torch.ops.aten.diagonal.default(add_tensor)
|
||||||
|
fill_scalar = torch.ops.aten.fill_.Scalar(diagonal_default, 0); diagonal_default = None
|
||||||
|
return add_tensor
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def test_resize_smaller(self):
|
def test_resize_smaller(self):
|
||||||
@ -713,6 +990,28 @@ def forward(self, a_1):
|
|||||||
return add_tensor_2
|
return add_tensor_2
|
||||||
""") # noqa: B950
|
""") # noqa: B950
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||||
|
view_default = torch.ops.aten.view.default(add_tensor, [4, 4])
|
||||||
|
resize_default = torch.ops.aten.resize.default(view_default, [3, 3])
|
||||||
|
as_strided_default = torch.ops.aten.as_strided.default(view_default, [3, 3], [3, 1]); view_default = None
|
||||||
|
view_default_1 = torch.ops.aten.view.default(as_strided_default, [-1]); as_strided_default = None
|
||||||
|
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_1, 1)
|
||||||
|
view_default_2 = torch.ops.aten.view.default(add_tensor, [4, 4]); add_tensor = None
|
||||||
|
as_strided_default_1 = torch.ops.aten.as_strided.default(view_default_2, [3, 3], [3, 1])
|
||||||
|
view_default_3 = torch.ops.aten.view.default(view_default_1, [3, 3]); view_default_1 = None
|
||||||
|
view_default_4 = torch.ops.aten.view.default(view_default_2, [8, 2]); view_default_2 = None
|
||||||
|
view_default_5 = torch.ops.aten.view.default(view_default_4, [4, 4]); view_default_4 = None
|
||||||
|
as_strided_default_2 = torch.ops.aten.as_strided.default(view_default_5, [3, 3], [3, 1]); view_default_5 = None
|
||||||
|
add_tensor_2 = torch.ops.aten.add_.Tensor(as_strided_default_2, 1)
|
||||||
|
return as_strided_default_2
|
||||||
|
""")
|
||||||
|
|
||||||
def test_resize_larger_valid(self):
|
def test_resize_larger_valid(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
y = x + 1
|
y = x + 1
|
||||||
@ -744,6 +1043,21 @@ def forward(self, a_1):
|
|||||||
view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None
|
view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None
|
||||||
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1)
|
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1)
|
||||||
return (view_copy_default_1, add_tensor_1)
|
return (view_copy_default_1, add_tensor_1)
|
||||||
|
""")
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||||
|
resize_default = torch.ops.aten.resize_.default(add_tensor, [5, 5])
|
||||||
|
view_default = torch.ops.aten.view.default(add_tensor, [25]); add_tensor = None
|
||||||
|
fill_scalar = torch.ops.aten.fill_.Scalar(view_default, 1)
|
||||||
|
view_default_1 = torch.ops.aten.view.default(view_default, [5, 5]); view_default = None
|
||||||
|
add_tensor_1 = torch.ops.aten.add.Tensor(view_default_1, 1)
|
||||||
|
return (view_default_1, add_tensor_1)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def test_resize_larger_invalid(self):
|
def test_resize_larger_invalid(self):
|
||||||
@ -809,5 +1123,40 @@ $3 = torch._ops.aten.add.Tensor($2, 1)""")
|
|||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
x1_not_functional.add_(x2_functional)
|
x1_not_functional.add_(x2_functional)
|
||||||
|
|
||||||
|
def test_index_mutation_on_non_input(self):
|
||||||
|
def f(x):
|
||||||
|
tmp = torch.zeros(10)
|
||||||
|
tmp[5].fill_(1)
|
||||||
|
return tmp
|
||||||
|
self.assert_functionalization(f, torch.ones(2))
|
||||||
|
logs = self.get_logs(f, torch.ones(2))
|
||||||
|
self.assertExpectedInline(logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([10], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
zero_default = torch.ops.aten.zero.default(empty); empty = None
|
||||||
|
select_copy_int = torch.ops.aten.select_copy.int(zero_default, 0, 5)
|
||||||
|
select_copy_int_1 = torch.ops.aten.select_copy.int(zero_default, 0, 5)
|
||||||
|
fill_scalar = torch.ops.aten.fill.Scalar(select_copy_int_1, 1); select_copy_int_1 = None
|
||||||
|
select_scatter_default = torch.ops.aten.select_scatter.default(zero_default, fill_scalar, 0, 5); zero_default = fill_scalar = None
|
||||||
|
return select_scatter_default
|
||||||
|
""") # noqa: B950
|
||||||
|
|
||||||
|
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
|
||||||
|
self.assertExpectedInline(reinplaced_logs, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, a_1):
|
||||||
|
empty = torch.ops.aten.empty.memory_format([10], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||||
|
zero_default = torch.ops.aten.zero_.default(empty)
|
||||||
|
select_int = torch.ops.aten.select.int(empty, 0, 5)
|
||||||
|
select_int_1 = torch.ops.aten.select.int(empty, 0, 5)
|
||||||
|
fill_scalar = torch.ops.aten.fill_.Scalar(select_int_1, 1); select_int_1 = None
|
||||||
|
return empty
|
||||||
|
""")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_tests()
|
run_tests()
|
||||||
test/test_fx_reinplace_pass.py (new file, 251 lines)
@@ -0,0 +1,251 @@
|
|||||||
|
# Owner(s): ["module: functionalization"]
|
||||||
|
import torch
|
||||||
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||||
|
from torch.fx.passes.reinplace import reinplace
|
||||||
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
|
|
||||||
|
try:
|
||||||
|
from functorch.experimental import functionalize
|
||||||
|
HAS_FUNCTIONALIZATION = True
|
||||||
|
except Exception as e:
|
||||||
|
HAS_FUNCTIONALIZATION = False
|
||||||
|
|
||||||
|
class TestReinplacePass(TestCase):
|
||||||
|
|
||||||
|
def test_reinplace_basic(self):
|
||||||
|
# Basic test: the out-of-place add() call should be converted
|
||||||
|
# into add_()
|
||||||
|
def f(x):
|
||||||
|
a = x.clone()
|
||||||
|
b = a.add(1)
|
||||||
|
return b
|
||||||
|
|
||||||
|
inpt = torch.ones(2)
|
||||||
|
f2 = reinplace(make_fx(f)(inpt), inpt)
|
||||||
|
expected_out = f(inpt)
|
||||||
|
actual_out = f2(inpt)
|
||||||
|
self.assertEqual(actual_out, expected_out)
|
||||||
|
self.assertExpectedInline(f2.code, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x_1):
|
||||||
|
clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
|
||||||
|
add_tensor = torch.ops.aten.add_.Tensor(clone_default, 1)
|
||||||
|
return clone_default
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def test_reinplace_with_view(self):
|
||||||
|
def f(x):
|
||||||
|
a = x.clone()
|
||||||
|
a_view = a.view(-1)
|
||||||
|
# We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
|
||||||
|
b = a.add(1)
|
||||||
|
# Second add() is fine to re-inplace
|
||||||
|
c = a_view.add(1)
|
||||||
|
return c
|
||||||
|
|
||||||
|
inpt = torch.ones(2)
|
||||||
|
f2 = reinplace(make_fx(f)(inpt), inpt)
|
||||||
|
expected_out = f(inpt)
|
||||||
|
actual_out = f2(inpt)
|
||||||
|
self.assertEqual(actual_out, expected_out)
|
||||||
|
self.assertExpectedInline(f2.code, """\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x_1):
|
||||||
|
clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
|
||||||
|
view_default = torch.ops.aten.view.default(clone_default, [-1])
|
||||||
|
add_tensor = torch.ops.aten.add.Tensor(clone_default, 1); clone_default = None
|
||||||
|
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default, 1)
|
||||||
|
return view_default
|
||||||
|
""")

    # This test won't actually run in CI, because it requires functionalize() from functorch.
    # I'm planning on testing more comprehensively with torchbench models,
    # but we can make this testing better once functorch moves into pytorch/pytorch.
    def test_reinplace_scatter_op(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is the fact that
        # there are a bunch of redundant views in the graph.
        # Technically, half of these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    view_default = torch.ops.aten.view.default(clone_default, [-1])
    view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
    select_int = torch.ops.aten.select.int(view_default_1, 0, 0); view_default_1 = None
    view_default_2 = torch.ops.aten.view.default(select_int, [-1]); select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
    view_default_3 = torch.ops.aten.view.default(clone_default, [-1]); clone_default = None
    select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
    view_default_4 = torch.ops.aten.view.default(view_default_2, []); view_default_2 = None
    view_default_5 = torch.ops.aten.view.default(view_default_3, [4]); view_default_3 = None
    view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
    add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6); view_default_6 = None
    return view_default_5
    """)

    def test_reinplace_scatter_twice(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        if not HAS_FUNCTIONALIZATION:
            return

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
    slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1); slice_tensor_1 = None
    return clone_default
    """)

    def test_reinplace_scatter_twice_with_different_view_op_valid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # good_mirror_of_b points to the same region of memory as b,
            # and this scatter op below tries to scatter c_updated into the same region
            # that c currently takes up.
            # reinplacing logic checks this by confirming that:
            #     c_updated
            #     good_mirror_of_b.select(0, 1)
            # have the same size/stride/storage_offset.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
    add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
    return as_strided_default
    """)

    # Test example where we have a scatter op, where the base tensor
    # has the same size/stride/storage offset (even though it is a different view),
    # making it valid to re-inplace
    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than what "c" currently occupies.
            #
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
    add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
    select_scatter_default = torch.ops.aten.select_scatter.default(as_strided_default, add_tensor, 0, 0); as_strided_default = add_tensor = None
    return select_scatter_default
    """) # noqa: B950

    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4,), (4,), 0)
            # The first arg to select_scatter points to a different region of memory than c's base.
            # This makes it invalid to re-inplace.
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        # self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
    select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
    select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
    add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
    as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0); clone_default = None
    select_scatter_default = torch.ops.aten.select_scatter.default(as_strided_default, add_tensor, 0, 1); as_strided_default = add_tensor = None
    return select_scatter_default
    """) # noqa: B950


if __name__ == '__main__':
    run_tests()
@ -3,6 +3,7 @@ from . import graph_manipulation
from . import net_min_base
from . import operator_support
from . import param_fetch
from . import reinplace
from . import shape_prop
from . import split_module
from . import split_utils

512 torch/fx/passes/reinplace.py (new file)
@ -0,0 +1,512 @@
import torch
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map
from torch.multiprocessing.reductions import StorageWeakRef

import _operator
from enum import Enum
import itertools
from typing import Set, Dict
from collections import defaultdict

__all__ = ['reinplace']

class _ViewType(Enum):
    NonView = 0
    SingleOutputView = 1
    MultiOutputView = 2

def _is_view_op(tgt):
    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
        schema = tgt._schema
        if len(schema.arguments) > 0:
            first_arg = schema.arguments[0]
            # check if op is a view
            return first_arg.alias_info is not None and not first_arg.alias_info.is_write

def _get_view_type(tgt) -> _ViewType:
    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
        schema = tgt._schema
        if len(schema.arguments) > 0:
            first_arg = schema.arguments[0]
            # check if op is a view
            if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
                # check if op is a multi-output view
                if '*' in first_arg.alias_info.after_set:
                    return _ViewType.MultiOutputView
                else:
                    return _ViewType.SingleOutputView
    return _ViewType.NonView


# Stores a bunch of metadata related to functionalization on each node.
# Relevant metadata:
# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTensors instead of Tensors)
#     The fake tensor output from running the current node
# n.meta['view_of']: Node
#     If the current node n is a view of some base tensor, the 'view_of' field tells us which
#     view node was used to generate the current node (a view tensor).
#     This information actually makes `fake_result` redundant, but we can use `fake_result`
#     to sanity check that our aliasing information is correct.
@compatibility(is_backward_compatible=False)
class _FunctionalizationMetadataProp(torch.fx.Interpreter):

    def run_node(self, node: Node):
        self.node_counter += 1
        result = super().run_node(node)
        node.meta['fake_result'] = result
        node.meta['node_idx'] = self.node_counter

        # (1) Update metadata with the list of nodes that are used by this node
        # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
        # We don't want to treat it as "being used as an input".
        node_args = node.args
        if node.target is torch.ops.aten.copy_.default:
            node_args = node_args[1:]

        # (2) Update metadata to track aliasing information about view tensor nodes.
        if node.op == 'call_function':
            view_type = _get_view_type(node.target)
            if view_type == _ViewType.SingleOutputView:
                assert isinstance(node.args[0], Node)
                node.meta['view_of'] = node.args[0]
            elif view_type == _ViewType.MultiOutputView:
                self.multi_output_view_nodes[node] = node.args[0]

            # Check if we returned a multi-output view,
            # and we're now grabbing the individual views from the output.
            #
            # For multi-output views, we want to map each output view to the base,
            # but this mapping involves two separate nodes in FX IR.
            # e.g. "a, b = x_1.split(...)" becomes:
            # %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
            # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
            # %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
            # And we'd like to set:
            # getitem1.meta['view_of'] = x_1
            elif node.target is _operator.getitem:
                list_arg = node.args[0]
                maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
                if maybe_base_of_view is not None:
                    # Note: we could also track indexing info here for multi-output views.
                    # I don't think this metadata is strictly needed for de-functionalization.
                    assert isinstance(maybe_base_of_view, Node)
                    node.meta['view_of'] = maybe_base_of_view

        if 'view_of' in node.meta:
            # We're linking the current node with its first argument as views.
            # Assert here that this is actually the case, and their storages are the same.
            assert isinstance(node.meta['fake_result'], FakeTensor)
            assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
            view_storage = StorageWeakRef(node.meta['fake_result'].storage())
            base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage())
            assert view_storage == base_storage
        return result


    def propagate(self, *args):
        self.multi_output_view_nodes = {}
        self.node_counter = -1
        with FakeTensorMode.push() as mode:
            fake_args = [mode.from_tensor(a) for a in args]
            return super().run(*fake_args)

def _schemas_match(functional_schema, inplace_schema):
    names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
    arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
        a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
    # for the inplace op, its first argument should be mutable
    assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
    # and its remaining arguments shouldn't be.
    assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
    return names_match and arg_types_match
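
# For reference, _schemas_match pairs schemas like the following (abbreviated):
#   aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)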

# TODO: this should be beefed up to be able to properly re-inplace with:
# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
# - out= ops (e.g. angle -> angle.out)
# TODO: we should also figure this info out using torchgen.
def _maybe_get_inplace_op(op):
    # __module__ seems broken; it returns torch._ops.aten which doesn't exist
    if not isinstance(op, torch._ops.OpOverload):
        return None
    # Some view ops have inplace variants (as_strided_, etc),
    # but we do NOT want the reinplacing pass to directly add these into the program.
    # (they'll require extra special handling, and aren't really useful for perf anyway)
    if _is_view_op(op):
        return None
    op_namespace = op.__module__.split(".")[-1]
    op_base_name = op.overloadpacket.__name__
    maybe_namespace_module = getattr(torch.ops, op_namespace)
    maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
    if maybe_inplace_op is None:
        return None

    inplace_overloads = [
        getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
    ]
    inplace_overloads_with_matching_schemas = [
        f
        for f in inplace_overloads
        if _schemas_match(op._schema, f._schema)
    ]
    # This is for sanity: if foo() and foo_() are both operators,
    # we expect them to have compatible schemas.
    # (This is asserted by codegen for ATen, but might not be true
    # for other arbitrary operators).
    assert len(inplace_overloads_with_matching_schemas) == 1
    inplace_op = inplace_overloads_with_matching_schemas[0]
    return inplace_op
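
# For example, a functional overload like torch.ops.aten.add.Tensor maps to
# torch.ops.aten.add_.Tensor here, while view ops such as torch.ops.aten.as_strided.default
# map to None, since we never re-inplace view ops.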

_VIEW_INVERSE_MAP = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}

# This function, given a set of (aliased) tensor nodes,
# returns any nodes in the graph that *use* any of the aliases and occur *after* op_index
# in the node ordering.
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
    def _add_if_tensor(x, set_):
        if isinstance(x, FakeTensor):
            set_.add(StorageWeakRef(x.storage()))

    nodes_used_after = set()
    for t in tensor_aliases:
        # get all nodes that use the current alias
        usage_nodes = t.users
        for n in usage_nodes:
            # We only care about usages after the current node
            if n.meta['node_idx'] <= op_index:
                continue
            # We also don't care about intermediate view ops.
            # They only matter if their output is then used elsewhere
            # (either in an out-of-place op, or as an output to the function).
            if n in tensor_aliases:
                if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
                    continue
            nodes_used_after.add(n)
    return nodes_used_after

# Given an op that we're trying to re-inplace, "b = foo(a)",
# and given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)",
# re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF
# there are any aliases in the alias_set(a) that satisfy:
# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
#     as "alias"
def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
    def matching_view_metadata(a, b):
        return a.size() == b.size() and \
            a.stride() == b.stride() and \
            a.storage_offset() == b.storage_offset()

    view_inverse_nodes = set()
    # Go through them in node order, so we can see chains of view_scatter ops.
    for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
        if n.target not in _VIEW_INVERSE_MAP:
            continue
        base = n.args[0]
        mutated_view = n.args[1]
        assert isinstance(base, Node)
        assert isinstance(base.meta['fake_result'], FakeTensor)
        assert isinstance(mutated_view, Node)
        assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
        # Check that this view_inverse op actually corresponds to doing the inverse
        # of one of our existing self_alias nodes.
        original_view = _VIEW_INVERSE_MAP[n.target]
        for self_alias in self_aliases:
            # We're looking for some alias of the self arg, "alias",
            # that was created from some op `alias = foo(base, args...)`
            # such that the current _scatter op "inverts" that foo call.
            # We can check that by running the original op again, and checking that the strides match.
            if 'view_of' not in self_alias.meta:
                continue
            self_alias_base = self_alias.meta['view_of']
            try:
                # Here we're trying to re-use the args from the view_scatter call inside of the corresponding
                # view op, which might throw. This just indicates that the view_scatter op isn't a valid inverse
                # of the current alias we're looking at.
                view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
                expected_metadata = self_alias.meta['fake_result']
                # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
                if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
                        matching_view_metadata(view_replay_metadata, expected_metadata):
                    view_inverse_nodes.add(n)
            except Exception:
                continue

    return view_inverse_nodes

@compatibility(is_backward_compatible=True)
def reinplace(gm, *sample_args):
    """
    Given an fx.GraphModule, modifies it to perform "reinplacing",
    mutating the nodes of the graph.
    We look for out-of-place op call sites like `b = a.add(...)`,
    and convert them to be inplace (`b = a.add_(...)`),
    as long as the input to the current operator ("a") isn't re-used
    anywhere later in the graph.

    This pass currently expects to operate on a **functional, ATen** graph.
    This can be obtained by running `make_fx(functionalize(f))`.

    Sample inputs are needed to determine aliasing relationships of the inputs.
    In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
    inputs to the program.

    Given a node "b = foo(a, ...)", the algorithm for re-inplacing is as follows:

    (1) Check if foo has a mutating variant. If not, move to the next node.

        Note that we ignore view ops (we don't bother to turn `as_strided()`
        into `as_strided_()`), as it complicates the algorithm and doesn't
        provide meaningful speedups.

        Currently, we also only check for an inplace op, `foo_`.
        Later, we should beef this up to check for out= or mutable ops.

    (2) Check if "a" is an alias of any of the program inputs.

        If it is, skip and move to the next node.
        Inplace'ing an op that would cause it to mutate a program input is not sound,
        because that would be a side effect visible to the user.

        NOTE: there's a future optimization that we should make:
        if "a" is a program input (or an alias of one), but later in the program
        there is a node that looks like "a.copy_(...)",
        then re-inplacing is ok to do - we are temporarily re-using a's buffer,
        which will later be overwritten by the copy_() call.

        This will be an important optimization to have for programs that mutate
        their inputs. It currently isn't implemented though.

    (3) Check that "a" and all of its outstanding aliases are not used anywhere
        later in the graph. If this is the case, then it's safe to re-inplace
        to "b = foo_(a)".

        There are a few caveats to this, explained in more detail below:
        (a) If "a" is used later as an argument to a view op, that is okay.
            It's only a problem if "a" (or that view) is later passed
            into a normal operator, or if it is returned as the program output.
        (b) If "a" is a repeat argument in `foo()`, then don't reinplace.
            Most ATen kernels don't make any guarantees that this is sound,
            e.g. if you do aten.mul_(a, a).
            So we'll just ban re-inplacing in this case.
        (c) If "a" is used as an input into a view "inverse" / "scatter"
            operator, it is potentially fine to re-inplace
            (and remove that scatter operator from the graph).
            See below for a more detailed example.

        NOTE: there is an optimization in this step that is crucial
        to fully recovering performance from functionalization.

        Given this program:
        def f(x):
            a = torch.ops.aten.add(x, x)
            b = torch.ops.aten.diagonal(a)
            torch.ops.aten.fill_(b, 0)
            return a

        Functionalization will emit the following:
        def f(x):
            a = torch.ops.aten.add(x, x)
            b = torch.ops.aten.diagonal(a, 0, 1)
            b_updated = torch.ops.aten.fill(b, 0)
            a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
            return a_updated

        Ordinarily, we would not be able to reinplace the fill,
        because "b" aliases with "a" which is used by the diagonal_scatter call.

        "re-inplacing" is on the hook for figuring out that it is ok to
        completely remove the expensive diagonal_scatter call, if we re-inplace the fill().

        So, for every `alias in alias_set(a)`, instead of checking
        that "alias" is not used anywhere later in the graph,
        we check that
        EITHER:
          (a) alias is not used anywhere later in the graph
        OR:
          (b) alias is used exactly once later on in the graph,
              in the following op:

                  out = foo_scatter(alias, x, args...)

              where the following must hold:
                  (i) "foo_scatter" is the "inverse" operator for foo.
                      This only applies to "foo" ops that are view operators,
                      which view into a subset of the original tensor's memory.
                      In practice, there are ~4 operators where this applies:
                          diagonal -> diagonal_scatter
                          slice -> slice_scatter
                          select -> select_scatter
                          as_strided -> as_strided_scatter
                  (ii) "args..." are the same between the foo() and foo_scatter() calls.

    (4) Finally, after converting "b = foo(a)" into "foo_(a)",
        we need to find all later nodes that use "b" as an argument
        and update them to take in "a" instead.

        Note that for the majority of inplace ops, this isn't actually necessary
        (because most inplace ops return "self" as their output).
        This isn't generally true for all mutable ops though, which is why
        we need to actually replace all of the arguments.

        We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
        that maps a given tensor storage to the set of all nodes that take in that storage
        as an input.
        Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
        together.

    (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
        during step (3) get manually deleted from the graph.
        Their outputs are no longer used, so technically standard DCE would be able
        to do this, but we can no longer run FX's DCE pass now that we have mutable
        ops in the graph.
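
    A minimal usage sketch, mirroring the unit tests (`make_fx` comes from
    torch.fx.experimental.proxy_tensor; the function and variable names are illustrative):

        def f(x):
            a = x.clone()
            b = a.add(1)
            return b

        inpt = torch.ones(2)
        gm = make_fx(f)(inpt)
        gm = reinplace(gm, inpt)
        # gm.graph now calls torch.ops.aten.add_.Tensor on the clone and returns the clone.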
    """
    _FunctionalizationMetadataProp(gm).propagate(*sample_args)

    # Useful debug printing
    # def _print(x):
    #     if isinstance(x, FakeTensor):
    #         print(f'fake_result: {StorageWeakRef(x.storage()).cdata}')

    # for n in gm.graph.nodes:
    #     print(n.format_node())
    #     if hasattr(n, 'meta'):
    #         print(f'node_idx: {n.meta["node_idx"]}')
    #         if 'fake_result' in n.meta:
    #             tree_map(_print, n.meta['fake_result'])
    #         if 'view_of' in n.meta:
    #             print(f'view_of: {str(n.meta["view_of"])}')
    #     print()

    # We need to know which nodes correspond to inputs (or their aliases)
    # so we know not to re-inplace them.
    # NOTE: later, we'll need to add an optimization for fully recovering performance
    # on programs that mutate inputs.
    input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder')

    # We also need to know for a given node, what are all of its aliasing nodes.
    storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
    for n in gm.graph.nodes:
        if 'fake_result' in n.meta:
            # Tree-mapping because some ops can return lists of tensors.
            def _add_to_map(x):
                if isinstance(x, FakeTensor):
                    storage_to_nodes[StorageWeakRef(x.storage())].add(n)
            tree_map(_add_to_map, n.meta['fake_result'])

    # inplace-ify functional ops, subject to the constraints written below.
    all_later_view_inverse_node_usages = set()
    for idx, node in enumerate(gm.graph.nodes):
        if node.op == 'call_function':
            # Step 1: Check to see if this operator has an inplace variant.
            maybe_inplace_op = _maybe_get_inplace_op(node.target)
            if maybe_inplace_op is None:
                continue
            # This is a proxy check for ensuring that the first argument is "tensor-like"
            # (This should be the case for all ops with inplace variants in ATen,
            # although we technically don't have guarantees for custom ops).
            assert len(node.target._schema.arguments) > 0
            assert 'Tensor' in str(node.target._schema.arguments[0].type)

            # Step 2: ensure that the op we're trying to re-inplace isn't a program input.
            self_arg = node.args[0]
            self_arg_name = self_arg.name
            self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
            if self_arg_storage in input_storages:
                # TODO: later, add the optimization for handling `copy_()` calls in the graph.
                continue
            if len([x for x in node.args if x is self_arg]) > 1:
                # Step (3b) in the original description.
                # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
                # so we prevent re-inplacing in this case.
                continue

            self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
            curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage())
            self_aliases = storage_to_nodes[self_arg_storage]

            # First, we find all later usages of any of the aliases of self_arg.
            later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
            # Then, we check if any of those later usages are actually view_scatter ops
            # that are safe to fully remove.
            later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)

            # Step 3: Check to see if the input to the op is re-used later in the graph.
            # If not (same goes for its aliases), then this op is safe to re-inplace.
            # This is a slightly roundabout way to check that there are no later usages of the current self argument.
            # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
            can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
            if not can_reinplace:
                continue
            # Step 4: replace the current out-of-place op with its inplace variant.
            node.target = maybe_inplace_op
            # At this point, 'storage_to_nodes' will be stale.
            # Now that we're inplacing `b = foo(a)`, we need to effectively
            # union together the dict values for b and a's storage.
            # Hmm... morally I think we also want to keep the `fake_result` metadata
            # up to date here, but I'm not sure how easy it is to do.
            # Maybe it's fine to wait until the end of the pass to update it.
            storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
            storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])

            # Need to remember the view_scatter nodes we found so we can remove them later.
            all_later_view_inverse_node_usages.update(later_view_inverse_node_usages)

            # Now that we've replaced b = a.foo() with a.foo_(),
            # we need to replace any later usages of "b" with "a"
            for old in itertools.chain([node], later_view_inverse_node_usages):
                new = old.args[0]
                nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
                for node_to_update in nodes_to_update:
                    new_args = []
                    for arg_idx, a in enumerate(node_to_update.args):
                        if a == old:
                            new_args.append(new)
                        else:
                            new_args.append(a)
                    new_kwargs = {}
                    for kwarg_idx, (k, v) in enumerate(node_to_update.kwargs.items()):
                        if isinstance(v, Node) and v.name == old.name:
                            new_kwargs[k] = new
                        else:
                            new_kwargs[k] = v
                    node_to_update.args = tuple(new_args)
                    node_to_update.kwargs = new_kwargs

                    old_ref = StorageWeakRef(old.meta['fake_result'].storage())
                    node_ref = StorageWeakRef(node_to_update.meta['fake_result'].storage())
                    if old_ref == node_ref:
                        # This will happen if we're updating a view op, e.g. replacing
                        #     x = view(old)
                        #     x = view(new)
                        # When that happens, we need to make sure to keep our
                        # storage mapping up to date.
                        new_ref = StorageWeakRef(new.meta['fake_result'].storage())
                        # Technically, "old_ref" and all its aliases will remain
                        # in our mapping.
                        # That should be fine though, since we deleted "old"
                        # from the graph at this point.
                        storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
                        storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])

    # Step 5: delete any _scatter nodes that we de-functionalized
    # Need to take care not to delete any of these nodes until after *all* modifications
    # to the graph are finished.
    for to_delete in all_later_view_inverse_node_usages:
        gm.graph.erase_node(to_delete)

    gm.recompile()
    return gm