Logcumsumexp for CPU (#93153)
Partial work from #90847, in the direction of solving #89205. Most of the content is from #90847, but this is only for CPU, so hopefully it does not increase the build time by a lot. tag: @albanD, @malfet Pull Request resolved: https://github.com/pytorch/pytorch/pull/93153 Approved by: https://github.com/malfet, https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 61457671a5
Commit: 75cfc0be21
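A minimal usage sketch of what this change enables (illustration only, not part of the commit; assumes a CPU build that includes this patch):

    import torch

    # Complex logcumsumexp on CPU is what this PR adds; previously the CPU
    # kernel only dispatched on floating-point (and bfloat16) inputs.
    x = torch.randn(5, dtype=torch.complex128)
    out = torch.logcumsumexp(x, dim=0)

    # Up to a 2*pi shift of the imaginary part, the result matches the naive
    # (but overflow-prone) reference below.
    ref = torch.log(torch.cumsum(torch.exp(x), dim=0))
    print(out)
    print(ref)

The diff follows.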
@@ -112,12 +112,63 @@ static void cumprod_cpu_kernel(const Tensor& result, const Tensor& self, int64_t
       );
     });
 }
+// custom min and max to be used in logcumsumexp for complex arguments
+template <typename scalar_t, bool min>
+c10::complex<scalar_t> _logcumsumexp_minmax(c10::complex<scalar_t> x, c10::complex<scalar_t> y) {
+  if (std::isnan(y)) { // either real is nan or imag is nan
+    return y;
+  } else if (std::isnan(x)) { // either real is nan or imag is nan
+    return x;
+  } else {
+    return ((x.real() < y.real()) == min) ? x : y; // logical xnor
+  }
+}
+
+template <typename scalar_t>
+scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) {
+  // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
+  scalar_t min = std::isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan
+  scalar_t max = std::isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan
+  if (min != max || std::isfinite(min)) {
+    // nan will be propagated here
+    return std::log1p(std::exp(min - max)) + max;
+  } else {
+    // special case to correctly handle infinite cases
+    return x;
+  }
+}
+
+template <typename scalar_t>
+c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
+  auto min = _logcumsumexp_minmax<scalar_t, true>(x, y);
+  auto max = _logcumsumexp_minmax<scalar_t, false>(x, y);
+  auto min_real = std::real(min);
+  auto max_real = std::real(max);
+
+  if (std::isnan(min)) { // either real is nan or imag is nan
+    // handling the "infectious" NaNs
+    return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
+  } else if ((!std::isfinite(min_real)) && (min_real == max_real)) {
+    if (min_real < 0) {
+      // handle the -inf case, the imaginary part here does not really matter as the exp(value)
+      // will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
+      // It does not matter if we're taking the exp of this value
+      return min;
+    } else {
+      // handle the +inf case, we don't need the special precision for log1p for small values
+      // and to avoid producing nan in case of real(max) == real(min) == +inf
+      return std::log(std::exp(min) + std::exp(max));
+    }
+  } else {
+    return std::log1p(std::exp(min - max)) + max;
+  }
+}
+
 static void logcumsumexp_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim) {
   auto wrap_dim = maybe_wrap_dim(dim, self.dim());
   int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
 
-  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
     cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
       scalar_t* result_data, auto result_dim_stride,
       const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
@@ -126,19 +177,7 @@ static void logcumsumexp_cpu_kernel(Tensor& result, const Tensor& self, int64_t
         for (const auto i : c10::irange(self_dim_size)) {
           accscalar_t x = self_data[i * self_dim_stride];
 
-          // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
-          auto log_add_exp = [](accscalar_t x, accscalar_t y) -> accscalar_t {
-            accscalar_t min = std::isnan(y) ? y : std::min(x,y); //std::min returns first arg if one of the args is nan
-            accscalar_t max = std::isnan(y) ? y : std::max(x,y); //std::max returns first arg if one of the args is nan
-            if (min != max || std::isfinite(min)) {
-              // nan will be propagated here
-              return std::log1p(std::exp(min - max)) + max;
-            } else {
-              // special case to correctly handle infinite cases
-              return x;
-            }
-          };
-          cum_number = log_add_exp(x, cum_number);
+          cum_number = _log_add_exp_helper(x, cum_number);
           result_data[i * result_dim_stride] = static_cast<scalar_t>(cum_number);
         }
       }, /*init_val=*/ -std::numeric_limits<scalar_t>::infinity()
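The helpers above implement the usual numerically stable identity log(exp(x) + exp(y)) = max(x, y) + log1p(exp(min(x, y) - max(x, y))), plus explicit handling of the infinite and NaN corner cases. A small Python sketch of the same idea for real scalars (illustration only; the function name log_add_exp here is not part of the commit):

    import math

    def log_add_exp(x, y):
        # Factor out the larger argument so exp() only ever sees a
        # non-positive value and cannot overflow.
        lo, hi = min(x, y), max(x, y)
        if lo != hi or math.isfinite(lo):
            return math.log1p(math.exp(lo - hi)) + hi
        # lo == hi and both infinite: return the input unchanged so that
        # (-inf, -inf) stays -inf and (+inf, +inf) stays +inf instead of nan.
        return x

    print(log_add_exp(1000.0, 1000.0))        # ~1000.6931, no overflow
    print(log_add_exp(-math.inf, -math.inf))  # -inf, not nan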
@@ -38,4 +38,9 @@ namespace std {
 template <typename T>
 class numeric_limits<c10::complex<T>> : public numeric_limits<T> {};
 
+template <typename T>
+bool isnan(const c10::complex<T>& v) {
+  return std::isnan(v.real()) || std::isnan(v.imag());
+}
+
 } // namespace std
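The isnan overload above follows the convention that a complex value counts as NaN when either its real or imaginary component is NaN, which is also how torch.isnan treats complex tensors. A quick illustration (not part of the commit):

    import torch

    t = torch.tensor([complex(float('nan'), 1.0),
                      complex(1.0, float('nan')),
                      complex(1.0, 2.0)])
    print(torch.isnan(t))  # tensor([ True,  True, False])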
@@ -614,7 +614,7 @@ meta_function_expected_failures = {
     torch.histogram : {f64, f32},
     torch.histogramdd : {f64, f32},
     torch.kthvalue : {f64, i32, i64, u8, i16, bf16, i8, f32},
-    torch.logcumsumexp : {f64, bf16, f32},
+    torch.logcumsumexp : {f64, bf16, f32, c64, c128},
     torch.median : {f64, i32, i64, u8, i16, bf16, i8, f32},
     torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
     torch.multinomial : {f64, bf16, f32},
@@ -869,8 +869,8 @@ meta_dispatch_expected_failures = {
     aten.histogram.bin_ct : {f32, f64},
     aten.histogram.bins_tensor : {f32, f64},
     aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
-    aten.logcumsumexp.default : {bf16, f32, f64},
-    aten.logcumsumexp.out : {bf16, f32, f64},
+    aten.logcumsumexp.default : {bf16, f32, f64, c64, c128},
+    aten.logcumsumexp.out : {bf16, f32, f64, c64, c128},
     aten.max_pool3d_with_indices.default : {f32, f64},
     aten.max_unpool2d.default : {f32, f64},
     aten.max_unpool3d.default : {f32, f64},
@@ -504,6 +504,117 @@ class TestReductions(TestCase):
         self.assertEqual(expected.shape, actual.shape)
         self.assertEqual(expected, actual)
 
+    @onlyCPU
+    @skipIfNoSciPy
+    @dtypes(torch.complex64, torch.complex128)
+    def test_logcumsumexp_complex(self, device, dtype):
+        # logcumsumexp is a more precise way to compute than ``log(cumsum(exp(a)))``
+        # and faster than ``[log(sum(exp(a[:i]))) for i in range(a.shape[0])]``
+        # the for-loop above should produce similar precision as logcumsumexp (it's just slower),
+        # so it can be used as the expected values to check our computation
+
+        # using logsumexp from scipy because by the time of writing this test code,
+        # torch.logsumexp has not been implemented for complex numbers
+        from scipy.special import logsumexp
+
+        def zero_out_neg_inf(t):
+            t = t.clone()
+            idx = torch.logical_and(~(torch.isfinite(t)), torch.real(t) < 0)
+            t[idx] = torch.real(t[idx]).to(t.dtype)
+            return t
+
+        def standardize_phase(t):
+            t = torch.real(t) + 1j * (torch.imag(t) % (2 * np.pi))
+            return t
+
+        def logcumsumexp_slow(a, dim):
+            res_lst = []
+            for i in range(a.size(dim)):
+                index = [slice(None, None, None) for _ in range(a.ndim)]
+                index[dim] = slice(None, i + 1, None)
+                a_inp = a[tuple(index)]
+                res_lst.append(logsumexp(a_inp.cpu().numpy(), axis=dim, keepdims=True))
+            res = np.concatenate(res_lst, axis=dim)
+            return torch.as_tensor(res)
+
+        def compare_logcumsumexp(a, expected=None):
+            for i in range(a.ndim):
+                actual = torch.logcumsumexp(a, dim=i)
+                # if the expected is not given, then revert to scipy's logsumexp
+                if expected is None:
+                    expected2 = logcumsumexp_slow(a, dim=i)
+                else:
+                    expected2 = expected
+
+                # move the imaginary values to (0, 2 * pi)
+                actual = standardize_phase(actual)
+                expected2 = standardize_phase(expected2)
+
+                # zeroing the imaginary part of the element if the real part is -inf
+                # as the imaginary part cannot be determined exactly and it does not
+                # really matter if we take the exp of the output
+                actual = zero_out_neg_inf(actual)
+                expected2 = zero_out_neg_inf(expected2)
+                self.assertEqual(expected2.shape, actual.shape)
+                self.assertEqual(expected2, actual)
+
+        # randomly specified values
+        # in this case, scipy.logsumexp should be enough
+        a1 = torch.randn((5, 10), dtype=dtype, device=device)
+        compare_logcumsumexp(a1)
+
+        # test with some non-normal values
+        a2 = torch.tensor([1e3 + 0j, 1e-18 + 1e4j, 1e2 + 1e-8j], dtype=dtype, device=device)
+        compare_logcumsumexp(a2)
+
+        # handle special case involving infinites and nans
+        # here we don't use scipy.logsumexp as it gives confusing answer on
+        # some inf cases
+        # see here:
+        inf = float('inf')
+        nan = float('nan')
+        a3_input = torch.tensor([
+            -inf + 4j,
+            -inf + 1j,
+            1.2 + 2.1j,
+            1e10 + 1e20j,
+            inf + 0j,
+            inf + 1j,
+            inf + 3j,
+            nan + 2j,
+        ])
+        a3_expected = torch.tensor([
+            -inf + 0j,
+            -inf + 0j,
+            1.2 + 2.1j,
+            1e10 + 1e20j,
+            inf + 0j,  # scipy's logsumexp gives (inf + 0.7853982j) here, unclear why
+            inf + (np.pi / 4) * 1j,  # the imaginary part thanks to some weird behaviour of log(inf + infj)
+            complex(inf, nan),
+            complex(nan, nan),
+        ])
+        # windows give strange results on the second-to-last results where it gives inf + pi/4 j
+        # instead of inf + nan j
+        if not IS_WINDOWS:
+            compare_logcumsumexp(a3_input, a3_expected)
+
+        a4_input = torch.tensor([
+            complex(-inf, inf),
+            complex(-inf, inf),
+            -inf + 1j,
+            1.2 + 2.1j,
+            complex(2.4, inf),
+        ])
+        a4_expected = torch.tensor([
+            -inf + 0j,
+            -inf + 0j,
+            -inf + 0j,
+            1.2 + 2.1j,
+            complex(nan, nan),
+        ])
+        if not IS_WINDOWS:
+            compare_logcumsumexp(a4_input, a4_expected)
+
     @onlyCPU
     def test_sum_parallel(self, device):
         # To use parallel branches we'll need to compare on tensors
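The comment at the top of the new test points out that logcumsumexp is more precise than log(cumsum(exp(a))). A tiny sketch of the failure mode the stable kernel avoids (illustration only, not part of the commit):

    import torch

    a = torch.tensor([1000.0, 1000.0, 1000.0], dtype=torch.float64)
    naive = torch.log(torch.cumsum(torch.exp(a), dim=0))  # exp(1000) overflows to inf
    stable = torch.logcumsumexp(a, dim=0)                 # stays finite
    print(naive)   # tensor([inf, inf, inf])
    print(stable)  # tensor([1000.0000, 1000.6931, 1001.0986])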
@@ -245,6 +245,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "log10",
     "log1p",
     "log2",
+    "logcumsumexp",
     "reciprocal",
     "tan",
     "pow",
@@ -814,26 +814,33 @@ Tensor logcumsumexp_backward(
 
   // Reference: https://github.com/tensorflow/tensorflow/blob/
   // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
-  return AT_DISPATCH_FLOATING_TYPES_AND(
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
       at::ScalarType::BFloat16,
       at::typeMetaToScalarType(grad.dtype()),
       "logcumsumexp_backward",
       [grad, self, result, dim]() {
         auto grad_min = at::empty_like(grad);
-        grad_min.fill_(std::numeric_limits<scalar_t>::lowest());
-        auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min);
-        auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min);
-
         auto reverse_logcumsumexp = [dim](auto x) {
           return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
         };
 
-        auto output_pos =
-            (reverse_logcumsumexp(log_grad_positive - result) + self).exp();
-        auto output_neg =
-            (reverse_logcumsumexp(log_grad_negative - result) + self).exp();
-
-        return output_pos - output_neg;
+        if (!at::is_complex(grad)) {
+          grad_min.fill_(std::numeric_limits<scalar_t>::lowest());
+          auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min);
+          auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min);
+
+          auto output_pos =
+              (reverse_logcumsumexp(log_grad_positive - result) + self).exp();
+          auto output_neg =
+              (reverse_logcumsumexp(log_grad_negative - result) + self).exp();
+
+          return output_pos - output_neg;
+        } else {
+          // no trick separating the positive and negative required
+          auto log_grad = grad.conj().log();
+          auto output = (reverse_logcumsumexp(log_grad - result) + self).exp();
+          return output.conj();
+        }
       });
 }
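With the backward above extended to complex inputs (and "logcumsumexp" added to GRADIENT_IMPLEMENTED_FOR_COMPLEX in the previous hunk), the analytical gradient can be checked numerically. A hedged sketch using torch.autograd.gradcheck (illustration only; assumes a build containing this patch):

    import torch

    x = torch.randn(4, dtype=torch.complex128, requires_grad=True)
    # gradcheck compares the analytical backward against finite differences;
    # double precision (complex128) is needed for the default tolerances.
    assert torch.autograd.gradcheck(lambda t: torch.logcumsumexp(t, dim=0), (x,))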
@@ -16093,9 +16093,9 @@ op_db: List[OpInfo] = [
            )
        ),
     OpInfo('logcumsumexp',
-           dtypes=floating_types_and(torch.bfloat16),
+           dtypes=floating_and_complex_types_and(torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
-           backward_dtypes=floating_types_and(torch.bfloat16),
+           backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
            backward_dtypesIfCUDA=floating_types_and(torch.bfloat16),
            skips=(
                # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.