Don't fastpath conj copy when conj/neg bit mismatch (#108881)

Fixes https://github.com/pytorch/pytorch/issues/106051

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108881
Approved by: https://github.com/soulitzer
This commit is contained in:
Edward Z. Yang
2023-09-08 10:48:49 -07:00
committed by PyTorch MergeBot
parent bd1229477d
commit 137afe74e0
3 changed files with 15 additions and 1 deletions

View File

@ -253,7 +253,9 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
self.storage_offset() == src.storage_offset() &&
self.strides().equals(src.strides()) &&
self.sizes().equals(src.sizes()) &&
self.scalar_type() == src.scalar_type()
self.scalar_type() == src.scalar_type() &&
self.is_conj() == src.is_conj() &&
self.is_neg() == src.is_neg()
);
if (is_same_data) {
return self;

View File

@ -25,6 +25,14 @@ class TestComplexTensor(TestCase):
x = torch.tensor([3., 3. + 5.j], device=device)
self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)
@dtypes(*complex_types())
def test_conj_copy(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/106051
x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype)
xc1 = torch.conj(x1)
x1.copy_(xc1)
self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))
@onlyCPU
@dtypes(*complex_types())
def test_eq(self, device, dtype):

View File

@ -3125,6 +3125,10 @@ else:
dst._neg_view().copy_(src)
self.assertEqual(dst, src.neg(), exact_dtype=False)
# issue: https://github.com/pytorch/pytorch/issues/106051
dst._neg_view().copy_(dst)
self.assertEqual(dst, src, exact_dtype=False)
for dst_dtype, src_dtype in [
(torch.complex64, torch.complex64),
(torch.complex128, torch.complex64),