Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add matmul optimization for the case A.ndim <= 2 && B.ndim >= 3 (#20448)
Summary: This addresses #18862.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20448
Differential Revision: D15393465
Pulled By: ezyang
fbshipit-source-id: 87e5b0ed8253ea00365f420d98ac96dd4e934028
Committed by Facebook Github Bot
commit 8c9f4c560a (parent a212a5b97a)
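The optimization in the first hunk below rests on the identity (A·B) = (Bᵀ·Aᵀ)ᵀ: when the left operand has at most two dimensions and the right operand is batched, the product can be routed through the existing fast path for "batched times at-most-2-d" by swapping and transposing the operands. A minimal Python sketch of the identity (illustration only, not part of the commit):

import torch

# Illustration: for 2-D A (n x m) and batched 3-D B (o x m x p),
# A @ B equals ((B^T) @ (A^T))^T, transposing only the inner dimensions.
A = torch.randn(4, 5)        # n x m
B = torch.randn(7, 5, 6)     # o x m x p

direct = torch.matmul(A, B)                                          # o x n x p
swapped = torch.matmul(B.transpose(-1, -2), A.t()).transpose(-1, -2)

assert direct.shape == (7, 4, 6)
assert torch.allclose(direct, swapped)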
@@ -428,6 +428,29 @@ Tensor matmul(
     Tensor output = has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size)
                             : at::_unsafe_view(t1.mm(t2), output_size);
     return has_out ? out.set_(output) : output;
+  } else if ((dim_tensor1 == 1 || dim_tensor1 == 2) && dim_tensor2 >= 3) {
+    // optimization: transpose the inner dimensions of the arguments, call
+    // matmul on the swapped arguments, then transpose the inner dimensions
+    // of the result.
+    const int64_t n = dim_tensor1 == 2 ? tensor1.size(-2) : 1;
+    const int64_t m = tensor1.size(-1);
+    const int64_t p = tensor2.size(-1);
+
+    const Tensor t2_T = tensor2.transpose(-1, -2);
+    const Tensor t1_T = dim_tensor1 == 2 ? tensor1.t() : tensor1.reshape({n, m}).t();
+    const Tensor res_T = matmul(out_opt, t2_T, t1_T);
+
+    if (dim_tensor1 == 2) {
+      Tensor res = res_T.transpose(-1, -2).contiguous();
+      return has_out ? out.set_(res) : res;
+    }
+    else {
+      std::vector<int64_t> shape = tensor2.sizes().slice(0, dim_tensor2 - 2).vec();
+      shape.push_back(p);
+
+      Tensor res = res_T.reshape(shape).contiguous();
+      return has_out ? out.set_(res) : res;
+    }
   } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
     // We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
     // we track m1 vs m2 separately even though they must match for nicer error messages
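The two branches at the end of the new block shape the result differently: with a 2-D tensor1 the n dimension is kept, while a 1-D tensor1 (n == 1) drops it, so the result carries only tensor2's batch dimensions plus p. A short shape sketch (illustration only, not part of the commit):

import torch

batched = torch.randn(7, 5, 6)   # o x m x p

# 1-D tensor1: the implicit n == 1 row dimension is squeezed out.
assert torch.matmul(torch.randn(5), batched).shape == (7, 6)

# 2-D tensor1: the n dimension survives, giving o x n x p.
assert torch.matmul(torch.randn(4, 5), batched).shape == (7, 4, 6)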
@@ -466,8 +466,12 @@ def method_tests():
     ('ger', (S,), ((M,),)),
     ('matmul', (L,), ((L,),), '', (True,)),
     ('matmul', (S, M), ((M,),), "2d_1d", (True,)),
-    ('matmul', (M, ), ((M, S),), "1d_2d", (True,)),
+    ('matmul', (M,), ((M, S),), "1d_2d", (True,)),
     ('matmul', (S, M), ((M, S),), "2d_2d", (True,)),
+    ('matmul', (S, S, M), ((M,),), "3d_1d", (True,)),
+    ('matmul', (S, S, M), ((M, S),), "3d_2d", (True,)),
+    ('matmul', (M,), ((S, M, S),), "1d_3d", (True,)),
+    ('matmul', (S, M), ((S, M, S),), "2d_3d", (True,)),
     ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d", (True,)),
     ('matmul', (S, S, M, M), ((M,),), "4d_1d", (True,)),
     ('matmul', (M,), ((S, S, M, S),), "1d_4d", (True,)),
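The four new entries (3d_1d, 3d_2d, 1d_3d, 2d_3d) extend gradient coverage to the rank combinations the new branch handles; the harness runs gradcheck-style numerical-vs-analytical comparisons over these shapes. A hypothetical standalone equivalent of the 3d_1d case (shapes chosen for illustration, double precision for stable numerical gradients):

import torch
from torch.autograd import gradcheck

# Hypothetical shapes mirroring the "3d_1d" entry: 3-D times 1-D.
x = torch.randn(3, 3, 5, dtype=torch.double, requires_grad=True)
y = torch.randn(5, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.matmul, (x, y))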
@@ -8200,6 +8200,98 @@ class _TestTorchMixin(object):
         A_LU, pivots = fn(torch.lu, (2, 0, 0))
         self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
 
+    def check_single_matmul(self, x, y, shape):
+        a = np.array(x, copy=False)
+        b = np.array(y, copy=False)
+        expected = np.matmul(a, b)
+        self.assertTrue(expected.flags['C_CONTIGUOUS'])
+
+        ans = torch.matmul(x, y)
+        self.assertTrue(ans.is_contiguous())
+        self.assertTrue(np.array_equal(ans, expected))
+
+        out = torch.zeros(*shape, dtype=torch.int64)
+        ans = torch.matmul(x, y, out=out)
+        self.assertIs(ans, out)
+        self.assertTrue(ans.is_contiguous())
+        self.assertTrue(np.array_equal(ans, expected))
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_matmul_small_brute_force_1d_Nd(self):
+        # Issue #20452: range(0, 10) does not work.
+        n = 1
+        for m in range(1, 8):
+            for p in range(1, 8):
+                for o in range(1, 5):
+                    # 1d, 3d, inner dimensions C
+                    x = torch.arange(m)
+                    y = torch.arange(o * m * p).reshape(o, m, p)
+                    self.check_single_matmul(x, y, (o, n, p))
+
+                    # 1d, 3d, inner dimensions Fortran
+                    x = torch.arange(m)
+                    y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2)
+                    self.check_single_matmul(x, y, (o, n, p))
+
+                    # 1d, 3d, inner dimensions non-contiguous
+                    x = torch.arange(2 * m)[::2]
+                    y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2]
+                    self.check_single_matmul(x, y, (o, n, p))
+
+                    for r in range(1, 5):
+                        # 1d, 4d, inner dimensions C
+                        x = torch.arange(m)
+                        y = torch.arange(r * o * m * p).reshape(r, o, m, p)
+                        self.check_single_matmul(x, y, (r, o, n, p))
+
+                        # 1d, 4d, inner dimensions Fortran
+                        x = torch.arange(m)
+                        y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2)
+                        self.check_single_matmul(x, y, (r, o, n, p))
+
+                        # 1d, 4d, inner dimensions non-contiguous
+                        x = torch.arange(2 * m)[::2]
+                        y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2]
+                        self.check_single_matmul(x, y, (r, o, n, p))
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_matmul_small_brute_force_2d_Nd(self):
+        # Issue #20452: range(0, 10) does not work.
+        for n in range(1, 5):
+            for m in range(1, 5):
+                for p in range(1, 5):
+                    for o in range(1, 3):
+                        # 2d, 3d, inner dimensions C
+                        x = torch.arange(n * m).reshape(n, m)
+                        y = torch.arange(o * m * p).reshape(o, m, p)
+                        self.check_single_matmul(x, y, (o, n, p))
+
+                        # 2d, 3d, inner dimensions Fortran
+                        x = torch.arange(m * n).reshape(m, n).transpose(-1, -2)
+                        y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2)
+                        self.check_single_matmul(x, y, (o, n, p))
+
+                        # 2d, 3d, inner dimensions non-contiguous
+                        x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2]
+                        y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2]
+                        self.check_single_matmul(x, y, (o, n, p))
+
+                        for r in range(1, 2):
+                            # 2d, 4d, inner dimensions C
+                            x = torch.arange(n * m).reshape(n, m)
+                            y = torch.arange(r * o * m * p).reshape(r, o, m, p)
+                            self.check_single_matmul(x, y, (r, o, n, p))
+
+                            # 2d, 4d, inner dimensions Fortran
+                            x = torch.arange(m * n).reshape(m, n).transpose(-1, -2)
+                            y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2)
+                            self.check_single_matmul(x, y, (r, o, n, p))
+
+                            # 2d, 4d, inner dimensions non-contiguous
+                            x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2]
+                            y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2]
+                            self.check_single_matmul(x, y, (r, o, n, p))
+
     @skipIfRocm
     def test_blas_alpha_beta_empty(self):
         for device in torch.testing.get_all_device_types():
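The step-2 slices in these tests construct the non-contiguous inputs: slicing with a step returns a strided view rather than a copy, and the new matmul path must both accept such views and still return a contiguous result, as check_single_matmul asserts. A small sketch of the construction (illustration only, not part of the commit):

import torch

m, o, p = 4, 3, 5

# A step-2 slice is a view with stride 2, hence non-contiguous.
x = torch.arange(2 * m)[::2]                                  # shape (4,)
y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2]

assert not x.is_contiguous() and not y.is_contiguous()
assert torch.matmul(x, y).is_contiguous()   # result comes back packed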