Added complex support for torch.logsumexp (#133187)

Added complex support for `torch.logsumexp` and implemented the complex backward pass for it.
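
As a quick illustration (a minimal sketch, not taken from the PR's test suite):

```python
import torch

# Complex inputs are now accepted; the result is complex as well.
z = torch.randn(3, 4, dtype=torch.complex128, requires_grad=True)
out = torch.logsumexp(z, dim=-1)   # shape (3,), dtype complex128

# The complex backward pass is implemented too; reduce to a real scalar first.
out.abs().sum().backward()
print(z.grad.dtype)  # torch.complex128
```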

Fixes #133047

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133187
Approved by: https://github.com/amjames, https://github.com/lezcano
Author: Tobias Ringwald
Date: 2024-09-03 17:28:36 +00:00
Committed by: PyTorch MergeBot
Parent: 6c3767452d
Commit: 758d787901
9 changed files with 48 additions and 15 deletions

View File

@@ -1456,7 +1456,9 @@ Tensor nanmean(
static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
// can't take max of empty tensor
if (self.numel() != 0) {
-auto maxes = at::amax(self, dims, true);
+// For complex numbers, use the real part to calculate the max. Based on
+// https://scicomp.stackexchange.com/questions/34273/log-sum-exp-trick-for-signed-complex-numbers
+auto maxes = at::amax(at::real(self), dims, true);
auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims));
maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
@@ -1469,7 +1471,8 @@ static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRe
}
Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
-TORCH_CHECK(at::isFloatingType(result.scalar_type()),
+// Complex type implies floating point type
+TORCH_CHECK(at::isFloatingType(result.scalar_type()) || at::isComplexType(result.scalar_type()),
"logsumexp(): Expected floating point type for result tensor, but got: ",
result.scalar_type());
{
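
The comment added in the first hunk refers to the standard log-sum-exp shift applied with a real offset: for any finite real m, logsumexp(z) = m + log(sum(exp(z - m))), and picking m = max(Re z) keeps |exp(z - m)| = exp(Re z - m) <= 1, so the intermediate sum cannot overflow. A rough Python equivalent of the kernel's approach (the helper name `complex_logsumexp_ref` is illustrative, not part of this change):

```python
import torch

def complex_logsumexp_ref(z: torch.Tensor, dim: int) -> torch.Tensor:
    # Shift by the max of the real part; |exp(z - m)| = exp(Re(z) - m) <= 1.
    m = torch.amax(torch.real(z), dim=dim, keepdim=True)
    # Mirror the kernel: zero out infinite shifts so they don't produce NaNs.
    m = torch.masked_fill(m, m.abs() == float("inf"), 0.0)
    return torch.log(torch.sum(torch.exp(z - m), dim=dim)) + m.squeeze(dim)

# Large real parts would overflow a naive log(sum(exp(z))).
z = 1000.0 + torch.randn(4, 5, dtype=torch.complex128)
torch.testing.assert_close(complex_logsumexp_ref(z, -1), torch.logsumexp(z, -1))
```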

View File

@@ -438,6 +438,7 @@ def mps_ops_modifier(ops):
'logical_not',
'logical_or',
'logical_xor',
+'logsumexp',
'long',
'masked_fill',
'masked.mean',
@@ -445,6 +446,7 @@ def mps_ops_modifier(ops):
'masked.std',
'masked.sum',
'masked.var',
+'masked.logsumexp',
'matmul',
'mean',
'mm',
@@ -6540,6 +6542,18 @@ class TestMPS(TestCaseMPS):
helper((2, 8, 4, 5))
+def test_logsumexp(self):
+def helper(shape):
+cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
+x = cpu_x.detach().clone().to('mps')
+log_result = torch.logsumexp(x, -1)
+log_result_cpu = torch.logsumexp(cpu_x, -1)
+self.assertEqual(log_result, log_result_cpu)
+helper((2, 8, 4, 5))
# Test concat forward
def test_cat2(self):

View File

@@ -487,10 +487,14 @@ class TestReductions(TestCase):
self.assertEqual(y, y2)
@skipIfNoSciPy
-def test_logsumexp(self, device):
+@dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
+def test_logsumexp(self, device, dtype):
from scipy.special import logsumexp
-a = torch.randn(5, 4, device=device)
-a[0, 0] = inf
+a = torch.randn(5, 4, device=device, dtype=dtype)
+# torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp,
+# numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740
+if torch.device(device) != torch.device('cpu'):
+a[0, 0] = inf
a[1, :] = -inf
actual = a.logsumexp(1)
expected = logsumexp(a.cpu().numpy(), 1)
@@ -498,11 +502,14 @@ class TestReductions(TestCase):
self.assertEqual(expected, actual)
# check that out is actually inplace
-b = torch.zeros(5, 2, device=device)
+b = torch.zeros(5, 2, device=device, dtype=dtype)
c = b[:, 0]
torch.logsumexp(a, 1, out=c)
self.assertEqual(expected, b[:, 0])
+@skipIfNoSciPy
+def test_logsumexp_integral_promotion(self, device):
+from scipy.special import logsumexp
# check integral inputs is promoted to floating point
e = torch.randint(-100, 100, [5, 4], device=device)
actual = e.logsumexp(1).to(torch.float64)
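
The updated test relies on `scipy.special.logsumexp` accepting complex arrays; a self-contained version of that cross-check looks roughly like this (double precision keeps the comparison tolerances comfortable):

```python
import torch
from scipy.special import logsumexp

a = torch.randn(5, 4, dtype=torch.complex128)
actual = torch.logsumexp(a, dim=1)
expected = torch.from_numpy(logsumexp(a.numpy(), axis=1))
torch.testing.assert_close(actual, expected)
```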

View File

@@ -259,6 +259,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"log1p",
"log2",
"logaddexp",
"logsumexp",
"logcumsumexp",
"reciprocal",
"tan",

View File

@@ -804,7 +804,7 @@ def logsumexp(
dim = (dim,)
if self.numel() == 0:
return torch.sum(torch.exp(self), dim, keepdim).log()
-maxes = torch.amax(self, dim, keepdim=True)
+maxes = torch.amax(torch.real(self), dim, keepdim=True)
maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
result = torch.sum(torch.exp(self - maxes), dim, keepdim)

View File

@@ -897,7 +897,7 @@ Tensor logsumexp_backward(
grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size());
result = unsqueeze_multiple(result, dim, self.sym_sizes().size());
}
-return grad * (self - result).exp();
+return grad * (self - result).exp().conj();
}
Tensor logcumsumexp_backward(
@@ -6689,7 +6689,8 @@ Tensor logsumexp_jvp(
// forward
auto self_p_exp = [&self_p, &dim]() {
if (self_p.sym_numel() > 0) {
-return (self_p - at::amax(self_p, dim, true))
+// Use only the real part for complex tensors
+return (self_p - at::amax(at::real(self_p), dim, true))
.exp(); // Use the exp-normalize trick
} else {
// amax fails if numel() == 0, in which case it doesn't matter anyway
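
The only change to the backward formula is the trailing `.conj()`: the complex derivative of `logsumexp` with respect to each input element is the softmax-like term `exp(self - result)`, and PyTorch's complex autograd returns conjugate Wirtinger derivatives, hence `grad * (self - result).exp().conj()`. A quick numerical sanity check (a sketch, not the PR's own test):

```python
import torch

z = torch.randn(3, 4, dtype=torch.complex128, requires_grad=True)

# gradcheck verifies the complex backward; check_forward_ad also exercises
# the forward-mode (jvp) path touched above.
assert torch.autograd.gradcheck(
    lambda t: torch.logsumexp(t, dim=-1), (z,), check_forward_ad=True
)
```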

View File

@@ -418,11 +418,18 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
return torch.tensor(0, dtype=dtype, device=device)
elif op_name in {"prod", "cumprod"}:
return torch.tensor(1, dtype=dtype, device=device)
elif op_name in {"amax", "argmax", "logsumexp"}:
elif op_name in {"amax", "argmax", "logaddexp"}:
if torch.is_floating_point(input):
return torch.tensor(-torch.inf, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8:
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
elif op_name in {"logsumexp"}:
if torch.is_floating_point(input):
return torch.tensor(-torch.inf, dtype=dtype, device=device)
elif torch.is_complex(input):
return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8:
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
elif op_name in {"amin", "argmin"}:
if torch.is_floating_point(input):
return torch.tensor(torch.inf, dtype=dtype, device=device)
@@ -1523,8 +1530,8 @@ def logaddexp(
if dtype is None:
dtype = input.dtype
if input.layout == torch.strided and other.layout == torch.strided:
-mask_input = _combine_input_and_mask(logsumexp, input, input_mask)
-mask_other = _combine_input_and_mask(logsumexp, other, other_mask)
+mask_input = _combine_input_and_mask(logaddexp, input, input_mask)
+mask_other = _combine_input_and_mask(logaddexp, other, other_mask)
return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
else:
raise ValueError(
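
For the masked path, the reduction identity is the value written into masked-out positions before reducing; `-inf` (or `-inf + 0j` for complex inputs) works for `logsumexp` because `exp(-inf)` is 0, so those positions drop out of the sum. A rough illustration of the idea (the real masked machinery goes through `_combine_input_and_mask` rather than an explicit `torch.where`):

```python
import torch

z = torch.randn(2, 4, dtype=torch.complex128)
mask = torch.tensor([[True, True, False, True],
                     [False, True, True, True]])

# Fill masked-out entries with the identity; they contribute exp(-inf) == 0.
identity = torch.tensor(complex(float("-inf"), 0.0), dtype=z.dtype)
filled = torch.where(mask, z, identity)
print(torch.logsumexp(filled, dim=1))
```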

View File

@@ -1573,7 +1573,7 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
((S, S), (0, 1), False),
)
# Test large inputs to check numerical stability
-lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64) else (None,)
+lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,)
for low in lows:
high = low * 2 if low is not None else None
for shape, dim, keepdim in inputs:
@@ -19593,7 +19593,7 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_zero_),
OpInfo('logsumexp',
aliases=('special.logsumexp',),
-dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,

View File

@@ -1182,7 +1182,7 @@ op_db: List[OpInfo] = [
),
ReductionOpInfo(
"masked.logsumexp",
-dtypes=all_types_and(torch.half, torch.bfloat16),
+dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
method_variant=None,
nan_policy="propagate",
supports_out=False,