mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62765 Fixes #27723 Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D30375181 Pulled By: msaroufim fbshipit-source-id: 715f4745899757ec405877980cd20c826028eb2c Co-authored-by: BowenBao <bowbao@microsoft.com>
This commit is contained in:
committed by
Facebook GitHub Bot
parent
db0771b05d
commit
1dd648f1c4
@ -5722,6 +5722,27 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
y = torch.randint(10, (5, ))
|
||||
self.run_test(MatmulModel(), (x, y))
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9) # MatMul long inputs is added in ONNX opset 9.
|
||||
def test_dot(self):
|
||||
class MatmulModel(torch.nn.Module):
|
||||
def forward(self, input, other):
|
||||
return torch.dot(input, other)
|
||||
|
||||
x = torch.randn(5, requires_grad=True)
|
||||
y = torch.randn(5, requires_grad=True)
|
||||
self.run_test(MatmulModel(), (x, y))
|
||||
|
||||
x = torch.randint(10, (5, ))
|
||||
y = torch.randint(10, (5, ))
|
||||
self.run_test(MatmulModel(), (x, y))
|
||||
|
||||
@disableScriptTest() # SpectralNorm not TorchScript compatible.
|
||||
def test_spectral_norm(self):
|
||||
m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))
|
||||
|
||||
x = torch.randn(6, 2)
|
||||
self.run_test(m, (x, ))
|
||||
|
||||
def test_prelu(self):
|
||||
class PReluModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -3138,6 +3138,10 @@ def mv(g, self, vec):
|
||||
return matmul(g, self, vec)
|
||||
|
||||
|
||||
def dot(g, self, other):
|
||||
return matmul(g, self, other)
|
||||
|
||||
|
||||
@parse_args('v', 'v')
|
||||
def fill(g, self, value):
|
||||
dtype = self.type().scalarType()
|
||||
|
Reference in New Issue
Block a user