Add matmul optimization for the case A.ndim <= 2 && B.ndim >= 3 (#20448)

Summary:
This addresses #18862.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20448

Differential Revision: D15393465

Pulled By: ezyang

fbshipit-source-id: 87e5b0ed8253ea00365f420d98ac96dd4e934028
Author: Stefan Krah
Date: 2019-05-17 09:37:06 -07:00
Committed by: Facebook Github Bot

parent a212a5b97a
commit 8c9f4c560a
3 changed files with 120 additions and 1 deletion
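For context (not part of the original commit message): #18862 reports that a
1-D or 2-D operand multiplied by a batched N-D operand was far slower than the
equivalent 2-D x 2-D product. A minimal snippet that exercises the affected
call, for illustration only (shapes are arbitrary; timings depend on the build):

    import time
    import torch

    v = torch.randn(128)
    B = torch.randn(1024, 128, 128)

    t0 = time.perf_counter()
    for _ in range(100):
        torch.matmul(v, B)  # 1-D x 3-D: the case this commit optimizes
    print("%.3fs" % (time.perf_counter() - t0))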


@@ -428,6 +428,29 @@ Tensor matmul(
    Tensor output = has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size)
                            : at::_unsafe_view(t1.mm(t2), output_size);
    return has_out ? out.set_(output) : output;
  } else if ((dim_tensor1 == 1 || dim_tensor1 == 2) && dim_tensor2 >= 3) {
    // optimization: transpose the inner dimensions of the arguments, call
    // matmul on the swapped arguments, then transpose the inner dimensions
    // of the result.
    const int64_t n = dim_tensor1 == 2 ? tensor1.size(-2) : 1;
    const int64_t m = tensor1.size(-1);
    const int64_t p = tensor2.size(-1);

    const Tensor t2_T = tensor2.transpose(-1, -2);
    const Tensor t1_T = dim_tensor1 == 2 ? tensor1.t() : tensor1.reshape({n, m}).t();
    const Tensor res_T = matmul(out_opt, t2_T, t1_T);

    if (dim_tensor1 == 2) {
      Tensor res = res_T.transpose(-1, -2).contiguous();
      return has_out ? out.set_(res) : res;
    } else {
      std::vector<int64_t> shape = tensor2.sizes().slice(0, dim_tensor2 - 2).vec();
      shape.push_back(p);

      Tensor res = res_T.reshape(shape).contiguous();
      return has_out ? out.set_(res) : res;
    }
  } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
    // We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
    // we track m1 vs m2 separately even though they must match for nicer error messages
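The identity the new branch exploits can be checked from Python. A minimal
sketch, illustrative and not part of the commit: for a 2-D A and a batched B,
A @ B equals matmul(B^T, A^T) transposed back, and because A^T is 2-D the
swapped call lands in matmul's existing N-D x 2-D branch, which folds the
batch dimensions into a single mm.

    import torch

    A = torch.randn(3, 4)     # n x m
    B = torch.randn(5, 4, 2)  # o x m x p

    # Swap the operands and transpose the inner dimensions, then
    # transpose the result back: (A @ B_i)^T == B_i^T @ A^T.
    ref = torch.matmul(A, B)  # o x n x p
    res = torch.matmul(B.transpose(-1, -2), A.t()).transpose(-1, -2)
    assert torch.allclose(ref, res)

    # 1-D case: the result drops the n dimension, so the code reshapes
    # the batch dimensions instead of doing the final transpose.
    v = torch.randn(4)
    assert torch.allclose(torch.matmul(v, B),
                          torch.matmul(B.transpose(-1, -2), v))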


@@ -466,8 +466,12 @@ def method_tests():
    ('ger', (S,), ((M,),)),
    ('matmul', (L,), ((L,),), '', (True,)),
    ('matmul', (S, M), ((M,),), "2d_1d", (True,)),
    ('matmul', (M,), ((M, S),), "1d_2d", (True,)),
    ('matmul', (S, M), ((M, S),), "2d_2d", (True,)),
    ('matmul', (S, S, M), ((M,),), "3d_1d", (True,)),
    ('matmul', (S, S, M), ((M, S),), "3d_2d", (True,)),
    ('matmul', (M,), ((S, M, S),), "1d_3d", (True,)),
    ('matmul', (S, M), ((S, M, S),), "2d_3d", (True,)),
    ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d", (True,)),
    ('matmul', (S, S, M, M), ((M,),), "4d_1d", (True,)),
    ('matmul', (M,), ((S, S, M, S),), "1d_4d", (True,)),
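Each method_tests() tuple gives the method name, the shape of self, the
argument shapes, and a suffix for the generated test name; the harness also
uses these entries to generate autograd checks. Roughly, the new "2d_3d"
entry drives a call like the following sketch (S and M are size constants
defined by the harness; 5 and 10 are assumed values here):

    import torch

    S, M = 5, 10  # assumed values of the harness size constants

    x = torch.randn(S, M, dtype=torch.double, requires_grad=True)
    y = torch.randn(S, M, S, dtype=torch.double, requires_grad=True)

    out = torch.matmul(x, y)  # 2-D x 3-D broadcasts to S x S x S
    assert out.shape == (S, S, S)

    # Analogue of the gradient checks run over entries like this one.
    torch.autograd.gradcheck(torch.matmul, (x, y))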


@@ -8200,6 +8200,98 @@ class _TestTorchMixin(object):
        A_LU, pivots = fn(torch.lu, (2, 0, 0))
        self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])

    def check_single_matmul(self, x, y, shape):
        a = np.array(x, copy=False)
        b = np.array(y, copy=False)
        expected = np.matmul(a, b)
        self.assertTrue(expected.flags['C_CONTIGUOUS'])

        ans = torch.matmul(x, y)
        self.assertTrue(ans.is_contiguous())
        self.assertTrue(np.array_equal(ans, expected))

        out = torch.zeros(*shape, dtype=torch.int64)
        ans = torch.matmul(x, y, out=out)
        self.assertIs(ans, out)
        self.assertTrue(ans.is_contiguous())
        self.assertTrue(np.array_equal(ans, expected))
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_matmul_small_brute_force_1d_Nd(self):
        # Issue #20452: range(0, 10) does not work.
        n = 1
        for m in range(1, 8):
            for p in range(1, 8):
                for o in range(1, 5):
                    # 1d, 3d, inner dimensions C
                    x = torch.arange(m)
                    y = torch.arange(o * m * p).reshape(o, m, p)
                    self.check_single_matmul(x, y, (o, n, p))

                    # 1d, 3d, inner dimensions Fortran
                    x = torch.arange(m)
                    y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2)
                    self.check_single_matmul(x, y, (o, n, p))

                    # 1d, 3d, inner dimensions non-contiguous
                    x = torch.arange(2 * m)[::2]
                    y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2]
                    self.check_single_matmul(x, y, (o, n, p))

                    for r in range(1, 5):
                        # 1d, 4d, inner dimensions C
                        x = torch.arange(m)
                        y = torch.arange(r * o * m * p).reshape(r, o, m, p)
                        self.check_single_matmul(x, y, (r, o, n, p))

                        # 1d, 4d, inner dimensions Fortran
                        x = torch.arange(m)
                        y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2)
                        self.check_single_matmul(x, y, (r, o, n, p))

                        # 1d, 4d, inner dimensions non-contiguous
                        x = torch.arange(2 * m)[::2]
                        y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2]
                        self.check_single_matmul(x, y, (r, o, n, p))
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_matmul_small_brute_force_2d_Nd(self):
        # Issue #20452: range(0, 10) does not work.
        for n in range(1, 5):
            for m in range(1, 5):
                for p in range(1, 5):
                    for o in range(1, 3):
                        # 2d, 3d, inner dimensions C
                        x = torch.arange(n * m).reshape(n, m)
                        y = torch.arange(o * m * p).reshape(o, m, p)
                        self.check_single_matmul(x, y, (o, n, p))

                        # 2d, 3d, inner dimensions Fortran
                        x = torch.arange(m * n).reshape(m, n).transpose(-1, -2)
                        y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2)
                        self.check_single_matmul(x, y, (o, n, p))

                        # 2d, 3d, inner dimensions non-contiguous
                        x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2]
                        y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2]
                        self.check_single_matmul(x, y, (o, n, p))

                        for r in range(1, 2):
                            # 2d, 4d, inner dimensions C
                            x = torch.arange(n * m).reshape(n, m)
                            y = torch.arange(r * o * m * p).reshape(r, o, m, p)
                            self.check_single_matmul(x, y, (r, o, n, p))

                            # 2d, 4d, inner dimensions Fortran
                            x = torch.arange(m * n).reshape(m, n).transpose(-1, -2)
                            y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2)
                            self.check_single_matmul(x, y, (r, o, n, p))

                            # 2d, 4d, inner dimensions non-contiguous
                            x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2]
                            y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2]
                            self.check_single_matmul(x, y, (r, o, n, p))
    @skipIfRocm
    def test_blas_alpha_beta_empty(self):
        for device in torch.testing.get_all_device_types():
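The three memory layouts the loops above construct ("C", "Fortran", and
non-contiguous inner dimensions) can be reproduced in isolation. A small
sketch mirroring the tests' constructions, for illustration:

    import torch

    o, m, p = 2, 3, 4

    # C-contiguous: a freshly reshaped arange.
    y_c = torch.arange(o * m * p).reshape(o, m, p)
    assert y_c.is_contiguous()

    # "Fortran" inner dimensions: reshape to the transposed shape, then
    # swap the last two dims; each batch matrix is now column-major.
    y_f = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2)
    assert y_f.shape == (o, m, p) and not y_f.is_contiguous()

    # Non-contiguous: allocate twice the columns, keep every other one.
    y_s = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2]
    assert y_s.shape == (o, m, p) and not y_s.is_contiguous()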