make einsum produce contiguous outputs in more cases (#161755)

Fixes #161729
Written by codex
This won't produce contiguous outputs for all einsum applications: because we flatten all right-only and left-only dimensions, we cannot (with the current algorithm) produce a contiguous output when right-operand and left-operand dimensions are interleaved in the output. For common cases like the one in the linked issue, however, it works. Let's see what CI says.
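As a rough illustration (not part of the PR; the shapes and equations below are arbitrary examples), a minimal Python sketch of the two situations described above:

import torch

# Case like the linked issue: the right-only output dim ("b") comes before
# the left-only output dim ("f"), so the operands can be swapped and bmm
# emits the result directly in the requested layout.
w = torch.randn(5, 3)   # "fd"
x = torch.randn(2, 3)   # "bd"
out = torch.einsum("fd,bd->bf", w, x)   # expected to be contiguous with this change

# Interleaved case: left-only ("a", "c") and right-only ("b") output dims
# alternate, so the flattened bmm result still needs a permutation and the
# output may remain non-contiguous.
p = torch.randn(2, 4, 3)   # "acd"
q = torch.randn(5, 3)      # "bd"
out2 = torch.einsum("acd,bd->abc", p, q)   # may still be non-contiguous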

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161755
Approved by: https://github.com/malfet, https://github.com/albanD
Natalia Gimelshein
2025-08-29 18:50:42 +00:00
committed by PyTorch MergeBot
parent 348d781055
commit e9bbd28f22
2 changed files with 23 additions and 0 deletions


@@ -185,6 +185,17 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra
  // right: "lro, summed, ro" permuted with rpermutation and the three flattened
  // then the permuted output is a view of bmm(left, right)
  // finally, opermutation reverts the permutation to the original order of dimensions
  // By default the output is "lro, lo, 1-for-summed-dims, ro" with original shape dimensions.
  // However, if all dimensions from the right operand appear before those from the left
  // operand in the final output, we can swap the operands so that bmm directly produces
  // the result in the correct memory order.
  bool swap_lo_ro = !lo.empty() && !ro.empty() && ro.back() < lo.front();
  if (swap_lo_ro) {
    std::swap(left, right);
    std::swap(lo, ro);
    std::swap(lo_size, ro_size);
  }
  auto out_num_dim = lro.size() + lo.size() + sum_dims_.size() + ro.size();
  std::vector<SymInt> out_size;
  out_size.reserve(out_num_dim);
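For intuition, a hedged Python mirror of the swap condition on plain output-position lists (the names lo and ro follow the C++ variables, and should_swap is a hypothetical helper for illustration, not the actual ATen code path):

def should_swap(lo, ro):
    # lo / ro hold the output positions of left-only / right-only dimensions.
    # Swapping is only safe when every right-only dim precedes every left-only
    # dim, i.e. the bmm result laid out as "lro, ro, lo" is already in order.
    return bool(lo) and bool(ro) and ro[-1] < lo[0]

# "fd,bd->bf": f (left-only) at position 1, b (right-only) at position 0 -> swap.
assert should_swap(lo=[1], ro=[0])
# "acd,bd->abc": a, c (left-only) at 0 and 2, b (right-only) at 1 -> interleaved, no swap.
assert not should_swap(lo=[0, 2], ro=[1])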


@@ -4254,6 +4254,18 @@ class TestLinalg(TestCase):
        test(500)

    @dtypes(torch.float)
    def test_einsum_output_layout(self, device, dtype):
        batch, in_dim, out_dim = 2, 3, 5
        x = make_tensor((batch, in_dim), dtype=dtype, device=device)
        w = make_tensor((out_dim, in_dim), dtype=dtype, device=device)
        result = torch.einsum("fd,bd->bf", w, x)
        expected = x.matmul(w.t())
        self.assertEqual(result, expected)
        self.assertTrue(result.is_contiguous())
        self.assertEqual(result.stride(), expected.stride())

    def test_einsum_corner_cases(self, device):
        def check(equation, *operands, expected_output):
            tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple)