mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make torch.conj() a composite function and return self for real tensors (#43270)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43270 `torch.conj` is a very commonly used operator for complex tensors, but it's mathematically a no op for real tensors. Switching to tensorflow gradients for complex tensors (as discussed in #41857) would involve adding `torch.conj()` to the backward definitions for a lot of operators. In order to preserve autograd performance for real tensors and maintain numpy compatibility for `torch.conj`, this PR updates `torch.conj()` which behaves the same for complex tensors but performs a view/returns `self` tensor for tensors of non-complex dtypes. The documentation states that the returned tensor for a real input shouldn't be mutated. We could perhaps return an immutable tensor for this case in future when that functionality is available (zdevito ezyang ). Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D23460493 Pulled By: anjali411 fbshipit-source-id: 3b3bf0af55423b77ff2d0e29f5d2c160291ae3d9
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f9efcb646b
commit
129f406062
@ -238,6 +238,8 @@ _(aten, clamp_min) \
|
||||
_(aten, clone) \
|
||||
_(aten, coalesce) \
|
||||
_(aten, combinations) \
|
||||
_(aten, _conj) \
|
||||
_(aten, conj) \
|
||||
_(aten, complex) \
|
||||
_(aten, polar) \
|
||||
_(aten, constant_pad_nd) \
|
||||
|
@ -200,8 +200,18 @@ Tensor imag(const Tensor& self) {
|
||||
}
|
||||
}
|
||||
|
||||
Tensor& conj_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, conj_stub); }
|
||||
Tensor conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); }
|
||||
Tensor& conj_out(Tensor& result, const Tensor& self) {
|
||||
return unary_op_impl_out(result, self, conj_stub);
|
||||
}
|
||||
|
||||
Tensor _conj(const Tensor& self) { return unary_op_impl(self, at::conj_out); }
|
||||
|
||||
Tensor conj(const Tensor& self) {
|
||||
if (!self.is_complex()) {
|
||||
return self;
|
||||
}
|
||||
return at::_conj(self);
|
||||
}
|
||||
|
||||
Tensor& bitwise_not_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, bitwise_not_stub); }
|
||||
Tensor bitwise_not(const Tensor& self) { return unary_op_impl(self, at::bitwise_not_out); }
|
||||
|
@ -288,6 +288,10 @@
|
||||
|
||||
- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
||||
- func: _conj(Tensor self) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: function
|
||||
|
||||
- func: acos(Tensor self) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: function, method
|
||||
|
@ -19061,6 +19061,12 @@ class TestViewOps(TestCase):
|
||||
else:
|
||||
return x.transpose(dim0, dim1)
|
||||
|
||||
@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()))
|
||||
def test_conj_self(self, device, dtype):
|
||||
t = torch.ones(5, 5, device=device)
|
||||
s = t.conj()
|
||||
self.assertTrue(s is t)
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
def test_view_as_complex(self, device):
|
||||
def fn(contiguous_input=True, dim0=0, dim1=1):
|
||||
|
@ -337,7 +337,7 @@
|
||||
abs: not_implemented("polar abs")
|
||||
angle: not_implemented("polar angle")
|
||||
|
||||
- name: conj(Tensor self) -> Tensor
|
||||
- name: _conj(Tensor self) -> Tensor
|
||||
self: grad.conj()
|
||||
|
||||
- name: cos(Tensor self) -> Tensor
|
||||
|
@ -1778,7 +1778,12 @@ add_docstr(torch.conj,
|
||||
r"""
|
||||
conj(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the element-wise conjugate of the given :attr:`input` tensor.
|
||||
Computes the element-wise conjugate of the given :attr:`input` tensor. If :attr`input` has a non-complex dtype,
|
||||
this function just returns :attr:`input`.
|
||||
|
||||
.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of
|
||||
non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj`
|
||||
when :attr:`input` is of non-complex dtype to be compatible with this change.
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = conj(\text{input}_{i})
|
||||
@ -6860,7 +6865,7 @@ Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`.
|
||||
\text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i
|
||||
""" + r"""
|
||||
|
||||
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
|
||||
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
|
||||
:ref:`type promotion <type-promotion-doc>`, and integer, float, and complex inputs.
|
||||
|
||||
Args:
|
||||
|
Reference in New Issue
Block a user