Logcumsumexp for CPU (#93153)

Partial work from #90847, in the direction of solving #89205.
Most of the content is from #90847, but this is only for CPU, so hopefully it does not increase the build time much.
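
For reference, a minimal usage sketch of what this enables (illustrative only; it assumes complex dtypes are accepted on CPU as described in this PR):

import torch

x = torch.randn(5, dtype=torch.complex128)
out = torch.logcumsumexp(x, dim=0)

# naive, less numerically stable reference
ref = torch.log(torch.cumsum(torch.exp(x), dim=0))
print(torch.allclose(out, ref))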

tag: @albanD, @malfet

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93153
Approved by: https://github.com/malfet, https://github.com/Skylion007
Author: mfkasim1
Date: 2023-01-27 22:29:30 +00:00
Committed by: PyTorch MergeBot
Parent: 61457671a5
Commit: 75cfc0be21
7 changed files with 192 additions and 29 deletions


@@ -112,12 +112,63 @@ static void cumprod_cpu_kernel(const Tensor& result, const Tensor& self, int64_t
);
});
}
// custom min and max to be used in logcumsumexp for complex arguments
template <typename scalar_t, bool min>
c10::complex<scalar_t> _logcumsumexp_minmax(c10::complex<scalar_t> x, c10::complex<scalar_t> y) {
if (std::isnan(y)) { // either real is nan or imag is nan
return y;
} else if (std::isnan(x)) { // either real is nan or imag is nan
return x;
} else {
return ((x.real() < y.real()) == min) ? x : y; // logical xnor
}
}
template <typename scalar_t>
scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) {
// Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
scalar_t min = std::isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan
scalar_t max = std::isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan
if (min != max || std::isfinite(min)) {
// nan will be propagated here
return std::log1p(std::exp(min - max)) + max;
} else {
// special case to correctly handle infinite cases
return x;
}
}
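
Roughly, this helper computes log(exp(x) + exp(y)) without overflow by factoring out the larger argument. A minimal Python sketch of the same idea (NaN propagation, which the C++ above handles explicitly, is omitted here):

import math

def log_add_exp(x, y):
    lo, hi = min(x, y), max(x, y)
    if lo != hi or math.isfinite(lo):
        # log(exp(lo) + exp(hi)) = log1p(exp(lo - hi)) + hi, with exp(lo - hi) <= 1
        return math.log1p(math.exp(lo - hi)) + hi
    # lo == hi == +/-inf: return the argument directly to avoid inf - inf
    return x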
template <typename scalar_t>
c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
auto min = _logcumsumexp_minmax<scalar_t, true>(x, y);
auto max = _logcumsumexp_minmax<scalar_t, false>(x, y);
auto min_real = std::real(min);
auto max_real = std::real(max);
if (std::isnan(min)) { // either real is nan or imag is nan
// handling the "infectious" NaNs
return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
} else if ((!std::isfinite(min_real)) && (min_real == max_real)) {
if (min_real < 0) {
// handle the -inf case: the imaginary part does not really matter here because
// exp(value) will be around 0.0, so its angle (i.e. the imaginary part) cannot be
// determined and is irrelevant once the value is exponentiated
return min;
} else {
// handle the +inf case: we don't need log1p's extra precision for small values here,
// and this avoids producing nan when real(max) == real(min) == +inf
return std::log(std::exp(min) + std::exp(max));
}
} else {
return std::log1p(std::exp(min - max)) + max;
}
}
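
The complex overload follows the same pattern, but orders the two operands by real part and special-cases the infinities. A rough Python sketch of the branch structure (again omitting the NaN handling, and using cmath.log(1 + ...) because cmath has no log1p):

import cmath, math

def log_add_exp_complex(x, y):
    lo, hi = (x, y) if x.real < y.real else (y, x)
    if not math.isfinite(lo.real) and lo.real == hi.real:
        if lo.real < 0:
            # both real parts are -inf: exp of the result is ~0, the phase is irrelevant
            return lo
        # both real parts are +inf: avoid inf - inf in the subtraction below
        return cmath.log(cmath.exp(lo) + cmath.exp(hi))
    return cmath.log(1 + cmath.exp(lo - hi)) + hi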
static void logcumsumexp_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim) {
auto wrap_dim = maybe_wrap_dim(dim, self.dim());
int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
scalar_t* result_data, auto result_dim_stride,
const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
@@ -126,19 +177,7 @@ static void logcumsumexp_cpu_kernel(Tensor& result, const Tensor& self, int64_t
for (const auto i : c10::irange(self_dim_size)) {
accscalar_t x = self_data[i * self_dim_stride];
// Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
auto log_add_exp = [](accscalar_t x, accscalar_t y) -> accscalar_t {
accscalar_t min = std::isnan(y) ? y : std::min(x,y); //std::min returns first arg if one of the args is nan
accscalar_t max = std::isnan(y) ? y : std::max(x,y); //std::max returns first arg if one of the args is nan
if (min != max || std::isfinite(min)) {
// nan will be propagated here
return std::log1p(std::exp(min - max)) + max;
} else {
// special case to correctly handle infinite cases
return x;
}
};
cum_number = log_add_exp(x, cum_number);
cum_number = _log_add_exp_helper(x, cum_number);
result_data[i * result_dim_stride] = static_cast<scalar_t>(cum_number);
}
}, /*init_val=*/ -std::numeric_limits<scalar_t>::infinity()
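
In the kernel, the pairwise helper is folded over the scan dimension with a running accumulator initialised to -inf. In sketch form, for a 1-D input (function name is illustrative, not the actual kernel API):

import math

def logcumsumexp_1d(xs):
    out, acc = [], -math.inf
    for x in xs:
        lo, hi = min(x, acc), max(x, acc)
        # stable log(exp(x) + exp(acc)); the equal-infinities case falls back to x
        acc = x if (lo == hi and not math.isfinite(lo)) else math.log1p(math.exp(lo - hi)) + hi
        out.append(acc)
    return out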


@@ -38,4 +38,9 @@ namespace std {
template <typename T>
class numeric_limits<c10::complex<T>> : public numeric_limits<T> {};
template <typename T>
bool isnan(const c10::complex<T>& v) {
return std::isnan(v.real()) || std::isnan(v.imag());
}
} // namespace std
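
This mirrors the usual convention that a complex value counts as NaN if either component is NaN; roughly, in Python terms:

import math

def complex_isnan(z):
    # NaN if either the real or the imaginary component is NaN
    return math.isnan(z.real) or math.isnan(z.imag)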


@@ -614,7 +614,7 @@ meta_function_expected_failures = {
torch.histogram : {f64, f32},
torch.histogramdd : {f64, f32},
torch.kthvalue : {f64, i32, i64, u8, i16, bf16, i8, f32},
torch.logcumsumexp : {f64, bf16, f32},
torch.logcumsumexp : {f64, bf16, f32, c64, c128},
torch.median : {f64, i32, i64, u8, i16, bf16, i8, f32},
torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
torch.multinomial : {f64, bf16, f32},
@@ -869,8 +869,8 @@ meta_dispatch_expected_failures = {
aten.histogram.bin_ct : {f32, f64},
aten.histogram.bins_tensor : {f32, f64},
aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
aten.logcumsumexp.default : {bf16, f32, f64},
aten.logcumsumexp.out : {bf16, f32, f64},
aten.logcumsumexp.default : {bf16, f32, f64, c64, c128},
aten.logcumsumexp.out : {bf16, f32, f64, c64, c128},
aten.max_pool3d_with_indices.default : {f32, f64},
aten.max_unpool2d.default : {f32, f64},
aten.max_unpool3d.default : {f32, f64},


@@ -504,6 +504,117 @@ class TestReductions(TestCase):
self.assertEqual(expected.shape, actual.shape)
self.assertEqual(expected, actual)
@onlyCPU
@skipIfNoSciPy
@dtypes(torch.complex64, torch.complex128)
def test_logcumsumexp_complex(self, device, dtype):
# logcumsumexp is a more precise way to compute ``log(cumsum(exp(a)))``
# and faster than ``[log(sum(exp(a[:i]))) for i in range(a.shape[0])]``
# the expression above should produce similar precision to logcumsumexp (it is just slower),
# so it can be used as the expected value to check our computation
# we use logsumexp from scipy because, at the time of writing this test,
# torch.logsumexp had not been implemented for complex numbers
from scipy.special import logsumexp
def zero_out_neg_inf(t):
t = t.clone()
idx = torch.logical_and(~(torch.isfinite(t)), torch.real(t) < 0)
t[idx] = torch.real(t[idx]).to(t.dtype)
return t
def standardize_phase(t):
t = torch.real(t) + 1j * (torch.imag(t) % (2 * np.pi))
return t
def logcumsumexp_slow(a, dim):
res_lst = []
for i in range(a.size(dim)):
index = [slice(None, None, None) for _ in range(a.ndim)]
index[dim] = slice(None, i + 1, None)
a_inp = a[tuple(index)]
res_lst.append(logsumexp(a_inp.cpu().numpy(), axis=dim, keepdims=True))
res = np.concatenate(res_lst, axis=dim)
return torch.as_tensor(res)
def compare_logcumsumexp(a, expected=None):
for i in range(a.ndim):
actual = torch.logcumsumexp(a, dim=i)
# if the expected is not given, then revert to scipy's logsumexp
if expected is None:
expected2 = logcumsumexp_slow(a, dim=i)
else:
expected2 = expected
# move the imaginary values to (0, 2 * pi)
actual = standardize_phase(actual)
expected2 = standardize_phase(expected2)
# zero out the imaginary part of elements whose real part is -inf,
# as the imaginary part cannot be determined exactly there and it does not
# really matter once we take the exp of the output
actual = zero_out_neg_inf(actual)
expected2 = zero_out_neg_inf(expected2)
self.assertEqual(expected2.shape, actual.shape)
self.assertEqual(expected2, actual)
# randomly specified values
# in this case, scipy.logsumexp should be enough
a1 = torch.randn((5, 10), dtype=dtype, device=device)
compare_logcumsumexp(a1)
# test with some non-normal values
a2 = torch.tensor([1e3 + 0j, 1e-18 + 1e4j, 1e2 + 1e-8j], dtype=dtype, device=device)
compare_logcumsumexp(a2)
# handle special cases involving infinities and nans
# here we don't use scipy.logsumexp as it gives confusing answers in
# some inf cases
# see here:
inf = float('inf')
nan = float('nan')
a3_input = torch.tensor([
-inf + 4j,
-inf + 1j,
1.2 + 2.1j,
1e10 + 1e20j,
inf + 0j,
inf + 1j,
inf + 3j,
nan + 2j,
])
a3_expected = torch.tensor([
-inf + 0j,
-inf + 0j,
1.2 + 2.1j,
1e10 + 1e20j,
inf + 0j, # scipy's logsumexp gives (inf + 0.7853982j) here, unclear why
inf + (np.pi / 4) * 1j, # the imaginary part thanks to some weird behaviour of log(inf + infj)
complex(inf, nan),
complex(nan, nan),
])
# Windows gives a strange result for the second-to-last entry: inf + (pi/4)j
# instead of inf + nan*j
if not IS_WINDOWS:
compare_logcumsumexp(a3_input, a3_expected)
a4_input = torch.tensor([
complex(-inf, inf),
complex(-inf, inf),
-inf + 1j,
1.2 + 2.1j,
complex(2.4, inf),
])
a4_expected = torch.tensor([
-inf + 0j,
-inf + 0j,
-inf + 0j,
1.2 + 2.1j,
complex(nan, nan),
])
if not IS_WINDOWS:
compare_logcumsumexp(a4_input, a4_expected)
@onlyCPU
def test_sum_parallel(self, device):
# To use parallel branches we'll need to compare on tensors


@@ -245,6 +245,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"log10",
"log1p",
"log2",
"logcumsumexp",
"reciprocal",
"tan",
"pow",


@@ -814,26 +814,33 @@ Tensor logcumsumexp_backward(
// Reference: https://github.com/tensorflow/tensorflow/blob/
// 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
return AT_DISPATCH_FLOATING_TYPES_AND(
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
at::ScalarType::BFloat16,
at::typeMetaToScalarType(grad.dtype()),
"logcumsumexp_backward",
[grad, self, result, dim]() {
auto grad_min = at::empty_like(grad);
grad_min.fill_(std::numeric_limits<scalar_t>::lowest());
auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min);
auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min);
auto reverse_logcumsumexp = [dim](auto x) {
return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
};
auto output_pos =
(reverse_logcumsumexp(log_grad_positive - result) + self).exp();
auto output_neg =
(reverse_logcumsumexp(log_grad_negative - result) + self).exp();
if (!at::is_complex(grad)) {
grad_min.fill_(std::numeric_limits<scalar_t>::lowest());
auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min);
auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min);
return output_pos - output_neg;
auto output_pos =
(reverse_logcumsumexp(log_grad_positive - result) + self).exp();
auto output_neg =
(reverse_logcumsumexp(log_grad_negative - result) + self).exp();
return output_pos - output_neg;
} else {
// no trick to separate the positive and negative parts is required
auto log_grad = grad.conj().log();
auto output = (reverse_logcumsumexp(log_grad - result) + self).exp();
return output.conj();
}
});
}
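
For intuition: with result_i = log sum_{j <= i} exp(x_j), we have d result_i / d x_j = exp(x_j - result_i) for j <= i, so grad_x_j = exp(x_j) * sum_{i >= j} grad_i * exp(-result_i), which is exactly the reverse cumulative logsumexp trick above. A NumPy sketch under the simplifying assumption of strictly positive real grad (the positive/negative split in the C++ handles general real grad; function names here are illustrative):

import numpy as np
from scipy.special import logsumexp

def logcumsumexp_np(x):
    return np.array([logsumexp(x[:i + 1]) for i in range(len(x))])

def logcumsumexp_backward_np(grad, x, result):
    rev = lambda t: t[::-1]
    # exp(x_j) * sum_{i >= j} grad_i * exp(-result_i), computed in log space
    return np.exp(rev(logcumsumexp_np(rev(np.log(grad) - result))) + x)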


@@ -16093,9 +16093,9 @@ op_db: List[OpInfo] = [
)
),
OpInfo('logcumsumexp',
dtypes=floating_types_and(torch.bfloat16),
dtypes=floating_and_complex_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
backward_dtypes=floating_types_and(torch.bfloat16),
backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
backward_dtypesIfCUDA=floating_types_and(torch.bfloat16),
skips=(
# AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.