mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33020 This is a pass to create functional blocks. The other PRs in the stack help avoid some of the limitations that are often found in graphs. It's possible that this would work well with a graph that is frozen. Follow-up work items that will help this pass: - We don't currently have any capacity in alias analysis to tell whether a Value that came from the wildcard set "re-escapes" back into the wildcard set. - More comments on the semantics of the graph and correctness conditions. - We could consider using dynamic dag if the perf of this is a limitation. - Potentially make Functional Graphs Functional Blocks instead, so that we do not repeatedly copy constants, and also to make the IR easier to read. Test Plan: Imported from OSS Differential Revision: D20603188 Pulled By: eellison fbshipit-source-id: 6822a6e65f4cc2676f8f6445fe8aa1cb858ebeeb
42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
import os
|
|
import sys
|
|
|
|
import torch
|
|
from torch.testing import FileCheck
|
|
|
|
# Make the helper files in test/ importable
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
if __name__ == '__main__':
|
|
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
|
"\tpython test/test_jit.py TESTNAME\n\n"
|
|
"instead.")
|
|
|
|
class TestFunctionalBlocks(JitTestCase):
    """Tests for the ``create_functional_graphs`` JIT pass.

    Each test scripts a small function with ``torch.jit.script``, runs the
    pass over its graph via ``self.run_pass``, and then inspects the printed
    IR with ``FileCheck`` for ``FunctionalGraph`` subgraphs.
    """

    def test_simple_no_merge(self):
        """Check which uses are sunk into a FunctionalGraph for a partly mutating fn."""
        # NOTE: ``fn`` must stay exactly as written. The FileCheck patterns
        # below match the IR value names %x and %y, which come from these
        # parameter names. They also depend on the exact ops emitted, e.g.
        # ``z = z + 1`` must not become ``z += 1`` (in-place on a Tensor in
        # TorchScript).
        def fn(x, y, z):
            x = x + 1
            y = y + 1
            z = z + 1
            z.add_(2)
            z = z * z
            y = y * z
            if y < 2:
                y = y + 5
            return x + y + z

        ir = torch.jit.script(fn).graph
        self.run_pass('create_functional_graphs', ir)

        # Every use of the inputs x and y is expected to be sunk into the
        # functional subgraph: once the input appears, it must not be
        # referenced again until "FunctionalGraph" shows up, after which it
        # appears once more. Inputs are checked in order: x first, then y.
        for value_name in (r"%x", r"%y"):
            (FileCheck()
                .check(value_name)
                .check_not(value_name)
                .check("FunctionalGraph")
                .check(value_name)
                .run(ir))

        # Outputs that would escape scope are not allowed, so one final
        # addition remains in the outer graph, on the line right after the
        # Tensor produced by the prim::Functional... node.
        FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(ir)

        # z + 1, z.add_(2) and z * z are considered non-functional, so
        # add, add_ and mul are expected to appear, in that order, before
        # the "FunctionalGraph" text.
        FileCheck().check("add").check("add_").check("mul").check("FunctionalGraph").run(ir)