Added complex support for torch.logsumexp (#133187)

Added complex support for `torch.logsumexp`. Implemented complex backward pass for `torch.logsumexp`.

Fixes #133047

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133187
Approved by: https://github.com/amjames, https://github.com/lezcano
This commit is contained in:
Tobias Ringwald
2024-09-03 17:28:36 +00:00
committed by PyTorch MergeBot
parent 6c3767452d
commit 758d787901
9 changed files with 48 additions and 15 deletions

View File

@ -487,10 +487,14 @@ class TestReductions(TestCase):
self.assertEqual(y, y2)
@skipIfNoSciPy
def test_logsumexp(self, device):
@dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
def test_logsumexp(self, device, dtype):
from scipy.special import logsumexp
a = torch.randn(5, 4, device=device)
a[0, 0] = inf
a = torch.randn(5, 4, device=device, dtype=dtype)
# torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp,
# numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740
if torch.device(device) != torch.device('cpu'):
a[0, 0] = inf
a[1, :] = -inf
actual = a.logsumexp(1)
expected = logsumexp(a.cpu().numpy(), 1)
@ -498,11 +502,14 @@ class TestReductions(TestCase):
self.assertEqual(expected, actual)
# check that out is actually inplace
b = torch.zeros(5, 2, device=device)
b = torch.zeros(5, 2, device=device, dtype=dtype)
c = b[:, 0]
torch.logsumexp(a, 1, out=c)
self.assertEqual(expected, b[:, 0])
@skipIfNoSciPy
def test_logsumexp_integral_promotion(self, device):
from scipy.special import logsumexp
# check integral inputs is promoted to floating point
e = torch.randint(-100, 100, [5, 4], device=device)
actual = e.logsumexp(1).to(torch.float64)