Add tensor overlap check for cross (#154999)

Fixes #132031

## Test Result

Calling `torch.cross` with `out` aliasing one of its inputs now raises a `RuntimeError`:

```python
In [1]: import torch
   ...: torch.manual_seed(0)
   ...: torch.cuda.manual_seed(0)
   ...: a = torch.randn(3, 4)
   ...: b = torch.randn(3, 4)
   ...: torch.cross(a, b, out=a)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 6
      4 a = torch.randn(3, 4)
      5 b = torch.randn(3, 4)
----> 6 torch.cross(a, b, out=a)

RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154999
Approved by: https://github.com/lezcano
This commit is contained in:
zeshengzong
2025-06-05 09:59:57 +00:00
committed by PyTorch MergeBot
parent 5b65628906
commit fa3c38c7ae
2 changed files with 16 additions and 0 deletions

View File

@ -6,6 +6,7 @@
#include <ATen/WrapDimUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/MemoryOverlap.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -77,6 +78,9 @@ Tensor & cross_out(const Tensor & input, const Tensor & other, const std::option
TORCH_IMPL_FUNC(linalg_cross_out)
(const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) {
at::assert_no_internal_overlap(out);
at::assert_no_overlap(out, input);
at::assert_no_overlap(out, other);
dim = maybe_wrap_dim(dim, input.dim());
auto out_size = out.sizes();
Tensor input_broadcasted = input.expand(out_size);

View File

@ -6022,6 +6022,18 @@ class TestLinalg(TestCase):
self.assertEqual(res1, res2)
self.assertEqual(res1, res3)
def test_cross_error(self, device):
x = torch.randn(4, 3, device=device)
y = torch.randn(4, 3, device=device)
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
torch.cross(x, y, out=x)
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
torch.cross(y, x, out=x)
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
torch.linalg.cross(x, y, out=x)
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
torch.linalg.cross(y, x, out=x)
def test_renorm(self, device):
m1 = torch.randn(20, 20, device=device) # big enough to exercise vectorized path
res1 = torch.tensor((), device=device)