mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:
- Add and use a common `raise_on_run_directly` method for when a user directly runs a test file that should not be run this way. Print the file which the user should have run instead.
- Raise a RuntimeError on tests which have been disabled (not run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
81 lines
2.1 KiB
Python
81 lines
2.1 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import torch
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.common_utils import raise_on_run_directly
|
|
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
|
|
|
|
|
class TestDCE(JitTestCase):
|
|
def test_setattr_no_aliasdb(self):
|
|
class Net(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.x = torch.empty([2, 2])
|
|
|
|
def forward(self):
|
|
x = torch.rand([3, 3])
|
|
self.x = x
|
|
|
|
net = torch.jit.script(Net())
|
|
|
|
FileCheck().check("prim::SetAttr").run(net.graph)
|
|
|
|
def test_setattr_removed(self):
|
|
@torch.jit.script
|
|
class Thing1:
|
|
def __init__(self) -> None:
|
|
self.x = torch.zeros([2, 2])
|
|
|
|
make_global(Thing1)
|
|
|
|
class Thing2(torch.nn.Module):
|
|
def forward(self):
|
|
x = torch.rand([2, 2])
|
|
y = torch.rand([2, 2])
|
|
t1 = Thing1()
|
|
t1.x = x
|
|
return y
|
|
|
|
unscripted = Thing2()
|
|
|
|
t2 = torch.jit.script(unscripted)
|
|
t2.eval()
|
|
|
|
# freezing inlines t1.__init__(), after which DCE can occur.
|
|
t2 = torch.jit.freeze(t2)
|
|
FileCheck().check_not("prim::SetAttr").run(t2.graph)
|
|
|
|
def test_mutated_simple(self):
|
|
def fn(x: torch.Tensor):
|
|
y = x.sin()
|
|
y_slice = y[::2]
|
|
y_slice.add_(x[::2])
|
|
z = y.cos()
|
|
return z
|
|
|
|
fn_s = torch.jit.script(fn)
|
|
torch._C._jit_pass_dce_graph(fn_s.graph)
|
|
|
|
FileCheck().check("aten::add_").run(fn_s.graph)
|
|
|
|
def test_mutated_loop(self):
|
|
def fn(x: torch.Tensor):
|
|
y = x.sin()
|
|
y_slice = y[::2]
|
|
y_slice.add_(x[::2])
|
|
for _ in range(2):
|
|
y_slice = y[::2]
|
|
y = y.repeat(2)
|
|
z = y.cos()
|
|
return z
|
|
|
|
fn_s = torch.jit.script(fn)
|
|
torch._C._jit_pass_dce_graph(fn_s.graph)
|
|
|
|
FileCheck().check("aten::add_").run(fn_s.graph)
|
|
|
|
|
|
# This file is not meant to be executed on its own; when it is, call the
# shared helper with the path of the test file that should be run instead.
if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")
|