Files
pytorch/test/jit/test_alias_analysis.py
Elias Ellison ab6395fc65 Add api for recursively analyzing function calls (#73329)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73329

There is a quantization use case for having better alias analysis while function calls remain in the graph. This takes the relatively dumb approach of getting the inlined graph of each function call and then analyzing that subgraph. Since we need a single, unique analysis of every `Value*`, we make a copy of the graph for each analysis of a function call after the first. This is relatively slow, but given the limited use case here it should work well enough (and is no slower than calling the inlining pass).

cc vkuzo

Test Plan: Imported from OSS

Reviewed By: davidberard98

Differential Revision: D34451424

Pulled By: eellison

fbshipit-source-id: b7c7e54679d723f5ded1e11ffb32eb6d2176431d
(cherry picked from commit 81a42b31522b890311a3f512448b372c4ebbefd1)
2022-02-28 17:44:45 +00:00

94 lines
3.4 KiB
Python

# Owner(s): ["oncall: jit"]
from torch.testing._internal.jit_utils import JitTestCase
from torch._C import parse_ir
import torch
if __name__ == "__main__":
    # This file is a sub-suite of test_jit.py and must be launched through it.
    msg = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(msg)
class TestAliasAnalysis(JitTestCase):
def test_becomes_wildcard_annotations(self):
graph_str = """
graph(%a.1 : Tensor, %b.1 : Tensor):
%11 : NoneType = prim::Constant()
%8 : int = prim::Constant[value=0]()
%7 : int = prim::Constant[value=1]()
%x.1 : Tensor = aten::add(%a.1, %b.1, %7)
%y.1 : Tensor[] = aten::split(%x.1, %7, %8)
return ()
"""
graph = parse_ir(graph_str)
alias_db = graph.alias_db()
split_node = graph.findNode("aten::split")
# split input enters wildcard set, list initalized as containing wildcard set
self.assertTrue(alias_db.may_contain_alias(next(split_node.inputs()), split_node.output()))
# because %x.1 enters wildcard set, it now aliases other members of wildcard set (graph inputs)
self.assertTrue(alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs())))
def test_nested_list_construct_not_wildcard(self):
@torch.jit.script
def foo(x):
y = torch.rand([2, 2])
return [y]
graph = foo.graph
graph.alias_db()
alias_db = graph.alias_db()
ten_construct = graph.findNode("aten::rand").output()
output = next(graph.outputs())
self.assertTrue(alias_db.may_contain_alias(ten_construct, output))
self.assertFalse(alias_db.may_contain_alias(next(graph.inputs()), ten_construct))
def test_recursive_calls(self):
@torch.jit.script
def foo(x, y):
x.add_(1)
return x + y
@torch.jit.script
def caller():
a = torch.rand([2, 2])
b = torch.ones([2, 2])
out1 = foo(a, b)
c = torch.rand([1])
d = torch.ones([2])
out2 = foo(d, c)
return out1, out2
isFrozen = False
descend_function_calls = True
alias_db = caller.graph.alias_db(isFrozen, descend_function_calls)
func_calls = caller.graph.findAllNodes("prim::CallFunction")
self.assertEqual(len(func_calls), 2)
for node in func_calls:
inps = list(node.inputs())
self.assertTrue(alias_db.has_writers(inps[1]))
self.assertFalse(alias_db.has_writers(inps[2]))
class Mod(torch.nn.Module):
def forward(self):
a = torch.rand([2, 2])
b = torch.ones([2, 2])
out1 = self.foo2(a, b)
c = torch.rand([1])
d = torch.ones([2])
out2 = self.foo2(d, c)
return out1, out2
def foo2(self, x, y):
x.add_(1)
return x + y
mod = torch.jit.script(Mod())
alias_db = mod.graph.alias_db(isFrozen, descend_function_calls)
func_calls = mod.graph.findAllNodes("prim::CallMethod")
self.assertEqual(len(func_calls), 2)
for node in func_calls:
inps = list(node.inputs())
self.assertTrue(alias_db.has_writers(inps[1]))
self.assertFalse(alias_db.has_writers(inps[2]))