Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Expand the coverage of test_addmm and test_addmm_sizes (#43831)
Summary:
- This test is very fast and very important, so it makes no sense to mark it as slowTest.
- This test should also run on CUDA.
- This test should check alpha and beta support.
- This test should check `out=` support.
- The manual reference computation should use a list instead of index_put, because a list is much faster.
- Precision for TF32 needs to be fixed; this will be done in a future PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43831
Reviewed By: ailzhang
Differential Revision: D23435032
Pulled By: ngimel
fbshipit-source-id: d1b8350addf1e2fe180fdf3df243f38d95aa3f5a
committed by Facebook GitHub Bot
parent f5ba489f93
commit bc45c47aa3
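To make the summary concrete, here is a minimal standalone sketch of the kind of check the reworked test performs: addmm with explicit alpha and beta, the out= variant, and an independent NumPy reference. The helper name, shapes, and tolerance below are illustrative only, not the exact test code.

import numpy as np
import torch

def check_addmm(M, m1, m2, alpha, beta, atol=1e-4):
    # Functional form: beta * M + alpha * (m1 @ m2)
    res1 = torch.addmm(M, m1, m2, alpha=alpha, beta=beta)
    # The out= form must produce the same values.
    res2 = torch.full_like(res1, float("nan"))
    torch.addmm(M, m1, m2, alpha=alpha, beta=beta, out=res2)
    # Independent NumPy reference.
    ref = beta * M.numpy() + alpha * (m1.numpy() @ m2.numpy())
    assert torch.allclose(res1, res2, atol=atol)
    assert np.allclose(res1.numpy(), ref, atol=atol)

check_addmm(torch.randn(10, 25), torch.randn(10, 50), torch.randn(50, 25),
            alpha=1.2, beta=0.8)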
@@ -74,7 +74,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 
   if (&result != &self) {
     at::native::resize_as_(result, self_);
-    if (beta.to<double>() != 0.0) {
+    if (beta.toComplexDouble() != 0.0) {
       at::native::copy_(result, self_);
     }
   }
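The one-line change above is about complex beta: with complex operands, beta is a complex scalar, and the zero-check that decides whether self must be copied into the result now goes through toComplexDouble(), so the full complex value is compared rather than a real-valued double conversion. A small illustration of the behaviour the expanded test exercises, assuming a PyTorch build with complex addmm support (which is exactly what this PR starts testing):

import torch

M = torch.randn(3, 3, dtype=torch.cfloat)
m1 = torch.randn(3, 4, dtype=torch.cfloat)
m2 = torch.randn(4, 3, dtype=torch.cfloat)

# alpha and beta are complex scalars here; a zero-check on beta must look at
# the full complex value, not only a real-valued cast of it.
out = torch.addmm(M, m1, m2, alpha=0.9 + 0.3j, beta=0.5 + 0.6j)
expected = (0.5 + 0.6j) * M + (0.9 + 0.3j) * (m1 @ m2)
assert torch.allclose(out, expected, atol=1e-5)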
@@ -16379,48 +16379,48 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
         for use_out, row_major, incx, incy, lda_tail in product((False, True), (False, True), (1, 2), (1, 2), (0, 1)):
             _test(use_out, row_major, incx, incy, lda_tail)
 
-    @slowTest
-    @onlyCPU
-    def test_addmm(self, device):
-        dtypes = {
-            torch.double: 1e-8,
-            torch.float: 1e-4,
-            torch.bfloat16: 1e-1,
-            torch.half: 1e-1,
-            torch.cfloat: 1e-4,
-            torch.cdouble: 1e-8
-        }
-        for dtype, prec in dtypes.items():
-            M = torch.randn(10, 25).to(device=device, dtype=dtype)
-            m1 = torch.randn(10, 50).to(device=device, dtype=dtype)
-            m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
-            res1 = torch.addmm(M, m1, m2)
-            res2 = torch.zeros(10, 25, device=device, dtype=dtype)
-            res2 += M
-            for i in range(10):
-                for j in range(25):
-                    for k in range(50):
-                        res2[i, j] += m1[i, k] * m2[k, j]
-            self.assertEqual(res1, res2, atol=prec, rtol=0)
+    def _test_addmm(self, M, m1, m2):
+        dtype = M.dtype
+        numpy_dtype = dtype
+        if dtype in {torch.bfloat16}:
+            numpy_dtype = torch.float
+        if dtype.is_complex:
+            alpha = 0.9 + 0.3j
+            beta = 0.5 + 0.6j
+        else:
+            alpha = 1.2
+            beta = 0.8
+        res1 = torch.addmm(M, m1, m2, alpha=alpha, beta=beta)
+        res2 = torch.full_like(res1, math.nan)
+        torch.addmm(M, m1, m2, alpha=alpha, beta=beta, out=res2)
+        res3 = (beta * M).to(numpy_dtype).cpu().numpy() + alpha * (
+            m1.to(numpy_dtype).cpu().numpy() @ m2.to(numpy_dtype).cpu().numpy())
+        res3 = torch.from_numpy(res3).to(dtype)
+        self.assertEqual(res1, res2)
+        self.assertEqual(res1, res3)
+
+    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
+                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
+    @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=False))
+    @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes())
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_addmm(self, device, dtype):
+        M = torch.randn(10, 25).to(device=device, dtype=dtype)
+        m1 = torch.randn(10, 50).to(device=device, dtype=dtype)
+        m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
+        self._test_addmm(M, m1, m2)
+
         # Test 0-strided
-        for dtype, prec in dtypes.items():
-            M = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 25)
-            m1 = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 50)
-            m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
-            res1 = torch.addmm(M, m1, m2)
-            res2 = torch.zeros(10, 25, device=device, dtype=dtype)
-            res2 += M
-            for i in range(10):
-                for j in range(25):
-                    for k in range(50):
-                        res2[i, j] += m1[i, k] * m2[k, j]
-            self.assertEqual(res1, res2, atol=prec, rtol=0)
+        M = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 25)
+        m1 = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 50)
+        m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
+        self._test_addmm(M, m1, m2)
 
     @dtypes(torch.float, torch.double)
     @dtypesIfCUDA(*([torch.float, torch.double] +
                     ([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes())))
     @tf32_on_and_off(0.005)
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_addmm_sizes(self, device, dtype):
         for m in [0, 1, 25]:
             for n in [0, 1, 10]:
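One detail of the new _test_addmm helper shown above: NumPy has no bfloat16, so for that dtype the reference result is computed in float32 and converted back before comparing. A minimal sketch of that pattern outside the test harness (function name and shapes here are illustrative):

import torch

def numpy_reference_addmm(M, m1, m2, alpha, beta):
    # NumPy cannot represent bfloat16, so upcast to float32 for the reference.
    ref_dtype = torch.float if M.dtype == torch.bfloat16 else M.dtype
    ref = (beta * M).to(ref_dtype).cpu().numpy() + alpha * (
        m1.to(ref_dtype).cpu().numpy() @ m2.to(ref_dtype).cpu().numpy())
    # Cast back to the original dtype before comparing against torch.addmm.
    return torch.from_numpy(ref).to(M.dtype)

M = torch.randn(4, 6).to(torch.bfloat16)
m1 = torch.randn(4, 5).to(torch.bfloat16)
m2 = torch.randn(5, 6).to(torch.bfloat16)
out = torch.addmm(M, m1, m2, alpha=1.2, beta=0.8)
ref = numpy_reference_addmm(M, m1, m2, alpha=1.2, beta=0.8)
# bfloat16 carries only ~3 decimal digits, hence the loose tolerance
# (the diff uses 0.6 via precisionOverride).
assert torch.allclose(out.float(), ref.float(), atol=0.6, rtol=0)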
@@ -16428,14 +16428,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
                     M = torch.randn(n, m, device=device, dtype=dtype)
                     m1 = torch.randn(n, k, device=device, dtype=dtype)
                     m2 = torch.randn(k, m, device=device, dtype=dtype)
-                    res1 = torch.addmm(M, m1, m2)
-                    res2 = torch.zeros(n, m, device=device, dtype=dtype)
-                    res2 += M
-                    for i in range(n):
-                        for j in range(m):
-                            for l in range(k):
-                                res2[i, j] += m1[i, l] * m2[l, j]
-                    self.assertEqual(res1, res2)
+                    self._test_addmm(M, m1, m2)
 
     def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
         def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
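test_addmm_sizes sweeps m and n (and the inner size k) over small values including 0, so the degenerate case of an empty inner dimension is now exercised through the same helper. With k == 0 the matrix product contributes nothing and addmm must reduce to the beta term alone; a quick illustration:

import torch

n, m, k = 3, 4, 0  # empty inner dimension
M = torch.randn(n, m)
m1 = torch.randn(n, k)
m2 = torch.randn(k, m)
out = torch.addmm(M, m1, m2, alpha=1.2, beta=0.8)
# m1 @ m2 is an all-zero (n, m) matrix when k == 0, so only beta * M remains.
assert torch.allclose(out, 0.8 * M)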