mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make einsum produce contiguous inputs in more cases (#161755)
Fixes #161729 Written by codex This won't produce contiguous inputs for all einsum applications, because we flatten all right-only and left-only dimensions, so if right and left operand dimensions are interleaved in output, we cannot (with current algo) produce contiguous output, however, for common cases like in the linked issue it works. Let's see what CI says Pull Request resolved: https://github.com/pytorch/pytorch/pull/161755 Approved by: https://github.com/malfet, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
348d781055
commit
e9bbd28f22
@ -4254,6 +4254,18 @@ class TestLinalg(TestCase):
|
||||
|
||||
test(500)
|
||||
|
||||
@dtypes(torch.float)
|
||||
def test_einsum_output_layout(self, device, dtype):
|
||||
batch, in_dim, out_dim = 2, 3, 5
|
||||
x = make_tensor((batch, in_dim), dtype=dtype, device=device)
|
||||
w = make_tensor((out_dim, in_dim), dtype=dtype, device=device)
|
||||
result = torch.einsum("fd,bd->bf", w, x)
|
||||
expected = x.matmul(w.t())
|
||||
self.assertEqual(result, expected)
|
||||
self.assertTrue(result.is_contiguous())
|
||||
self.assertEqual(result.stride(), expected.stride())
|
||||
|
||||
|
||||
def test_einsum_corner_cases(self, device):
|
||||
def check(equation, *operands, expected_output):
|
||||
tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple)
|
||||
|
Reference in New Issue
Block a user