[static runtime] Swap to out-variant compatible nodes (#44127)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44127

Test Plan: Imported from OSS

Reviewed By: hlu1

Differential Revision: D23604306

Pulled By: bwasti

fbshipit-source-id: 18ccfb9b466b822e28130be3d5c4fae36c76820b
This commit is contained in:
Bram Wasti
2020-09-14 12:33:02 -07:00
committed by Facebook GitHub Bot
parent 856510c96d
commit a475613d1d
3 changed files with 100 additions and 95 deletions

View File

@ -16,6 +16,7 @@ class StaticRuntime:
def __call__(self, *inps):
    # Make the wrapper callable like a module: forward all positional
    # arguments (as a tuple) to the underlying static runtime's run().
    # NOTE(review): return value is whatever run() produces — presumably
    # the output tensor(s); confirm against the C++ binding.
    return self.static_runtime.run(inps)
def linear_shim(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    # Pure-ATen re-implementation of torch.nn.functional.linear:
    # y = input @ weight.T (+ bias). Decomposing linear into matmul/add
    # presumably lets the static runtime swap in out-variant ops — TODO
    # confirm against the commit description at the top of this page.
    output = input.matmul(weight.t())
@ -23,6 +24,8 @@ def linear_shim(input, weight, bias=None):
    # NOTE(review): the diff hunk above hides context lines — almost
    # certainly an `if bias is not None:` guard, since bias defaults to
    # None and an unconditional += would raise. Verify in the full file.
    output += bias
    ret = output
    return ret

# Monkeypatch: route every F.linear call through the shim for the tests
# in this file.
torch.nn.functional.linear = linear_shim
@ -92,6 +95,7 @@ def trivial_graph(a, b, c):
s = torch.tensor([[3, 3], [3, 3]])
return a + b * c + s
class TestStaticRuntime(TestCase):
def test_multihead_attention_layer(self):
HID_DIM = 256
@ -133,7 +137,15 @@ class TestStaticRuntime(TestCase):
ref_top = top_l(top_inp)
acc_top = top_l_acc(top_inp)[0]
torch.testing.assert_allclose(acc_top, ref_top)
for _ in range(5):
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
ref_bot = bot_l(bot_inp)
acc_bot = bot_l_acc(bot_inp)[0]
torch.testing.assert_allclose(acc_bot, ref_bot)
ref_top = top_l(top_inp)
acc_top = top_l_acc(top_inp)[0]
torch.testing.assert_allclose(acc_top, ref_top)
# def test_trivial_graph(self):
# s = torch.full((2, 2), 2)
@ -143,5 +155,6 @@ class TestStaticRuntime(TestCase):
# o_test = tg_a(s, s, s)[0]
# torch.testing.assert_allclose(o_ref, o_test)
# Standard test-file entry point; run_tests is not visible in this diff
# chunk — presumably imported from torch.testing._internal.common_utils,
# as in other PyTorch test files.
if __name__ == "__main__":
    run_tests()