# Owner(s): ["module: functionalization"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.fx.passes.reinplace import reinplace
from torch.fx.experimental.proxy_tensor import make_fx

try:
    from functorch.experimental import functionalize
    HAS_FUNCTIONALIZATION = True
except Exception:
    HAS_FUNCTIONALIZATION = False
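
# NOTE: functionalize() still lives in functorch here (see the comment above
# test_reinplace_scatter_op); tests that need it return early instead of
# failing when it isn't available.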

class TestReinplacePass(TestCase):

    def test_reinplace_basic(self):
        # Basic test: the out-of-place add() call should be converted
        # into add_().
        def f(x):
            a = x.clone()
            b = a.add(1)
            return b

        inpt = torch.ones(2)
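        # make_fx() traces f into an ATen-level FX graph; reinplace() is also
        # given the sample inputs so that it can propagate tensor metadata
        # (sizes, strides, aliasing) through the graph before mutating any ops.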
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1);  x_1 = None
    add = torch.ops.aten.add_.Tensor(clone, 1)
    return clone
    """)

    def test_reinplace_with_view(self):
        def f(x):
            a = x.clone()
            a_view = a.view(-1)
            # We shouldn't re-inplace the first add(), because an alias of a
            # (a_view) is re-used later in the program.
            b = a.add(1)
            # The second add() is fine to re-inplace.
            c = a_view.add(1)
            return c

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1);  x_1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
    add_1 = torch.ops.aten.add_.Tensor(view, 1)
    return view
    """)

    def test_reinplace_different_metadata(self):
        def f(a_):
            a = a_.clone()
            b = a + 1
            # We shouldn't try to re-inplace the .ge() call, because that
            # would require reusing b's buffer with different metadata
            # (a float tensor's storage for a bool result).
            c = torch.ge(b, a)
            return c

        inpt = torch.ones(4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # The .ge() should not be re-inplaced.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    add = torch.ops.aten.add.Tensor(clone, 1)
    ge = torch.ops.aten.ge.Tensor(add, clone);  add = clone = None
    return ge
    """)

    def test_reinplace_overlapping_memory(self):
        def f(a_):
            a = a_.clone()
            b = a.expand(4, 4)
            # Can't re-inplace because b has overlapping memory.
            c = b.add(1)
            return c

        inpt = torch.ones(1)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    expand = torch.ops.aten.expand.default(clone, [4, 4]);  clone = None
    add = torch.ops.aten.add.Tensor(expand, 1);  expand = None
    return add
    """)

    # This test won't actually run in CI, because it requires functionalize()
    # from functorch. I'm planning on testing more comprehensively with
    # torchbench models, but we can make this testing better once functorch
    # moves into pytorch/pytorch.
    def test_reinplace_scatter_op(self):
        def f(a_):
            # For now, don't test mutations to inputs.
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is that the graph contains a
        # bunch of redundant views. Technically, half of these views are
        # duplicates that we could de-dup. This shouldn't really hurt
        # performance though, since creating an extra view is effectively just
        # moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    view_1 = torch.ops.aten.view.default(clone, [-1])
    select = torch.ops.aten.select.int(view_1, 0, 0);  view_1 = None
    view_2 = torch.ops.aten.view.default(select, [-1]);  select = None
    add = torch.ops.aten.add_.Tensor(view_2, 1)
    view_3 = torch.ops.aten.view.default(clone, [-1]);  clone = None
    select_1 = torch.ops.aten.select.int(view_3, 0, 0)
    view_4 = torch.ops.aten.view.default(view_2, []);  view_2 = None
    view_5 = torch.ops.aten.view.default(view_3, [4]);  view_3 = None
    view_6 = torch.ops.aten.view.default(view_5, [-1])
    select_2 = torch.ops.aten.select.int(view_6, 0, 0);  view_6 = None
    view_7 = torch.ops.aten.view.default(select_2, [-1]);  select_2 = None
    view_8 = torch.ops.aten.view.default(view_5, [-1])
    add_1 = torch.ops.aten.add_.Tensor(view_5, view_8);  view_8 = None
    return view_5
    """)

    def test_reinplace_scatter_twice(self):
        def f(a_):
            # For now, don't test mutations to inputs.
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        if not HAS_FUNCTIONALIZATION:
            return

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = None
    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_2 = torch.ops.aten.select.int(slice_2, 1, 1);  slice_2 = None
    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_3 = torch.ops.aten.select.int(slice_3, 1, 1);  slice_3 = None
    select_4 = torch.ops.aten.select.int(select_3, 0, 1);  select_3 = None
    return clone
    """)

    # Test example with a scatter op where the base tensor has the same
    # size/stride/storage_offset as the original view (even though it is a
    # different view), making it valid to re-inplace.
    def test_reinplace_scatter_twice_with_different_view_op_valid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # good_mirror_of_b points to the same region of memory as b,
            # and the scatter op below tries to scatter c_updated into the
            # same region that c currently occupies.
            # The reinplacing logic checks this by confirming that:
            #     c_updated
            #     good_mirror_of_b.select(0, 1)
            # have the same size/stride/storage_offset.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
    return as_strided
    """)

    # Test example with a scatter op where the scatter destination is a
    # *different* slice than the one the original data occupies, making it
    # invalid to re-inplace.
    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than what "c" currently occupies.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int, add);  select_int = add = None
    return as_strided
    """)  # noqa: B950

    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4,), (4,), 0)
            # The first arg to select_scatter points to a different region of
            # memory than c's base. This makes it invalid to re-inplace.
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0);  clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int, add);  select_int = add = None
    return as_strided
    """)  # noqa: B950

    def test_out_node_updated(self):
        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace the output [z] with [x].
            return [z]

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal = torch.ops.aten.diagonal.default(zeros)
    add = torch.ops.aten.add_.Tensor(diagonal, 1);  diagonal = None
    return [zeros]
    """)

    def test_reinplace_index_mutation(self):
        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807);  slice_1 = None
    copy = torch.ops.aten.copy_.default(slice_2, ones);  slice_2 = ones = None
    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807);  slice_4 = None
    return zeros
    """)

if __name__ == '__main__':
    run_tests()