mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Added complex support for torch.logsumexp (#133187)
Added complex support for `torch.logsumexp` and implemented the complex backward pass for `torch.logsumexp`.

Fixes #133047

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133187
Approved by: https://github.com/amjames, https://github.com/lezcano
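As a quick orientation, a hedged usage sketch of what this change enables: `torch.logsumexp` now accepts complex inputs and supports backpropagation through them. The shapes, dtype, tolerance, and the scalar loss below are illustrative only.

import torch

x = torch.randn(3, 5, dtype=torch.complex64, requires_grad=True)
y = torch.logsumexp(x, dim=-1)      # complex result of shape (3,)

# Compare against the naive, less numerically stable formula.
naive = x.exp().sum(dim=-1).log()
print(torch.allclose(y, naive, atol=1e-5))

# The complex backward pass is implemented as well.
y.real.sum().backward()
print(x.grad.shape)                 # torch.Size([3, 5])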
committed by PyTorch MergeBot
parent 6c3767452d
commit 758d787901
@@ -1456,7 +1456,9 @@ Tensor nanmean(
 static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
   // can't take max of empty tensor
   if (self.numel() != 0) {
-    auto maxes = at::amax(self, dims, true);
+    // For complex numbers, use the real part to calculate the max. Based on
+    // https://scicomp.stackexchange.com/questions/34273/log-sum-exp-trick-for-signed-complex-numbers
+    auto maxes = at::amax(at::real(self), dims, true);
     auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims));
     maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
     at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
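The kernel change above relies on the shift invariance of log-sum-exp: for any finite constant c, logsumexp(x) = c + log(sum(exp(x - c))), and choosing c = max(Re(x)) keeps |exp(x - c)| = exp(Re(x) - c) <= 1 even when x is complex. A minimal Python sketch of that trick (the helper name logsumexp_via_real_shift is made up here; this is not the actual ATen code path):

import torch

def logsumexp_via_real_shift(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Shift by the maximum of the real part so no exponential can overflow.
    shift = torch.amax(torch.real(x), dim, keepdim=True)
    # Guard the all-(-inf) slice, where the shift itself would be infinite.
    shift = torch.masked_fill(shift, shift.abs() == float("inf"), 0)
    return torch.log(torch.sum(torch.exp(x - shift), dim)) + shift.squeeze(dim)

x = torch.randn(3, 5, dtype=torch.complex64)
print(torch.allclose(logsumexp_via_real_shift(x, 1), torch.logsumexp(x, 1), atol=1e-6))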
@@ -1469,7 +1471,8 @@ static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRe
 }
 
 Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
-  TORCH_CHECK(at::isFloatingType(result.scalar_type()),
+  // Complex type implies floating point type
+  TORCH_CHECK(at::isFloatingType(result.scalar_type()) || at::isComplexType(result.scalar_type()),
              "logsumexp(): Expected floating point type for result tensor, but got: ",
              result.scalar_type());
   {
@@ -438,6 +438,7 @@ def mps_ops_modifier(ops):
         'logical_not',
         'logical_or',
         'logical_xor',
+        'logsumexp',
         'long',
         'masked_fill',
         'masked.mean',
@@ -445,6 +446,7 @@ def mps_ops_modifier(ops):
         'masked.std',
         'masked.sum',
         'masked.var',
+        'masked.logsumexp',
         'matmul',
         'mean',
         'mm',
@@ -6540,6 +6542,18 @@ class TestMPS(TestCaseMPS):
 
         helper((2, 8, 4, 5))
 
+    def test_logsumexp(self):
+        def helper(shape):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
+            x = cpu_x.detach().clone().to('mps')
+
+            log_result = torch.logsumexp(x, -1)
+            log_result_cpu = torch.logsumexp(cpu_x, -1)
+
+            self.assertEqual(log_result, log_result_cpu)
+
+        helper((2, 8, 4, 5))
+
     # Test concat forward
     def test_cat2(self):
@@ -487,10 +487,14 @@ class TestReductions(TestCase):
         self.assertEqual(y, y2)
 
     @skipIfNoSciPy
-    def test_logsumexp(self, device):
+    @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
+    def test_logsumexp(self, device, dtype):
         from scipy.special import logsumexp
-        a = torch.randn(5, 4, device=device)
-        a[0, 0] = inf
+        a = torch.randn(5, 4, device=device, dtype=dtype)
+        # torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp,
+        # numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740
+        if torch.device(device) != torch.device('cpu'):
+            a[0, 0] = inf
         a[1, :] = -inf
         actual = a.logsumexp(1)
         expected = logsumexp(a.cpu().numpy(), 1)
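SciPy's `scipy.special.logsumexp` accepts complex arrays, which is what makes the cross-check above possible for the new dtypes. A standalone version of the same comparison, with shape and dtype chosen arbitrarily:

import numpy as np
import torch
from scipy.special import logsumexp

a = torch.randn(5, 4, dtype=torch.complex128)
expected = logsumexp(a.numpy(), axis=1)
actual = torch.logsumexp(a, dim=1)
print(np.allclose(expected, actual.numpy()))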
@@ -498,11 +502,14 @@ class TestReductions(TestCase):
         self.assertEqual(expected, actual)
 
         # check that out is actually inplace
-        b = torch.zeros(5, 2, device=device)
+        b = torch.zeros(5, 2, device=device, dtype=dtype)
         c = b[:, 0]
         torch.logsumexp(a, 1, out=c)
         self.assertEqual(expected, b[:, 0])
 
+    @skipIfNoSciPy
+    def test_logsumexp_integral_promotion(self, device):
+        from scipy.special import logsumexp
         # check integral inputs is promoted to floating point
         e = torch.randint(-100, 100, [5, 4], device=device)
         actual = e.logsumexp(1).to(torch.float64)
@@ -259,6 +259,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "log1p",
     "log2",
     "logaddexp",
+    "logsumexp",
     "logcumsumexp",
     "reciprocal",
     "tan",
@@ -804,7 +804,7 @@ def logsumexp(
         dim = (dim,)
     if self.numel() == 0:
         return torch.sum(torch.exp(self), dim, keepdim).log()
-    maxes = torch.amax(self, dim, keepdim=True)
+    maxes = torch.amax(torch.real(self), dim, keepdim=True)
     maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
     maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
     result = torch.sum(torch.exp(self - maxes), dim, keepdim)
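The switch to `torch.real` in this decomposition is needed because amax/max is not defined for complex tensors (complex numbers are unordered), while shifting by the real-part maximum is enough for numerical stability. A small illustration; the exact error message may differ between builds:

import torch

z = torch.randn(4, dtype=torch.complex64)
try:
    torch.amax(z, dim=0)                     # complex tensors cannot be ordered
except RuntimeError as err:
    print("amax on a complex tensor raises:", err)
print(torch.amax(torch.real(z), dim=0))      # taking the max of the real part works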
@@ -897,7 +897,7 @@ Tensor logsumexp_backward(
     grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size());
     result = unsqueeze_multiple(result, dim, self.sym_sizes().size());
   }
-  return grad * (self - result).exp();
+  return grad * (self - result).exp().conj();
 }
 
 Tensor logcumsumexp_backward(
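The added `.conj()` conjugates the holomorphic derivative d/dz logsumexp(z) = exp(z - logsumexp(z)), which is what PyTorch's complex autograd convention expects the backward to return. A quick numerical sanity check with double-precision complex inputs (shapes are illustrative):

import torch

x = torch.randn(3, 4, dtype=torch.complex128, requires_grad=True)

# gradcheck compares the analytic complex gradient above against numerical Wirtinger derivatives.
print(torch.autograd.gradcheck(lambda t: torch.logsumexp(t, dim=1), (x,)))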
@@ -6689,7 +6689,8 @@ Tensor logsumexp_jvp(
   // forward
   auto self_p_exp = [&self_p, &dim]() {
     if (self_p.sym_numel() > 0) {
-      return (self_p - at::amax(self_p, dim, true))
+      // Use only the real part for complex tensors
+      return (self_p - at::amax(at::real(self_p), dim, true))
           .exp(); // Use the exp-normalize trick
     } else {
       // amax fails if numel() == 0, in which case it doesn't matter anyway
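The JVP uses the same exp-normalize trick, again taking amax of the real part. A hedged forward-mode check, assuming forward-mode AD accepts complex dual tensors for this op (the OpInfo entries further down advertise forward-AD support):

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(2, 5, dtype=torch.complex128)
t = torch.randn(2, 5, dtype=torch.complex128)   # tangent direction

with fwAD.dual_level():
    dual = fwAD.make_dual(x, t)
    out = torch.logsumexp(dual, dim=1)
    primal, tangent = fwAD.unpack_dual(out)

print(primal.shape, tangent.shape)              # torch.Size([2]) torch.Size([2])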
@@ -418,11 +418,18 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
         return torch.tensor(0, dtype=dtype, device=device)
     elif op_name in {"prod", "cumprod"}:
         return torch.tensor(1, dtype=dtype, device=device)
-    elif op_name in {"amax", "argmax", "logsumexp"}:
+    elif op_name in {"amax", "argmax", "logaddexp"}:
         if torch.is_floating_point(input):
             return torch.tensor(-torch.inf, dtype=dtype, device=device)
         elif torch.is_signed(input) or dtype == torch.uint8:
             return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
+    elif op_name in {"logsumexp"}:
+        if torch.is_floating_point(input):
+            return torch.tensor(-torch.inf, dtype=dtype, device=device)
+        elif torch.is_complex(input):
+            return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device)
+        elif torch.is_signed(input) or dtype == torch.uint8:
+            return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
     elif op_name in {"amin", "argmin"}:
         if torch.is_floating_point(input):
             return torch.tensor(torch.inf, dtype=dtype, device=device)
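The masked reduction machinery substitutes this identity for masked-out elements, so for complex inputs -inf + 0j has to leave logsumexp unchanged. A tiny, purely illustrative check of that property:

import torch

x = torch.randn(5, dtype=torch.complex64)
identity = torch.tensor([-torch.inf + 0j], dtype=torch.complex64)

# Appending the identity element must not change the reduction.
print(torch.allclose(torch.logsumexp(x, dim=0),
                     torch.logsumexp(torch.cat([x, identity]), dim=0)))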
@@ -1523,8 +1530,8 @@ def logaddexp(
     if dtype is None:
         dtype = input.dtype
     if input.layout == torch.strided and other.layout == torch.strided:
-        mask_input = _combine_input_and_mask(logsumexp, input, input_mask)
-        mask_other = _combine_input_and_mask(logsumexp, other, other_mask)
+        mask_input = _combine_input_and_mask(logaddexp, input, input_mask)
+        mask_other = _combine_input_and_mask(logaddexp, other, other_mask)
         return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
     else:
         raise ValueError(
@@ -1573,7 +1573,7 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
         ((S, S), (0, 1), False),
     )
     # Test large inputs to check numerical stability
-    lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64) else (None,)
+    lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,)
     for low in lows:
         high = low * 2 if low is not None else None
         for shape, dim, keepdim in inputs:
@@ -19593,7 +19593,7 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_zero_),
     OpInfo('logsumexp',
            aliases=('special.logsumexp',),
-           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            assert_autodiffed=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -1182,7 +1182,7 @@ op_db: List[OpInfo] = [
     ),
     ReductionOpInfo(
         "masked.logsumexp",
-        dtypes=all_types_and(torch.half, torch.bfloat16),
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
         method_variant=None,
         nan_policy="propagate",
         supports_out=False,