[ONNX] Suppport torch.dot and torch.nn.utils.spectral_norm (#62596) (#62765)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62765

Fixes #27723

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D30375181

Pulled By: msaroufim

fbshipit-source-id: 715f4745899757ec405877980cd20c826028eb2c

Co-authored-by: BowenBao <bowbao@microsoft.com>
This commit is contained in:
BowenBao
2021-08-20 12:44:29 -07:00
committed by Facebook GitHub Bot
parent db0771b05d
commit 1dd648f1c4
2 changed files with 25 additions and 0 deletions

View File

@ -5722,6 +5722,27 @@ class TestONNXRuntime(unittest.TestCase):
y = torch.randint(10, (5, ))
self.run_test(MatmulModel(), (x, y))
@skipIfUnsupportedMinOpsetVersion(9) # MatMul long inputs is added in ONNX opset 9.
def test_dot(self):
class MatmulModel(torch.nn.Module):
def forward(self, input, other):
return torch.dot(input, other)
x = torch.randn(5, requires_grad=True)
y = torch.randn(5, requires_grad=True)
self.run_test(MatmulModel(), (x, y))
x = torch.randint(10, (5, ))
y = torch.randint(10, (5, ))
self.run_test(MatmulModel(), (x, y))
@disableScriptTest() # SpectralNorm not TorchScript compatible.
def test_spectral_norm(self):
m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))
x = torch.randn(6, 2)
self.run_test(m, (x, ))
def test_prelu(self):
class PReluModel(torch.nn.Module):
def __init__(self):

View File

@ -3138,6 +3138,10 @@ def mv(g, self, vec):
return matmul(g, self, vec)
def dot(g, self, other):
return matmul(g, self, other)
@parse_args('v', 'v')
def fill(g, self, value):
dtype = self.type().scalarType()