[JIT] remove prim::SetAttr from list of ops with side effects (#68311)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68311

prim::SetAttr is listed as an op with side effects, but in AliasDb, `analyzeSetAttr` already accounts for its behavior. By removing it from the list of ops with side effects, dead code elimination will work in a few other scenarios.

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D32409510

fbshipit-source-id: 52ed9e19f92afb95c669ad3c2440f72f9515ba4c
This commit is contained in:
David Berard
2021-11-16 08:37:50 -08:00
committed by Facebook GitHub Bot
parent add79722dd
commit bf60c6e71b
4 changed files with 53 additions and 1 deletions

46
test/jit/test_dce.py Normal file
View File

@ -0,0 +1,46 @@
# Owner(s): ["oncall: jit"]
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase, make_global
class TestDCE(JitTestCase):
def test_setattr_no_aliasdb(self):
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.x = torch.empty([2, 2])
def forward(self):
x = torch.rand([3, 3])
self.x = x
net = torch.jit.script(Net())
FileCheck().check("prim::SetAttr").run(net.graph)
def test_setattr_removed(self):
@torch.jit.script
class Thing1(object):
def __init__(self):
self.x = torch.zeros([2, 2])
make_global(Thing1)
class Thing2(torch.nn.Module):
def forward(self):
x = torch.rand([2, 2])
y = torch.rand([2, 2])
t1 = Thing1()
t1.x = x
return y
unscripted = Thing2()
t2 = torch.jit.script(unscripted)
t2.eval()
# freezing inlines t1.__init__(), after which DCE can occur.
t2 = torch.jit.freeze(t2)
FileCheck().check_not("prim::SetAttr").run(t2.graph)

View File

@ -69,6 +69,7 @@ from jit.test_union import TestUnion # noqa: F401
from jit.test_models import MnistNet
from jit.test_batch_mm import TestBatchMM # noqa: F401
from jit.test_dtype_analysis import TestDtypeAnalysis # noqa: F401
from jit.test_dce import TestDCE # noqa: F401
# Torch
from torch import Tensor

View File

@ -1175,7 +1175,6 @@ bool Node::hasSideEffects() const {
case prim::IgnoredPythonOp:
case prim::Print:
case prim::RaiseException:
case prim::SetAttr:
case aten::warn:
case aten::save:
case aten::manual_seed:

View File

@ -299,6 +299,12 @@ class DeadCodeEliminator {
// If we don't have alias information, all mutable ops have unknown
// effects and can't be considered for elimination.
if (node->kind() == prim::SetAttr) {
// SetAttr is a special case: it doesn't have a schema, but does
// have untracked mutations
return true;
}
// onnx export calls EliminateDeadCode but sometimes passes invalid
// aten operators. So we call maybeSchema so we handle the cases when
// there is no valid schema for a node