mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] remove prim::SetAttr from list of ops with side effects (#68311)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68311 prim::SetAttr is listed as an op with side effects, but in AliasDb, `analyzeSetAttr` already accounts for its behavior. By removing it from the list of ops with side effects, dead code elimination will work in a few other scenarios. Test Plan: Imported from OSS Reviewed By: mrshenli Differential Revision: D32409510 fbshipit-source-id: 52ed9e19f92afb95c669ad3c2440f72f9515ba4c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
add79722dd
commit
bf60c6e71b
46
test/jit/test_dce.py
Normal file
46
test/jit/test_dce.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
class TestDCE(JitTestCase):
|
||||
def test_setattr_no_aliasdb(self):
|
||||
class Net(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.x = torch.empty([2, 2])
|
||||
|
||||
def forward(self):
|
||||
x = torch.rand([3, 3])
|
||||
self.x = x
|
||||
|
||||
net = torch.jit.script(Net())
|
||||
|
||||
FileCheck().check("prim::SetAttr").run(net.graph)
|
||||
|
||||
def test_setattr_removed(self):
|
||||
@torch.jit.script
|
||||
class Thing1(object):
|
||||
def __init__(self):
|
||||
self.x = torch.zeros([2, 2])
|
||||
|
||||
make_global(Thing1)
|
||||
|
||||
class Thing2(torch.nn.Module):
|
||||
def forward(self):
|
||||
x = torch.rand([2, 2])
|
||||
y = torch.rand([2, 2])
|
||||
t1 = Thing1()
|
||||
t1.x = x
|
||||
return y
|
||||
|
||||
unscripted = Thing2()
|
||||
|
||||
t2 = torch.jit.script(unscripted)
|
||||
t2.eval()
|
||||
|
||||
# freezing inlines t1.__init__(), after which DCE can occur.
|
||||
t2 = torch.jit.freeze(t2)
|
||||
FileCheck().check_not("prim::SetAttr").run(t2.graph)
|
@ -69,6 +69,7 @@ from jit.test_union import TestUnion # noqa: F401
|
||||
from jit.test_models import MnistNet
|
||||
from jit.test_batch_mm import TestBatchMM # noqa: F401
|
||||
from jit.test_dtype_analysis import TestDtypeAnalysis # noqa: F401
|
||||
from jit.test_dce import TestDCE # noqa: F401
|
||||
|
||||
# Torch
|
||||
from torch import Tensor
|
||||
|
@ -1175,7 +1175,6 @@ bool Node::hasSideEffects() const {
|
||||
case prim::IgnoredPythonOp:
|
||||
case prim::Print:
|
||||
case prim::RaiseException:
|
||||
case prim::SetAttr:
|
||||
case aten::warn:
|
||||
case aten::save:
|
||||
case aten::manual_seed:
|
||||
|
@ -299,6 +299,12 @@ class DeadCodeEliminator {
|
||||
// If we don't have alias information, all mutable ops have unknown
|
||||
// effects and can't be considered for elimination.
|
||||
|
||||
if (node->kind() == prim::SetAttr) {
|
||||
// SetAttr is a special case: it doesn't have a schema, but does
|
||||
// have untracked mutations
|
||||
return true;
|
||||
}
|
||||
|
||||
// onnx export calls EliminateDeadCode but sometimes passes invalid
|
||||
// aten operators. So we call maybeSchema so we handle the cases when
|
||||
// there is no valid schema for a node
|
||||
|
Reference in New Issue
Block a user