Files
pytorch/test/jit/test_dce.py
Aaron Gokaslan 8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00

47 lines
1.2 KiB
Python

# Owner(s): ["oncall: jit"]
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase, make_global
class TestDCE(JitTestCase):
def test_setattr_no_aliasdb(self):
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.x = torch.empty([2, 2])
def forward(self):
x = torch.rand([3, 3])
self.x = x
net = torch.jit.script(Net())
FileCheck().check("prim::SetAttr").run(net.graph)
def test_setattr_removed(self):
@torch.jit.script
class Thing1:
def __init__(self):
self.x = torch.zeros([2, 2])
make_global(Thing1)
class Thing2(torch.nn.Module):
def forward(self):
x = torch.rand([2, 2])
y = torch.rand([2, 2])
t1 = Thing1()
t1.x = x
return y
unscripted = Thing2()
t2 = torch.jit.script(unscripted)
t2.eval()
# freezing inlines t1.__init__(), after which DCE can occur.
t2 = torch.jit.freeze(t2)
FileCheck().check_not("prim::SetAttr").run(t2.graph)