mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[static runtime] Swap to out-variant compatible nodes (#44127)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44127 Test Plan: Imported from OSS Reviewed By: hlu1 Differential Revision: D23604306 Pulled By: bwasti fbshipit-source-id: 18ccfb9b466b822e28130be3d5c4fae36c76820b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
856510c96d
commit
a475613d1d
@ -16,6 +16,7 @@ class StaticRuntime:
|
||||
def __call__(self, *inps):
|
||||
return self.static_runtime.run(inps)
|
||||
|
||||
|
||||
def linear_shim(input, weight, bias=None):
|
||||
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
|
||||
output = input.matmul(weight.t())
|
||||
@ -23,6 +24,8 @@ def linear_shim(input, weight, bias=None):
|
||||
output += bias
|
||||
ret = output
|
||||
return ret
|
||||
|
||||
|
||||
torch.nn.functional.linear = linear_shim
|
||||
|
||||
|
||||
@ -92,6 +95,7 @@ def trivial_graph(a, b, c):
|
||||
s = torch.tensor([[3, 3], [3, 3]])
|
||||
return a + b * c + s
|
||||
|
||||
|
||||
class TestStaticRuntime(TestCase):
|
||||
def test_multihead_attention_layer(self):
|
||||
HID_DIM = 256
|
||||
@ -133,7 +137,15 @@ class TestStaticRuntime(TestCase):
|
||||
ref_top = top_l(top_inp)
|
||||
acc_top = top_l_acc(top_inp)[0]
|
||||
torch.testing.assert_allclose(acc_top, ref_top)
|
||||
|
||||
for _ in range(5):
|
||||
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
||||
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
||||
ref_bot = bot_l(bot_inp)
|
||||
acc_bot = bot_l_acc(bot_inp)[0]
|
||||
torch.testing.assert_allclose(acc_bot, ref_bot)
|
||||
ref_top = top_l(top_inp)
|
||||
acc_top = top_l_acc(top_inp)[0]
|
||||
torch.testing.assert_allclose(acc_top, ref_top)
|
||||
|
||||
# def test_trivial_graph(self):
|
||||
# s = torch.full((2, 2), 2)
|
||||
@ -143,5 +155,6 @@ class TestStaticRuntime(TestCase):
|
||||
# o_test = tg_a(s, s, s)[0]
|
||||
# torch.testing.assert_allclose(o_ref, o_test)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user