pytorch/test/jit/test_functional_blocks.py
Elias Ellison 5b2f8cef08 [JIT] Functional Graph Pass (#33020)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33020

This is a pass to create functional blocks. The other PRs in the stack help avoid some of the limitations that are often found in graphs. It's possible that this would work well with a graph that is frozen. Follow-up work items that will help this pass:

- We don't currently have any capacity in alias analysis to tell whether a Value that came from the wildcard set "re-escapes" back into the wildcard set.
- More comments on the semantics of the graph and correctness conditions
- We could consider using a dynamic DAG if the performance of this pass becomes a limitation.
- Potentially make Functional Graphs into Functional Blocks instead, so that we do not repeatedly copy constants and so that the IR is easier to read.

Test Plan: Imported from OSS

Differential Revision: D20603188

Pulled By: eellison

fbshipit-source-id: 6822a6e65f4cc2676f8f6445fe8aa1cb858ebeeb
2020-03-24 23:44:18 -07:00
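The test in this file checks which operations end up inside a `prim::FunctionalGraph`. As a rough mental model only, the following sketch assumes a simplified rule that is not taken from the pass itself: a straight-line op counts as functional if it mutates nothing, no value it reads or produces is mutated anywhere, and its result is not returned. The `Op` class, the `classify` function, and the SSA names (`x1`, `z2`, ...) are hypothetical. The `if` branch of the test's `fn` is omitted, and the real pass operates on TorchScript IR using alias analysis.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Op:
    """One straight-line statement (hypothetical toy IR, not TorchScript IR)."""
    kind: str                 # e.g. "add", "add_", "mul"
    inputs: List[str]         # names of the values the op reads
    out: str = ""             # name of the produced value ("" for in-place ops)
    mutates: List[str] = field(default_factory=list)  # values written in place


def classify(ops: List[Op], returned: Set[str]) -> List[bool]:
    """Assumed toy rule: an op is 'functional' if it mutates nothing, no value
    it reads or produces is mutated anywhere, and its result is not returned."""
    mutated = {v for op in ops for v in op.mutates}
    flags = []
    for op in ops:
        touched = set(op.inputs) | ({op.out} if op.out else set())
        flags.append(
            not op.mutates
            and not (touched & mutated)
            and op.out not in returned
        )
    return flags


# Straight-line part of the test's fn in SSA form (the `if` is omitted).
FN_OPS = [
    Op("add", ["x"], "x1"),              # x = x + 1
    Op("add", ["y"], "y1"),              # y = y + 1
    Op("add", ["z"], "z1"),              # z = z + 1
    Op("add_", ["z1"], mutates=["z1"]),  # z.add_(2)
    Op("mul", ["z1", "z1"], "z2"),       # z = z * z
    Op("mul", ["y1", "z2"], "y2"),       # y = y * z
    Op("add", ["x1", "y2"], "t"),        # x + y
    Op("add", ["t", "z2"], "r"),         # (x + y) + z, returned
]

if __name__ == "__main__":
    for op, ok in zip(FN_OPS, classify(FN_OPS, {"r"})):
        print(f"{op.kind:5} -> {op.out or '(in place)':10} functional={ok}")
```

Under this assumed rule, the ops derived from `x` and `y` come out functional, `z + 1`, `z.add_(2)`, and `z * z` do not (they all touch the mutated value), and the final addition stays outside because its result escapes through the return. That lines up with the expectations stated in the test's comments, but it is an illustration, not a description of how the pass decides.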


import os
import sys

import torch
from torch.testing import FileCheck

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

class TestFunctionalBlocks(JitTestCase):
    def test_simple_no_merge(self):
        def fn(x, y, z):
            x = x + 1
            y = y + 1
            z = z + 1
            z.add_(2)
            z = z * z
            y = y * z
            if y < 2:
                y = y + 5
            return x + y + z

        graph = torch.jit.script(fn).graph
        self.run_pass('create_functional_graphs', graph)

        # All uses of x and y should be sunk into the FunctionalGraph: after the
        # graph inputs, %x and %y must not appear again before the FunctionalGraph node
        FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(r"%x").run(graph)
        FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(r"%y").run(graph)

        # Outputs that escape scope are not allowed inside the FunctionalGraph, so one
        # final addition remains in the outer graph, right after the FunctionalGraph node
        FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(graph)

        # z + 1, z.add_(2), and z * z are considered non-functional
        FileCheck().check("add").check("add_").check("mul").check("FunctionalGraph").run(graph)
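
The assertions above rely on `torch.testing.FileCheck`. Roughly: `check` finds a string after the previous match, `check_not` requires that a string not appear between the previous match and the next positive match, and `check_next` requires a match on the line immediately after the previous match. The following is a simplified, hypothetical re-implementation (`MiniFileCheck` is not part of PyTorch) meant only to make those three semantics concrete. The `SAMPLE` text is hand-written and schematic, not real output of the pass.

```python
from typing import List, Tuple


class MiniFileCheck:
    """Hypothetical model of check / check_not / check_next on plain strings.
    Not torch.testing.FileCheck itself."""

    def __init__(self) -> None:
        self.directives: List[Tuple[str, str]] = []

    def check(self, s: str) -> "MiniFileCheck":
        self.directives.append(("check", s))
        return self

    def check_not(self, s: str) -> "MiniFileCheck":
        self.directives.append(("not", s))
        return self

    def check_next(self, s: str) -> "MiniFileCheck":
        self.directives.append(("next", s))
        return self

    def run(self, text: str) -> None:
        pos = 0                   # end of the previous positive match
        pending: List[str] = []   # check_not strings waiting for the next match
        for kind, s in self.directives:
            if kind == "not":
                pending.append(s)
                continue
            if kind == "check":
                idx = text.find(s, pos)
            else:  # "next": must match on the line right after the previous match
                nl = text.find("\n", pos)
                if nl == -1:
                    raise AssertionError(f"no line after previous match for {s!r}")
                end = text.find("\n", nl + 1)
                idx = text.find(s, nl + 1, len(text) if end == -1 else end)
            if idx == -1:
                raise AssertionError(f"expected {s!r} not found")
            self._check_absent(pending, text[pos:idx])
            pending = []
            pos = idx + len(s)
        # Trailing check_nots apply from the last match to the end of the input.
        self._check_absent(pending, text[pos:])

    @staticmethod
    def _check_absent(strings: List[str], region: str) -> None:
        for s in strings:
            if s in region:
                raise AssertionError(f"unexpected {s!r} found")


# Hand-written, schematic text (not real pass output).
SAMPLE = (
    "graph(%x : Tensor, %y : Tensor):\n"
    "  %g : Tensor = prim::FunctionalGraph_0(%x, %y)\n"
    "  %r : Tensor = aten::add(%g, %y)\n"
    "  return (%r)\n"
)

# Same shapes as the assertions in the test; both pass on SAMPLE.
MiniFileCheck().check("%x").check_not("%x").check("FunctionalGraph").check("%x").run(SAMPLE)
MiniFileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(SAMPLE)
```

The real FileCheck supports more directives (for example `check_count` and `check_same`) and can run directly on a TorchScript graph as well as on strings; this sketch covers only the three directives this test uses.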