add a reinplacing FX pass (#80897)

Adds a "reinplacing" FX transform, that goes through an FX graph and tries to convert out-of-place op calls into inplace calls whenever possible.

Followups from this PR include:
- Set up torchbench and run the whole torchbench suite using AOTAutograd + functionalize + the reinplacing transform to surface any issues (this is what I'm currently working on). Right now, I have some basic unit tests just to sanity-check that the general logic makes sense.
- Add any missing inplace ops. These are mostly the `*_scatter*` ops, e.g. `diagonal_scatter_`, because these ops commonly show up in an FX graph after running functionalization.

The criteria for when an out-of-place call `b = a.add(...)` can be swapped for `a.add_(...)` are:
(1) An inplace variant of the operator with the same schema needs to exist (`aten.add` -> `aten.add_`)
(2) `a` (**or any of its aliases**) can't be used as an input to any other operator later on in the graph (see the sketch after the note below)
(3) `a` can't be one of the inputs to the entire graph. It also can't be an **alias** of any of the inputs ***

*** One thing to note: (3) means that we can't technically guarantee that we'll win back **all** of the memory that functionalization gave up. Functionalization converts input mutations into out-of-place calls, and then adds a `copy_()` at the end of the graph to preserve semantics.
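
To make criterion (2) concrete, here is a minimal sketch of how the pass is invoked and what it decides, mirroring `test_reinplace_with_view` from the test file added in this PR:

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace

def f(x):
    a = x.clone()
    a_view = a.view(-1)
    # `a` has a live alias (a_view) that is read later, so this add() must stay out-of-place
    b = a.add(1)
    # nothing reads `a` or its aliases after this point, so this add() can become add_()
    c = a_view.add(1)
    return c

inpt = torch.ones(2)
gm = reinplace(make_fx(f)(inpt), inpt)
print(gm.code)  # first call stays aten.add.Tensor, second becomes aten.add_.Tensor
```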

I added logic to handle `copy_()` in this PR because it's a pretty important optimization in the context of functionalization: any program that performs input mutations will have a `copy_()` in it after running functionalization.
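
As a rough illustration of the `copy_()` special case: the pass treats a graph input whose contents are overwritten by a later `copy_()` as a candidate for mutation, since its original data is dead by the end of the graph. The helper below is a hypothetical, simplified sketch of that check (the real logic in `torch/fx/passes/reinplace.py` also has to account for aliases and intervening reads), not the actual implementation:

```
import torch
from torch.fx import Node

def input_overwritten_by_copy(graph_input: Node) -> bool:
    # Hypothetical sketch: look for a later aten.copy_(graph_input, ...) call.
    # copy_() writes into its first argument, so if the graph input is that
    # argument, its original contents are fully replaced before the graph ends.
    for user in graph_input.users:
        if user.op == 'call_function' and user.target is torch.ops.aten.copy_.default:
            if user.args[0] is graph_input:
                return True
    return False
```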

There are some examples in the test file, but I think staring at an example of where re-inplacing is/isn't allowed to run is helpful:
```
// Before functionalization
def foo(a):
    tmp1 = a.add_(1)
    tmp2 = a.add(2)

// After functionalization
def foo(a):
    tmp1 = a.add(1)
    tmp2 = a.add(2)
    ....
    a.copy_(tmp1)

// After re-inplacing
def foo(a):
    // first add() is safe to re-inplace even though a is a program input,
    // because a's data is overwritten later by a copy_()
    tmp1 = a.add_(1)
    // second add() is NOT safe to re-inplace, because:
    // (1) a and tmp1 are aliased. Note that they weren't aliased in the original program,
    //     but they are now that we've done some re-inplacing.
    // (2) tmp1 is used as an input later in the program
    tmp2 = a.add(2)
    ....
    a.copy_(tmp1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80897
Approved by: https://github.com/ezyang
Author: Brian Hirsh
Date: 2022-07-27 07:10:15 -07:00
Committed by: PyTorch MergeBot
Parent: 46b83f66ec
Commit: 3ef7a6921d
6 changed files with 1182 additions and 60 deletions


@ -9,6 +9,12 @@
#include <c10/util/irange.h> #include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_to_copy.h>
#endif
namespace at { namespace at {
void FunctionalTensorWrapper::set_constructor_metadata() { void FunctionalTensorWrapper::set_constructor_metadata() {
@ -205,7 +211,9 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) { if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
// .to() should not re-entrantly go through functionalization. // .to() should not re-entrantly go through functionalization.
at::AutoDispatchSkipFunctionalize guard; at::AutoDispatchSkipFunctionalize guard;
value_ = value_.to(c10::TensorOptions().dtype(dtype()).layout(layout())); // and we want _to_copy() to show up in the graph, not the composite .to() operator
// (this can happen if autograd has already run by the time we enter this code)
value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
} }
} }


@ -59,6 +59,7 @@ torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> No
torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument
torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
torch.fx.passes.reinplace.reinplace(gm, *sample_args)
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None) torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None)
torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str) torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None) torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)


@ -5,6 +5,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorc
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace
import unittest import unittest
@ -17,59 +18,72 @@ def are_aliased(x, y):
return y._base is x return y._base is x
return x._base is y._base return x._base is y._base
# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
# This is basically a crappy version of `functionalize()` for single-tensor-arg inputs.
def _functionalize(f, *, reapply_views: bool):
def wrapped(a):
input_functional = torch._to_functional_tensor(a)
torch._enable_functionalization(reapply_views=reapply_views)
try:
out = f(input_functional)
finally:
torch._disable_functionalization()
torch._sync(input_functional)
inpt_new = torch._from_functional_tensor(input_functional)
if inpt_new is not a:
# Existing deficiency in functionalize():
# we don't correctly mutate input metadata (yet?)
if inpt_new.shape == a.shape:
a.copy_(inpt_new)
tree_map(torch._sync, out)
out_unwrapped = tree_map(torch._from_functional_tensor, out)
return out_unwrapped
return wrapped
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457") @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
class TestFunctionalization(TestCase): class TestFunctionalization(TestCase):
# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
def _functionalize(self, f, *, reapply_views: bool):
def wrapped(a):
input_functional = torch._to_functional_tensor(a)
torch._enable_functionalization(reapply_views=reapply_views)
try:
out = f(input_functional)
finally:
torch._disable_functionalization()
torch._sync(input_functional)
tree_map(torch._sync, out)
out_unwrapped = tree_map(torch._from_functional_tensor, out)
return out_unwrapped
return wrapped def get_logs(self, func, inpt, *, reapply_views=False, run_reinplace=False):
inpt_clone = inpt.clone()
def get_logs(self, func, inpt, *, reapply_views=False): traced_f = make_fx(_functionalize(func, reapply_views=reapply_views))(inpt)
traced_f = make_fx(self._functionalize(func, reapply_views=reapply_views))(inpt) if run_reinplace:
traced_f = reinplace(traced_f, inpt_clone)
return traced_f.code return traced_f.code
def assert_functionalization(self, func, inpt, *, reapply_views=False): def assert_functionalization(self, func, inpt, *, reapply_views=False, mutated_input_metadata=False):
input_clone = inpt.clone() input_clone = inpt.clone()
input_clone2 = inpt.clone() input_clone2 = inpt.clone()
input_functional = torch._to_functional_tensor(input_clone2) input_clone3 = inpt.clone()
# Compare outputs (and mutated inputs), with and without functionalization. # Compare outputs (and mutated inputs), with and without functionalization.
out_ref = func(inpt) out_ref = func(inpt)
out_functional = _functionalize(func, reapply_views=reapply_views)(input_clone)
# The reinplacing pass is only valid to run with reapply_views=True.
functional_func = make_fx(_functionalize(func, reapply_views=True))(input_clone2)
reinplace_func = reinplace(make_fx(_functionalize(func, reapply_views=True))(input_clone2), input_clone2)
torch._enable_functionalization(reapply_views=reapply_views) # NOTE: for now, need to pass in fresh inputs here, because make_fx
try: # will directly mutate the inputs that you trace with.
out_functional = func(input_functional) # Once this is fixed we can clean this up.
finally: out_reinplace = reinplace_func(input_clone3)
torch._disable_functionalization()
# We need to sync the input tensors first, in case there are any queued mutations left. # functionalize() deficiency: input metadata mutations aren't propagated properly,
torch._sync(input_functional) # so we just need to skip checks here for the tests that exercise that.
self.assertEqual(inpt, torch._from_functional_tensor(input_functional)) # input mutations should still occur if not mutated_input_metadata:
self.assertEqual(inpt, input_clone) # input mutations should still occur
self.assertEqual(inpt, input_clone3)
# Handle tests with multi-tensor outputs # Handle tests with multi-tensor outputs
if isinstance(out_ref, tuple) and isinstance(out_functional, tuple): if isinstance(out_ref, tuple):
out_refs, out_functionals = list(out_ref), list(out_functional) out_refs, out_functionals, out_reinplaces = list(out_ref), list(out_functional), list(out_reinplace)
else: else:
out_refs, out_functionals = [out_ref], [out_functional] out_refs, out_functionals, out_reinplaces = [out_ref], [out_functional], [out_reinplace]
for out_ref_, out_functional_ in zip(out_refs, out_functionals): for out_ref_, out_functional_, out_reinplace_ in zip(out_refs, out_functionals, out_reinplaces):
self.assertEqual(out_ref_.size(), out_functional_.size()) self.assertEqual(out_ref_, out_functional_)
torch._sync(out_functional_) self.assertEqual(out_ref_, out_reinplace_)
out_functional_unwrapped = torch._from_functional_tensor(out_functional_)
self.assertEqual(out_ref_, out_functional_unwrapped)
def test_save_for_backwards_segfault(self): def test_save_for_backwards_segfault(self):
inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True) inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
@ -104,10 +118,27 @@ class TestFunctionalization(TestCase):
def forward(self, a_1): def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]) view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1); view_copy_default_1 = None mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1)
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
return add_tensor
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
view_default = torch.ops.aten.view.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_default, empty); view_default = empty = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1)
copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None
return add_tensor return add_tensor
""") """)
@ -134,6 +165,21 @@ def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor return mul_tensor
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
empty_1 = torch.ops.aten.empty.SymInt([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(view_default, empty); view_default = empty = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor
""") """)
def test_multi_out(self): def test_multi_out(self):
@ -157,6 +203,20 @@ def forward(self, a_1):
getitem = aminmax_default[0] getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None getitem_1 = aminmax_default[1]; aminmax_default = None
return getitem return getitem
""")
reinplaced_logs = self.get_logs(f, torch.arange(8, dtype=torch.float32), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
return getitem
""") """)
def test_tensor_ctr(self): def test_tensor_ctr(self):
@ -165,7 +225,38 @@ def forward(self, a_1):
z = y.view(-1) z = y.view(-1)
z.add_(1) z.add_(1)
return y return y
self.assert_functionalization(f, torch.arange(3, dtype=torch.float32))
inpt = torch.arange(3, dtype=torch.float32)
self.assert_functionalization(f, inpt)
logs = self.get_logs(f, inpt)
self.assertExpectedInline(logs, """\
def forward(self, a_1):
_tensor_constant0 = self._tensor_constant0
lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_copy_default = torch.ops.aten.view_copy.default(lift_fresh, [-1]); lift_fresh = None
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [3]); add_tensor = None
return view_copy_default_1
""")
reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
_tensor_constant0 = self._tensor_constant0
lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_default = torch.ops.aten.view.default(lift_fresh, [-1]); lift_fresh = None
add_tensor = torch.ops.aten.add_.Tensor(view_default, 1)
view_default_1 = torch.ops.aten.view.default(view_default, [3]); view_default = None
return view_default_1
""")
def test_inplace_on_non_view(self): def test_inplace_on_non_view(self):
def f(x): def f(x):
@ -185,9 +276,25 @@ def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]) view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(a_1, fill_scalar); a_1 = fill_scalar = None add_tensor = torch.ops.aten.add.Tensor(a_1, fill_scalar); fill_scalar = None
copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
return view_copy_default_1 return view_copy_default_1
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
view_default = torch.ops.aten.view.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(a_1, empty); empty = None
copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]); add_tensor = None
return view_default_1
""") """)
# Some ops that are mutable are neither inplace nor out= ops. # Some ops that are mutable are neither inplace nor out= ops.
@ -202,13 +309,14 @@ def forward(self, a_1):
def forward(self, a_1): def forward(self, a_1):
_fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0); a_1 = None _fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0)
getitem = _fused_moving_avg_obs_fq_helper_functional_default[0] getitem = _fused_moving_avg_obs_fq_helper_functional_default[0]
getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1] getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1]
getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2] getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2]
getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3] getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3]
getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4] getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4]
getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None
copy__default = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None
return (getitem, getitem_1) return (getitem, getitem_1)
""") # noqa: B950 """) # noqa: B950
@ -226,7 +334,8 @@ def forward(self, a_1):
def forward(self, a_1): def forward(self, a_1):
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1) as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); a_1 = add_tensor = None as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); add_tensor = None
copy__default = torch.ops.aten.copy_.default(a_1, as_strided_scatter_default); a_1 = None
return as_strided_scatter_default return as_strided_scatter_default
""") """)
@ -263,11 +372,23 @@ def forward(self, a_1):
return cat_default return cat_default
""") """)
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.SymInt([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat_default
""")
def test_diagonal(self): def test_diagonal(self):
def f(x): def f(x):
# test: view ops that take a subset of the original tensor (select/diagonal) # test: view ops that take a subset of the original tensor (select/diagonal)
tmp = torch.ones(2) tmp = torch.ones(2)
y = x.diagonal() y = x.clone().diagonal()
y.add_(tmp) y.add_(tmp)
z = x * x z = x * x
return z return z
@ -280,10 +401,25 @@ def forward(self, a_1):
def forward(self, a_1): def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1) clone_default = torch.ops.aten.clone.default(a_1)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(clone_default); clone_default = None
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); a_1 = add_tensor = None mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
mul_tensor = torch.ops.aten.mul.Tensor(diagonal_scatter_default, diagonal_scatter_default); diagonal_scatter_default = None return mul_tensor
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
clone_default = torch.ops.aten.clone.default(a_1)
diagonal_default = torch.ops.aten.diagonal.default(clone_default); clone_default = None
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, empty); diagonal_default = empty = None
mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul_tensor return mul_tensor
""") """)
@ -296,6 +432,20 @@ def forward(self, a_1):
return x return x
x = torch.ones(2, 2) x = torch.ones(2, 2)
self.assert_functionalization(f, x) self.assert_functionalization(f, x)
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); add_tensor = None
copy__default = torch.ops.aten.copy_.default(a_1, diagonal_scatter_default); a_1 = None
return diagonal_scatter_default
""")
def test_split(self): def test_split(self):
def f(x): def f(x):
@ -324,8 +474,9 @@ def forward(self, a_1):
getitem_2 = split_copy_tensor_1[0] getitem_2 = split_copy_tensor_1[0]
getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None
slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); a_1 = diagonal_scatter_default = None slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); diagonal_scatter_default = None
mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default); slice_scatter_default = None mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default)
copy__default = torch.ops.aten.copy_.default(a_1, slice_scatter_default); a_1 = slice_scatter_default = None
return add_tensor return add_tensor
""") # noqa: B950 """) # noqa: B950
@ -337,7 +488,7 @@ def forward(self, a_1):
y = x[0] y = x[0]
y.add_(tmp) y.add_(tmp)
return x return x
self.assert_functionalization(f, torch.ones(4, 2)) self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
logs = self.get_logs(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\ self.assertExpectedInline(logs, """\
@ -372,13 +523,14 @@ def forward(self, a_1):
def forward(self, a_1): def forward(self, a_1):
view_copy_default = torch.ops.aten.view_copy.default(a_1, [8]); a_1 = None view_copy_default = torch.ops.aten.view_copy.default(a_1, [8])
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) empty = torch.ops.aten.empty.memory_format([0], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
arange = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')) arange = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
empty_1 = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) empty_1 = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
arange_1 = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) arange_1 = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2]) view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
return index_put_default return index_put_default
""") # noqa: B950 """) # noqa: B950
@ -400,11 +552,12 @@ def forward(self, a_1):
def forward(self, a_1): def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2) mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
return div_tensor return div_tensor
""") """)
@ -413,7 +566,9 @@ def forward(self, a_1):
def f(x): def f(x):
# ops like ge_() are allowed to change the dtype of the input. # ops like ge_() are allowed to change the dtype of the input.
# functionalization should pick up on that. # functionalization should pick up on that.
return x.ge_(0) y = x.clone()
out = y.ge_(0)
return out
self.assert_functionalization(f, torch.ones(4, 2)) self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\ self.assertExpectedInline(logs, """\
@ -421,11 +576,24 @@ def forward(self, a_1):
def forward(self, a_1): def forward(self, a_1):
ge_scalar = torch.ops.aten.ge.Scalar(a_1, 0); a_1 = None clone_default = torch.ops.aten.clone.default(a_1); a_1 = None
to_dtype_layout = torch.ops.aten.to.dtype_layout(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None ge_scalar = torch.ops.aten.ge.Scalar(clone_default, 0); clone_default = None
return to_dtype_layout _to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
return _to_copy_default
""") """)
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
clone_default = torch.ops.aten.clone.default(a_1); a_1 = None
ge_scalar = torch.ops.aten.ge_.Scalar(clone_default, 0)
_to_copy_default = torch.ops.aten._to_copy.default(clone_default, dtype = torch.float32, layout = torch.strided); clone_default = None
return _to_copy_default
""") # noqa: B950
@skipIfTorchDynamo("Test does not work with TorchDynamo") @skipIfTorchDynamo("Test does not work with TorchDynamo")
def test_metadata_change_out_op(self): def test_metadata_change_out_op(self):
def f(t, y): def f(t, y):
@ -514,6 +682,44 @@ def forward(self, a_1):
return add_tensor_1 return add_tensor_1
""") # noqa: B950 """) # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill_.Scalar(empty, 1.0)
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_default = torch.ops.aten.view.default(add_tensor, [8])
_reshape_alias_default = torch.ops.aten._reshape_alias.default(view_default, [2, 4], [4, 1]); view_default = None
transpose_int = torch.ops.aten.transpose.int(_reshape_alias_default, 1, 0)
unsqueeze_default = torch.ops.aten.unsqueeze.default(transpose_int, 0); transpose_int = None
squeeze_default = torch.ops.aten.squeeze.default(unsqueeze_default); unsqueeze_default = None
split_tensor = torch.ops.aten.split.Tensor(squeeze_default, 2); squeeze_default = None
getitem = split_tensor[0]
getitem_1 = split_tensor[1]; split_tensor = None
add_tensor_1 = torch.ops.aten.add_.Tensor(getitem, empty); empty = None
select_int = torch.ops.aten.select.int(_reshape_alias_default, 0, 0); _reshape_alias_default = None
clone_default = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [8]); add_tensor = None
_reshape_alias_default_1 = torch.ops.aten._reshape_alias.default(view_default_1, [2, 4], [4, 1]); view_default_1 = None
transpose_int_1 = torch.ops.aten.transpose.int(_reshape_alias_default_1, 1, 0); _reshape_alias_default_1 = None
unsqueeze_default_1 = torch.ops.aten.unsqueeze.default(transpose_int_1, 0); transpose_int_1 = None
squeeze_default_1 = torch.ops.aten.squeeze.default(unsqueeze_default_1); unsqueeze_default_1 = None
unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(squeeze_default_1, 0); squeeze_default_1 = None
squeeze_dim = torch.ops.aten.squeeze.dim(unsqueeze_default_2, 0); unsqueeze_default_2 = None
transpose_int_2 = torch.ops.aten.transpose.int(squeeze_dim, 1, 0); squeeze_dim = None
_reshape_alias_default_2 = torch.ops.aten._reshape_alias.default(transpose_int_2, [8], [1]); transpose_int_2 = None
view_default_2 = torch.ops.aten.view.default(_reshape_alias_default_2, [4, 2]); _reshape_alias_default_2 = None
view_default_3 = torch.ops.aten.view.default(view_default_2, [8]); view_default_2 = None
_reshape_alias_default_3 = torch.ops.aten._reshape_alias.default(view_default_3, [2, 4], [4, 1]); view_default_3 = None
select_int_1 = torch.ops.aten.select.int(_reshape_alias_default_3, 0, 0); _reshape_alias_default_3 = None
add_tensor_2 = torch.ops.aten.add.Tensor(select_int_1, _unsafe_view_default); select_int_1 = _unsafe_view_default = None
return getitem
""")
def test_reapply_views_simple(self): def test_reapply_views_simple(self):
def f(x): def f(x):
tmp = torch.ones(4, 2) tmp = torch.ones(4, 2)
@ -530,10 +736,11 @@ def forward(self, a_1):
def forward(self, a_1): def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None view_default = torch.ops.aten.view.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_default, fill_scalar); view_default = fill_scalar = None add_tensor = torch.ops.aten.add.Tensor(view_default, fill_scalar); view_default = fill_scalar = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]) view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1); view_default_1 = None mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1)
copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None
return add_tensor return add_tensor
""") """)
@ -564,8 +771,6 @@ def forward(self, a_1):
def test_copy_(self): def test_copy_(self):
def f(x): def f(x):
tmp = torch.zeros(2, 2) tmp = torch.zeros(2, 2)
# NOTE: LoggingTensor isn't a mode, which means that the diagonal call
# will not be logged. This is fine for testing.
tmp_slice = tmp.diagonal() tmp_slice = tmp.diagonal()
y = tmp_slice.copy_(x) y = tmp_slice.copy_(x)
z = y.add_(x) z = y.add_(x)
@ -587,6 +792,21 @@ def forward(self, a_1):
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor return add_tensor
""")
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero_.default(empty)
diagonal_default = torch.ops.aten.diagonal.default(empty)
diagonal_default_1 = torch.ops.aten.diagonal.default(empty); empty = None
copy_default = torch.ops.aten.copy_.default(diagonal_default_1, a_1)
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default_1, a_1); a_1 = None
return diagonal_default_1
""") """)
# Test 2: copy_() with same dtype, different shape # Test 2: copy_() with same dtype, different shape
@ -604,6 +824,21 @@ def forward(self, a_1):
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor return add_tensor
""")
reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero_.default(empty)
diagonal_default = torch.ops.aten.diagonal.default(empty)
diagonal_default_1 = torch.ops.aten.diagonal.default(empty); empty = None
copy_default = torch.ops.aten.copy_.default(diagonal_default_1, a_1)
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default_1, a_1); a_1 = None
return diagonal_default_1
""") """)
# Test 3: copy_() with different dtype, same shape # Test 3: copy_() with different dtype, same shape
@ -621,6 +856,21 @@ def forward(self, a_1):
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor return add_tensor
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero_.default(empty)
diagonal_default = torch.ops.aten.diagonal.default(empty)
diagonal_default_1 = torch.ops.aten.diagonal.default(empty); empty = None
copy_default = torch.ops.aten.copy_.default(diagonal_default_1, a_1)
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default_1, a_1); a_1 = None
return diagonal_default_1
""") """)
# Test 4: copy_() with different dtype, different shape # Test 4: copy_() with different dtype, different shape
@ -638,6 +888,21 @@ def forward(self, a_1):
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor return add_tensor
""")
reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero_.default(empty)
diagonal_default = torch.ops.aten.diagonal.default(empty)
diagonal_default_1 = torch.ops.aten.diagonal.default(empty); empty = None
copy_default = torch.ops.aten.copy_.default(diagonal_default_1, a_1)
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default_1, a_1); a_1 = None
return diagonal_default_1
""") """)
def test_expand_symint(self): def test_expand_symint(self):
@ -676,6 +941,18 @@ def forward(self, a_1):
fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None
return diagonal_scatter_default return diagonal_scatter_default
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal_default = torch.ops.aten.diagonal.default(add_tensor)
fill_scalar = torch.ops.aten.fill_.Scalar(diagonal_default, 0); diagonal_default = None
return add_tensor
""") """)
def test_resize_smaller(self): def test_resize_smaller(self):
@ -713,6 +990,28 @@ def forward(self, a_1):
return add_tensor_2 return add_tensor_2
""") # noqa: B950 """) # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view_default = torch.ops.aten.view.default(add_tensor, [4, 4])
resize_default = torch.ops.aten.resize.default(view_default, [3, 3])
as_strided_default = torch.ops.aten.as_strided.default(view_default, [3, 3], [3, 1]); view_default = None
view_default_1 = torch.ops.aten.view.default(as_strided_default, [-1]); as_strided_default = None
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_1, 1)
view_default_2 = torch.ops.aten.view.default(add_tensor, [4, 4]); add_tensor = None
as_strided_default_1 = torch.ops.aten.as_strided.default(view_default_2, [3, 3], [3, 1])
view_default_3 = torch.ops.aten.view.default(view_default_1, [3, 3]); view_default_1 = None
view_default_4 = torch.ops.aten.view.default(view_default_2, [8, 2]); view_default_2 = None
view_default_5 = torch.ops.aten.view.default(view_default_4, [4, 4]); view_default_4 = None
as_strided_default_2 = torch.ops.aten.as_strided.default(view_default_5, [3, 3], [3, 1]); view_default_5 = None
add_tensor_2 = torch.ops.aten.add_.Tensor(as_strided_default_2, 1)
return as_strided_default_2
""")
def test_resize_larger_valid(self): def test_resize_larger_valid(self):
def f(x): def f(x):
y = x + 1 y = x + 1
@ -744,6 +1043,21 @@ def forward(self, a_1):
view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1) add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1)
return (view_copy_default_1, add_tensor_1) return (view_copy_default_1, add_tensor_1)
""")
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize_default = torch.ops.aten.resize_.default(add_tensor, [5, 5])
view_default = torch.ops.aten.view.default(add_tensor, [25]); add_tensor = None
fill_scalar = torch.ops.aten.fill_.Scalar(view_default, 1)
view_default_1 = torch.ops.aten.view.default(view_default, [5, 5]); view_default = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_default_1, 1)
return (view_default_1, add_tensor_1)
""") """)
def test_resize_larger_invalid(self): def test_resize_larger_invalid(self):
@ -809,5 +1123,40 @@ $3 = torch._ops.aten.add.Tensor($2, 1)""")
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
x1_not_functional.add_(x2_functional) x1_not_functional.add_(x2_functional)
def test_index_mutation_on_non_input(self):
def f(x):
tmp = torch.zeros(10)
tmp[5].fill_(1)
return tmp
self.assert_functionalization(f, torch.ones(2))
logs = self.get_logs(f, torch.ones(2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([10], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
select_copy_int = torch.ops.aten.select_copy.int(zero_default, 0, 5)
select_copy_int_1 = torch.ops.aten.select_copy.int(zero_default, 0, 5)
fill_scalar = torch.ops.aten.fill.Scalar(select_copy_int_1, 1); select_copy_int_1 = None
select_scatter_default = torch.ops.aten.select_scatter.default(zero_default, fill_scalar, 0, 5); zero_default = fill_scalar = None
return select_scatter_default
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([10], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero_.default(empty)
select_int = torch.ops.aten.select.int(empty, 0, 5)
select_int_1 = torch.ops.aten.select.int(empty, 0, 5)
fill_scalar = torch.ops.aten.fill_.Scalar(select_int_1, 1); select_int_1 = None
return empty
""")
if __name__ == '__main__': if __name__ == '__main__':
run_tests() run_tests()


@ -0,0 +1,251 @@
# Owner(s): ["module: functionalization"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.fx.passes.reinplace import reinplace
from torch.fx.experimental.proxy_tensor import make_fx
try:
from functorch.experimental import functionalize
HAS_FUNCTIONALIZATION = True
except Exception as e:
HAS_FUNCTIONALIZATION = False
class TestReinplacePass(TestCase):
def test_reinplace_basic(self):
# Basic test: the out-of-place add() call should be converted
# into add_()
def f(x):
a = x.clone()
b = a.add(1)
return b
inpt = torch.ones(2)
f2 = reinplace(make_fx(f)(inpt), inpt)
expected_out = f(inpt)
actual_out = f2(inpt)
self.assertEqual(actual_out, expected_out)
self.assertExpectedInline(f2.code, """\
def forward(self, x_1):
clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
add_tensor = torch.ops.aten.add_.Tensor(clone_default, 1)
return clone_default
""")
def test_reinplace_with_view(self):
def f(x):
a = x.clone()
a_view = a.view(-1)
# We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
b = a.add(1)
# Second add() is fine to re-inplace
c = a_view.add(1)
return c
inpt = torch.ones(2)
f2 = reinplace(make_fx(f)(inpt), inpt)
expected_out = f(inpt)
actual_out = f2(inpt)
self.assertEqual(actual_out, expected_out)
self.assertExpectedInline(f2.code, """\
def forward(self, x_1):
clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
view_default = torch.ops.aten.view.default(clone_default, [-1])
add_tensor = torch.ops.aten.add.Tensor(clone_default, 1); clone_default = None
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default, 1)
return view_default
""")
# This test won't actually run in CI, because it requires functionalize() from functorch.
# I'm planning on testing more comprehensively with torchbench models,
# but we can make this testing better once functorch moves into pytorch/pytorch.
def test_reinplace_scatter_op(self):
def f(a_):
# for now, don't test mutations to inputs
a = a_.clone()
e = a.view(-1)
b = a.view(-1)
c = b[0]
d = c.view(-1)
d.add_(1)
return a + e
if not HAS_FUNCTIONALIZATION:
return
inpt = torch.ones(4)
f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
expected_out = f(inpt)
actual_out = f2(inpt)
self.assertEqual(actual_out, expected_out)
# NOTE: one slight pessimization here is the fact that
# there are a bunch of redundant views in the graph.
# Technically, half of these views are duplicates that we could de-dup.
# This shouldn't really hurt performance though, since creating an extra view
# is effectively just moving some metadata around (and allocating a new TensorImpl).
# We can/should update the pass in the future to clean this up.
self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
view_default = torch.ops.aten.view.default(clone_default, [-1])
view_default_1 = torch.ops.aten.view.default(clone_default, [-1])
select_int = torch.ops.aten.select.int(view_default_1, 0, 0); view_default_1 = None
view_default_2 = torch.ops.aten.view.default(select_int, [-1]); select_int = None
add_tensor = torch.ops.aten.add_.Tensor(view_default_2, 1)
view_default_3 = torch.ops.aten.view.default(clone_default, [-1]); clone_default = None
select_int_1 = torch.ops.aten.select.int(view_default_3, 0, 0)
view_default_4 = torch.ops.aten.view.default(view_default_2, []); view_default_2 = None
view_default_5 = torch.ops.aten.view.default(view_default_3, [4]); view_default_3 = None
view_default_6 = torch.ops.aten.view.default(view_default_5, [-1])
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_5, view_default_6); view_default_6 = None
return view_default_5
""")
def test_reinplace_scatter_twice(self):
def f(a_):
# for now, don't test mutations to inputs
a = a_.clone()
b = a[:, 1]
c = b[1]
c.add_(1)
return a
if not HAS_FUNCTIONALIZATION:
return
inpt = torch.ones(4, 4)
f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
expected_out = f(inpt)
actual_out = f2(inpt)
self.assertEqual(actual_out, expected_out)
self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
slice_tensor_1 = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int_2 = torch.ops.aten.select.int(slice_tensor_1, 1, 1); slice_tensor_1 = None
return clone_default
""")
def test_reinplace_scatter_twice_with_different_view_op_valid(self):
def f(a_):
a = a_.clone()
b = a[:, 1]
c = b[1]
c_updated = c.add(1)
good_mirror_of_b = a.as_strided((4,), (4,), 1)
# good_mirror_of_b points to the same region of memory as b.
# and this scatter op below tries to scatter c_updated into the same region
# that c currently takes up.
# reinplacing logic checks this by confirming that:
# c_updated
# good_mirror_of_b.select(0, 1)
# have the same size/stride/storage_offset.
b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
return b_updated
inpt = torch.ones(4, 4)
f2 = reinplace(make_fx(f)(inpt), inpt)
expected_out = f(inpt)
actual_out = f2(inpt)
self.assertEqual(actual_out, expected_out)
self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
return as_strided_default
""")
# Test example where we have a scatter op, where the base tensor
# has the same size/stride/storage offset (even though it is a different view),
# making it valid to re-inplace
def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
def f(a_):
a = a_.clone()
b = a[:, 1]
c = b[1]
c_updated = c.add(1)
good_mirror_of_b = a.as_strided((4,), (4,), 1)
# The first arg to select_scatter is an equivalent view to b.
# However, the select_scatter call below tries to put c_updated
# into a different slice of "b" than what "c" currently occupies.
#
b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
return b_updated
inpt = torch.ones(4, 4)
f2 = reinplace(make_fx(f)(inpt), inpt)
expected_out = f(inpt)
actual_out = f2(inpt)
self.assertEqual(actual_out, expected_out)
self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
select_scatter_default = torch.ops.aten.select_scatter.default(as_strided_default, add_tensor, 0, 0); as_strided_default = add_tensor = None
return select_scatter_default
""") # noqa: B950
def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
def f(a_):
a = a_.clone()
b = a[:, 1]
c = b[1]
c_updated = c.add(1)
bad_mirror_of_b = a.as_strided((4,), (4,), 0)
# The first arg to select_scatter points to a different than c's base.
# This makes it invalid to re-inplace.
b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
return b_updated
inpt = torch.ones(4, 4)
f2 = reinplace(make_fx(f)(inpt), inpt)
expected_out = f(inpt)
actual_out = f2(inpt)
# self.assertEqual(actual_out, expected_out)
self.assertExpectedInline(f2.code, """\
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0); clone_default = None
select_scatter_default = torch.ops.aten.select_scatter.default(as_strided_default, add_tensor, 0, 1); as_strided_default = add_tensor = None
return select_scatter_default
""") # noqa: B950
if __name__ == '__main__':
run_tests()


@ -3,6 +3,7 @@ from . import graph_manipulation
from . import net_min_base from . import net_min_base
from . import operator_support from . import operator_support
from . import param_fetch from . import param_fetch
from . import reinplace
from . import shape_prop from . import shape_prop
from . import split_module from . import split_module
from . import split_utils from . import split_utils


@ -0,0 +1,512 @@
import torch
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map
from torch.multiprocessing.reductions import StorageWeakRef
import _operator
from enum import Enum
import itertools
from typing import Set, Dict
from collections import defaultdict
__all__ = ['reinplace']
class _ViewType(Enum):
NonView = 0
SingleOutputView = 1
MultiOutputView = 2
def _is_view_op(tgt):
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
schema = tgt._schema
if len(schema.arguments) > 0:
first_arg = schema.arguments[0]
# check if op is a view
return first_arg.alias_info is not None and not first_arg.alias_info.is_write
def _get_view_type(tgt) -> _ViewType:
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
schema = tgt._schema
if len(schema.arguments) > 0:
first_arg = schema.arguments[0]
# check if op is a view
if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
# check if op is a multi-output view
if '*' in first_arg.alias_info.after_set:
return _ViewType.MultiOutputView
else:
return _ViewType.SingleOutputView
return _ViewType.NonView
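# For example (illustrative, based on the ATen schema annotations):
#   _get_view_type(torch.ops.aten.diagonal.default)  ->  _ViewType.SingleOutputView
#       ("Tensor(a) self": first arg aliases the output and is not written to)
#   _get_view_type(torch.ops.aten.split.Tensor)      ->  _ViewType.MultiOutputView
#       ("Tensor(a -> *) self": the after-set contains '*')
#   _get_view_type(torch.ops.aten.add.Tensor)        ->  _ViewType.NonView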
# Stores a bunch of metadata related to functionalization for each node.
# Relevant metadata:
# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTensors instead of Tensors)
# The fake tensor output from running the current node
# n.meta['view_of']: Node
# If the current node n is a view of some base tensor, the 'view_of' field tells us which
# node n's view was taken from (i.e. the node that produced n's base).
# This information actually makes `fake_result` redundant, but we can use `fake_result`
# to sanity check that our aliasing information is correct.
@compatibility(is_backward_compatible=False)
class _FunctionalizationMetadataProp(torch.fx.Interpreter):
def run_node(self, node: Node):
self.node_counter += 1
result = super().run_node(node)
node.meta['fake_result'] = result
node.meta['node_idx'] = self.node_counter
# (1) Update metadata with the list of nodes that are used by this node
# copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
# We don't want to treat it as "being used as an input".
node_args = node.args
if node.target is torch.ops.aten.copy_.default:
node_args = node_args[1:]
# (2) Update metadata to track aliasing information about view tensor nodes.
if node.op == 'call_function':
view_type = _get_view_type(node.target)
if view_type == _ViewType.SingleOutputView:
assert isinstance(node.args[0], Node)
node.meta['view_of'] = node.args[0]
elif view_type == _ViewType.MultiOutputView:
self.multi_output_view_nodes[node] = node.args[0]
# Check if we returned a multi-output view,
# and we're now grabbing the individual views from the output.
#
# For multi-output views, we want to map each output view to the base,
# but this mapping involves two separate nodes in FX IR.
# e.g. "a, b = x_1.split(...)" becomes:
# %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
# %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
# %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
# And we'd like to set:
#   getitem.meta['view_of'] = x_1
#   getitem_1.meta['view_of'] = x_1
elif node.target is _operator.getitem:
list_arg = node.args[0]
maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
if maybe_base_of_view is not None:
# Note: we could also track indexing info here for multi-output views.
# I don't think this metadata is strictly needed for de-functionalization.
assert isinstance(maybe_base_of_view, Node)
node.meta['view_of'] = maybe_base_of_view
if 'view_of' in node.meta:
# We're linking the current node with its first argument as views.
# Assert here that this is actually the case, and their storages are the same.
assert isinstance(node.meta['fake_result'], FakeTensor)
assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
view_storage = StorageWeakRef(node.meta['fake_result'].storage())
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage())
assert view_storage == base_storage
return result
def propagate(self, *args):
self.multi_output_view_nodes = {}
self.node_counter = -1
with FakeTensorMode.push() as mode:
fake_args = [mode.from_tensor(a) for a in args]
return super().run(*fake_args)
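# A rough sketch of the metadata this produces (node names are illustrative):
# tracing `def f(a): return a[:, 1].add(1)` with make_fx gives a graph like
#   slice_tensor = torch.ops.aten.slice.Tensor(a_1, 0, 0, ...)
#   select_int   = torch.ops.aten.select.int(slice_tensor, 1, 1)
#   add_tensor   = torch.ops.aten.add.Tensor(select_int, 1)
# and after propagation, select_int.meta['view_of'] is the slice_tensor node,
# slice_tensor.meta['view_of'] is the placeholder node for `a`, and every node
# also carries a 'fake_result' FakeTensor and a 'node_idx'.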
def _schemas_match(functional_schema, inplace_schema):
names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
# for the inplace op, its first argument should be mutable
assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
# and its remaining arguments shouldn't be.
assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
return names_match and arg_types_match
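# For instance (illustrative): torch.ops.aten.add.Tensor and torch.ops.aten.add_.Tensor
# have matching schemas ("add_" == "add" + "_", and both take (Tensor self, Tensor other, *, Scalar alpha)),
# so _schemas_match(torch.ops.aten.add.Tensor._schema, torch.ops.aten.add_.Tensor._schema) is True.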
# TODO: this should be beefed up to be able to properly re-inplace with:
# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
# - out= ops (e.g. angle -> angle.out)
# TODO: we should also figure this info out using torchgen.
def _maybe_get_inplace_op(op):
# __module__ seems broken; it returns torch._ops.aten which doesn't exist
if not isinstance(op, torch._ops.OpOverload):
return None
# Some view ops have inplace variants (as_strided_, etc),
# but we do NOT want the reinplacing pass to directly add these into the program.
# (they'll require extra special handling, and aren't really useful for perf anyway)
if _is_view_op(op):
return None
op_namespace = op.__module__.split(".")[-1]
op_base_name = op.overloadpacket.__name__
maybe_namespace_module = getattr(torch.ops, op_namespace)
maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
if maybe_inplace_op is None:
return None
inplace_overloads = [
getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
]
inplace_overloads_with_matching_schemas = [
f
for f in inplace_overloads
if _schemas_match(op._schema, f._schema)
]
# This is for sanity: if foo() and foo_() are both operators,
# we expect them to have compatible schemas.
# (This is asserted by codegen for ATen, but might not be true
# for other arbitrary operators).
assert len(inplace_overloads_with_matching_schemas) == 1
inplace_op = inplace_overloads_with_matching_schemas[0]
return inplace_op
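# For example (illustrative): _maybe_get_inplace_op(torch.ops.aten.add.Tensor) should return
# torch.ops.aten.add_.Tensor, while view ops (e.g. aten.as_strided.default) and ops with
# no inplace variant return None.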
_VIEW_INVERSE_MAP = {
torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
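# Each *_scatter op above is the functional "inverse" of the corresponding view op:
# e.g. select_scatter(base, x, dim, idx) returns a copy of `base` where the slice
# `base.select(dim, idx)` has been replaced with `x`. Functionalization emits these
# in place of mutations made through views; the re-inplacing pass tries to eliminate
# them again (see _get_view_inverse_node_usages below).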
# This function, given a set of (aliased) tensor nodes,
# returns any nodes in the graph that *use* any of the aliases and occur *after* op_index
# in the node ordering.
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
def _add_if_tensor(x, set_):
if isinstance(x, FakeTensor):
set_.add(StorageWeakRef(x.storage()))
nodes_used_after = set()
for t in tensor_aliases:
# get all nodes that use the current alias
usage_nodes = t.users
for n in usage_nodes:
# We only care about usages after the current node
if n.meta['node_idx'] <= op_index:
continue
# We also don't care about intermediate view ops.
# They only matter if their output is then used elsewhere
# (either in an out-of-place op, or as an output to the function).
if n in tensor_aliases:
if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
continue
nodes_used_after.add(n)
return nodes_used_after
# Given an op that we're trying to re-inplace, "b = foo(a)",
# and given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)",
# re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF
# there are any aliases in the alias_set(a) that satisfy:
# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
# (2) The output of running {view}(alias_base, args...) gives you the same size/stride/offset metadata
# as "alias"
def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
def matching_view_metadata(a, b):
return a.size() == b.size() and \
a.stride() == b.stride() and \
a.storage_offset() == b.storage_offset()
view_inverse_nodes = set()
# Go through them in node order, so we can see chains of view_scatter ops.
for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
if n.target not in _VIEW_INVERSE_MAP:
continue
base = n.args[0]
mutated_view = n.args[1]
assert isinstance(base, Node)
assert isinstance(base.meta['fake_result'], FakeTensor)
assert isinstance(mutated_view, Node)
assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
# Check that this view_inverse op actually corresponds to doing the inverse
# of one of our existing self_alias nodes.
original_view = _VIEW_INVERSE_MAP[n.target]
for self_alias in self_aliases:
# We're looking for some alias of the self arg, "alias",
# that was created from some op `alias = foo(base, args...)`
# such that the current _scatter op "inverts" that foo call.
# We can check that by running the original op again, and checking that the strides match.
if 'view_of' not in self_alias.meta:
continue
self_alias_base = self_alias.meta['view_of']
try:
# Here, we're trying to re-use the args from the view_scatter call inside of the corresponding
# view op, which might throw. This just indicates that the view_scatter op isn't a valid inverse
# of the current alias we're looking at.
view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
expected_metadata = self_alias.meta['fake_result']
# If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
matching_view_metadata(view_replay_metadata, expected_metadata):
view_inverse_nodes.add(n)
except Exception:
continue
return view_inverse_nodes
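# Sketch of the pattern this detects (node names illustrative), for a functionalized graph like:
#   a         = torch.ops.aten.clone.default(x)
#   b         = torch.ops.aten.select.int(a, 0, 1)
#   b_updated = torch.ops.aten.add.Tensor(b, 1)
#   a_updated = torch.ops.aten.select_scatter.default(a, b_updated, 0, 1)
# the select_scatter call "inverts" the earlier select (matching base metadata and view args),
# so it gets reported as a view_inverse usage; if the add is re-inplaced, the scatter node
# becomes dead and can be erased in step (5) of `reinplace` below.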
@compatibility(is_backward_compatible=True)
def reinplace(gm, *sample_args):
"""
Given an fx.GraphModule, modifies it to perform "reinplacing",
mutating the nodes of the graph.
We look for out-of-place op call sites like `b = a.add(...)`,
and convert them to be inplace (`b = a.add_(...)`),
as long as the input to the current operator ("a") isn't re-used
anywhere later in the graph.
This pass currently expects to operate on a **functional, ATen** graph.
This can be obtained by running `make_fx(functionalize(f))`.
Sample inputs are needed to determine aliasing relationships of the inputs.
In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
inputs to the program.
Given a node "b = foo(a, ...)", the algorithm for re-inplacing is as follows:
(1) Check if foo has a mutating variant. If not, move to the next node.
Note that we ignore view ops (we don't bother to turn `as_strided()`
into `as_strided_()`), as it complicates the algorithm and doesn't
provide meaningful speedups.
Currently, we also only check for an inplace op, `foo_`.
Later, we should beef this up to check for out= or mutable ops.
(2) Check if "a" is an alias of any of the program inputs.
If it is, skip and move to the next node.
Inplace'ing an op that would cause it to mutate a program input is not sound,
because that would be a side effect visible to the user.
NOTE: there's a future optimization that we should make:
if "a" is a (alias of a) program input, but later in the program
there is a node that looks like "a.copy_(...)",
Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
which will later be overwritten by the copy_() call.
This will be an important optimization to have for programs that mutate
their inputs. It currently isn't implemented though.
(3) Check that "a" and all of its outstanding aliases are not used anywhere
later in the graph. If this is the case, then it's safe to re-inplace
to "b = foo_(a)".
There are a few caveats to this, explained in more detail below:
(a) If "a" is used later as an argument to a view op, that is okay.
It's only a problem if "a" (or that view) is later passed
into a normal operator, or if it is returned as the program output.
(b) If "a" is a repeat argument in `foo()`, then don't reinplace.
Most ATen kernels don't make any guarantees that this is sound,
e.g. if you do aten.mul_(a, a).
So we'll just ban re-inplacing in this case.
(c) If "a" is used as an input into a view "inverse" / "scatter"
operator, it is potentially fine to re-inplace
(and remove that scatter operator from the graph).
See below for a more detailed example.
NOTE: there is an optimization in this step that is crucial
to fully recovering performance from functionalization.
Given this program:
def f(x):
a = torch.ops.aten.add(x, x)
b = torch.ops.aten.diagonal(a)
torch.ops.aten.fill_(b, 0)
return a
Functionalization will emit the following:
def f(x):
a = torch.ops.aten.add(x, x)
b = torch.ops.aten.diagonal(a, 0, 1)
b_updated = torch.ops.aten.fill(b, 0)
a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
return a_updated
Ordinarily, we would not be able to reinplace the fill,
because "b" aliases with "a" which is used by the diagonal_scatter call.
"re-inplacing" is on the hook for figuring out that it is ok to
completely, the expensive diagonal_scatter call, if we re-inplace the add().
So, for every `alias in alias_set(a)`, instead of checking
that "alias" is not used anywhere later in the graph,
we check that
EITHER:
(a) alias is not used anywhere later in the graph
OR:
(b) alias is used exactly once later on in the graph,
in the following op:
out = foo_scatter(alias, x, args...)
where the following must hold:
(i) "foo_scatter" is the "inverse" operator for foo.
This only applies to "foo" ops that are view operators,
which view into a subset of the original tensor's memory.
In practice, there are ~4 operators where this applies:
diagonal -> diagonal_scatter
slice -> slice_scatter
select -> select_scatter
as_strided -> as_strided_scatter
(ii) "args..." are the same between the foo() and foo_scatter() calls.
(4) Finally, after converting "b = foo(a)" into "foo_(a)",
we need to find all later nodes that use "b" as an argument
and update them to take in "a" instead.
Note that for the majority of inplace ops, this isn't actually necessary
(because most inplace ops return "self" as their output).
This isn't generally true for all mutable ops though, which is why
we need to actually replace all of the arguments.
We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
that maps a given tensor storage to the set of all nodes that take in that storage
as an input.
Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
together.
(5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
during step (3) get manually deleted from the graph.
Their outputs are no longer used, so technically standard DCE would be able
to do this, but we can no longer run FX's DCE pass now that we have mutable
ops in the graph.
"""
_FunctionalizationMetadataProp(gm).propagate(*sample_args)
# Useful debug printing
# def _print(x):
# if isinstance(x, FakeTensor):
# print(f'fake_result: {StorageWeakRef(x.storage()).cdata}')
# for n in gm.graph.nodes:
# print(n.format_node())
# if hasattr(n, 'meta'):
# print(f'node_idx: {n.meta["node_idx"]}')
# if 'fake_result' in n.meta:
# tree_map(_print, n.meta['fake_result'])
# if 'view_of' in n.meta:
# print(f'view_of: {str(n.meta["view_of"])}')
# print()
# We need to know which nodes correspond to inputs (or their aliases)
# so we know not to re-inplace them.
# NOTE: later, we'll need to add an optimization for fully recovering performance
# on programs that mutate inputs.
input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder')
# We also need to know for a given node, what are all of its aliasing nodes.
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
for n in gm.graph.nodes:
if 'fake_result' in n.meta:
# Tree-mapping because some ops can return lists of tensors.
def _add_to_map(x):
if isinstance(x, FakeTensor):
storage_to_nodes[StorageWeakRef(x.storage())].add(n)
tree_map(_add_to_map, n.meta['fake_result'])
# inplace-ify functional ops, subject to the constraints written below.
all_later_view_inverse_node_usages = set()
for idx, node in enumerate(gm.graph.nodes):
if node.op == 'call_function':
# Step 1: Check to see if this operator has an inplace variant.
maybe_inplace_op = _maybe_get_inplace_op(node.target)
if maybe_inplace_op is None:
continue
# This is a proxy check for ensuring that the first argument is "tensor-like"
# (This should be the case for all ops with inplace variants in ATen,
# although we technically don't have guarantees for custom ops).
assert len(node.target._schema.arguments) > 0
assert 'Tensor' in str(node.target._schema.arguments[0].type)
# Step 2: ensure that the input to the op we're trying to re-inplace isn't a program input.
self_arg = node.args[0]
self_arg_name = self_arg.name
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
if self_arg_storage in input_storages:
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
continue
if len([x for x in node.args if x is self_arg]) > 1:
# Step (3b) in the original description.
# Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
# so we prevent re-inplacing in this case.
continue
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage())
self_aliases = storage_to_nodes[self_arg_storage]
# First, we find all later usages of any of the aliases of self_arg.
later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
# Then, we check if any of those later usages are actually view_scatter ops
# that are safe to fully remove.
later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)
# Step 3: Check to see if the input to the op is re-used later in the graph.
# If not (same goes for its aliases), then this op is safe to re-inplace.
# This is a slightly roundabout way to check that there are no later usages of the current self argument.
# (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
if not can_reinplace:
continue
# Step 4: replace the current out-of-place op with its inplace variant.
node.target = maybe_inplace_op
# At this point, 'storage_to_nodes' will be stale.
# Now that we're inplacing `b = foo(a)`, we need to effectively
# union together the dict values for b and a's storage.
# Hmm... morally I think we also want to keep the `fake_result` metadata
# up to date here, but I'm not sure how easy it is to do.
# Maybe it's fine to wait until the end of the pass to update it.
storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
# Need to remember the view_scatter nodes we found so we can remove them later.
all_later_view_inverse_node_usages.update(later_view_inverse_node_usages)
# Now that we've replaced b = a.foo() with a.foo_(),
# We need to replace any later usages of "b" with "a"
for old in itertools.chain([node], later_view_inverse_node_usages):
new = old.args[0]
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
for node_to_update in nodes_to_update:
new_args = []
for arg_idx, a in enumerate(node_to_update.args):
if a == old:
new_args.append(new)
else:
new_args.append(a)
new_kwargs = {}
for kwarg_idx, (k, v) in enumerate(node_to_update.kwargs.items()):
if isinstance(v, Node) and v.name == old.name:
new_kwargs[k] = new
else:
new_kwargs[k] = v
node_to_update.args = tuple(new_args)
node_to_update.kwargs = new_kwargs
old_ref = StorageWeakRef(old.meta['fake_result'].storage())
node_ref = StorageWeakRef(node_to_update.meta['fake_result'].storage())
if old_ref == node_ref:
# This will happen if we're updating a view op,
# e.g. replacing
#     x = view(old)
# with
#     x = view(new)
# When that happens, we need to make sure to keep our
# storage mapping up to date.
new_ref = StorageWeakRef(new.meta['fake_result'].storage())
# Technically, "old_ref" and all its aliases will remain
# in our mapping.
# That should be fine though, since we deleted "old"
# from the graph at this point.
storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])
# Step 5: delete any _scatter nodes that we de-functionalized
# Need to take care not to delete any of these nodes until after *all* modifications
# to the graph are finished.
for to_delete in all_later_view_inverse_node_usages:
gm.graph.erase_node(to_delete)
gm.recompile()
return gm