Make torch.conj() a composite function and return self for real tensors (#43270)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43270

`torch.conj` is a very commonly used operator for complex tensors, but it is mathematically a no-op for real tensors. Switching to TensorFlow-style gradients for complex tensors (as discussed in #41857) would involve adding `torch.conj()` to the backward definitions of many operators. To preserve autograd performance for real tensors and maintain numpy compatibility for `torch.conj`, this PR updates `torch.conj()` so that it behaves as before for complex tensors but simply returns the `self` tensor for non-complex dtypes. The documentation states that the returned tensor for a real input should not be mutated; in the future, once that functionality is available, we could return an immutable tensor for this case instead (zdevito, ezyang).
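A quick illustration of the new behavior (a minimal sketch, assuming a build that includes this change):

    import torch

    # Complex input: conj still materializes a new tensor with the imaginary part negated.
    z = torch.tensor([1 + 2j, 3 - 4j])
    w = torch.conj(z)
    print(w is z)              # False: a fresh, conjugated tensor (1-2j, 3+4j)

    # Real input: conjugation is mathematically a no-op, so conj now returns the input itself.
    x = torch.tensor([1.0, 2.0, 3.0])
    print(torch.conj(x) is x)  # True: no copy, no extra kernel launch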

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D23460493

Pulled By: anjali411

fbshipit-source-id: 3b3bf0af55423b77ff2d0e29f5d2c160291ae3d9
anjali411
2020-09-02 17:04:28 -07:00
committed by Facebook GitHub Bot
parent f9efcb646b
commit 129f406062
6 changed files with 32 additions and 5 deletions


@@ -238,6 +238,8 @@ _(aten, clamp_min) \
_(aten, clone) \
_(aten, coalesce) \
_(aten, combinations) \
_(aten, _conj) \
_(aten, conj) \
_(aten, complex) \
_(aten, polar) \
_(aten, constant_pad_nd) \


@@ -200,8 +200,18 @@ Tensor imag(const Tensor& self) {
}
}
Tensor& conj_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, conj_stub); }
Tensor conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); }
Tensor& conj_out(Tensor& result, const Tensor& self) {
  return unary_op_impl_out(result, self, conj_stub);
}

Tensor _conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); }

Tensor conj(const Tensor& self) {
  if (!self.is_complex()) {
    return self;
  }
  return at::_conj(self);
}
Tensor& bitwise_not_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, bitwise_not_stub); }
Tensor bitwise_not(const Tensor& self) { return unary_op_impl(self, at::bitwise_not_out); }
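For readers following along from Python, the dispatch above can be summarized with the rough sketch below (the helper names are illustrative, not part of this diff; `materialized_conj` stands in for the internal `at::_conj` path):

    import torch

    def conj_composite(t: torch.Tensor) -> torch.Tensor:
        # Non-complex dtypes: conjugation is a no-op, so return the input tensor itself.
        if not t.is_complex():
            return t
        # Complex dtypes: materialize the conjugate, as at::_conj does.
        return materialized_conj(t)

    def materialized_conj(t: torch.Tensor) -> torch.Tensor:
        # Illustrative stand-in for the dispatched kernel: keep the real part,
        # negate the imaginary part.
        return torch.complex(t.real, -t.imag)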


@@ -288,6 +288,10 @@
- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: _conj(Tensor self) -> Tensor
  use_c10_dispatcher: full
  variants: function

- func: acos(Tensor self) -> Tensor
  use_c10_dispatcher: full
  variants: function, method


@@ -19061,6 +19061,12 @@ class TestViewOps(TestCase):
        else:
            return x.transpose(dim0, dim1)

    @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()))
    def test_conj_self(self, device, dtype):
        t = torch.ones(5, 5, device=device, dtype=dtype)
        s = t.conj()
        self.assertTrue(s is t)

    @onlyOnCPUAndCUDA
    def test_view_as_complex(self, device):
        def fn(contiguous_input=True, dim0=0, dim1=1):
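A complementary check for complex dtypes is not part of this diff, but could look roughly like the sketch below, in the style of the test above (assuming a `get_all_complex_dtypes()` helper analogous to the ones used there; the test name is hypothetical):

    @dtypes(*torch.testing.get_all_complex_dtypes())
    def test_conj_copy(self, device, dtype):
        # For complex dtypes, conj must still allocate a new, conjugated tensor.
        t = torch.tensor([1 + 1j, 2 - 2j], device=device, dtype=dtype)
        s = t.conj()
        self.assertFalse(s is t)
        self.assertEqual(s, torch.tensor([1 - 1j, 2 + 2j], device=device, dtype=dtype))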


@@ -337,7 +337,7 @@
  abs: not_implemented("polar abs")
  angle: not_implemented("polar angle")

- name: conj(Tensor self) -> Tensor
- name: _conj(Tensor self) -> Tensor
  self: grad.conj()

- name: cos(Tensor self) -> Tensor
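With the derivative now registered on `_conj`, complex backward passes pick up the `grad.conj()` formula above, while real tensors never enter a conj node at all. A minimal sketch of the real-tensor case (assuming a build with this change):

    import torch

    # conj on a real tensor returns `self`, so it adds no node to the autograd graph
    # and gradients flow through unchanged.
    x = torch.ones(3, requires_grad=True)
    y = torch.conj(x).sum()
    y.backward()
    print(x.grad)   # tensor([1., 1., 1.])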


@@ -1778,7 +1778,12 @@ add_docstr(torch.conj,
r"""
conj(input, *, out=None) -> Tensor
Computes the element-wise conjugate of the given :attr:`input` tensor.
Computes the element-wise conjugate of the given :attr:`input` tensor. If :attr:`input` has a non-complex dtype,
this function just returns :attr:`input`.
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
   non-complex dtype. To be compatible with this change, programs should not modify the tensor returned
   by :func:`torch.conj` when :attr:`input` is of non-complex dtype.
.. math::
\text{out}_{i} = conj(\text{input}_{i})
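The warning exists because, after this change, the tensor returned for a real :attr:`input` is the input itself, so in-place modification writes through. A small sketch (assuming a build with this change):

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    y = torch.conj(x)   # for real dtypes, y is x
    y.add_(10)          # mutating y therefore also mutates x
    print(x)            # tensor([11., 12., 13.])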
@@ -6860,7 +6865,7 @@ Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`.
\text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i
""" + r"""
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
:ref:`type promotion <type-promotion-doc>`, and integer, float, and complex inputs.
Args: