Files
pytorch/test/test_functionalization.py
Edward Z. Yang 0e3031f7e7 Functionalize and compute joint simultaneously. (#88063)
This also comes with some bug fixes that were uncovered from doing
this:

- Forward device calls to inner tensor in FunctionalTensorWrapper

- Make legacyExtractDispatchKey exclude Functionalize, so that
  it can get at the real device type key.  This is noncontroversial.

- Stop stripping dense from key set.  The reason for this is
  FunctionalTensorWrapper may be used in contexts where people
  query if it is dense or not.  If it doesn't report this correctly
  (from the dispatch key), it will cause errors.  This caused some
  torchbench models to fail when I did one-pass tracing.

- Save and restore reapply views TLS correctly

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88063
Approved by: https://github.com/bdhirsh
2022-11-05 03:52:40 +00:00

1206 lines
51 KiB
Python

# Owner(s): ["module: codegen"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace
import unittest
def are_aliased(x, y):
    """Return whether ``x`` and ``y`` are related through their ``._base``.

    This only inspects the ``_base`` attribute (it does not compare storages):
    one tensor is a view of the other, or both are views of the same base.
    Two tensors that are not views are never reported as aliased, even when
    they are the same object.
    """
    x_base, y_base = x._base, y._base
    if x_base is None:
        # x is not a view: aliased only when y is a view whose base is x.
        return y_base is not None and y_base is x
    if y_base is None:
        # y is not a view: aliased only when x is a view whose base is y.
        return x_base is y
    # Both are views: aliased when they share the same base.
    return x_base is y_base
# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
# This is basically a crappy version of `functionalize()` for single-tensor-arg inputs.
def _functionalize(f, *, reapply_views: bool):
def wrapped(a):
input_functional = torch._to_functional_tensor(a)
input_functional.requires_grad = a.requires_grad
torch._enable_functionalization(reapply_views=reapply_views)
try:
out = f(input_functional)
finally:
torch._disable_functionalization()
torch._sync(input_functional)
inpt_new = torch._from_functional_tensor(input_functional)
if inpt_new is not a:
# Existing deficiency in functionalize():
# we don't correctly mutate input metadata (yet?)
if inpt_new.shape == a.shape:
a.copy_(inpt_new)
tree_map(torch._sync, out)
out_unwrapped = tree_map(torch._from_functional_tensor, out)
return out_unwrapped
return wrapped
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
class TestFunctionalization(TestCase):
def get_logs(self, func, inpt, *, reapply_views=False, run_reinplace=False):
    """Trace ``func`` under functionalization and return the graph's source.

    When ``run_reinplace`` is set, the reinplacing pass is applied to the
    traced graph first, using a clone of ``inpt`` taken before tracing as the
    sample input.
    """
    # Clone up front: make_fx may mutate the input it traces with.
    sample_input = inpt.clone()
    graph_module = make_fx(_functionalize(func, reapply_views=reapply_views))(inpt)
    if not run_reinplace:
        return graph_module.code
    return reinplace(graph_module, sample_input).code
def assert_functionalization(self, func, inpt, *, reapply_views=False, mutated_input_metadata=False):
    """Check that ``func`` behaves the same with and without functionalization.

    ``func`` is run three ways, each on its own clone of ``inpt``:

    * eagerly,
    * under ``_functionalize`` (with the given ``reapply_views``),
    * as a make_fx trace of the functionalized function after the reinplacing pass.

    The outputs must match. Unless ``mutated_input_metadata`` is set, input
    mutations must also show up on the inputs.
    """
    input_clone = inpt.clone()
    input_clone2 = inpt.clone()
    input_clone3 = inpt.clone()
    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)
    out_functional = _functionalize(func, reapply_views=reapply_views)(input_clone)
    # The reinplacing pass is only valid to run with reapply_views=True.
    # Trace exactly once. make_fx mutates the inputs it traces with, so a
    # second trace would run on already-mutated values.
    traced_func = make_fx(_functionalize(func, reapply_views=True))(input_clone2)
    reinplace_func = reinplace(traced_func, input_clone2)
    # NOTE: for now, need to pass in fresh inputs here, because make_fx
    # will directly mutate the inputs that you trace with.
    # Once this is fixed we can clean this up.
    out_reinplace = reinplace_func(input_clone3)
    # functionalize() deficiency: input metadata mutations aren't propagated properly,
    # so we just need to skip checks here for the tests that exercise that.
    if not mutated_input_metadata:
        self.assertEqual(inpt, input_clone)  # input mutations should still occur
        self.assertEqual(inpt, input_clone3)
    # Handle tests with multi-tensor outputs
    if isinstance(out_ref, tuple):
        out_refs, out_functionals, out_reinplaces = list(out_ref), list(out_functional), list(out_reinplace)
    else:
        out_refs, out_functionals, out_reinplaces = [out_ref], [out_functional], [out_reinplace]
    # zip() would silently drop extra outputs, so compare the counts first.
    self.assertEqual(len(out_refs), len(out_functionals))
    self.assertEqual(len(out_refs), len(out_reinplaces))
    for out_ref_, out_functional_, out_reinplace_ in zip(out_refs, out_functionals, out_reinplaces):
        self.assertEqual(out_ref_, out_functional_)
        self.assertEqual(out_ref_, out_reinplace_)
