mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Partially addresses #123062 Ran lintrunner on: - `test/jit` with command: ```bash lintrunner -a --take UFMT --all-files ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623 Approved by: https://github.com/ezyang
319 lines
10 KiB
Python
319 lines
10 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
from typing import List
|
|
|
|
import torch
|
|
from torch.testing import FileCheck
|
|
|
|
# Make the helper files in test/ importable
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase
|
|
|
|
# Refuse direct execution: the error text points the user at
# test/test_jit.py as the entry point for running these tests.
if __name__ == "__main__":
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)
|
|
|
|
|
|
class TestRemoveMutation(JitTestCase):
|
|
def test_aten_inplace(self):
|
|
def test_not_new_alias(x):
|
|
y = x[0]
|
|
y.add_(2)
|
|
return y
|
|
|
|
fn = torch.jit.script(test_not_new_alias)
|
|
graph = fn.graph
|
|
self.run_pass("remove_mutation", graph)
|
|
FileCheck().check("aten::add_").run(graph)
|
|
self.assertEqual(fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2])))
|
|
|
|
def test_no_lowering():
|
|
x = torch.tensor([2, 2])
|
|
x[0] = 3
|
|
return x
|
|
|
|
# there is no functional equivalent of x[0] = ...
|
|
fn = torch.jit.script(test_no_lowering)
|
|
graph = fn.graph
|
|
self.run_pass("remove_mutation", graph)
|
|
FileCheck().check("aten::copy_").run(graph)
|
|
self.assertEqual(fn(), test_no_lowering())
|
|
|
|
def test_move_before_not_valid():
|
|
y = torch.tensor([2, 2])
|
|
z = y + 2
|
|
y.add_(2)
|
|
return y, z
|
|
|
|
fn = torch.jit.script(test_move_before_not_valid)
|
|
graph = fn.graph
|
|
self.run_pass("remove_mutation", graph)
|
|
FileCheck().check("aten::add_").run(graph)
|
|
self.assertEqual(fn(), test_move_before_not_valid())
|
|
|
|
def test_successful():
|
|
x = torch.tensor([2, 2])
|
|
x.add_(1)
|
|
x.add_(3)
|
|
y = x + 4
|
|
return x, y
|
|
|
|
fn = torch.jit.script(test_successful)
|
|
graph = fn.graph
|
|
self.run_pass("remove_mutation", graph)
|
|
FileCheck().check_not("aten::add_").run(graph)
|
|
self.assertEqual(test_successful(), fn())
|
|
|
|
def test_intermediary_use():
|
|
x = torch.tensor([2, 2])
|
|
x.add_(1)
|
|
y = x + 4
|
|
x.add_(3)
|
|
return x, y
|
|
|
|
fn = torch.jit.script(test_intermediary_use)
|
|
graph = fn.graph
|
|
FileCheck().check_count("aten::add_", 2).run(graph)
|
|
self.run_pass("remove_mutation", graph)
|
|
# Unable to remove the second add_ because of the y = x + 4 use
|
|
# In the future we could duplicating the value of x as a temporary and replacing
|
|
# its intermediary use (so long as aliasing is safe)
|
|
FileCheck().check_count("aten::add_", 1).run(graph)
|
|
self.assertEqual(test_intermediary_use(), fn())
|
|
|
|
def test_if_output(self):
|
|
def foo(x, cond: bool):
|
|
if cond:
|
|
y = x + 5
|
|
else:
|
|
y = x + 2
|
|
y.add_(4)
|
|
return y
|
|
|
|
out_eager = foo(torch.tensor(5), True)
|
|
foo_script = torch.jit.script(foo)
|
|
FileCheck().check("aten::add_").run(foo_script.graph)
|
|
self.run_pass("remove_mutation", foo_script.graph)
|
|
FileCheck().check_not("aten::add_").run(foo_script.graph)
|
|
|
|
self.assertEqual(out_eager, foo_script(torch.tensor(5), True))
|
|
|
|
def test_if_output_fail(self):
|
|
@torch.jit.script
|
|
def foo(cond: bool):
|
|
li = []
|
|
if cond:
|
|
x = torch.tensor(1)
|
|
li.append(x)
|
|
else:
|
|
x = torch.tensor(2)
|
|
y = x.add_(2)
|
|
return y, li
|
|
|
|
self.run_pass("inline", foo.graph)
|
|
self.run_pass("remove_mutation", foo.graph)
|
|
FileCheck().check("aten::add_").run(foo.graph)
|
|
|
|
@torch.jit.script
|
|
def foo(cond: bool, y):
|
|
if cond:
|
|
x = y
|
|
else:
|
|
x = torch.tensor(2)
|
|
z = x.add_(2)
|
|
return z
|
|
|
|
self.run_pass("inline", foo.graph)
|
|
self.run_pass("remove_mutation", foo.graph)
|
|
FileCheck().check("aten::add_").run(foo.graph)
|
|
|
|
def test_special_mapped_op(self):
|
|
def test_successful():
|
|
x = torch.tensor([2, 2])
|
|
y = torch.tensor([2, 4])
|
|
x.zero_()
|
|
y.fill_(3)
|
|
return x, y
|
|
|
|
fn = torch.jit.script(test_successful)
|
|
graph = fn.graph
|
|
self.run_pass("remove_mutation", graph)
|
|
FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
|
|
self.assertEqual(test_successful(), fn())
|
|
|
|
# full_like is not implemented for a tensor fill value
|
|
|
|
def test_successful():
|
|
x = torch.tensor([2, 2])
|
|
y = torch.tensor([2, 4])
|
|
x.fill_(y)
|
|
return x + x
|
|
|
|
fn = torch.jit.script(test_successful)
|
|
graph = fn.graph
|
|
self.run_pass("remove_mutation", graph)
|
|
FileCheck().check_not("aten::fill_").run(graph)
|
|
|
|
def normal():
|
|
# NOTE: For some unknown reason, the
|
|
# `torch._C._jit_pass_remove_mutation` call within `self.run_pass`
|
|
# replaces `torch.randn(..., dtype=None).normal_()` with an
|
|
# `aten::normal` call with dtype double, even if the default dtype
|
|
# is float. So we must explicitly set the dtype here
|
|
return torch.rand(2, 1, 3, 4, dtype=torch.float).normal_()
|
|
|
|
fn = torch.jit.script(normal)
|
|
graph = fn.graph
|
|
self.run_pass("remove_mutation", graph)
|
|
FileCheck().check_not("normal_").run(graph)
|
|
with freeze_rng_state():
|
|
out_eager = normal()
|
|
with freeze_rng_state():
|
|
out_script = fn()
|
|
self.assertEqual(out_eager, out_script)
|
|
|
|
def test_lists_append(self):
|
|
def successful_remove():
|
|
return [i for i in range(5)] # noqa: C416
|
|
|
|
fn = torch.jit.script(successful_remove)
|
|
graph = fn.graph
|
|
self.run_pass("loop_unrolling", graph)
|
|
self.run_pass("remove_mutation", graph)
|
|
self.run_pass("constant_propagation", graph)
|
|
FileCheck().check("graph").check_next("Constant").check_next("return").run(
|
|
graph
|
|
)
|
|
self.assertEqual(successful_remove(), successful_remove())
|
|
|
|
def intermediary_use():
|
|
a = [1, 2]
|
|
b = len(a)
|
|
a.append(3)
|
|
return a
|
|
|
|
fn = torch.jit.script(intermediary_use)
|
|
graph = fn.graph
|
|
FileCheck().check("append").run(graph)
|
|
self.run_pass("remove_mutation", graph)
|
|
# it is possible to remove the append here but don't currently have the logic for it
|
|
FileCheck().check_not("append").run(graph)
|
|
self.assertEqual(intermediary_use(), fn())
|
|
|
|
def test_lists_insert(self):
|
|
def successful_remove():
|
|
a: List[int] = []
|
|
a.insert(0, 1)
|
|
a.insert(0, 2)
|
|
a.insert(-10, 3)
|
|
a.insert(-9, 4)
|
|
a.insert(10, 5)
|
|
return a
|
|
|
|
fn = torch.jit.script(successful_remove)
|
|
graph = fn.graph
|
|
torch._C._jit_pass_remove_mutation(graph)
|
|
torch._C._jit_pass_constant_propagation(graph)
|
|
FileCheck().check("graph").check_next("Constant").check_next("return").run(
|
|
graph
|
|
)
|
|
self.assertEqual(successful_remove(), fn())
|
|
|
|
def test_list_indexing_removal(self):
|
|
@torch.jit.script
|
|
def out_of_bounds():
|
|
x = [1, 2]
|
|
x[4] = 3
|
|
return x
|
|
|
|
torch._C._jit_pass_remove_mutation(out_of_bounds.graph)
|
|
FileCheck().check("set_item").run(out_of_bounds.graph)
|
|
|
|
@torch.jit.script
|
|
def unknown(y: int):
|
|
x = [1, 2]
|
|
x[y] = 3
|
|
return x
|
|
|
|
torch._C._jit_pass_remove_mutation(out_of_bounds.graph)
|
|
FileCheck().check("set_item").run(out_of_bounds.graph)
|
|
|
|
def successful():
|
|
x = [1, 2, 3]
|
|
x[0] = 4
|
|
x[-1] = 0
|
|
return x
|
|
|
|
scripted_fn = torch.jit.script(successful)
|
|
torch._C._jit_pass_remove_mutation(scripted_fn.graph)
|
|
FileCheck().check_not("set_item").run(scripted_fn.graph)
|
|
self.checkScript(successful, ())
|
|
|
|
def successful():
|
|
x = [1, 2, 3]
|
|
x[0] = 4
|
|
x[-1] = 0
|
|
return x
|
|
|
|
scripted_fn = torch.jit.script(successful)
|
|
torch._C._jit_pass_remove_mutation(scripted_fn.graph)
|
|
FileCheck().check_not("set_item").run(scripted_fn.graph)
|
|
self.checkScript(successful, ())
|
|
|
|
def successful():
|
|
x = [1]
|
|
x[-1] = 3
|
|
return x
|
|
|
|
scripted_fn = torch.jit.script(successful)
|
|
torch._C._jit_pass_remove_mutation(scripted_fn.graph)
|
|
FileCheck().check_not("set_item").run(scripted_fn.graph)
|
|
self.checkScript(successful, ())
|
|
|
|
def test_common_pytorch_list_ops(self):
|
|
for op in ["cat", "stack", "vstack", "hstack", "dstack"]:
|
|
|
|
class OpMod(torch.nn.Module):
|
|
def __init__(self, op):
|
|
super().__init__()
|
|
self.op = torch_op
|
|
|
|
def forward(self):
|
|
x = torch.tensor([1, 2, 3, 4])
|
|
x.add_(3)
|
|
y = [x, x]
|
|
return self.op(y) + 3
|
|
|
|
torch_op = getattr(torch, op)
|
|
mod = OpMod(torch_op)
|
|
mod_script = torch.jit.script(mod)
|
|
self.run_pass("remove_mutation", mod_script.forward.graph)
|
|
FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
|
|
self.assertEqual(mod(), mod_script())
|
|
|
|
# test that the output doesnt alias the input
|
|
for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]:
|
|
result = torch_op(inputs)
|
|
sums = [ten.sum() for ten in result]
|
|
|
|
for inp in inputs:
|
|
inp.fill_(10)
|
|
|
|
self.assertEqual(sums, [ten.sum() for ten in result])
|
|
|
|
@torch.jit.script
|
|
def test_multiple_uses():
|
|
x = torch.tensor([1, 2, 3, 4])
|
|
x.add_(3)
|
|
y = [x, x]
|
|
return torch.cat(y), y
|
|
|
|
self.run_pass("remove_mutation", mod_script.forward.graph)
|
|
FileCheck().check("aten::add_").run(test_multiple_uses.graph)
|