Expand the coverage of test_addmm and test_addmm_sizes (#43831)

Summary:
- This test is very fast and very important, so it makes no sense in marking it as slowTest
- This test is should also run on CUDA
- This test should check alpha and beta support
- This test should check `out=` support
- manual computation should use list instead of index_put because list is much faster
- precision for TF32 needs to be fixed. Will do it in future PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43831

Reviewed By: ailzhang

Differential Revision: D23435032

Pulled By: ngimel

fbshipit-source-id: d1b8350addf1e2fe180fdf3df243f38d95aa3f5a
This commit is contained in:
Xiang Gao
2020-09-02 20:50:20 -07:00
committed by Facebook GitHub Bot
parent f5ba489f93
commit bc45c47aa3
2 changed files with 37 additions and 44 deletions

View File

@ -74,7 +74,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
if (&result != &self) {
at::native::resize_as_(result, self_);
if (beta.to<double>() != 0.0) {
if (beta.toComplexDouble() != 0.0) {
at::native::copy_(result, self_);
}
}

View File

@ -16379,48 +16379,48 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
for use_out, row_major, incx, incy, lda_tail in product((False, True), (False, True), (1, 2), (1, 2), (0, 1)):
_test(use_out, row_major, incx, incy, lda_tail)
@slowTest
@onlyCPU
def test_addmm(self, device):
dtypes = {
torch.double: 1e-8,
torch.float: 1e-4,
torch.bfloat16: 1e-1,
torch.half: 1e-1,
torch.cfloat: 1e-4,
torch.cdouble: 1e-8
}
for dtype, prec in dtypes.items():
M = torch.randn(10, 25).to(device=device, dtype=dtype)
m1 = torch.randn(10, 50).to(device=device, dtype=dtype)
m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
res1 = torch.addmm(M, m1, m2)
res2 = torch.zeros(10, 25, device=device, dtype=dtype)
res2 += M
for i in range(10):
for j in range(25):
for k in range(50):
res2[i, j] += m1[i, k] * m2[k, j]
self.assertEqual(res1, res2, atol=prec, rtol=0)
def _test_addmm(self, M, m1, m2):
dtype = M.dtype
numpy_dtype = dtype
if dtype in {torch.bfloat16}:
numpy_dtype = torch.float
if dtype.is_complex:
alpha = 0.9 + 0.3j
beta = 0.5 + 0.6j
else:
alpha = 1.2
beta = 0.8
res1 = torch.addmm(M, m1, m2, alpha=alpha, beta=beta)
res2 = torch.full_like(res1, math.nan)
torch.addmm(M, m1, m2, alpha=alpha, beta=beta, out=res2)
res3 = (beta * M).to(numpy_dtype).cpu().numpy() + alpha * (
m1.to(numpy_dtype).cpu().numpy() @ m2.to(numpy_dtype).cpu().numpy())
res3 = torch.from_numpy(res3).to(dtype)
self.assertEqual(res1, res2)
self.assertEqual(res1, res3)
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=False))
@dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes())
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_addmm(self, device, dtype):
M = torch.randn(10, 25).to(device=device, dtype=dtype)
m1 = torch.randn(10, 50).to(device=device, dtype=dtype)
m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
self._test_addmm(M, m1, m2)
# Test 0-strided
for dtype, prec in dtypes.items():
M = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 25)
m1 = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 50)
m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
res1 = torch.addmm(M, m1, m2)
res2 = torch.zeros(10, 25, device=device, dtype=dtype)
res2 += M
for i in range(10):
for j in range(25):
for k in range(50):
res2[i, j] += m1[i, k] * m2[k, j]
self.assertEqual(res1, res2, atol=prec, rtol=0)
M = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 25)
m1 = torch.randn(10, 1).to(device=device, dtype=dtype).expand(10, 50)
m2 = torch.randn(50, 25).to(device=device, dtype=dtype)
self._test_addmm(M, m1, m2)
@dtypes(torch.float, torch.double)
@dtypesIfCUDA(*([torch.float, torch.double] +
([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes())))
@tf32_on_and_off(0.005)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_addmm_sizes(self, device, dtype):
for m in [0, 1, 25]:
for n in [0, 1, 10]:
@ -16428,14 +16428,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
M = torch.randn(n, m, device=device, dtype=dtype)
m1 = torch.randn(n, k, device=device, dtype=dtype)
m2 = torch.randn(k, m, device=device, dtype=dtype)
res1 = torch.addmm(M, m1, m2)
res2 = torch.zeros(n, m, device=device, dtype=dtype)
res2 += M
for i in range(n):
for j in range(m):
for l in range(k):
res2[i, j] += m1[i, l] * m2[l, j]
self.assertEqual(res1, res2)
self._test_addmm(M, m1, m2)
def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):