def test_save_for_backwards_segfault(self):
inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
inp.exp()
def test_multiple_views_of_same_base(self):
    # Two separate views of x must both observe an in-place update made to x.
    def f(x):
        flat_a = x.view(-1)
        flat_b = x.view(-1)
        x.add_(1)
        # flat_a should reflect the update to x.
        _ = flat_a + 1
        # flat_b should reflect it too; this result is what gets compared.
        return flat_b + 1
    self.assert_functionalization(f, torch.ones(4))
def test_freeze(self):
    # After freezing a functional tensor, in-place ops on it and on views of it
    # must raise RuntimeError, while the (unfrozen) input can still be mutated.
    def f(x):
        frozen = x.clone()
        row = frozen[0]
        torch._freeze_functional_tensor(frozen)
        x.add_(1)
        for target in (frozen, row):
            self.assertRaises(RuntimeError, target.add_, 1)
        return row
    _functionalize(f, reapply_views=True)(torch.ones(3, 3))
def test_view_clone_view_inplace(self):
    # Trace forward + backward through a view -> clone -> view -> in-place
    # relu_ chain under functionalization, and pin the resulting graph.
    def f(input):
        # [16, 64, 128, 128] input viewed as [1, 1024, 128, 128] (same numel).
        shape = [1, 1024, 128, 128]
        input_reshaped = input.view(shape)
        out = input_reshaped.clone()
        # r is a view of the clone (not of the input) and is mutated in place.
        r = out.view(input.shape)
        r.relu_()
        return r
    def g(x):
        loss = f(x).sum()
        # Imported locally, inside g.
        from functorch._src.aot_autograd import setup_stacktrace_preservation_hooks
        import torch.fx.traceback as fx_traceback
        # Install functorch's stack-trace preservation hooks on the backward
        # graph and run backward under fx_traceback.override_stack_trace().
        setup_stacktrace_preservation_hooks([loss.grad_fn])
        with fx_traceback.override_stack_trace():
            loss.backward()
        # The traced graph returns the gradient w.r.t. the input.
        return x.grad
    # Tracing runs with autograd anomaly detection on (NaN checks disabled).
    with torch.autograd.detect_anomaly(check_nan=False):
        logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
view_copy = torch.ops.aten.view_copy.default(a_1, [1, 1024, 128, 128]); a_1 = None
clone = torch.ops.aten.clone.default(view_copy); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None
sum_1 = torch.ops.aten.sum.default(relu)
ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format); sum_1 = None
expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None
_reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(expand_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]); expand_copy = None
new_empty_strided = torch.ops.aten.new_empty_strided.default(_reshape_alias_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
view_copy_3 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128])
view_copy_4 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128])
clone_1 = torch.ops.aten.clone.default(view_copy_4, memory_format = torch.contiguous_format); view_copy_4 = None
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(_reshape_alias_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1]); _reshape_alias_copy = None
detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1); _reshape_alias_copy_1 = None
view_copy_5 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_5, [16, 64, 128, 128], [1048576, 16384, 128, 1]); view_copy_5 = None
detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_2); _reshape_alias_copy_2 = None
return detach_copy_1
""") # noqa: B950
def test_simple(self):
    # One view op and one in-place op on that view. The functionalized trace
    # uses out-of-place add and copies the result back into the input (copy_).
    # The reinplaced trace keeps plain views instead of view_copy.
    def f(x):
        # simple test: 1 view op, 1 inplace op
        tmp = torch.ones(4, 2)
        y = x.view(4, 2)
        y.add_(tmp)
        # z is unused, but x * x still appears in the trace (as `mul`).
        z = x * x
        return y
    self.assert_functionalization(f, torch.ones(4, 2))
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1)
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
return add
""")
    reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view = torch.ops.aten.view.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
