Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
I added some tests for Conj, Neg and ZeroTensor for both Python and C++ functionalization. This also fixes a nasty segfault when running a functorch `jacfwd` test with `torch.compile`, once AOTAutograd is using `FunctionalTensor`. Changes: (1) I use Jeffrey's `make_wrapper_subclass(extra_dispatch_keys)` kwarg to plumb extra dispatch keys onto the wrapper, mirroring what C++ functionalization does (C++ functionalization mirrors all dispatch keys from the inner tensor to the wrapper, except for the Python and functorch keys). (2) FunctionalTensorMode will decompose CompositeImplicitAutograd ops, since (for example) ZeroTensor kernels can send ops like `.to()` directly to the Python key. We'll need a way to toggle this later for pre-dispatch functionalization. (3) Bound `_ForceDispatchKeyGuard` and BatchedTensorImpl's dispatch keyset to Python.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109023
Approved by: https://github.com/zou3519
ghstack dependencies: #108654, #109662, #109632
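Below is a minimal, hypothetical sketch (not taken from this PR) of the behavior the test suite in this file exercises, using only public APIs assumed to be available in a recent PyTorch build: functionalization rewrites a program that mutates views in place into an equivalent graph of purely functional ops.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

def f(x):
    y = x.view(4, 2)  # create a view of the input
    y.add_(1)         # mutate the input through that view
    return y

# Tracing the functionalized function is expected to yield out-of-place ops
# (e.g. aten.view_copy / aten.add) in place of the in-place aten.add_.
gm = make_fx(functionalize(f))(torch.ones(4, 2))
print(gm.code)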
1661 lines
74 KiB
Python
# Owner(s): ["module: codegen"]

import torch
from contextlib import nullcontext
from torch.testing._internal.common_utils import (
    TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, IS_WINDOWS,
    xfail_inherited_tests
)
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode, dispatch_functionalize
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
from torch.utils._pytree import tree_map_only, tree_flatten
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace
from torch._dispatch.python import enable_crossref_functionalize, enable_python_dispatcher
from torch.multiprocessing.reductions import StorageWeakRef

import unittest

def are_aliased(x, y):
    x_storage = StorageWeakRef(x.storage())
    y_storage = StorageWeakRef(y.storage())
    return x_storage == y_storage

# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
# This is basically a crappy version of `functionalize()`.
def _functionalize(f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False):
    def to_fun(t: torch.Tensor):
        func_t = torch._to_functional_tensor(t)
        func_t.requires_grad = t.requires_grad
        return func_t

    def wrapped(*inputs):
        ctx = nullcontext()
        if crossref:
            ctx = enable_crossref_functionalize()
        with ctx:
            inputs_functional = tree_map_only(torch.Tensor, to_fun, inputs)
            torch._enable_functionalization(reapply_views=reapply_views)
            try:
                out = f(*inputs_functional)
            finally:
                torch._disable_functionalization()
            flat_inputs, _ = tree_flatten(inputs)
            flat_inputs_functional, _ = tree_flatten(inputs_functional)

            for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
                torch._sync(input_functional)
                inpt_new = torch._from_functional_tensor(input_functional)
                if inpt_new is not inpt and not skip_input_mutations:
                    # Existing deficiency in functionalize():
                    # we don't correctly mutate input metadata (yet?)
                    if inpt_new.shape == inpt.shape:
                        inpt.copy_(inpt_new)
            tree_map_only(torch.Tensor, torch._sync, out)
            out_unwrapped = tree_map_only(torch.Tensor, torch._from_functional_tensor, out)
            return out_unwrapped

    return wrapped

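# Example (illustrative only, not part of the original file): get_logs() below
# drives the helper above roughly like this, assuming `f` is a function that
# mutates a view of its input:
#
#     traced = make_fx(_functionalize(f, reapply_views=False, crossref=False))(torch.ones(4, 2))
#     print(traced.code)  # a graph of *_copy ops; with reapply_views=True, plain view ops appear instead
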
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
class TestFunctionalization(TestCase):

    crossref = False

    def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False):
        inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts)
        traced_f = make_fx(_functionalize(func, reapply_views=reapply_views, crossref=self.crossref))(*inpts)
        if run_reinplace:
            traced_f = reinplace(traced_f, *inpts_clone)
        return traced_f.code

    def assert_functionalization(self, func, *inpts, reapply_views=False, mutated_input_metadata=False):
        clones1 = tree_map_only(torch.Tensor, torch.clone, inpts)
        clones2 = tree_map_only(torch.Tensor, torch.clone, inpts)
        clones3 = tree_map_only(torch.Tensor, torch.clone, inpts)

        # Compare outputs (and mutated inputs), with and without functionalization.
        out_ref = func(*inpts)
        out_functional = _functionalize(func, reapply_views=reapply_views, crossref=self.crossref)(*clones1)

        # The reinplacing pass is only valid to run with reapply_views=True.
        functional_func = make_fx(_functionalize(func, reapply_views=True, crossref=self.crossref))(*clones2)
        reinplace_func = reinplace(functional_func, *clones2)

        # NOTE: for now, need to pass in fresh inputs here, because make_fx
        # will directly mutate the inputs that you trace with.
        # Once this is fixed we can clean this up.
        out_reinplace = reinplace_func(*clones3)

        # functionalize() deficiency: input metadata mutations aren't propagated properly,
        # so we just need to skip checks here for the tests that exercise that.
        if not mutated_input_metadata:
            flat_inpts, _ = tree_flatten(inpts)
            flat_clones1, _ = tree_flatten(clones1)
            flat_clones3, _ = tree_flatten(clones3)
            for inpt, input_clone, input_clone3 in zip(flat_inpts, flat_clones1, flat_clones3):
                self.assertEqual(inpt, input_clone)  # input mutations should still occur
                self.assertEqual(inpt, input_clone3)

        # Handle tests with multi-tensor outputs
        if isinstance(out_ref, tuple):
            out_refs, out_functionals, out_reinplaces = list(out_ref), list(out_functional), list(out_reinplace)
        else:
            out_refs, out_functionals, out_reinplaces = [out_ref], [out_functional], [out_reinplace]

        for out_ref_, out_functional_, out_reinplace_ in zip(out_refs, out_functionals, out_reinplaces):
            self.assertEqual(out_ref_, out_functional_)
            self.assertEqual(out_ref_, out_reinplace_)

    def test_save_for_backwards_segfault(self):
        inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
        inp.exp()

    def test_multiple_views_of_same_base(self):
        def f(x):
            y = x.view(-1)
            z = x.view(-1)
            x.add_(1)
            # y should have been updated.
            y2 = y + 1
            # z should have been updated too.
            z2 = z + 1
            return z2
        self.assert_functionalization(f, torch.ones(4))

    def test_freeze(self):
        def f(x):
            y = x.clone()
            z = y[0]
            torch._freeze_functional_tensor(y)
            x.add_(1)
            self.assertRaises(RuntimeError, lambda: y.add_(1))
            self.assertRaises(RuntimeError, lambda: z.add_(1))
            return z

        _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(3, 3))

    def test_copy_stride_mismatch(self):
        def f(x):
            y = torch.empty_strided((2, 2), (5, 1))
            y.copy_(x)
            return y

        r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
        self.assertEqual(r.stride(), (5, 1))

    def test_set_(self):
        def f(x):
            y = torch.ones(2)
            y.set_(x.storage())
            return y

        # We should probably get the crossref test to work,
        # but fixing it for Storage() objects is annoying.
        r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
        self.assertEqual(str(r.device), 'cpu')

    def test_advanced_indexing(self):
        def f():
            x = torch.zeros(3, 3)
            idx = torch.tensor([0])
            val = torch.ones(3, 1)
            x[:, idx] = val
            return x

        self.assert_functionalization(f)

def test_view_clone_view_inplace(self):
|
|
def f(input):
|
|
shape = [1, 1024, 128, 128]
|
|
input_reshaped = input.view(shape)
|
|
out = input_reshaped.clone()
|
|
r = out.view(input.shape)
|
|
r.relu_()
|
|
return r
|
|
|
|
def g(x):
|
|
loss = f(x).sum()
|
|
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
|
|
import torch.fx.traceback as fx_traceback
|
|
setup_stacktrace_preservation_hooks([loss.grad_fn])
|
|
with fx_traceback.preserve_node_meta():
|
|
loss.backward()
|
|
return x.grad
|
|
|
|
with torch.autograd.detect_anomaly(check_nan=False):
|
|
logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 1024, 128, 128]); arg0_1 = None
|
|
clone = torch.ops.aten.clone.default(view_copy); view_copy = None
|
|
view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
|
|
relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None
|
|
view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]); relu = None
|
|
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]); view_copy_2 = None
|
|
view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None
|
|
sum_1 = torch.ops.aten.sum.default(view_copy_3)
|
|
ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None
|
|
expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None
|
|
view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None
|
|
new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
|
|
copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5); new_empty_strided = view_copy_5 = None
|
|
view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
|
|
view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
|
|
clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format)
|
|
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0); clone_1 = view_copy_3 = None
|
|
copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward); view_copy_7 = threshold_backward = None
|
|
view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None
|
|
view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128])
|
|
view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None
|
|
detach_copy = torch.ops.aten.detach_copy.default(view_copy_10); view_copy_10 = None
|
|
view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_8 = None
|
|
detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11); view_copy_11 = None
|
|
return detach_copy_1
|
|
""") # noqa: B950
|
|
|
|
def test_simple(self):
|
|
def f(x):
|
|
# simple test: 1 view op, 1 inplace op
|
|
tmp = torch.ones(4, 2)
|
|
y = x.view(4, 2)
|
|
y.add_(tmp)
|
|
z = x * x
|
|
return y
|
|
self.assert_functionalization(f, torch.ones(4, 2))
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
|
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
|
|
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
|
|
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
|
|
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
|
|
mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1)
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
|
|
return view_copy_2
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
|
view = torch.ops.aten.view.default(arg0_1, [4, 2])
|
|
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
|
|
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
|
|
view_2 = torch.ops.aten.view.default(view_1, [4, 2])
|
|
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None
|
|
return view_2
|
|
""")
|
|
|
|
def test_simple_out(self):
|
|
def f(x):
|
|
tmp = torch.ones(4, 2)
|
|
y = x.view(4, 2)
|
|
# the out= tensor will get resized, since it has size=0 to start.
|
|
z = torch.empty(())
|
|
torch.add(y, tmp, out=z)
|
|
w = z * z
|
|
return w
|
|
self.assert_functionalization(f, torch.ones(4, 2))
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
|
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None
|
|
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
|
|
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
|
|
mul = torch.ops.aten.mul.Tensor(add, add); add = None
|
|
return mul
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
|
view = torch.ops.aten.view.default(arg0_1, [4, 2]); arg0_1 = None
|
|
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
|
|
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
|
|
mul = torch.ops.aten.mul.Tensor(add, add); add = None
|
|
return mul
|
|
""")
|
|
|
|
def test_multi_out(self):
|
|
def f(x):
|
|
# aminmax.out returns a tuple of tensors.
|
|
# functionalization should properly handle the tuple.
|
|
out_min = torch.empty(4)
|
|
out_max = torch.empty(4)
|
|
torch.aminmax(x, dim=0, out=(out_max, out_min))
|
|
return out_max
|
|
self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
|
|
logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
|
|
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
|
|
aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None
|
|
getitem = aminmax[0]
|
|
getitem_1 = aminmax[1]; aminmax = None
|
|
return getitem
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.arange(8, dtype=torch.float32), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
|
|
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
|
|
aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None
|
|
getitem = aminmax[0]
|
|
getitem_1 = aminmax[1]; aminmax = None
|
|
return getitem
|
|
""")
|
|
|
|
def test_tensor_ctr(self):
|
|
def f(x):
|
|
y = torch.tensor((1, 2, 3))
|
|
z = y.view(-1)
|
|
z.add_(1)
|
|
return y
|
|
|
|
inpt = torch.arange(3, dtype=torch.float32)
|
|
self.assert_functionalization(f, inpt)
|
|
|
|
logs = self.get_logs(f, inpt)
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
_tensor_constant0 = self._tensor_constant0
|
|
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
|
|
view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
|
|
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
|
|
view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None
|
|
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1])
|
|
return view_copy_1
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
_tensor_constant0 = self._tensor_constant0
|
|
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
|
|
view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
|
|
add = torch.ops.aten.add_.Tensor(view, 1)
|
|
view_1 = torch.ops.aten.view.default(view, [3]); view = None
|
|
view_2 = torch.ops.aten.view.default(view_1, [-1])
|
|
return view_1
|
|
""")
|
|
|
|
def test_advanced_indexing_correct_strides(self):
|
|
def f(a):
|
|
# This test requires that *_scatter ops are able to return
|
|
# non-contiguous tensors.
|
|
b = a.clone()[:, 1]
|
|
c = torch.ones_like(b, dtype=torch.bool)
|
|
d = b.masked_fill_(c, 0)
|
|
return d
|
|
self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True)
|
|
|
|
def test_tensor_list_mixed_functional_nonfunctional(self):
|
|
nonfunctional_tensor = torch.ones(2, dtype=torch.long)
|
|
|
|
def f(x):
|
|
# test: mix a functional tensor (created inside f) with a non-functional tensor from the enclosing scope
|
|
functional_tensor = torch.ones(2, dtype=torch.long)
|
|
out = x[functional_tensor, nonfunctional_tensor]
|
|
return out
|
|
out = f(torch.ones(2, 2))
|
|
out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
|
|
self.assertEqual(out, out_functional)
|
|
|
|
def test_inplace_on_non_view(self):
|
|
def f(x):
|
|
# test for the case where we functionalize an inplace op on the other tensor - not a view.
|
|
# This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased.
|
|
tmp = torch.ones(4, 2)
|
|
y = x.view(4, 2)
|
|
x.add_(tmp)
|
|
return y
|
|
self.assert_functionalization(f, torch.ones(4, 2))
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
|
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
|
|
add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None
|
|
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
|
|
return view_copy_1
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
|
view = torch.ops.aten.view.default(arg0_1, [4, 2])
|
|
add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None
|
|
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
|
|
return view_1
|
|
""")
|
|
|
|
# Some ops that are mutable are neither inplace nor out= ops.
|
|
# They also need special handling.
|
|
def test_mutable_op_not_inplace_or_other(self):
|
|
def f(x):
|
|
return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
|
|
|
|
logs = self.get_logs(f, torch.ones(1))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
_fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0)
|
|
getitem = _fused_moving_avg_obs_fq_helper_functional[0]
|
|
getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
|
|
getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]
|
|
getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]
|
|
getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]
|
|
getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = None
|
|
return (getitem, getitem_1)
|
|
""") # noqa: B950
|
|
|
|
def test_as_strided(self):
|
|
def f(x):
|
|
y = x.as_strided((2,), (2,), 1)
|
|
y.add_(1)
|
|
return x
|
|
self.assert_functionalization(f, torch.ones(9))
|
|
logs = self.get_logs(f, torch.ones(9))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1)
|
|
add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None
|
|
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None
|
|
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1)
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None
|
|
return as_strided_scatter
|
|
""")
|
|
|
|
def test_tensor_list_composite(self):
|
|
def f(x):
|
|
# Test an op with TensorList input
|
|
y = torch.block_diag(x, x)
|
|
return y
|
|
self.assert_functionalization(f, torch.ones(2, 2))
|
|
logs = self.get_logs(f, torch.ones(2, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]); arg0_1 = None
|
|
return block_diag
|
|
""")
|
|
|
|
def test_cat(self):
|
|
def f(x):
|
|
out = torch.empty(0)
|
|
torch.cat((x,), out=out)
|
|
return out
|
|
self.assert_functionalization(f, torch.ones(2, 2))
|
|
logs = self.get_logs(f, torch.ones(2, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
|
|
cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None
|
|
return cat
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
|
|
cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None
|
|
return cat
|
|
""")
|
|
|
|
|
|
def test_diagonal(self):
|
|
def f(x):
|
|
# test: view ops that take a subset of the original tensor (select/diagonal)
|
|
tmp = torch.ones(2)
|
|
y = x.clone().diagonal()
|
|
y.add_(tmp)
|
|
z = x * x
|
|
return z
|
|
self.assert_functionalization(f, torch.ones(2, 2))
|
|
logs = self.get_logs(f, torch.ones(2, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
|
|
clone = torch.ops.aten.clone.default(arg0_1)
|
|
diagonal_copy = torch.ops.aten.diagonal_copy.default(clone)
|
|
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
|
|
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add); clone = add = None
|
|
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_scatter = None
|
|
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
return mul
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
|
|
clone = torch.ops.aten.clone.default(arg0_1)
|
|
diagonal = torch.ops.aten.diagonal.default(clone)
|
|
add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = None
|
|
diagonal_1 = torch.ops.aten.diagonal.default(clone); clone = None
|
|
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
return mul
|
|
""")
|
|
|
|
def test_diagonal_mutated_input(self):
|
|
def f(x):
|
|
# simple test: there are pending updates afterwards, which the test syncs manually
|
|
tmp = torch.ones(2)
|
|
y = x.diagonal()
|
|
y.add_(tmp)
|
|
return x
|
|
x = torch.ones(2, 2)
|
|
self.assert_functionalization(f, x)
|
|
logs = self.get_logs(f, torch.ones(2, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1)
|
|
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
|
|
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None
|
|
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None
|
|
return diagonal_scatter
|
|
""")
|
|
|
|
    def test_channels_last_contiguous(self):
        def f(x):
            return x.contiguous(memory_format=torch.channels_last)

        x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
        self.assert_functionalization(f, x)
        logs = self.get_logs(f, x).strip()
        # There should be no clone in the graph
        self.assertExpectedInline(logs, """\
def forward(self, arg0_1):
    return arg0_1""")
|
|
|
|
def test_split(self):
|
|
def f(x):
|
|
# test: view ops that return multiple tensors (split)
|
|
tmp = torch.ones(2)
|
|
y1, y2 = x.split(2)
|
|
y3 = y2.diagonal()
|
|
y3.add_(tmp)
|
|
z = x * x
|
|
return y3
|
|
self.assert_functionalization(f, torch.ones(4, 2))
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
|
|
split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
|
|
getitem = split_copy[0]
|
|
getitem_1 = split_copy[1]; split_copy = None
|
|
diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
|
|
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
|
|
split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
|
|
getitem_2 = split_copy_1[0]
|
|
getitem_3 = split_copy_1[1]; split_copy_1 = None
|
|
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None
|
|
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None
|
|
split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2)
|
|
getitem_4 = split_copy_2[0]
|
|
getitem_5 = split_copy_2[1]; split_copy_2 = None
|
|
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5); getitem_5 = None
|
|
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
|
|
return diagonal_copy_1
|
|
""") # noqa: B950
|
|
|
|
def test_view_inplace(self):
|
|
def f(x):
|
|
# test: view + inplace op (transpose_)
|
|
tmp = torch.ones(4)
|
|
x.transpose_(1, 0)
|
|
y = x[0]
|
|
y.add_(tmp)
|
|
return x
|
|
self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
|
|
transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
|
|
select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0); transpose_copy = None
|
|
add = torch.ops.aten.add.Tensor(select_copy, ones); select_copy = ones = None
|
|
transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None
|
|
select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None
|
|
transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None
|
|
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
|
|
select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = None
|
|
transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
|
|
return transpose_copy_4
|
|
""") # noqa: B950
|
|
|
|
def test_optional_tensor_list(self):
|
|
def f(x):
|
|
# test: an operator that takes in a List[Optional[Tensor]] argument
|
|
# (index_put)
|
|
y = x.view(8)
|
|
indices = torch.arange(4)
|
|
values = torch.arange(4, dtype=y.dtype)
|
|
y.index_put_((indices,), values, accumulate=False)
|
|
return y
|
|
self.assert_functionalization(f, torch.ones(4, 2))
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
view_copy = torch.ops.aten.view_copy.default(arg0_1, [8])
|
|
arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
|
|
arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
|
index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None
|
|
view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]); index_put = None
|
|
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8])
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
|
|
return view_copy_2
|
|
""") # noqa: B950
|
|
|
|
def test_scalars(self):
|
|
def f(x):
|
|
# test: the pass can handle scalar inputs properly
|
|
tmp = torch.ones(4, 2)
|
|
y = x.view(4, 2)
|
|
y.add_(1)
|
|
z = 2 * y
|
|
z.div_(1)
|
|
return z
|
|
self.assert_functionalization(f, torch.ones(4, 2))
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
|
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
|
|
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
|
|
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
|
|
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
|
|
mul = torch.ops.aten.mul.Tensor(view_copy_2, 2); view_copy_2 = None
|
|
div = torch.ops.aten.div.Tensor(mul, 1); mul = None
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
|
|
return div
|
|
""")
|
|
|
|
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
|
def test_metadata_change(self):
|
|
def f(x):
|
|
# ops like ge_() are allowed to change the dtype of the input.
|
|
# functionalization should pick up on that.
|
|
y = x.clone()
|
|
out = y.ge_(0)
|
|
return out
|
|
self.assert_functionalization(f, torch.ones(4, 2))
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
|
|
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
|
|
return _to_copy
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
|
|
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
|
|
return _to_copy
|
|
""") # noqa: B950
|
|
|
|
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
|
def test_metadata_change_out_op(self):
|
|
def f(t, y):
|
|
out_1 = torch.ones(1)
|
|
return torch.add(t, y, out=out_1)
|
|
|
|
inpt1, inpt2 = torch.tensor([1]), torch.tensor([1])
|
|
inpt1_func, inpt2_func = torch._to_functional_tensor(inpt1), torch._to_functional_tensor(inpt2)
|
|
|
|
out_ref = f(inpt1, inpt2)
|
|
torch._enable_functionalization(reapply_views=True)
|
|
try:
|
|
out_functional = f(inpt1_func, inpt2_func)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
|
|
|
|
|
|
def test_only_one_view(self):
|
|
def f(x):
|
|
# This tests that we don't have any unnecessary views in the trace.
|
|
# If the input wasn't mutated, we don't need to regenerate it,
|
|
# so there should be a total of 1 op in the output trace.
|
|
return x.view(4, 2)
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None
|
|
return view_copy
|
|
""")
|
|
|
|
def test_everything(self):
|
|
def f(x):
|
|
# test: everything
|
|
tmp = torch.ones(2, 2)
|
|
x2 = x + x
|
|
y = x2.view(8)
|
|
z0 = y.reshape(2, 4)
|
|
z1 = z0.transpose(1, 0)
|
|
z1.unsqueeze_(0)
|
|
z1.squeeze_()
|
|
z2, z3 = z1.split(2)
|
|
z2.add_(tmp)
|
|
z4 = z0[0] + z2.reshape(4)
|
|
return z2
|
|
self.assert_functionalization(f, torch.ones(4, 2))
|
|
logs = self.get_logs(f, torch.ones(4, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
view_copy = torch.ops.aten.view_copy.default(add, [8])
|
|
view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None
|
|
transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0)
|
|
unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None
|
|
squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None
|
|
split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None
|
|
getitem = split_copy[0]
|
|
getitem_1 = split_copy[1]; split_copy = None
|
|
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
|
|
view_copy_2 = torch.ops.aten.view_copy.default(add, [8]); add = None
|
|
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]); view_copy_2 = None
|
|
transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_3, 1, 0); view_copy_3 = None
|
|
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
|
|
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
|
|
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = add_1 = None
|
|
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
|
|
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
|
|
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
|
|
view_copy_4 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None
|
|
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 2]); view_copy_4 = None
|
|
view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [8])
|
|
view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [2, 4]); view_copy_6 = None
|
|
transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_7, 1, 0); view_copy_7 = None
|
|
unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None
|
|
squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None
|
|
split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None
|
|
getitem_2 = split_copy_1[0]
|
|
getitem_3 = split_copy_1[1]; split_copy_1 = None
|
|
select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None
|
|
view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4])
|
|
view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8])
|
|
view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None
|
|
select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0); view_copy_10 = None
|
|
view_copy_11 = torch.ops.aten.view_copy.default(view_copy_5, [8]); view_copy_5 = None
|
|
view_copy_12 = torch.ops.aten.view_copy.default(view_copy_11, [2, 4]); view_copy_11 = None
|
|
transpose_copy_4 = torch.ops.aten.transpose_copy.int(view_copy_12, 1, 0); view_copy_12 = None
|
|
unsqueeze_copy_4 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_4, 0); transpose_copy_4 = None
|
|
squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4); unsqueeze_copy_4 = None
|
|
split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2); squeeze_copy_4 = None
|
|
getitem_4 = split_copy_2[0]
|
|
getitem_5 = split_copy_2[1]; split_copy_2 = None
|
|
view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]); getitem_4 = None
|
|
add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = None
|
|
return getitem_2
|
|
""") # noqa: B950
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
view = torch.ops.aten.view.default(add, [8])
|
|
view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None
|
|
transpose = torch.ops.aten.transpose.int(view_1, 1, 0)
|
|
unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None
|
|
squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None
|
|
split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None
|
|
getitem = split[0]
|
|
getitem_1 = split[1]; split = None
|
|
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); getitem = ones = None
|
|
view_2 = torch.ops.aten.view.default(add, [8]); add = None
|
|
view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None
|
|
transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None
|
|
unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None
|
|
squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None
|
|
unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None
|
|
squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None
|
|
transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None
|
|
view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None
|
|
view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None
|
|
view_6 = torch.ops.aten.view.default(view_5, [8])
|
|
view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None
|
|
transpose_3 = torch.ops.aten.transpose.int(view_7, 1, 0); view_7 = None
|
|
unsqueeze_3 = torch.ops.aten.unsqueeze.default(transpose_3, 0); transpose_3 = None
|
|
squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3); unsqueeze_3 = None
|
|
split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2); squeeze_3 = None
|
|
getitem_2 = split_1[0]
|
|
getitem_3 = split_1[1]; split_1 = None
|
|
select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
|
|
clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format)
|
|
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
|
|
view_8 = torch.ops.aten.view.default(view_5, [8]); view_5 = None
|
|
view_9 = torch.ops.aten.view.default(view_8, [2, 4]); view_8 = None
|
|
select_1 = torch.ops.aten.select.int(view_9, 0, 0); view_9 = None
|
|
add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None
|
|
return getitem_2
|
|
""")
|
|
|
|
def test_reapply_views_simple(self):
|
|
def f(x):
|
|
tmp = torch.ones(4, 2)
|
|
y = x.view(4, 2)
|
|
y.add_(tmp)
|
|
z = x * x
|
|
return y
|
|
self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
|
|
logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
|
view = torch.ops.aten.view.default(arg0_1, [4, 2])
|
|
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
|
|
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
|
|
view_2 = torch.ops.aten.view.default(view_1, [4, 2])
|
|
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
|
|
copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None
|
|
return view_2
|
|
""")
|
|
|
|
def test_aliases_maintained_after_pass_when_reapplying_views(self):
|
|
def f(x):
|
|
tmp = torch.ones(4, 2)
|
|
y = x.view(4, 2)
|
|
z = x.view(4, 2)
|
|
y.add_(tmp)
|
|
return y, z
|
|
|
|
input_functional = torch._to_functional_tensor(torch.ones(4, 2))
|
|
torch._enable_functionalization(reapply_views=True)
|
|
try:
|
|
y, z = f(input_functional)
|
|
torch._sync(y)
|
|
torch._sync(z)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
# y and z are aliases inside of the function, and that aliasing relationship should be maintained.
|
|
_y = torch._from_functional_tensor(y)
|
|
_z = torch._from_functional_tensor(z)
|
|
self.assertTrue(are_aliased(_y, _z))
|
|
|
|
# copy_() gets its own test, because it used to be special cased in functionalization.
# However, now it works pretty similarly to other functional ops.
|
|
def test_copy_(self):
|
|
def f(x):
|
|
tmp = torch.zeros(2, 2)
|
|
tmp_slice = tmp.diagonal()
|
|
y = tmp_slice.copy_(x)
|
|
z = y.add_(x)
|
|
return z
|
|
|
|
# Test 1: copy_() with same dtype and shape
|
|
# to() is a composite op that noops when the dtype/shape match, so nothing gets logged.
|
|
# self.assert_functionalization(f, torch.ones(2))
|
|
logs = self.get_logs(f, torch.ones(2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
|
|
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
|
|
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
|
|
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
|
|
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
|
|
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
|
|
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
|
|
return diagonal_copy_2
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal = torch.ops.aten.diagonal.default(zeros)
|
|
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
|
|
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
|
|
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
|
|
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
|
|
return diagonal_2
|
|
""")
|
|
|
|
# Test 2: copy_() with same dtype, different shape
|
|
self.assert_functionalization(f, torch.ones(1))
|
|
logs = self.get_logs(f, torch.ones(1))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
|
|
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
|
|
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
|
|
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
|
|
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
|
|
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
|
|
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
|
|
return diagonal_copy_2
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal = torch.ops.aten.diagonal.default(zeros)
|
|
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
|
|
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
|
|
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
|
|
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
|
|
return diagonal_2
|
|
""")
|
|
|
|
# Test 3: copy_() with different dtype, same shape
|
|
self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
|
|
logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
|
|
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
|
|
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
|
|
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
|
|
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
|
|
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
|
|
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
|
|
return diagonal_copy_2
|
|
""") # noqa: B950
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal = torch.ops.aten.diagonal.default(zeros)
|
|
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
|
|
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
|
|
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
|
|
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
|
|
return diagonal_2
|
|
""") # noqa: B950
|
|
|
|
# Test 4: copy_() with different dtype, different shape
|
|
self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
|
|
logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
|
|
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
|
|
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
|
|
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
|
|
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
|
|
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
|
|
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
|
|
return diagonal_copy_2
|
|
""") # noqa: B950
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
|
diagonal = torch.ops.aten.diagonal.default(zeros)
|
|
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
|
|
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
|
|
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
|
|
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
|
|
return diagonal_2
|
|
""") # noqa: B950
|
|
|
|
def test_expand_symint(self):
|
|
# Once some existing SymInt bugs are ironed out, we should update
|
|
# this test to plumb FakeSymbolicTensors through it
|
|
def f(x):
|
|
return x.expand(x.size(0), x.size(1))
|
|
|
|
self.assert_functionalization(f, torch.ones(2, 2))
|
|
logs = self.get_logs(f, torch.ones(2, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]); arg0_1 = None
|
|
return expand_copy
|
|
""")
|
|
|
|
def test_fill_(self):
|
|
def f(x):
|
|
y = x + x
|
|
z = y.diagonal()
|
|
z.fill_(0)
|
|
return y
|
|
|
|
self.assert_functionalization(f, torch.ones(2, 2))
|
|
logs = self.get_logs(f, torch.ones(2, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
|
|
fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None
|
|
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None
|
|
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
|
|
return diagonal_scatter
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
diagonal = torch.ops.aten.diagonal.default(add)
|
|
fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None
|
|
diagonal_1 = torch.ops.aten.diagonal.default(add)
|
|
return add
|
|
""")
|
|
|
|
def test_resize_smaller(self):
|
|
def f(w):
|
|
# Resizing to a smaller size doesn't affect storage
|
|
x = w + 1
|
|
y = x.view(4, 4)
|
|
y.resize_(3, 3)
|
|
y2 = y.view(-1)
|
|
y2.add_(1)
|
|
z = y + 1
|
|
return z
|
|
|
|
self.assert_functionalization(f, torch.ones(8, 2))
|
|
logs = self.get_logs(f, torch.ones(8, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
|
|
view_copy = torch.ops.aten.view_copy.default(add, [4, 4])
|
|
resize = torch.ops.aten.resize.default(view_copy, [3, 3])
|
|
as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None
|
|
view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]); as_strided_copy = None
|
|
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
|
|
view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None
|
|
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1])
|
|
view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None
|
|
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None
|
|
view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None
|
|
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4])
|
|
as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None
|
|
view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]); as_strided_copy_2 = None
|
|
view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None
|
|
as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None
|
|
add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None
|
|
return add_2
|
|
""") # noqa: B950
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
|
|
view = torch.ops.aten.view.default(add, [4, 4])
|
|
resize = torch.ops.aten.resize.default(view, [3, 3])
|
|
as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None
|
|
view_1 = torch.ops.aten.view.default(as_strided, [-1]); as_strided = None
|
|
add_1 = torch.ops.aten.add_.Tensor(view_1, 1)
|
|
view_2 = torch.ops.aten.view.default(add, [4, 4]); add = None
|
|
as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1])
|
|
view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = None
|
|
view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None
|
|
view_5 = torch.ops.aten.view.default(view_4, [4, 4])
|
|
as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None
|
|
view_6 = torch.ops.aten.view.default(as_strided_2, [-1]); as_strided_2 = None
|
|
view_7 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None
|
|
as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]); view_7 = None
|
|
add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1)
|
|
return as_strided_3
|
|
""")
|
|
|
|
def test_resize_same_size_diff_rank(self):
|
|
def f(x):
|
|
y = x.clone()
|
|
y.resize_(25, 5)
|
|
return y
|
|
|
|
self.assert_functionalization(f, torch.ones(5, 5, 5))
|
|
|
|
def test_resize_larger_valid(self):
|
|
def f(x):
|
|
y = x + 1
|
|
# resizing a tensor to a larger size is only currently allowed
|
|
# if the tensor-to-resize is not a view / has no outstanding views.
|
|
# See Note [resize_() in functionalization pass]
|
|
y.resize_(5, 5)
|
|
y2 = y.view(25)
|
|
# Do a mutation to ensure that aliases of the output of resize_()
|
|
# propagate mutations correctly.
|
|
# I'm using fill_ specifically because I want to guarantee that
|
|
# none of the output has uninitialized memory at the end
|
|
# (since these tests compare the data output against a reference impl)
|
|
y2.fill_(1)
|
|
out = y + 1
|
|
return y, out
|
|
|
|
self.assert_functionalization(f, torch.ones(8, 2))
|
|
logs = self.get_logs(f, torch.ones(8, 2))
|
|
self.assertExpectedInline(logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
|
|
resize = torch.ops.aten.resize.default(add, [5, 5]); add = None
|
|
view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None
|
|
fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None
|
|
view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None
|
|
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25])
|
|
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
|
|
return (view_copy_1, add_1)
|
|
""")
|
|
|
|
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
|
|
self.assertExpectedInline(reinplaced_logs, """\
|
|
|
|
|
|
|
|
def forward(self, arg0_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
|
|
resize = torch.ops.aten.resize_.default(add, [5, 5])
|
|
view = torch.ops.aten.view.default(add, [25]); add = None
|
|
fill = torch.ops.aten.fill_.Scalar(view, 1)
|
|
view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None
|
|
view_2 = torch.ops.aten.view.default(view_1, [25])
|
|
add_1 = torch.ops.aten.add.Tensor(view_1, 1)
|
|
return (view_1, add_1)
|
|
""")
|
|
|
|
def test_resize_larger_invalid(self):
|
|
def f(x):
|
|
y = x + 1
|
|
z = y.view(4, 4)
|
|
# resizing a tensor to a larger size is only currently allowed
|
|
# if the tensor-to-resize is not a view / has no outstanding views.
|
|
# See Note [resize_() in functionalization pass]
|
|
# This should fail
|
|
z.resize_(5, 5)
|
|
z2 = z.view(25)
|
|
z2.fill_(1)
|
|
out = z + 1
|
|
return y, out
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r'Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass'):
|
|
self.assert_functionalization(f, torch.ones(8, 2))
|
|
|
|
    def test_nested_functions_propagate_updates(self):
        def g(x):
            # Create a view of x
            y = x[0]
            y.add_(1)
            # The view, y, gets deallocated at the end of this function

        def f(x):
            # Calling g(x) should mutate x
            g(x)
            # We expect x to be synced here, even though the alias created in g() has been deallocated!
            y = x + x
            return y

        self.assert_functionalization(f, torch.ones(2, 2))

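    # Illustrative sketch, not one of the checked tests: the same propagation can be observed by
    # comparing eager execution against dispatch_functionalize (the python functionalization entry
    # point used later in this file). Only the outputs are compared here, and the method name is
    # made up for illustration.
    def _sketch_view_mutation_output_matches(self):
        def f(x):
            y = x[0]
            y.add_(1)
            return x + x

        out_ref = f(torch.ones(2, 2))
        out_test = dispatch_functionalize(f)(torch.ones(2, 2))
        # Both paths should see the updated first row when computing x + x.
        return torch.equal(out_ref, out_test)
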
    def test_mixed_wrappers_valid(self):
        def f(x, y):
            z = x + y
            z.add_(1)
            return z

        x1_not_functional = LoggingTensor(torch.ones(4))
        x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))

        with capture_logs() as logs:
            y = f(x1_not_functional, x2_functional)

        # Make sure that functionalization ran the "+" kernel
        # with a functional + non-functional tensor, and wrapped the output appropriately.
        self.assertExpectedInline('\n'.join(logs), """\
$2: f32[4] = torch._ops.aten.add.Tensor($0, $1)
$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""")

    def test_mixed_wrappers_invalid(self):
        x1_not_functional = torch.ones(4)
        x2_functional = torch._to_functional_tensor(torch.ones(4))

        # When dealing with mixed functional + non-functional tensors,
        # normal_tensor.add_(functional_tensor) is not valid
        # because normal_tensor would need to be "promoted" to a functional tensor.
        with self.assertRaises(RuntimeError):
            x1_not_functional.add_(x2_functional)

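    # Illustrative sketch, not one of the checked tests: torch._is_functional_tensor can be used to
    # tell which side of a mixed pair is already wrapped before attempting an in-place op, avoiding
    # the RuntimeError exercised above. The method name is made up for illustration.
    def _sketch_check_wrapping_before_inplace(self):
        x1_not_functional = torch.ones(4)
        x2_functional = torch._to_functional_tensor(torch.ones(4))
        # Expected to return (False, True): only the explicitly wrapped tensor is functional.
        return (torch._is_functional_tensor(x1_not_functional),
                torch._is_functional_tensor(x2_functional))
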
    def test_index_mutation_on_non_input(self):
        def f(x):
            tmp = torch.zeros(10)
            tmp[5].fill_(1)
            return tmp
        self.assert_functionalization(f, torch.ones(2))
        logs = self.get_logs(f, torch.ones(2))
        self.assertExpectedInline(logs, """\



def forward(self, arg0_1):
    zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
    select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
    fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None
    select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None
    select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5)
    return select_scatter
    """) # noqa: B950

        reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
        self.assertExpectedInline(reinplaced_logs, """\



def forward(self, arg0_1):
    zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
    select = torch.ops.aten.select.int(zeros, 0, 5)
    fill = torch.ops.aten.fill_.Scalar(select, 1); select = None
    select_1 = torch.ops.aten.select.int(zeros, 0, 5)
    return zeros
    """)

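    # Illustrative sketch, not one of the checked tests: the select_scatter node in the graph above
    # can be written by hand. Mutating tmp[5] is equivalent to scattering a filled slice back into
    # the base tensor; the method name is made up for illustration.
    def _sketch_select_scatter_equivalent(self):
        tmp = torch.zeros(10)
        filled = torch.full_like(tmp[5], 1)
        functional = torch.select_scatter(tmp, filled, 0, 5)

        expected = torch.zeros(10)
        expected[5].fill_(1)
        # The functional rewrite and the in-place mutation produce the same values.
        return torch.equal(functional, expected)
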
    def test_instance_norm(self):
        size = 100

        def f(x, running_mean, running_var):
            with enable_python_dispatcher():
                return torch.instance_norm(x, None, None, running_mean, running_var,
                                           use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False)
        self.assert_functionalization(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size))
        # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used
        # whereas on other platforms, the alias_copy's are before the view_copy's.
        # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment.
        if not IS_WINDOWS:
            logs = self.get_logs(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size))
            self.assertExpectedInline(logs, """\



def forward(self, arg0_1, arg1_1, arg2_1):
    repeat = torch.ops.aten.repeat.default(arg1_1, [20])
    repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_1 = _native_batch_norm_legit_functional[1]
    getitem_2 = _native_batch_norm_legit_functional[2]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
    alias_copy = torch.ops.aten.alias_copy.default(arg1_1)
    view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100])
    view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None
    mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None
    copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None
    alias_copy_1 = torch.ops.aten.alias_copy.default(copy); copy = None
    alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1)
    alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1)
    view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100])
    view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None
    mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None
    copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1); alias_copy_3 = mean_1 = None
    alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1); copy_1 = None
    alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4)
    view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None
    copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = None
    copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = None
    return view_copy_5
    """) # noqa: B950

            reinplaced_logs = self.get_logs(
                f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size),
                reapply_views=True, run_reinplace=True
            )
            self.assertExpectedInline(reinplaced_logs, """\



def forward(self, arg0_1, arg1_1, arg2_1):
    repeat = torch.ops.aten.repeat.default(arg1_1, [20])
    repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
    view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_1 = _native_batch_norm_legit_functional[1]
    getitem_2 = _native_batch_norm_legit_functional[2]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
    alias = torch.ops.aten.alias.default(arg1_1)
    view_1 = torch.ops.aten.view.default(getitem_3, [20, 100])
    view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None
    mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None
    copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None
    alias_1 = torch.ops.aten.alias.default(copy); copy = None
    alias_2 = torch.ops.aten.alias.default(alias_1)
    alias_3 = torch.ops.aten.alias.default(arg2_1)
    view_3 = torch.ops.aten.view.default(getitem_4, [20, 100])
    view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None
    mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None
    copy_1 = torch.ops.aten.copy.default(alias_3, mean_1); alias_3 = mean_1 = None
    alias_4 = torch.ops.aten.alias.default(copy_1); copy_1 = None
    alias_5 = torch.ops.aten.alias.default(alias_4)
    view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None
    copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = None
    copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = None
    return view_5
    """) # noqa: B950

    def test_mutation_overlapping_mem(self):
        def fn(x):
            # x: (1, 5)
            t1 = torch.add(x, x)
            t2 = t1.unfold(1, 3, 2)
            t3 = t2.abs_()
            return t3

        with self.assertRaisesRegex(RuntimeError, r'encountered a tensor being mutated that has internal overlap'):
            x = torch.ones(1, 5)
            out = _functionalize(fn, reapply_views=True, crossref=False)(x)

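    # Illustrative sketch, not one of the checked tests: the unfold above creates windows that share
    # memory (window length 3 with step 2), which is why functionalization refuses to run abs_() on
    # the result. The method name is made up for illustration.
    def _sketch_show_overlapping_windows(self):
        t = torch.arange(5.0).reshape(1, 5)
        windows = t.unfold(1, 3, 2)
        # Shape (1, 2, 3) with strides (5, 2, 1): windows[0, 0] and windows[0, 1] both alias t[0, 2].
        return windows.shape, windows.stride()
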
    def test_batch_norm(self):
        def f(x, running_mean, running_var):
            with enable_python_dispatcher():
                return torch.batch_norm(x, None, None, running_mean, running_var, True, 0.1, 1e-5, False)

        self.assert_functionalization(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100))
        logs = self.get_logs(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100))
        self.assertExpectedInline(logs, """\



def forward(self, arg0_1, arg1_1, arg2_1):
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_1 = _native_batch_norm_legit_functional[1]
    getitem_2 = _native_batch_norm_legit_functional[2]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
    copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None
    copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None
    return getitem
    """) # noqa: B950

        reinplaced_logs = self.get_logs(
            f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100), reapply_views=True, run_reinplace=True
        )
        self.assertExpectedInline(reinplaced_logs, """\



def forward(self, arg0_1, arg1_1, arg2_1):
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_1 = _native_batch_norm_legit_functional[1]
    getitem_2 = _native_batch_norm_legit_functional[2]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
    copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None
    copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None
    return getitem
    """) # noqa: B950

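    # Illustrative sketch, not one of the checked tests: the running-stat updates surface as copy_
    # nodes at the end of the functionalized graph, one per mutated input, as in the expected
    # strings above. get_logs is the same helper used throughout this file; the method name is
    # made up for illustration.
    def _sketch_count_running_stat_mutations(self):
        def f(x, running_mean, running_var):
            with enable_python_dispatcher():
                return torch.batch_norm(x, None, None, running_mean, running_var, True, 0.1, 1e-5, False)

        logs = self.get_logs(f, torch.randn(2, 3, 4, 4), torch.zeros(3), torch.ones(3))
        # One copy_ for running_mean and one for running_var are expected.
        return str(logs).count("torch.ops.aten.copy_.default")
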
    # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode
    def test_python_functionalization(self):
        def f(x):
            x_view = x.view(-1)
            x.mul_(2)
            return x_view + 1

        def f_functionalized(x):
            x_wrapped = FunctionalTensor.to_functional(x)

            # Note [Disabling Functionalize TLS Above Python Functionalization]
            # This UX is pretty annoying (although python functionalization's main customer is AOTAutograd,
            # and is not really advertised as a user API).
            # We need to explicitly disable functionalization when using python FunctionalTensor and FunctionalTensorMode.
            # Why? FunctionalTensor is a wrapper tensor that holds an inner FunctionalTensorWrapper.
            # Since the inner tensor has `DispatchKey.Functionalize` in its keyset, by default
            # our FunctionalTensor will inherit the same keyset.
            # We don't have an easy way of directly mutating a tensor's keyset from python,
            # so globally disabling functionalization here is easier.
            maybe_disable = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize))
            with maybe_disable, FunctionalTensorMode():
                out_wrapped = f(x_wrapped)
            out_unwrapped = out_wrapped.elem
            torch._sync(out_unwrapped)
            return torch._from_functional_tensor(out_unwrapped)

        # Make a non-leaf
        x = torch.randn(2, requires_grad=True) + 1
        fx_g = make_fx(f_functionalized)(x)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
    view = torch.ops.aten.view.default(x_1, [-1])
    mul = torch.ops.aten.mul.Tensor(x_1, 2); x_1 = None
    view_1 = torch.ops.aten.view.default(mul, [-1]); mul = None
    add = torch.ops.aten.add.Tensor(view_1, 1); view_1 = None
    return add""")

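    # Illustrative sketch, not one of the checked tests: dispatch_functionalize (used by the tests
    # below) roughly bundles the boilerplate above (wrapping the inputs, handling the dispatch keys
    # and mode, and unwrapping the outputs), so a comparable trace can be produced in one line.
    # The method name is made up for illustration, and the resulting graph may also contain a copy_
    # back into the mutated input, depending on how input mutations are replayed.
    def _sketch_dispatch_functionalize_shortcut(self):
        def f(x):
            x_view = x.view(-1)
            x.mul_(2)
            return x_view + 1

        x = torch.randn(2)
        return make_fx(dispatch_functionalize(f))(x).code
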
    def test_python_functionalization_zero_tensor(self):
        def f(x):
            y = torch.ops.aten._efficientzerotensor([4])
            out = x + y
            out.mul_(2)
            return out
        x = torch.randn(4)
        out_ref = f(x)
        out_test = dispatch_functionalize(f)(x)
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(out_ref, out_test_cpp)
        fx_g = make_fx(dispatch_functionalize(f))(x)
        fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())

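    # Illustrative sketch, not one of the checked tests: _efficientzerotensor produces a zero tensor
    # backed by the ZeroTensor dispatch key, so adding it behaves like an identity on the dense
    # operand (the mul_ in the test above then doubles the result). The method name is made up for
    # illustration.
    def _sketch_zero_tensor_identity(self):
        x = torch.randn(4)
        zero = torch.ops.aten._efficientzerotensor([4])
        # x + zero is expected to have the same values as x.
        return torch.equal(x + zero, x)
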
    def test_python_functionalization_is_conj(self):
        def f(x):
            out = x.conj()
            return out, out.is_conj()

        x = torch.randn(4, dtype=torch.complex64)
        out_ref = f(x)
        out_test = dispatch_functionalize(f)(x)
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
        self.assertEqual(out_ref[0], out_test[0])
        self.assertEqual(out_ref[1], out_test[1])
        self.assertEqual(out_ref[0], out_test_cpp[0])
        self.assertEqual(out_ref[1], out_test_cpp[1])

    def test_python_functionalization_is_neg(self):
        def f(x):
            out = x.neg()
            return out, out.is_neg()

        x = torch.randn(4, dtype=torch.complex64)
        out_ref = f(x)
        out_test = dispatch_functionalize(f)(x)
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
        self.assertEqual(out_ref[0], out_test[0])
        self.assertEqual(out_ref[1], out_test[1])
        self.assertEqual(out_ref[0], out_test_cpp[0])
        self.assertEqual(out_ref[1], out_test_cpp[1])

    def test_python_functionalization_conj(self):
        def f(x):
            y = x.clone().conj()
            y.mul_(2)
            return torch.view_as_real(y.resolve_conj())

        x = torch.randn(4, dtype=torch.complex64)
        out_ref = f(x)
        out_test = dispatch_functionalize(f)(x)
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(out_test, out_test_cpp)
        fx_g = make_fx(dispatch_functionalize(f))(x)
        fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, arg0_1):
    clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
    _conj = torch.ops.aten._conj.default(clone); clone = None
    clone_1 = torch.ops.aten.clone.default(_conj)
    mul = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None
    clone_2 = torch.ops.aten.clone.default(_conj); _conj = None
    copy = torch.ops.aten.copy.default(clone_2, mul); clone_2 = mul = None
    _conj_1 = torch.ops.aten._conj.default(copy); copy = None
    _conj_2 = torch.ops.aten._conj.default(_conj_1); _conj_1 = None
    clone_3 = torch.ops.aten.clone.default(_conj_2); _conj_2 = None
    view_as_real = torch.ops.aten.view_as_real.default(clone_3); clone_3 = None
    return view_as_real""")
        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())

    def test_python_functionalization_neg(self):
        def f(x):
            y = x._neg_view()
            z = y.resolve_neg()
            return z + 1

        x = torch.randn(4)
        out_ref = f(x)
        out_test = dispatch_functionalize(f)(x)
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(out_ref, out_test_cpp)
        fx_g = make_fx(dispatch_functionalize(f))(x)
        fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, arg0_1):
    _neg_view = torch.ops.aten._neg_view.default(arg0_1); arg0_1 = None
    clone = torch.ops.aten.clone.default(_neg_view); _neg_view = None
    add = torch.ops.aten.add.Tensor(clone, 1); clone = None
    return add""")
        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())

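    # Illustrative sketch, not one of the checked tests: _neg_view only flips the negation bit while
    # resolve_neg materializes it, which is consistent with the graph above lowering resolve_neg to
    # a clone of the _neg_view. The method name is made up for illustration.
    def _sketch_neg_view_semantics(self):
        x = torch.randn(4)
        y = x._neg_view()
        z = y.resolve_neg()
        # y still carries the neg bit; z is materialized and equals -x.
        return y.is_neg(), z.is_neg(), torch.equal(z, -x)

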
@xfail_inherited_tests([
    "test_as_strided",
    "test_copy_",
    "test_diagonal",
    "test_diagonal_mutated_input",
    "test_everything",
    "test_fill_",
    "test_split",
    "test_view_clone_view_inplace",
    "test_view_inplace",
])
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesn't work well")
class TestCrossRefFunctionalization(TestFunctionalization):
    crossref = True


if __name__ == '__main__':
    run_tests()