mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch.tensordot
: performance improvements when contracting to a scalar. (#145936)
As per title. Fixes https://github.com/pytorch/pytorch/issues/145731 Touches only compute. The CPU overhead can potentially be further reduced. Before: ```python In [3]: n = 512 In [4]: A = torch.rand(n, n) In [5]: B = torch.rand(n, n) In [6]: %timeit torch.tensordot(A, B, [[0, 1], [0, 1]]) 2.04 ms ± 70 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) In [7]: %timeit torch.tensordot(A, B, [[0, 1], [1, 0]]) 2.85 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) In [8]: %timeit torch.tensordot(A, B, [[1, 0], [0, 1]]) 2.9 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) In [9]: %timeit torch.tensordot(A, B, [[1, 0], [1, 0]]) 4.07 ms ± 262 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) ``` After ```python In [2]: n = 512 In [3]: A = torch.rand(n, n) In [4]: B = torch.rand(n, n) In [5]: %timeit torch.tensordot(A, B, [[0, 1], [0, 1]]) 30.7 µs ± 2.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each) In [6]: %timeit torch.tensordot(A, B, [[0, 1], [1, 0]]) 141 µs ± 6.52 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each) In [7]: %timeit torch.tensordot(A, B, [[1, 0], [0, 1]]) 142 µs ± 4.03 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each) In [8]: %timeit torch.tensordot(A, B, [[1, 0], [1, 0]]) 62.8 µs ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/145936 Approved by: https://github.com/albanD, https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
8d7dec6e92
commit
edc2d539d1
@ -9470,6 +9470,24 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
|
||||
self.assertEqual(a, an)
|
||||
|
||||
# Testing the fast path introduced in #145936,
|
||||
# i.e. reduction to a scalar has to be of right dim.
|
||||
a = torch.rand(2, 2, device=device)
|
||||
a_dims = [-1, -2]
|
||||
b = torch.rand(2, 2, device=device)
|
||||
b_dims = [-2, -1]
|
||||
for res_ndim in range(5):
|
||||
res_torch = torch.tensordot(a, b, [a_dims, b_dims])
|
||||
self.assertEqual(res_torch.ndim, res_ndim)
|
||||
|
||||
res_numpy = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), [a_dims, b_dims]))
|
||||
self.assertEqual(res_torch, res_numpy)
|
||||
|
||||
if res_ndim % 2:
|
||||
b.unsqueeze_(0)
|
||||
else:
|
||||
a.unsqueeze_(0)
|
||||
|
||||
@skipCUDAIfNoCusolver
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
|
Reference in New Issue
Block a user