copy_ = torch.ops.aten.copy_.default(a_1, view_1); a_1 = view_1 = None
return add
""")
def test_simple_out(self):
    # An out= op whose out tensor must be resized. Functionalization turns it
    # into a plain add; the `empty` allocation stays in the trace, unused.
    def f(x):
        tmp = torch.ones(4, 2)
        y = x.view(4, 2)
        # the out= tensor will get resized: torch.empty(()) is a 0-dim tensor
        # (shape [] in the trace), not the [4, 2] shape of the result.
        z = torch.empty(())
        torch.add(y, tmp, out=z)
        w = z * z
        return w
    self.assert_functionalization(f, torch.ones(4, 2))
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
mul = torch.ops.aten.mul.Tensor(add, add); add = None
return mul
""")
    reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
mul = torch.ops.aten.mul.Tensor(add, add); add = None
return mul
""")
def test_multi_out(self):
    # An out= op that writes into a tuple of tensors. Functionalization should
    # replace it with the functional aminmax op and return the right element
    # of its result tuple.
    def f(x):
        # aminmax.out returns a tuple of tensors.
        # functionalization should properly handle the tuple.
        # torch.aminmax produces (min, max), so the out= tuple is ordered
        # (out_min, out_max) and the min result is returned (aminmax[0]).
        out_min = torch.empty(4)
        out_max = torch.empty(4)
        torch.aminmax(x, dim=0, out=(out_min, out_max))
        return out_min
    self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
    logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
aminmax = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
return getitem
""")
    reinplaced_logs = self.get_logs(f, torch.arange(8, dtype=torch.float32), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
aminmax = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
return getitem
""")
def test_tensor_ctr(self):
    # A tensor built with torch.tensor(...) inside the function shows up as a
    # traced constant lifted via lift_fresh_copy, and can be viewed and mutated.
    def f(x):
        # The input x is not used.
        y = torch.tensor((1, 2, 3))
        z = y.view(-1)
        z.add_(1)
        # y (the base of z) is returned, so it must reflect the update through z.
        return y
    inpt = torch.arange(3, dtype=torch.float32)
    self.assert_functionalization(f, inpt)
    logs = self.get_logs(f, inpt)
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
_tensor_constant0 = self._tensor_constant0
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None
return view_copy_1
""")
    reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
_tensor_constant0 = self._tensor_constant0
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
add = torch.ops.aten.add_.Tensor(view, 1)
view_1 = torch.ops.aten.view.default(view, [3]); view = None
return view_1
""")
def test_tensor_list_mixed_functional_nonfunctional(self):
    # Advanced indexing with an index list that mixes a tensor created inside
    # the functionalized function (functional) with one captured from outside
    # it (non-functional). The results must match eager execution.
    outer_index = torch.ones(2, dtype=torch.long)
    def f(x):
        inner_index = torch.ones(2, dtype=torch.long)
        return x[inner_index, outer_index]
    expected = f(torch.ones(2, 2))
    actual = _functionalize(f, reapply_views=True)(torch.ones(2, 2))
    self.assertEqual(expected, actual)
def test_inplace_on_non_view(self):
    # The in-place op targets the base tensor x, not the view y. The view
    # must then be regenerated from the updated base (view_copy_1 of add).
    def f(x):
        # test for the case where we functionalize an inplace op on the base tensor itself - not on a view of it.
        # This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased.
        tmp = torch.ones(4, 2)
        y = x.view(4, 2)
        x.add_(tmp)
        return y
    self.assert_functionalization(f, torch.ones(4, 2))
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy_ = torch.ops.aten.copy_.default(a_1, add); a_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
return view_copy_1
""")
    reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view = torch.ops.aten.view.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy_ = torch.ops.aten.copy_.default(a_1, add); a_1 = None
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
return view_1
""")
# Some ops that are mutable are neither inplace nor out= ops.
# They also need special handling.
def test_mutable_op_not_inplace_or_other(self):
    # _fused_moving_avg_obs_fq_helper mutates an argument without being an
    # in-place (foo_) or out= op. Functionalization should call its
    # *_functional variant and copy the mutated value back into the input.
    def f(x):
        return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
    logs = self.get_logs(f, torch.ones(1))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
_fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0)
getitem = _fused_moving_avg_obs_fq_helper_functional[0]
getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]
getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]
getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]
getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None
copy_ = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None
return (getitem, getitem_1)
""") # noqa: B950
def test_as_strided(self):
    # In-place add on an as_strided view. Functionalization rebuilds the base
    # with as_strided_scatter and copies it back into the input.
    def f(x):
        # Size (2,), stride (2,), offset 1: y covers elements 1 and 3 of x.
        y = x.as_strided((2,), (2,), 1)
        y.add_(1)
        return x
    self.assert_functionalization(f, torch.ones(9))
    logs = self.get_logs(f, torch.ones(9))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
