Files
pytorch/test/jit/test_dce.py
Anthony Barbier bf7e290854 Add __main__ guards to jit tests (#154725)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:

- Add and use a common `raise_on_run_directly` function for when a user directly runs a test file that is not meant to be run on its own; it prints the file the user should have run instead.
- Raise a RuntimeError on tests which have been disabled (not run)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
2025-06-16 10:28:45 +00:00

81 lines
2.1 KiB
Python

# Owner(s): ["oncall: jit"]
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase, make_global
class TestDCE(JitTestCase):
    """Dead-code-elimination tests for TorchScript graphs.

    Each test compiles a small function or module with TorchScript and uses
    FileCheck to assert which nodes are (or are not) left in the graph.
    """

    def _check_inplace_add_kept_after_dce(self, fn):
        """Script ``fn``, run the DCE pass on its graph, and assert that the
        in-place ``aten::add_`` node was not eliminated."""
        scripted = torch.jit.script(fn)
        torch._C._jit_pass_dce_graph(scripted.graph)
        FileCheck().check("aten::add_").run(scripted.graph)

    def test_setattr_no_aliasdb(self):
        # ``forward`` returns nothing, but it writes to the module attribute
        # ``self.x``; that store must still appear as prim::SetAttr in the
        # scripted graph.
        class Net(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.empty([2, 2])

            def forward(self):
                x = torch.rand([3, 3])
                self.x = x

        scripted_net = torch.jit.script(Net())
        FileCheck().check("prim::SetAttr").run(scripted_net.graph)

    def test_setattr_removed(self):
        # Here the attribute store targets a local object (``t1``) whose
        # value is never used, so after freezing the SetAttr should be gone.
        @torch.jit.script
        class Thing1:
            def __init__(self) -> None:
                self.x = torch.zeros([2, 2])

        # Register Thing1 so the scripted Thing2.forward can refer to it
        # (presumably by exposing it at module scope -- see jit_utils).
        make_global(Thing1)

        class Thing2(torch.nn.Module):
            def forward(self):
                x = torch.rand([2, 2])
                y = torch.rand([2, 2])
                t1 = Thing1()
                t1.x = x
                return y

        scripted = torch.jit.script(Thing2())
        scripted.eval()

        # freezing inlines t1.__init__(), after which DCE can occur.
        frozen = torch.jit.freeze(scripted)
        FileCheck().check_not("prim::SetAttr").run(frozen.graph)

    def test_mutated_simple(self):
        # ``y_slice`` is a view of ``y``: the in-place add writes through it
        # into ``y``, which feeds the returned value, so DCE must keep it.
        def fn(x: torch.Tensor):
            y = x.sin()
            y_slice = y[::2]
            y_slice.add_(x[::2])
            z = y.cos()
            return z

        self._check_inplace_add_kept_after_dce(fn)

    def test_mutated_loop(self):
        # Same aliasing pattern as test_mutated_simple, followed by a loop
        # that re-slices and rebinds ``y``; the earlier in-place add must
        # still survive DCE.
        def fn(x: torch.Tensor):
            y = x.sin()
            y_slice = y[::2]
            y_slice.add_(x[::2])
            for _ in range(2):
                y_slice = y[::2]
                y = y.repeat(2)
            z = y.cos()
            return z

        self._check_inplace_add_kept_after_dce(fn)
if __name__ == "__main__":
    # This file is not meant to be run on its own; raise_on_run_directly
    # reports that and points the user at test/test_jit.py instead.
    raise_on_run_directly("test/test_jit.py")