[static runtime] add static subgraph fusion pass (#49185)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185

This diff adds a fusion pass that lets us use the static runtime for *parts* of a graph. This is useful in cases where fully eliminating control flow is hard.
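
The pass is exposed to Python via the `torch._C._fuse_to_static_runtime` binding added in this diff. A minimal sketch of driving it on a scripted function, mirroring the `test_fusion_trivial_graph` case below:

```
import torch

def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s

s = torch.full((2, 2), 2)
tg = torch.jit.script(trivial_graph)
o_ref = tg(s, s, s)

# Rewrite fusible regions of the TorchScript graph into
# prim::StaticSubgraph nodes that execute via the static runtime.
torch._C._fuse_to_static_runtime(tg.graph)
assert "StaticSubgraph" in str(tg.graph)

# The fused graph must still match the reference output.
o_test = tg(s, s, s)
torch.testing.assert_allclose(o_ref, o_test)
```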

TODO:
- [x] Factor out into a separate fusion file
- [x] Add a Python test case
- [x] Add a test case for a graph that isn't fully lowered
- [x] Add a test case for a graph with unusual list/tuple outputs

The loop example looks quite good:
```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %iters.1 : int):
  %12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
  %c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
  %c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
    block0(%i : int, %c.12 : Tensor):
      %c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
      -> (%12, %c.10)
  return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
      %4 : Tensor):
  %5 : int = prim::Constant[value=2]()
  %6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
  %2 : int = prim::Constant[value=1]()
  %c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
  return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
      %7 : Tensor,
      %8 : Tensor):
  %9 : int = prim::Constant[value=1]()
  %c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
  %5 : int = prim::Constant[value=2]()
  %c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
  %2 : int = prim::Constant[value=1]()
  %c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
  return (%c.10)
```
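
For reference, this IR is produced from the `loop_graph` test function added in this diff: the statement before the loop is fused into `StaticSubgraph_0`, the loop body is fused into `StaticSubgraph_1`, and the `prim::Loop` itself stays in the TorchScript interpreter:

```
def loop_graph(a, b, iters: int):
    c = a + b * 2           # fused into prim::StaticSubgraph_0
    for i in range(iters):  # prim::Loop remains in the interpreter
        c = c + b           # the loop body (add, in-place mul,
        c *= 2              # in-place sub) is fused into
        c -= a              # prim::StaticSubgraph_1
    return c
```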

(Note: this ignores all push blocking failures!)

Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest

buck test mode/no-gpu caffe2/test:static_runtime

Reviewed By: bertmaher

Differential Revision: D25385702

fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
Author: Bram Wasti
Date: 2020-12-10 14:01:36 -08:00
Committed by: Facebook GitHub Bot
Parent: 95a1725a4a
Commit: f4226b5c90
12 changed files with 406 additions and 4 deletions

test/test_static_runtime.py:

```
@@ -105,6 +105,21 @@ def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


def loop_graph(a, b, iters: int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c


def output_graph(a, b, c, iters: int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d: Dict[int, Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d


class TestStaticRuntime(TestCase):
    def test_multihead_attention_layer(self):
```
```
@@ -203,5 +218,63 @@ class TestStaticRuntime(TestCase):
        o_test = tg_a(s)[0]
        torch.testing.assert_allclose(o_ref, o_test)

    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_runtime(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_allclose(o_ref, o_test)

    def test_fusion_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_runtime(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_allclose(a, b)

    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_runtime(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_allclose(o_ref, o_test)

    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_runtime(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_allclose(o_ref[i], o_test[i])


if __name__ == "__main__":
    run_tests()
```