as_strided_copy = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(a_1, add, [2], [2], 1); add = None
copy_ = torch.ops.aten.copy_.default(a_1, as_strided_scatter); a_1 = None
return as_strided_scatter
""")
def test_tensor_list_composite(self):
    # block_diag takes a TensorList. With no mutations, the expected trace is
    # a single block_diag call.
    def f(x):
        # Test an op with TensorList input
        y = torch.block_diag(x, x)
        return y
    self.assert_functionalization(f, torch.ones(2, 2))
    logs = self.get_logs(f, torch.ones(2, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
block_diag = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None
return block_diag
""")
def test_cat(self):
    # cat with an out= tensor that starts with shape [0]. Functionalization
    # emits a plain cat; the `empty` allocation remains in the trace, unused.
    def f(x):
        out = torch.empty(0)
        torch.cat((x,), out=out)
        return out
    self.assert_functionalization(f, torch.ones(2, 2))
    logs = self.get_logs(f, torch.ones(2, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
cat = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat
""")
    reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
cat = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat
""")
def test_diagonal(self):
    # The mutation targets a diagonal view of a *clone* of x, so x is untouched
    # (no copy_ in the trace). The reinplaced trace turns add back into add_.
    def f(x):
        # test: view ops that take a subset of the original tensor (select/diagonal)
        tmp = torch.ones(2)
        y = x.clone().diagonal()
        y.add_(tmp)
        z = x * x
        return z
    self.assert_functionalization(f, torch.ones(2, 2))
    logs = self.get_logs(f, torch.ones(2, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
clone = torch.ops.aten.clone.default(a_1)
diagonal_copy = torch.ops.aten.diagonal_copy.default(clone); clone = None
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
mul = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul
""")
    reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
clone = torch.ops.aten.clone.default(a_1)
diagonal = torch.ops.aten.diagonal.default(clone); clone = None
add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = None
mul = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul
""")
def test_diagonal_mutated_input(self):
    # In-place add on a diagonal view of the input itself. The update is
    # written back into x via diagonal_scatter followed by copy_.
    def f(x):
        # simple test: there are pending updates afterwards, which the test syncs manually
        tmp = torch.ones(2)
        y = x.diagonal()
        y.add_(tmp)
        return x
    x = torch.ones(2, 2)
    self.assert_functionalization(f, x)
    logs = self.get_logs(f, torch.ones(2, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(a_1)
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(a_1, add); add = None
copy_ = torch.ops.aten.copy_.default(a_1, diagonal_scatter); a_1 = None
return diagonal_scatter
""")
def test_split(self):
    # A multi-output view op (split), then another view (diagonal) and an
    # in-place add. The update is scattered back through diagonal_scatter and
    # slice_scatter, and reads of x afterwards see the new values.
    def f(x):
        # test: view ops that return multiple tensors (split)
        tmp = torch.ones(2)
        y1, y2 = x.split(2)
        y3 = y2.diagonal()
        y3.add_(tmp)
        # x * x must observe the mutation made through y3 (uses slice_scatter).
        z = x * x
        return y3
    self.assert_functionalization(f, torch.ones(4, 2))
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
split_copy = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem = split_copy[0]
getitem_1 = split_copy[1]; split_copy = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = None
slice_scatter = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
copy_ = torch.ops.aten.copy_.default(a_1, slice_scatter); a_1 = slice_scatter = None
return add
""") # noqa: B950
def test_view_inplace(self):
    # transpose_ is an in-place op that changes the input's metadata (its
    # shape). The eager-vs-functional input comparison is skipped
    # (mutated_input_metadata=True) because functionalize() does not propagate
    # input metadata mutations.
    def f(x):
        # test: view + inplace op (transpose_)
        tmp = torch.ones(4)
        x.transpose_(1, 0)
        y = x[0]
        y.add_(tmp)
        return x
    self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
transpose_copy = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0); transpose_copy = None
add = torch.ops.aten.add.Tensor(select_copy, ones); select_copy = ones = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
return transpose_copy_3
""") # noqa: B950
def test_optional_tensor_list(self):
    # index_put_ takes a List[Optional[Tensor]] of indices. It is mutated
    # through a view of the input, so the trace ends with a copy_ back into it.
    def f(x):
        # test: an operator that takes in a List[Optional[Tensor]] argument
        # (index_put)
        y = x.view(8)
        indices = torch.arange(4)
        values = torch.arange(4, dtype=y.dtype)
        y.index_put_((indices,), values, accumulate=False)
        return y
    self.assert_functionalization(f, torch.ones(4, 2))
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
view_copy = torch.ops.aten.view_copy.default(a_1, [8])
arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2])
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
return index_put
""") # noqa: B950
def test_scalars(self):
    # In-place and out-of-place ops with Python scalar operands. z is a fresh
    # tensor (not a view of the input), so its div_ becomes a plain div with
    # no copy back.
    def f(x):
        # test: the pass can handle scalar inputs properly
        # tmp is unused, but the ones() call still appears in the trace.
        tmp = torch.ones(4, 2)
        y = x.view(4, 2)
        y.add_(1)
        z = 2 * y
        z.div_(1)
        return z
    self.assert_functionalization(f, torch.ones(4, 2))
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
mul = torch.ops.aten.mul.Tensor(add, 2)
div = torch.ops.aten.div.Tensor(mul, 1); mul = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
return div
""")
@skipIfTorchDynamo("Test does not work with TorchDynamo")
def test_metadata_change(self):
    # The out-of-place counterpart of ge_() (ge) does not produce a float32
    # result, while the in-place ge_() must keep self's dtype. Functionalization
    # therefore casts the ge result back to float32 (the _to_copy in the trace).
    def f(x):
        # functionalization should pick up on the dtype difference between ge_() and ge().
        y = x.clone()
        out = y.ge_(0)
        return out
    self.assert_functionalization(f, torch.ones(4, 2))
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
clone = torch.ops.aten.clone.default(a_1); a_1 = None
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
return _to_copy
""")
    # NOTE(review): this input is (2, 2) while the one above is (4, 2); shapes
    # do not appear in this trace, so the expected output is unaffected.
    reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
