Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[static runtime] add static subgraph fusion pass (#49185)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49185

This diff adds a fusion pass that lets us use the static runtime for *parts* of a graph. This is useful in cases where fully eliminating control flow from a model is hard.

TODO:
- [x] factor out into a separate fusion file
- [x] add python test case
- [x] add test case for a graph that isn't fully lowered
- [x] add test case for a graph that has weird list/tuple outputs

The loop example looks quite good:

```
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %iters.1 : int):
  %12 : bool = prim::Constant[value=1]() # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
  %c.2 : Tensor = prim::StaticSubgraph_0(%a.1, %b.1)
  %c : Tensor = prim::Loop(%iters.1, %12, %c.2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:110:4
    block0(%i : int, %c.12 : Tensor):
      %c.10 : Tensor = prim::StaticSubgraph_1(%a.1, %c.12, %b.1)
      -> (%12, %c.10)
  return (%c)
with prim::StaticSubgraph_0 = graph(%0 : Tensor,
      %4 : Tensor):
  %5 : int = prim::Constant[value=2]()
  %6 : Tensor = aten::mul(%4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:12
  %2 : int = prim::Constant[value=1]()
  %c.2 : Tensor = aten::add(%0, %6, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:109:8
  return (%c.2)
with prim::StaticSubgraph_1 = graph(%1 : Tensor,
      %7 : Tensor,
      %8 : Tensor):
  %9 : int = prim::Constant[value=1]()
  %c.4 : Tensor = aten::add(%7, %8, %9) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:111:12
  %5 : int = prim::Constant[value=2]()
  %c.7 : Tensor = aten::mul_(%c.4, %5) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:112:8
  %2 : int = prim::Constant[value=1]()
  %c.10 : Tensor = aten::sub_(%c.7, %1, %2) # /data/users/bwasti/fbsource/fbcode/buck-out/dev/gen/caffe2/test/static_runtime#binary,link-tree/test_static_runtime.py:113:8
  return (%c.10)
```

(Note: this ignores all push blocking failures!)

Test Plan:
buck test mode/no-gpu //caffe2/benchmarks/static_runtime:static_runtime_cpptest
buck test mode/no-gpu caffe2/test:static_runtime

Reviewed By: bertmaher

Differential Revision: D25385702

fbshipit-source-id: 2f24af4f11d92a959167facd03fbd24f464a6098
Committed by: Facebook GitHub Bot
Parent: 95a1725a4a
Commit: f4226b5c90
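To make the graph dump in the summary easier to follow, here is a sketch of how the Python source maps onto the fused subgraphs. This is the `loop_graph` helper added in the diff below; the mapping is read off the `aten::` ops and the source-line comments in the IR above:

```python
import torch

def loop_graph(a, b, iters : int):
    c = a + b * 2            # becomes prim::StaticSubgraph_0 (aten::mul, aten::add)
    for i in range(iters):   # stays a prim::Loop -- control flow is not fused
        c = c + b            # \
        c *= 2               #  } becomes prim::StaticSubgraph_1
        c -= a               # /  (aten::add, aten::mul_, aten::sub_)
    return c
```

The pass only claims the straight-line regions; the loop itself remains an ordinary JIT `prim::Loop` whose body invokes the fused subgraph on each iteration.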
```diff
@@ -105,6 +105,21 @@ def trivial_graph(a, b, c):
     s = torch.tensor([[3, 3], [3, 3]])
     return a + b * c + s
 
+def loop_graph(a, b, iters : int):
+    c = a + b * 2
+    for i in range(iters):
+        c = c + b
+        c *= 2
+        c -= a
+    return c
+
+def output_graph(a, b, c, iters : int):
+    s = torch.tensor([[3, 3], [3, 3]])
+    k = a + b * c + s
+    d : Dict[int, Tensor] = {}
+    for i in range(iters):
+        d[i] = k + i
+    return d
 
 class TestStaticRuntime(TestCase):
     def test_multihead_attention_layer(self):
```
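The `output_graph` helper above covers the "weird list/tuple outputs" item from the TODO list: the scripted function returns a `Dict[int, Tensor]` rather than a plain tensor, so the pass has to produce correct results when a fused region feeds a non-tensor aggregate output. A quick eager-mode sanity check of what it computes (the concrete inputs here are illustrative, not from the test):

```python
import torch
from typing import Dict
from torch import Tensor

# output_graph as defined in the hunk above
def output_graph(a, b, c, iters : int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d : Dict[int, Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d

ones = torch.ones(2, 2)
d = output_graph(ones, ones, ones, 3)
print(sorted(d.keys()))  # [0, 1, 2] -- one dict entry per loop iteration
print(d[1])              # k + 1, where k = 1 + 1*1 + 3, i.e. a 2x2 tensor of 6.0
```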
```diff
@@ -203,5 +218,63 @@ class TestStaticRuntime(TestCase):
         o_test = tg_a(s)[0]
         torch.testing.assert_allclose(o_ref, o_test)
 
+    def test_fusion_trivial_graph(self):
+        s = torch.full((2, 2), 2)
+        tg = torch.jit.script(trivial_graph)
+        o_ref = tg(s, s, s)
+        torch._C._fuse_to_static_runtime(tg.graph)
+        assert "StaticSubgraph" in str(tg.graph)
+        o_test = tg(s, s, s)
+        torch.testing.assert_allclose(o_ref, o_test)
+
+    def test_fusion_multihead_attention_layer(self):
+        HID_DIM = 256
+        QUERY_LEN = 8
+        BATCH_SIZE = 128
+        LAYERS = 3
+        HEADS = 8
+        DROPOUT = 0.1
+        device = torch.device("cpu")
+        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
+        with torch.no_grad():
+            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
+        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
+
+        attention.eval()
+        attention = torch.jit.script(attention)
+        attention.eval()
+        o_ref = attention(src, src, src, src_mask)
+
+        torch._C._fuse_to_static_runtime(attention._c)
+        o_test = attention(src, src, src, src_mask)
+
+        for a, b in zip(o_ref, o_test):
+            torch.testing.assert_allclose(a, b)
+
+    def test_fusion_loop(self):
+        a = torch.randn(5, 5)
+        b = torch.randn(5, 5)
+        c = 4
+        lg = torch.jit.script(loop_graph)
+        o_ref = lg(a, b, c)
+        torch._C._fuse_to_static_runtime(lg.graph)
+        assert "StaticSubgraph" in str(lg.graph)
+        o_test = lg(a, b, c)
+        torch.testing.assert_allclose(o_ref, o_test)
+
+    def test_fusion_outputs(self):
+        a = torch.randn(2, 2)
+        b = torch.randn(2, 2)
+        c = 4
+        og = torch.jit.script(output_graph)
+        o_ref = og(a, b, b, c)
+        torch._C._fuse_to_static_runtime(og.graph)
+        assert "StaticSubgraph" in str(og.graph)
+        o_test = og(a, b, b, c)
+        for i in o_ref.keys():
+            torch.testing.assert_allclose(o_ref[i], o_test[i])
+
+
 if __name__ == "__main__":
     run_tests()
```
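One usage detail worth noting from these tests: the entry point is invoked differently for scripted functions and scripted modules. A minimal sketch of the two forms, lifted from the tests above (the names come from that test code, and `_fuse_to_static_runtime` is an internal binding, so treat the exact spelling as subject to change):

```python
import torch

# Scripted function: run the pass on its Graph, as test_fusion_loop does.
lg = torch.jit.script(loop_graph)
torch._C._fuse_to_static_runtime(lg.graph)
assert "StaticSubgraph" in str(lg.graph)

# Scripted module: run the pass on the underlying C++ module handle (._c),
# as test_fusion_multihead_attention_layer does.
mod = torch.jit.script(MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device))
torch._C._fuse_to_static_runtime(mod._c)
```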