Added complex support for torch.logsumexp (#133187)
Added complex support for `torch.logsumexp` and implemented the complex backward pass for `torch.logsumexp`.

Fixes #133047

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133187
Approved by: https://github.com/amjames, https://github.com/lezcano
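As a quick illustration of what the change enables (a minimal sketch assuming a build that includes this PR; `scipy.special.logsumexp` serves only as the reference, mirroring the updated test):

    import torch
    from scipy.special import logsumexp

    # Forward: complex dtypes are now accepted and should match SciPy's reference.
    a = torch.randn(5, 4, dtype=torch.complex64)
    actual = torch.logsumexp(a, dim=1)
    expected = torch.from_numpy(logsumexp(a.numpy(), axis=1))
    torch.testing.assert_close(actual, expected.to(actual.dtype))

    # Backward: the new complex backward pass can be exercised with gradcheck,
    # which expects double-precision (complex128) inputs.
    x = torch.randn(3, 4, dtype=torch.complex128, requires_grad=True)
    assert torch.autograd.gradcheck(lambda t: torch.logsumexp(t, dim=1), (x,))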
Committed by: PyTorch MergeBot
Parent: 6c3767452d
Commit: 758d787901
test/test_reductions.py
@@ -487,10 +487,14 @@ class TestReductions(TestCase):
         self.assertEqual(y, y2)
 
     @skipIfNoSciPy
-    def test_logsumexp(self, device):
+    @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
+    def test_logsumexp(self, device, dtype):
         from scipy.special import logsumexp
-        a = torch.randn(5, 4, device=device)
-        a[0, 0] = inf
+        a = torch.randn(5, 4, device=device, dtype=dtype)
+        # torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp,
+        # numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740
+        if torch.device(device) != torch.device('cpu'):
+            a[0, 0] = inf
         a[1, :] = -inf
         actual = a.logsumexp(1)
         expected = logsumexp(a.cpu().numpy(), 1)
@@ -498,11 +502,14 @@ class TestReductions(TestCase):
         self.assertEqual(expected, actual)
 
         # check that out is actually inplace
-        b = torch.zeros(5, 2, device=device)
+        b = torch.zeros(5, 2, device=device, dtype=dtype)
         c = b[:, 0]
         torch.logsumexp(a, 1, out=c)
         self.assertEqual(expected, b[:, 0])
 
+    @skipIfNoSciPy
+    def test_logsumexp_integral_promotion(self, device):
+        from scipy.special import logsumexp
         # check integral inputs is promoted to floating point
         e = torch.randint(-100, 100, [5, 4], device=device)
         actual = e.logsumexp(1).to(torch.float64)
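For context, the complex case can reuse the real case's max-shift trick by shifting with the maximum of the real parts, since |exp(z)| = exp(Re z) is the quantity that can overflow; the imaginary part falls out of the complex logarithm unchanged. The sketch below is an illustration under that assumption, not PyTorch's actual kernel; `complex_logsumexp` is a hypothetical helper name:

    import torch

    def complex_logsumexp(z: torch.Tensor, dim: int) -> torch.Tensor:
        # Shift by the max of the real parts so Re(z - m) <= 0 and exp cannot overflow.
        m = z.real.amax(dim=dim, keepdim=True)
        # Guard rows whose real parts are all -inf (or inf): exp(z - m) would be nan.
        m = torch.where(torch.isfinite(m), m, torch.zeros_like(m))
        s = torch.exp(z - m).sum(dim=dim, keepdim=True)
        return (m + torch.log(s)).squeeze(dim)

    z = torch.randn(5, 4, dtype=torch.complex64)
    torch.testing.assert_close(complex_logsumexp(z, 1), torch.logsumexp(z, 1))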