clone = torch.ops.aten.clone.default(a_1); a_1 = None
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
return _to_copy
""") # noqa: B950
@skipIfTorchDynamo("Test does not work with TorchDynamo")
def test_metadata_change_out_op(self):
def f(t, y):
out_1 = torch.ones(1)
return torch.add(t, y, out=out_1)
inpt1, inpt2 = torch.tensor([1]), torch.tensor([1])
inpt1_func, inpt2_func = torch._to_functional_tensor(inpt1), torch._to_functional_tensor(inpt2)
out_ref = f(inpt1, inpt2)
torch._enable_functionalization(reapply_views=True)
try:
out_functional = f(inpt1_func, inpt2_func)
finally:
torch._disable_functionalization()
self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
def test_only_one_view(self):
    # With no mutations, the traced graph should contain only the requested
    # view (as view_copy) and nothing else.
    def f(x):
        # This tests that we don't have any unnecessary views in the trace.
        # If the input wasn't mutated, we don't need to regenerate it,
        # so there should be a total of 1 op in the output trace.
        return x.view(4, 2)
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
return view_copy
""")
def test_everything(self):
    # Stress test for the view chain: view -> reshape -> transpose, in-place
    # view ops (unsqueeze_/squeeze_), split, an in-place add on one split
    # chunk, and finally a read back through an earlier view (z0). Every view
    # in the chain must be replayed after the mutation.
    def f(x):
        # test: everything
        tmp = torch.ones(2, 2)
        x2 = x + x
        y = x2.view(8)
        z0 = y.reshape(2, 4)
        z1 = z0.transpose(1, 0)
        z1.unsqueeze_(0)
        z1.squeeze_()
        z2, z3 = z1.split(2)
        z2.add_(tmp)
        # z4 is unused, but reading z0 after the mutation forces it to be
        # regenerated (add_2 in the trace).
        z4 = z0[0] + z2.reshape(4)
        return z2
    self.assert_functionalization(f, torch.ones(4, 2))
    logs = self.get_logs(f, torch.ones(4, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [8])
_reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]); view_copy = None
transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0)
unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None
squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None
split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None
getitem = split_copy[0]
getitem_1 = split_copy[1]; split_copy = None
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(add_1, [4], [1])
view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_2, 1, 0); _reshape_alias_copy_2 = None
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
_reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_3, [4, 2]); _reshape_alias_copy_3 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8])
_reshape_alias_copy_4 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_4, 0, 0); _reshape_alias_copy_4 = None
view_copy_4 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
_reshape_alias_copy_5 = torch.ops.aten._reshape_alias_copy.default(view_copy_4, [2, 4], [4, 1]); view_copy_4 = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_5, 1, 0); _reshape_alias_copy_5 = None
unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None
squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
_reshape_alias_copy_6 = torch.ops.aten._reshape_alias_copy.default(getitem_2, [4], [1]); getitem_2 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, _reshape_alias_copy_6); select_copy_1 = _reshape_alias_copy_6 = None
return add_1
""") # noqa: B950
    reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view = torch.ops.aten.view.default(add, [8])
_reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]); view = None
transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0)
unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None
squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None
split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None
getitem = split[0]
getitem_1 = split[1]; split = None
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None
select = torch.ops.aten.select.int(_reshape_alias, 0, 0); _reshape_alias = None
clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
view_1 = torch.ops.aten.view.default(add, [8]); add = None
_reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]); view_1 = None
transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0); _reshape_alias_1 = None
unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None
squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None
unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None
squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None
transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None
_reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]); transpose_2 = None
view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]); _reshape_alias_2 = None
view_3 = torch.ops.aten.view.default(view_2, [8]); view_2 = None
_reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]); view_3 = None
select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0); _reshape_alias_3 = None
add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None
return getitem
""")
def test_reapply_views_simple(self):
    # Same function as test_simple, but traced with reapply_views=True. View
    # ops are replayed as real views (`view`) instead of `view_copy`.
    def f(x):
        tmp = torch.ones(4, 2)
        y = x.view(4, 2)
        y.add_(tmp)
        z = x * x
        return y
    self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
    logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view = torch.ops.aten.view.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
