mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add tensor overlap check for cross
(#154999)
Fixes #132031 ## Test Result ```python In [1]: import torch ...: torch.manual_seed(0) ...: torch.cuda.manual_seed(0) ...: a = torch.randn(3, 4) ...: b = torch.randn(3, 4) ...: torch.cross(a, b, out=a) --------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) Cell In[1], line 6 4 a = torch.randn(3, 4) 5 b = torch.randn(3, 4) ----> 6 torch.cross(a, b, out=a) RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/154999 Approved by: https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
5b65628906
commit
fa3c38c7ae
@ -6,6 +6,7 @@
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
@ -77,6 +78,9 @@ Tensor & cross_out(const Tensor & input, const Tensor & other, const std::option
|
||||
|
||||
TORCH_IMPL_FUNC(linalg_cross_out)
|
||||
(const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) {
|
||||
at::assert_no_internal_overlap(out);
|
||||
at::assert_no_overlap(out, input);
|
||||
at::assert_no_overlap(out, other);
|
||||
dim = maybe_wrap_dim(dim, input.dim());
|
||||
auto out_size = out.sizes();
|
||||
Tensor input_broadcasted = input.expand(out_size);
|
||||
|
@ -6022,6 +6022,18 @@ class TestLinalg(TestCase):
|
||||
self.assertEqual(res1, res2)
|
||||
self.assertEqual(res1, res3)
|
||||
|
||||
def test_cross_error(self, device):
|
||||
x = torch.randn(4, 3, device=device)
|
||||
y = torch.randn(4, 3, device=device)
|
||||
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
|
||||
torch.cross(x, y, out=x)
|
||||
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
|
||||
torch.cross(y, x, out=x)
|
||||
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
|
||||
torch.linalg.cross(x, y, out=x)
|
||||
with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"):
|
||||
torch.linalg.cross(y, x, out=x)
|
||||
|
||||
def test_renorm(self, device):
|
||||
m1 = torch.randn(20, 20, device=device) # big enough to exercise vectorized path
|
||||
res1 = torch.tensor((), device=device)
|
||||
|
Reference in New Issue
Block a user