mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62765 Fixes #27723 Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D30375181 Pulled By: msaroufim fbshipit-source-id: 715f4745899757ec405877980cd20c826028eb2c Co-authored-by: BowenBao <bowbao@microsoft.com>
This commit is contained in:
committed by
Facebook GitHub Bot
parent
db0771b05d
commit
1dd648f1c4
@ -5722,6 +5722,27 @@ class TestONNXRuntime(unittest.TestCase):
|
|||||||
y = torch.randint(10, (5, ))
|
y = torch.randint(10, (5, ))
|
||||||
self.run_test(MatmulModel(), (x, y))
|
self.run_test(MatmulModel(), (x, y))
|
||||||
|
|
||||||
|
@skipIfUnsupportedMinOpsetVersion(9) # MatMul long inputs is added in ONNX opset 9.
|
||||||
|
def test_dot(self):
|
||||||
|
class MatmulModel(torch.nn.Module):
|
||||||
|
def forward(self, input, other):
|
||||||
|
return torch.dot(input, other)
|
||||||
|
|
||||||
|
x = torch.randn(5, requires_grad=True)
|
||||||
|
y = torch.randn(5, requires_grad=True)
|
||||||
|
self.run_test(MatmulModel(), (x, y))
|
||||||
|
|
||||||
|
x = torch.randint(10, (5, ))
|
||||||
|
y = torch.randint(10, (5, ))
|
||||||
|
self.run_test(MatmulModel(), (x, y))
|
||||||
|
|
||||||
|
@disableScriptTest() # SpectralNorm not TorchScript compatible.
|
||||||
|
def test_spectral_norm(self):
|
||||||
|
m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))
|
||||||
|
|
||||||
|
x = torch.randn(6, 2)
|
||||||
|
self.run_test(m, (x, ))
|
||||||
|
|
||||||
def test_prelu(self):
|
def test_prelu(self):
|
||||||
class PReluModel(torch.nn.Module):
|
class PReluModel(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -3138,6 +3138,10 @@ def mv(g, self, vec):
|
|||||||
return matmul(g, self, vec)
|
return matmul(g, self, vec)
|
||||||
|
|
||||||
|
|
||||||
|
def dot(g, self, other):
|
||||||
|
return matmul(g, self, other)
|
||||||
|
|
||||||
|
|
||||||
@parse_args('v', 'v')
|
@parse_args('v', 'v')
|
||||||
def fill(g, self, value):
|
def fill(g, self, value):
|
||||||
dtype = self.type().scalarType()
|
dtype = self.type().scalarType()
|
||||||
|
Reference in New Issue
Block a user