copy_ = torch.ops.aten.copy_.default(a_1, view_1); a_1 = view_1 = None
return add
""")
def test_aliases_maintained_after_pass_when_reapplying_views(self):
    """Two views of one input must still alias each other after unwrapping.

    Functionalization runs with reapply_views=True. It is always switched
    off again in the ``finally`` clause, even if ``f`` raises.
    """
    def f(x):
        ones_buf = torch.ones(4, 2)
        first, second = x.view(4, 2), x.view(4, 2)
        first.add_(ones_buf)
        return first, second

    wrapped_input = torch._to_functional_tensor(torch.ones(4, 2))
    torch._enable_functionalization(reapply_views=True)
    try:
        first, second = f(wrapped_input)
        for out in (first, second):
            torch._sync(out)
    finally:
        torch._disable_functionalization()
    # The two outputs alias each other inside f; that relationship must survive.
    unwrapped_first = torch._from_functional_tensor(first)
    unwrapped_second = torch._from_functional_tensor(second)
    self.assertTrue(are_aliased(unwrapped_first, unwrapped_second))
# copy_() gets its own test, because it is special cased in functionalization.
# self.copy_(src) decomposes into src.to(self).expand_as(self).
def test_copy_(self):
    """copy_() into a diagonal view of a fresh zeros tensor, then add_ the source.

    Four input variants are covered (same/different dtype x same/different
    shape). For each variant, the plain functionalized trace and the trace
    produced with reapply_views=True plus the reinplace pass are compared
    against expected graphs. In every expected graph, the zeros/diagonal
    nodes are not used by the returned value.
    """
    def f(x):
        tmp = torch.zeros(2, 2)
        # View (the diagonal of tmp) that copy_ writes into.
        tmp_slice = tmp.diagonal()
        y = tmp_slice.copy_(x)
        z = y.add_(x)
        return z
    # Test 1: copy_() with same dtype and shape
    # to() is a composite op that noops when the dtype/shape match, so nothing gets logged.
    # NOTE(review): the numerical check for this case is commented out; the
    # reason is not visible here - confirm before re-enabling.
    # self.assert_functionalization(f, torch.ones(2))
    logs = self.get_logs(f, torch.ones(2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
return add
""")
    reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
