torch.tensordot: performance improvements when contracting to a scalar. (#145936)

As per title: when all dimensions of both inputs are contracted (i.e. the result is a scalar), `torch.tensordot` now dispatches to a cheaper `dot` / `mul` + `sum` path instead of `mm`.
Fixes https://github.com/pytorch/pytorch/issues/145731

This change touches only the compute path; the CPU overhead can potentially be reduced further.

Before:
```python
In [3]: n = 512

In [4]: A = torch.rand(n, n)

In [5]: B = torch.rand(n, n)

In [6]: %timeit torch.tensordot(A, B, [[0, 1], [0, 1]])
2.04 ms ± 70 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [7]: %timeit torch.tensordot(A, B, [[0, 1], [1, 0]])
2.85 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [8]: %timeit torch.tensordot(A, B, [[1, 0], [0, 1]])
2.9 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [9]: %timeit torch.tensordot(A, B, [[1, 0], [1, 0]])
4.07 ms ± 262 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

After:
```python
In [2]: n = 512

In [3]: A = torch.rand(n, n)

In [4]: B = torch.rand(n, n)

In [5]: %timeit torch.tensordot(A, B, [[0, 1], [0, 1]])
30.7 µs ± 2.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [6]: %timeit torch.tensordot(A, B, [[0, 1], [1, 0]])
141 µs ± 6.52 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [7]: %timeit torch.tensordot(A, B, [[1, 0], [0, 1]])
142 µs ± 4.03 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [8]: %timeit torch.tensordot(A, B, [[1, 0], [1, 0]])
62.8 µs ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
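For intuition (a sketch of mine, not part of the PR): a full contraction is numerically just a dot product of the flattened, permutation-aligned operands, which is what makes the fast path possible:

```python
import torch

# float64 keeps accumulation-order differences well within tolerance
A = torch.rand(64, 64, dtype=torch.float64)
B = torch.rand(64, 64, dtype=torch.float64)

# Contracting all dims of both inputs yields a 0-dim result:
# sum_ij A[i, j] * B[i, j] == dot(A.flatten(), B.flatten())
ref = torch.tensordot(A, B, dims=([0, 1], [0, 1]))
torch.testing.assert_close(ref, torch.dot(A.flatten(), B.flatten()))

# With mismatched contraction orders, one operand is transposed first:
# sum_ij A[i, j] * B[j, i] == dot(A.flatten(), B.t().flatten())
ref_t = torch.tensordot(A, B, dims=([0, 1], [1, 0]))
torch.testing.assert_close(ref_t, torch.dot(A.flatten(), B.t().flatten()))
```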

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145936
Approved by: https://github.com/albanD, https://github.com/ngimel
Author: nikitaved
Date: 2025-05-13 10:57:30 +00:00
Committed by: PyTorch MergeBot
Parent: 8d7dec6e92
Commit: edc2d539d1
5 changed files with 57 additions and 8 deletions

```diff
@@ -20,6 +20,7 @@
 #include <ATen/ops/addmm.h>
 #include <ATen/ops/bilinear_native.h>
 #include <ATen/ops/bmm.h>
+#include <ATen/ops/dot.h>
 #include <ATen/ops/einsum_native.h>
 #include <ATen/ops/linear_native.h>
 #include <ATen/ops/matmul.h>
```
```diff
@@ -811,11 +812,35 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
       rsizes.emplace_back(t2.sym_size(i));
     }
   }
-  // permute and reshape for matrix multiplication
-  t1 = t1.permute(p1).reshape_symint({size1, csize});
-  t2 = t2.permute(p2).reshape_symint({csize, size2});
-  // multiply and reshape to target size
-  return at::mm(t1, t2).reshape_symint(rsizes);
+  // Full contraction (size1 == 1 and size2 == 1) is much faster when done with dot ...
+  // TODO(@nikitaved): there are other cases where dot outperforms gemms,
+  // for example, when the non-contracted dims are relatively small.
+  // NOTE(@nikitaved): contract with gemm when on MPS,
+  // otherwise there are issues with the tests xpassing/xfailing
+  // when the fast path with dot is enabled.
+  // TODO: resolve that
+  if ((t1.device().type() == at::kMPS || t2.device().type() == at::kMPS) || size1 != 1 || size2 != 1) {
+    // permute and reshape for matrix multiplication
+    t1 = t1.permute(p1).reshape_symint({size1, csize});
+    t2 = t2.permute(p2).reshape_symint({csize, size2});
+    // multiply and reshape to target size
+    return at::mm(t1, t2).reshape_symint(rsizes);
+  } else {
+    // permute to align for contraction
+    t1 = t1.permute(p1);
+    t2 = t2.permute(p2);
+
+    if (t1.is_contiguous() && t2.is_contiguous()) {
+      // If t1 and t2 are both contiguous, flatten is a view,
+      // and dot is the method of choice.
+      return at::dot(t1.flatten(), t2.flatten()).reshape_symint(rsizes);
+    } else {
+      // Otherwise mul + sum can be faster, as it avoids up to two contiguous() calls.
+      // NOTE: t1.dtype == t2.dtype -- checked above.
+      return (t1.squeeze() * t2.squeeze()).sum(t1.scalar_type()).reshape_symint(rsizes);
+    }
+  }
 }
 
 Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
```
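A Python-level illustration (mine, not from the PR) of the two fast-path branches above: contiguous operands flatten as a view and go through `dot`, while non-contiguous operands are handled with `mul` + `sum` to avoid materializing contiguous copies:

```python
import torch

t1 = torch.rand(64, 64, dtype=torch.float64)
t2 = torch.rand(64, 64, dtype=torch.float64).t()  # non-contiguous view

ref = torch.tensordot(t1, t2, dims=([0, 1], [0, 1]))

assert t1.is_contiguous() and not t2.is_contiguous()

# dot branch: flatten() on a contiguous tensor is a view (no copy);
# for the strided t2 a contiguous copy would be required first.
via_dot = torch.dot(t1.flatten(), t2.contiguous().flatten())

# mul + sum branch: elementwise multiply handles the strided t2 directly.
via_mul_sum = (t1 * t2).sum()

torch.testing.assert_close(ref, via_dot)
torch.testing.assert_close(ref, via_mul_sum)
```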

```diff
@@ -131,7 +131,6 @@ dtensor_fails = {
     xfail("cummin"),
     xfail("diagonal_scatter"),
     xfail("dist"),
-    xfail("dot"),
     xfail("empty"),
     xfail("empty_strided"),
     xfail("empty_like"),
```

```diff
@@ -9470,6 +9470,24 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
         self.assertEqual(a, an)
 
+        # Testing the fast path introduced in #145936,
+        # i.e. the reduction to a scalar has to have the right ndim.
+        a = torch.rand(2, 2, device=device)
+        a_dims = [-1, -2]
+        b = torch.rand(2, 2, device=device)
+        b_dims = [-2, -1]
+        for res_ndim in range(5):
+            res_torch = torch.tensordot(a, b, [a_dims, b_dims])
+            self.assertEqual(res_torch.ndim, res_ndim)
+
+            res_numpy = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), [a_dims, b_dims]))
+            self.assertEqual(res_torch, res_numpy)
+
+            if res_ndim % 2:
+                b.unsqueeze_(0)
+            else:
+                a.unsqueeze_(0)
+
     @skipCUDAIfNoCusolver
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
```
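The `unsqueeze_` dance in the test above checks that the result keeps the right dimensionality as size-1 non-contracted dims accumulate on either operand. A standalone illustration (not from the PR):

```python
import torch

a = torch.rand(1, 2, 2)  # one leading non-contracted dim of size 1
b = torch.rand(2, 2)

# The two trailing dims of a and both dims of b are contracted away,
# but the leading size-1 dim of a must survive: the result is 1-dim,
# not a 0-dim scalar, matching NumPy's tensordot.
out = torch.tensordot(a, b, dims=([-1, -2], [-2, -1]))
print(out.shape)  # torch.Size([1])
```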

```diff
@@ -208,6 +208,12 @@ def _scaled_mm_like_strategy(
     return mm_strategy
 
 
+@register_op_strategy(aten.dot.default)
+def dot_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    return _mm_like_strategy("i,i->", mesh, op_schema)
+
+
 @register_op_strategy(aten.mm.default)
 def mm_strategy(op_schema: OpSchema) -> OpStrategy:
     mesh = op_schema.get_mesh_from_args()
```
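For context on this hunk: registering `dot_strategy` (the einsum-like `"i,i->"` sharding rule) is what lets sharded 1-D DTensors flow through `torch.dot`, so the `xfail("dot")` entry removed above is no longer needed. A hedged sketch of such a run, assuming a standard `torchrun` setup and the public `torch.distributed.tensor` import surface:

```python
# Run with e.g.: torchrun --nproc-per-node=2 dot_example.py  (illustrative)
import torch
import torch.distributed as dist
from torch.distributed.tensor import init_device_mesh, distribute_tensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Shard both 1-D operands along the mesh dim, matching the "i,i->" strategy.
x = distribute_tensor(torch.arange(8, dtype=torch.float32), mesh, [Shard(0)])
y = distribute_tensor(torch.ones(8), mesh, [Shard(0)])

# Each rank computes a local dot over its shard; the partial results
# are summed across ranks to form the scalar output.
out = torch.dot(x, y)
print(out.full_tensor())  # tensor(28.)
```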

```diff
@@ -7101,12 +7101,13 @@ def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs):
     cases = (
         ((2, 2, 2), (2, 2, 2), (2)),
         ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])),
+        ((1, 1, 1), (2, 1, 2), ([0, 1], [2, 0])),
     )
     for first_shape, second_shape, dims in cases:
         yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device,
-                                      requires_grad=requires_grad),
+                                      requires_grad=requires_grad, low=-1, high=+2),
                           make_tensor(second_shape, dtype=dtype, device=device,
-                                      requires_grad=requires_grad),
+                                      requires_grad=requires_grad, low=-1, high=+2),
                           dims=dims)
 
 def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs):
```