return add
""")
    # Test 2: copy_() with same dtype, different shape
    # The size-1 source is broadcast with expand_copy before the add. After
    # reinplacing, the add becomes add_ on that expanded buffer, which is returned.
    self.assert_functionalization(f, torch.ones(1))
    logs = self.get_logs(f, torch.ones(1))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2])
add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None
return add
""")
    reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2])
add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None
return expand_copy
""")
    # Test 3: copy_() with different dtype, same shape
    # The long source is converted to float32 via _to_copy before the add.
    self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
    logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(_to_copy, a_1); _to_copy = a_1 = None
return add
""") # noqa: B950
    reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add_.Tensor(_to_copy, a_1); a_1 = None
return _to_copy
""") # noqa: B950
    # Test 4: copy_() with different dtype, different shape
    # Both conversions appear: _to_copy (dtype) followed by expand_copy (shape).
    self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
    logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None
add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None
return add
""") # noqa: B950
    reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None
add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None
return expand_copy
""") # noqa: B950
def test_expand_symint(self):
    """expand() using sizes read back from the tensor traces to a single expand_copy."""
    # Once some existing SymInt bugs are ironed out, we should update
    # this test to plumb FakeSymbolicTensors through it
    def f(x):
        rows, cols = x.size(0), x.size(1)
        return x.expand(rows, cols)

    self.assert_functionalization(f, torch.ones(2, 2))
    traced = self.get_logs(f, torch.ones(2, 2))
    self.assertExpectedInline(traced, """\
def forward(self, a_1):
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2, 2]); a_1 = None
return expand_copy
""")
def test_fill_(self):
    """fill_() through a diagonal view of an intermediate tensor.

    Functionalized, this becomes fill + diagonal_scatter. With
    reapply_views + reinplace, it becomes an in-place fill_ on the view.
    """
    def f(x):
        doubled = x + x
        diag = doubled.diagonal()
        diag.fill_(0)
        return doubled

    self.assert_functionalization(f, torch.ones(2, 2))
    functional_trace = self.get_logs(f, torch.ones(2, 2))
    self.assertExpectedInline(functional_trace, """\
def forward(self, a_1):
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None
return diagonal_scatter
""")
    reinplaced_trace = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_trace, """\
def forward(self, a_1):
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal = torch.ops.aten.diagonal.default(add)
fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None
return add
""")
def test_resize_smaller(self):
    """resize_() of a view to a smaller shape, then mutation through a further view.

    The plain functionalized trace and the reapply_views + reinplace trace are
    both compared against expected graphs. In the functional trace, the
    shrunken view is materialized as an as_strided_copy with size [3, 3] and
    stride [3, 1]. The mutation made through y2 is written back with
    as_strided_scatter before z is computed.
    """
    def f(w):
        # Resizing to a smaller size doesn't affect storage
        x = w + 1
        y = x.view(4, 4)
        y.resize_(3, 3)
        y2 = y.view(-1)
        # The update through y2 must be visible when y is read below.
        y2.add_(1)
        z = y + 1
        return z
    self.assert_functionalization(f, torch.ones(8, 2))
    logs = self.get_logs(f, torch.ones(8, 2))
    self.assertExpectedInline(logs, """\
def forward(self, a_1):
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [4, 4])
resize = torch.ops.aten.resize.default(view_copy, [3, 3])
as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]); as_strided_copy = None
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1])
view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None
view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None
as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None
add_2 = torch.ops.aten.add.Tensor(as_strided_copy_2, 1); as_strided_copy_2 = None
return add_2
""") # noqa: B950
    # With reapply_views + reinplace, the adds run in place (add_) on views,
    # and the final in-place result (as_strided_2) is returned.
    reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view = torch.ops.aten.view.default(add, [4, 4])
resize = torch.ops.aten.resize.default(view, [3, 3])
as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None
view_1 = torch.ops.aten.view.default(as_strided, [-1]); as_strided = None
add_1 = torch.ops.aten.add_.Tensor(view_1, 1)
view_2 = torch.ops.aten.view.default(add, [4, 4]); add = None
as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1])
view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = None
view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None
view_5 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None
as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None
add_2 = torch.ops.aten.add_.Tensor(as_strided_2, 1)
return as_strided_2
""")
def test_resize_larger_valid(self):
    """Growing a non-view tensor with resize_() is allowed.

    Mutations made through aliases of the resized tensor must propagate back
    to it.
    """
    def f(x):
        grown = x + 1
        # resizing a tensor to a larger size is only currently allowed
        # if the tensor-to-resize is not a view / has no outstanding views.
        # See Note [resize_() in functionalization pass]
        grown.resize_(5, 5)
        flat = grown.view(25)
        # Do a mutation to ensure that aliases of the output of resize_()
        # propagate mutations correctly.
        # fill_ guarantees that none of the output is left with uninitialized
        # memory, since the data output is compared against a reference impl.
        flat.fill_(1)
        shifted = grown + 1
        return grown, shifted

    self.assert_functionalization(f, torch.ones(8, 2))
    functional_trace = self.get_logs(f, torch.ones(8, 2))
    self.assertExpectedInline(functional_trace, """\
def forward(self, a_1):
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize = torch.ops.aten.resize.default(add, [5, 5]); add = None
view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None
fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
return (view_copy_1, add_1)
""")
    reinplaced_trace = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_trace, """\
def forward(self, a_1):
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize = torch.ops.aten.resize_.default(add, [5, 5])
view = torch.ops.aten.view.default(add, [25]); add = None
fill = torch.ops.aten.fill_.Scalar(view, 1)
view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None
add_1 = torch.ops.aten.add.Tensor(view_1, 1)
return (view_1, add_1)
""")
def test_resize_larger_invalid(self):
    """Growing a view tensor with resize_() must raise under functionalization."""
    def f(x):
        base = x + 1
        as_view = base.view(4, 4)
        # resizing a tensor to a larger size is only currently allowed
        # if the tensor-to-resize is not a view / has no outstanding views.
        # See Note [resize_() in functionalization pass]
        # This should fail
        as_view.resize_(5, 5)
        flat = as_view.view(25)
        flat.fill_(1)
        shifted = as_view + 1
        return base, shifted

    expected_msg = r'Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass'  # noqa: B950
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        self.assert_functionalization(f, torch.ones(8, 2))
def test_nested_functions_propagate_updates(self):
    """A mutation made through a short-lived view inside a helper is seen by the caller."""
    def g(x):
        # Mutate x through a view; the view is dropped when g returns.
        first_row = x[0]
        first_row.add_(1)

    def f(x):
        # g(x) mutates x in place.
        g(x)
        # x must be synced here, even though the alias created in g() is gone.
        return x + x

    self.assert_functionalization(f, torch.ones(2, 2))
def test_mixed_wrappers_valid(self):
    """Out-of-place ops may mix a functional tensor with a non-functional one.

    One argument is a plain LoggingTensor; the other is a LoggingTensor
    wrapped as a functional tensor. The captured log must show both adds
    (the ``+`` and the ``add_`` on its result) reaching the LoggingTensor layer.
    """
    def f(x, y):
        z = x + y
        z.add_(1)
        return z
    x1_not_functional = LoggingTensor(torch.ones(4))
    x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))
    with capture_logs() as logs:
        # Only the logged ops are checked; the returned tensor itself is not needed.
        f(x1_not_functional, x2_functional)
    # Make sure that functionalization ran the "+" kernel
    # with a functional + non-functional tensor, and wrapped the output appropriately.
    self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
def test_mixed_wrappers_invalid(self):
    """An in-place op on a plain tensor with a functional argument must raise."""
    plain = torch.ones(4)
    functional = torch._to_functional_tensor(torch.ones(4))
    # With mixed functional and non-functional tensors, plain.add_(functional)
    # is rejected: the plain tensor would need to be "promoted" to a
    # functional tensor.
    with self.assertRaises(RuntimeError):
        plain.add_(functional)
def test_index_mutation_on_non_input(self):
    """Index-then-fill_ on a tensor created inside the function, not on an input.

    Functionalized, this becomes select_copy + fill + select_scatter. With
    reapply_views + reinplace, it becomes an in-place fill_ on a select view
    of the local buffer.
    """
    def f(x):
        # x is not used by f; only the locally created buffer is mutated.
        buf = torch.zeros(10)
        buf[5].fill_(1)
        return buf

    self.assert_functionalization(f, torch.ones(2))
    functional_trace = self.get_logs(f, torch.ones(2))
    self.assertExpectedInline(functional_trace, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None
select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None
return select_scatter
""") # noqa: B950
    reinplaced_trace = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
    self.assertExpectedInline(reinplaced_trace, """\
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
select = torch.ops.aten.select.int(zeros, 0, 5)
fill = torch.ops.aten.fill_.Scalar(select, 1); select = None
return zeros
""")
if __name__ == "__main__":
    # Run the test suite when this file is executed directly.
    run_tests()