mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[special] add zeta (#59623)
Summary: Reference https://github.com/pytorch/pytorch/issues/50345 `zeta` was already present in the codebase to support the computation of `polygamma`. However, before this PR `zeta` only had a `double(double, double)` signature **for CPU** (which meant that `polygamma` computations were always upcast to `double` for the zeta part). With this PR, float computations take place in float and double computations in double. The duplicated code has also been refactored and moved from `Math.cuh` to `Math.h`. **Note**: For scipy, `q` is optional, and if it is `None` it defaults to `1`, which corresponds to the Riemann zeta function. However, for `torch.special.zeta` I made `q` mandatory, because it feels odd that without `q` the function is the Riemann zeta and with `q` it is the general Hurwitz zeta. Sticking to just the general form made more sense, as passing `1` for `q` is trivial. Verify: * [x] Docs https://14234587-65600975-gh.circle-artifacts.com/0/docs/special.html#torch.special.zeta Pull Request resolved: https://github.com/pytorch/pytorch/pull/59623 Reviewed By: ngimel Differential Revision: D29348269 Pulled By: mruberry fbshipit-source-id: a3f9ebe1f7724dbe66de2b391afb9da1cfc3e4bb
This commit is contained in:
committed by
Facebook GitHub Bot
parent
26cdec6ce4
commit
dfd2edc025
@ -351,6 +351,7 @@ namespace c10 {
|
||||
_(aten, special_i0e) \
|
||||
_(aten, special_i1) \
|
||||
_(aten, special_i1e) \
|
||||
_(aten, special_zeta) \
|
||||
_(aten, has_torch_function) \
|
||||
_(aten, hardswish) \
|
||||
_(aten, hardswish_) \
|
||||
|
@ -57,6 +57,10 @@ TORCH_META_FUNC(special_xlog1py) (const Tensor& self, const Tensor& other) {
|
||||
build_borrowing_binary_float_op(maybe_get_output(), self, other);
|
||||
}
|
||||
|
||||
// Meta function for special_zeta: sets the op up through
// build_borrowing_binary_float_op, passing the (possibly absent) output
// together with `self` and `other`.
TORCH_META_FUNC(special_zeta) (const Tensor& self, const Tensor& other) {
|
||||
build_borrowing_binary_float_op(maybe_get_output(), self, other);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC2(copysign, Tensor) (
|
||||
const Tensor& self, const Tensor& other
|
||||
) {
|
||||
@ -221,6 +225,7 @@ DEFINE_DISPATCH(copysign_stub);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(xlogy_stub);
|
||||
DEFINE_DISPATCH(xlog1py_stub);
|
||||
DEFINE_DISPATCH(zeta_stub);
|
||||
|
||||
TORCH_IMPL_FUNC(add_out) (
|
||||
const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result
|
||||
@ -262,6 +267,10 @@ TORCH_IMPL_FUNC(special_xlog1py_out) (const Tensor& self, const Tensor& other, c
|
||||
xlog1py_stub(device_type(), *this);
|
||||
}
|
||||
|
||||
// Implementation function for special_zeta_out: hands `*this` to zeta_stub for
// the current device type. `self`, `other` and `result` are not referenced
// directly in this body.
TORCH_IMPL_FUNC(special_zeta_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
|
||||
zeta_stub(device_type(), *this);
|
||||
}
|
||||
|
||||
#define CREATE_BINARY_TORCH_IMPL_FUNC(func_out, func_stub) \
|
||||
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& other, const Tensor& result) { \
|
||||
func_stub(device_type(), *this); \
|
||||
@ -297,6 +306,22 @@ Tensor& special_xlog1py_out(const Tensor& self, const Scalar& other, Tensor& res
|
||||
return at::special_xlog1py_out(result, self, wrapped_scalar_tensor(other));
|
||||
}
|
||||
|
||||
// Scalar-first overload: wraps `x` with wrapped_scalar_tensor and forwards it,
// together with `y`, to at::special_zeta.
Tensor special_zeta(const Scalar& x, const Tensor& y) {
|
||||
return at::special_zeta(wrapped_scalar_tensor(x), y);
|
||||
}
|
||||
|
||||
// Scalar-second overload: wraps `y` with wrapped_scalar_tensor and forwards
// `x` and the wrapped value to at::special_zeta.
Tensor special_zeta(const Tensor& x, const Scalar& y) {
|
||||
return at::special_zeta(x, wrapped_scalar_tensor(y));
|
||||
}
|
||||
|
||||
// Out-variant with a Scalar `self`: wraps `self` with wrapped_scalar_tensor and
// forwards to at::special_zeta_out. Note the argument order of the callee:
// `result` comes first.
Tensor& special_zeta_out(const Scalar& self, const Tensor& other, Tensor& result) {
|
||||
return at::special_zeta_out(result, wrapped_scalar_tensor(self), other);
|
||||
}
|
||||
|
||||
// Out-variant with a Scalar `other`: wraps `other` with wrapped_scalar_tensor
// and forwards to at::special_zeta_out. Note the argument order of the callee:
// `result` comes first.
Tensor& special_zeta_out(const Tensor& self, const Scalar& other, Tensor& result) {
|
||||
return at::special_zeta_out(result, self, wrapped_scalar_tensor(other));
|
||||
}
|
||||
|
||||
// Implementation function for atan2_out: hands `*this` to atan2_stub for the
// current device type; the tensor parameters are not referenced directly here.
TORCH_IMPL_FUNC(atan2_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
|
||||
atan2_stub(device_type(), *this);
|
||||
}
|
||||
|
@ -93,5 +93,6 @@ DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
|
||||
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
|
||||
DECLARE_DISPATCH(binary_fn, xlogy_stub);
|
||||
DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
|
||||
DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
|
||||
|
||||
}} // namespace at::native
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/MathConstants.h>
|
||||
#include <c10/util/math_compat.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
|
||||
|
||||
/* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
|
||||
@ -148,9 +149,14 @@ Date: February 1996
|
||||
* This function is derived from the implementation of the zeta function in the Cephes Math Library.
|
||||
* See note [3-Clause BSD License for the Cephes Math Library].
|
||||
*/
|
||||
static inline double zeta(double x, double q) {
|
||||
static double MACHEP = 1.11022302462515654042E-16;
|
||||
static double A[] = {
|
||||
template <typename scalar_t, bool is_cuda=false>
|
||||
C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) {
|
||||
using acc_t = at::acc_type<scalar_t, is_cuda>;
|
||||
const acc_t MACHEP = acc_t{1.11022302462515654042E-16};
|
||||
constexpr acc_t zero = acc_t{0.0};
|
||||
constexpr acc_t half = acc_t{0.5};
|
||||
constexpr acc_t one = acc_t{1.0};
|
||||
static const acc_t A[] = {
|
||||
12.0,
|
||||
-720.0,
|
||||
30240.0,
|
||||
@ -166,58 +172,58 @@ static inline double zeta(double x, double q) {
|
||||
};
|
||||
|
||||
int i = 0;
|
||||
double a, b, k, s, t, w;
|
||||
if (x == 1.0) {
|
||||
return INFINITY;
|
||||
acc_t a, b, k, s, t, w;
|
||||
if (x == one) {
|
||||
return std::numeric_limits<scalar_t>::infinity();
|
||||
}
|
||||
|
||||
if (x < 1.0) {
|
||||
return std::numeric_limits<double>::quiet_NaN();
|
||||
if (x < one) {
|
||||
return std::numeric_limits<scalar_t>::quiet_NaN();
|
||||
}
|
||||
|
||||
if (q <= 0.0) {
|
||||
if (q == floor(q)) {
|
||||
return INFINITY;
|
||||
if (q <= zero) {
|
||||
if (q == ::floor(q)) {
|
||||
return std::numeric_limits<scalar_t>::infinity();
|
||||
}
|
||||
if (x != floor(x)) {
|
||||
return std::numeric_limits<double>::quiet_NaN();
|
||||
if (x != ::floor(x)) {
|
||||
return std::numeric_limits<scalar_t>::quiet_NaN();
|
||||
}
|
||||
}
|
||||
|
||||
s = std::pow(q, -x);
|
||||
s = ::pow(q, -x);
|
||||
a = q;
|
||||
i = 0;
|
||||
b = 0.0;
|
||||
while ((i < 9) || (a <= 9.0)) {
|
||||
b = zero;
|
||||
while ((i < 9) || (a <= acc_t{9.0})) {
|
||||
i += 1;
|
||||
a += 1.0;
|
||||
b = std::pow(a, -x);
|
||||
a += one;
|
||||
b = ::pow(a, -x);
|
||||
s += b;
|
||||
if ((-MACHEP * s < b) && (b < MACHEP * s)) {
|
||||
return s;
|
||||
return static_cast<scalar_t>(s);
|
||||
}
|
||||
};
|
||||
|
||||
w = a;
|
||||
s += b * w / (x - 1.0);
|
||||
s -= 0.5 * b;
|
||||
a = 1.0;
|
||||
k = 0.0;
|
||||
s += b * w / (x - one);
|
||||
s -= half * b;
|
||||
a = one;
|
||||
k = zero;
|
||||
for (int i = 0; i < 12; i++) {
|
||||
a *= x + k;
|
||||
b /= w;
|
||||
t = a * b / A[i];
|
||||
s = s + t;
|
||||
t = std::abs(t / s);
|
||||
t = ::abs(t / s);
|
||||
if (t < MACHEP) {
|
||||
return s;
|
||||
return static_cast<scalar_t>(s);
|
||||
}
|
||||
k += 1.0;
|
||||
k += one;
|
||||
a *= x + k;
|
||||
b /= w;
|
||||
k += 1.0;
|
||||
k += one;
|
||||
}
|
||||
return s;
|
||||
return static_cast<scalar_t>(s);
|
||||
}
|
||||
|
||||
/*
|
||||
@ -397,16 +403,12 @@ static inline float calc_digamma(float x) {
|
||||
return result + logf(x) - (0.5f / x) - y;
|
||||
}
|
||||
|
||||
static inline double calc_polygamma(int64_t n, double x) {
|
||||
template <typename scalar_t, bool is_cuda=false>
|
||||
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
|
||||
// already blocked if n <= 1
|
||||
return ((n % 2) ? 1.0 : -1.0) * std::exp(lgamma(double(n) + 1.0)) *
|
||||
zeta(double(n + 1), x);
|
||||
}
|
||||
|
||||
static inline float calc_polygamma(int64_t n, float x) {
|
||||
// already blocked if n <= 1
|
||||
return ((n % 2) ? 1.0f : -1.0f) * std::exp(lgamma(double(n) + 1.0)) *
|
||||
zeta(double(n + 1), x);
|
||||
return ((n % 2) ? 1.0 : -1.0) *
|
||||
::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) *
|
||||
zeta<scalar_t, is_cuda>(static_cast<scalar_t>(n + 1), x);
|
||||
}
|
||||
|
||||
// regularized lower incomplete gamma
|
||||
|
@ -989,6 +989,14 @@ void xlog1py_kernel(TensorIteratorBase& iter) {
|
||||
});
|
||||
}
|
||||
|
||||
// CPU kernel for zeta: dispatches on iter.common_dtype() through
// AT_DISPATCH_FLOATING_TYPES and applies zeta(x, q) to each pair of elements
// via cpu_kernel.
void zeta_kernel(TensorIteratorBase& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cpu", [&]() {
|
||||
cpu_kernel(iter, [](scalar_t x, scalar_t q) -> scalar_t {
|
||||
// Template arguments are left to deduction/defaults here (no explicit is_cuda).
return zeta(x, q);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
@ -1082,6 +1090,7 @@ REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);
|
||||
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel);
|
||||
REGISTER_DISPATCH(zeta_stub, &zeta_kernel);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -3,6 +3,8 @@
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/BinaryOps.h>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
|
||||
// NOTE: CUDA on Windows requires that the enclosing function
|
||||
@ -67,11 +69,20 @@ void xlog1py_kernel_cuda(TensorIteratorBase& iter) {
|
||||
});
|
||||
}
|
||||
|
||||
// CUDA kernel for zeta: dispatches on iter.common_dtype() through
// AT_DISPATCH_FLOATING_TYPES and applies zeta elementwise via
// gpu_kernel_with_scalars.
void zeta_kernel_cuda(TensorIteratorBase& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cuda", [&]() {
|
||||
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t x, scalar_t q) -> scalar_t {
|
||||
// is_cuda=true is passed explicitly as the second template argument of zeta.
return zeta<scalar_t, /*is_cuda=*/true>(x, q);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
|
||||
REGISTER_DISPATCH(huber_stub, &huber_kernel_cuda);
|
||||
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
|
||||
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);
|
||||
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel_cuda);
|
||||
REGISTER_DISPATCH(zeta_stub, &zeta_kernel_cuda);
|
||||
|
||||
// DO NOT ADD ANY NEW KERNELS HERE
|
||||
// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel.
|
||||
|
@ -6,88 +6,6 @@
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
/*
|
||||
* For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h".
|
||||
*/
|
||||
// Zeta function zeta(_x, _q) for device code. All arithmetic is carried out in
// accscalar_t (= at::acc_type<scalar_t, true>) and the result is cast back to
// scalar_t on return.
//
// Returns +INFINITY when x == 1 and when q <= 0 is an integer; otherwise the
// value accumulated by the direct sum plus the correction series below.
template <typename scalar_t>
|
||||
static inline C10_HOST_DEVICE scalar_t zeta(scalar_t _x, scalar_t _q) {
|
||||
using accscalar_t = at::acc_type<scalar_t, true>;
|
||||
// Relative tolerance used by both early-exit checks below.
static const accscalar_t MACHEP = 1.11022302462515654042E-16;
|
||||
// Divisors for the 12 terms of the correction series in the final loop
// (presumably (2k)!/B_2k from the Euler-Maclaurin expansion - TODO confirm).
const accscalar_t A[] = {
|
||||
12.0,
|
||||
-720.0,
|
||||
30240.0,
|
||||
-1209600.0,
|
||||
47900160.0,
|
||||
-1.8924375803183791606e9, /*1.307674368e12/691*/
|
||||
7.47242496e10,
|
||||
-2.950130727918164224e12, /*1.067062284288e16/3617*/
|
||||
1.1646782814350067249e14, /*5.109094217170944e18/43867*/
|
||||
-4.5979787224074726105e15, /*8.028576626982912e20/174611*/
|
||||
1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/
|
||||
-7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/
|
||||
};
|
||||
accscalar_t x = static_cast<accscalar_t>(_x);
|
||||
accscalar_t q = static_cast<accscalar_t>(_q);
|
||||
|
||||
int i = 0;
|
||||
accscalar_t a, b, k, s, t, w;
|
||||
if( x == 1.0 ) {
|
||||
return static_cast<scalar_t>(INFINITY);
|
||||
}
|
||||
|
||||
if( x < 1.0 ){
|
||||
// NOTE(review): the NaN value is computed and discarded - there is no `return`,
// so x < 1 falls through to the summation below.
std::numeric_limits<scalar_t>::quiet_NaN();
|
||||
}
|
||||
bool q_is_integer = q == ::floor(q);
|
||||
|
||||
if(q <= 0.0) {
|
||||
if(q_is_integer) {
|
||||
return static_cast<scalar_t>(INFINITY);
|
||||
}
|
||||
else {
|
||||
// NOTE(review): same as above - no `return`, so a non-integer q <= 0 also
// falls through to the summation below.
std::numeric_limits<scalar_t>::quiet_NaN();
|
||||
}
|
||||
}
|
||||
|
||||
// Direct summation: s starts at q^-x, then (q + i)^-x is added for i = 1, 2, ...
// The loop runs for at least 9 iterations and continues while a <= 9.
s = ::pow(q, -x);
|
||||
a = q;
|
||||
i = 0;
|
||||
b = 0.0;
|
||||
while ((i < 9) || (a <= 9.0)) {
|
||||
i += 1;
|
||||
a += 1.0;
|
||||
b = ::pow( a, -x );
|
||||
s += b;
|
||||
// Early exit once the newest term is negligible relative to the running sum.
if ((-MACHEP < (b / s)) && ((b / s) < MACHEP)) {
|
||||
return static_cast<scalar_t>(s);
|
||||
}
|
||||
};
|
||||
// Tail correction built from the last term b and the last argument w = a.
w = a;
|
||||
s += b * w / (x - 1.0);
|
||||
s -= 0.5 * b;
|
||||
a = 1.0;
|
||||
k = 0.0;
|
||||
// Correction series: up to 12 terms, each divided by A[i].
// Note: this `int i` shadows the function-scope `i` declared above.
for (int i=0; i < 12; i++) {
|
||||
a *= x + k;
|
||||
b /= w;
|
||||
t = a * b / A[i];
|
||||
s = s + t;
|
||||
t = t / s;
|
||||
// Take the absolute value of the relative contribution.
if (t < 0){
|
||||
t = -t;
|
||||
}
|
||||
// t is non-negative at this point, so only the upper bound can fail.
if ((-MACHEP <t) && (t < MACHEP)){
|
||||
return static_cast<scalar_t>(s);
|
||||
}
|
||||
k += 1.0;
|
||||
a *= x + k;
|
||||
b /= w;
|
||||
k += 1.0;
|
||||
}
|
||||
return static_cast<scalar_t>(s);
|
||||
}
|
||||
|
||||
/*
|
||||
* For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h".
|
||||
*/
|
||||
@ -177,12 +95,6 @@ static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) {
|
||||
return static_cast<scalar_t>(sign * result);
|
||||
}
|
||||
|
||||
// Computes sign * exp(lgamma(n + 1)) * zeta(n + 1, x), where sign is +1.0 for
// odd n and -1.0 for even n.
template <typename scalar_t>
|
||||
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
|
||||
// already blocked if n <= 1
|
||||
// The literals 1.0 / -1.0 are doubles, so the product is formed in at least
// double precision before being converted to the scalar_t return type.
return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) * zeta(static_cast<scalar_t>(n + 1), x);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
|
||||
scalar_t a = ::abs(a_in);
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/native/Math.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
@ -34,7 +35,7 @@ void polygamma_kernel_cuda(TensorIteratorBase& iter, int64_t n) {
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "polygamma_cuda", [&]() {
|
||||
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return calc_polygamma(int(n), a);
|
||||
return calc_polygamma<scalar_t, /*is_cuda=*/true>(int(n), a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -9705,6 +9705,51 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: special_xlog1py_out
|
||||
|
||||
- func: special_zeta(Tensor self, Tensor other) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
python_module: special
|
||||
variants: function
|
||||
structured_delegate: special_zeta.out
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: special_zeta
|
||||
|
||||
- func: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: special_zeta
|
||||
|
||||
- func: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: special_zeta
|
||||
|
||||
- func: special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA: special_zeta_out
|
||||
|
||||
- func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: special_zeta_out
|
||||
|
||||
- func: special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: special_zeta_out
|
||||
|
||||
- func: special_i0(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
@ -40,3 +40,4 @@ Functions
|
||||
.. autofunction:: round
|
||||
.. autofunction:: sinc
|
||||
.. autofunction:: xlog1py
|
||||
.. autofunction:: zeta
|
||||
|
@ -2810,6 +2810,60 @@ class TestBinaryUfuncs(TestCase):
|
||||
_compare_helper(t, zeros, *xlog1py_fns)
|
||||
_compare_helper(t, 0., *xlog1py_fns)
|
||||
|
||||
# Compares torch.special.zeta against scipy.special.zeta for every pair of
# (x dtype, q dtype) produced by the product of the two dtype lists below
# (complex, half and bfloat16 are excluded from both).
@dtypes(*product(torch.testing.get_all_dtypes(include_complex=False,
|
||||
include_half=False, include_bfloat16=False),
|
||||
torch.testing.get_all_dtypes(include_complex=False,
|
||||
include_half=False, include_bfloat16=False)))
|
||||
@skipIf(not TEST_SCIPY, "Scipy required for the test.")
|
||||
def test_zeta(self, device, dtypes):
|
||||
x_dtype, q_dtype = dtypes
|
||||
|
||||
# Evaluates both implementations on (x, q) and asserts the results match.
# Either argument may be a Python float instead of a tensor.
def test_helper(x, q):
|
||||
x_np = x if isinstance(x, float) else x.cpu().numpy()
|
||||
q_np = q if isinstance(q, float) else q.cpu().numpy()
|
||||
expected = torch.from_numpy(scipy.special.zeta(x_np, q_np))
|
||||
actual = torch.special.zeta(x, q)
|
||||
|
||||
# Explicit tolerances are only set on CPU; other devices pass None.
rtol, atol = None, None
|
||||
if self.device_type == 'cpu':
|
||||
rtol, atol = 1e-6, 1e-6
|
||||
# exact_dtype=False: only values are compared, not result dtypes.
self.assertEqual(expected, actual, rtol=rtol, atol=atol, exact_dtype=False)
|
||||
|
||||
# x tensor - q tensor same size
|
||||
x = make_tensor((2, 3, 4), device, x_dtype)
|
||||
q = make_tensor((2, 3, 4), device, q_dtype)
|
||||
test_helper(x, q)
|
||||
|
||||
# x tensor - q tensor broadcast lhs
|
||||
x = make_tensor((2, 1, 4), device, x_dtype)
|
||||
q = make_tensor((2, 3, 4), device, q_dtype)
|
||||
test_helper(x, q)
|
||||
|
||||
# x tensor - q tensor broadcast rhs
|
||||
x = make_tensor((2, 3, 4), device, x_dtype)
|
||||
q = make_tensor((2, 1, 4), device, q_dtype)
|
||||
test_helper(x, q)
|
||||
|
||||
# x tensor - q tensor broadcast all
|
||||
x = make_tensor((2, 3, 1), device, x_dtype)
|
||||
q = make_tensor((2, 1, 4), device, q_dtype)
|
||||
test_helper(x, q)
|
||||
|
||||
# x scalar - q tensor
|
||||
for x in np.linspace(-5, 5, num=10).tolist():
|
||||
# A non-floating q_dtype is replaced by the default dtype; the reassignment
# persists for the remaining iterations.
if not q_dtype.is_floating_point:
|
||||
q_dtype = torch.get_default_dtype()
|
||||
q = make_tensor((2, 3, 4), device, q_dtype)
|
||||
test_helper(x, q)
|
||||
|
||||
# x tensor - q scalar
|
||||
for q in np.linspace(-5, 5, num=10).tolist():
|
||||
# Same substitution as above, this time for x_dtype.
if not x_dtype.is_floating_point:
|
||||
x_dtype = torch.get_default_dtype()
|
||||
x = make_tensor((2, 3, 4), device, x_dtype)
|
||||
test_helper(x, q)
|
||||
|
||||
|
||||
tensor_binary_ops = [
|
||||
'__lt__', '__le__',
|
||||
'__gt__', '__ge__',
|
||||
|
@ -2661,6 +2661,7 @@ class TestOperatorSignatures(JitTestCase):
|
||||
'reshape_as',
|
||||
'resize_',
|
||||
'resize_as_',
|
||||
'special.zeta',
|
||||
'stack',
|
||||
'to_sparse',
|
||||
'view',
|
||||
|
@ -1471,6 +1471,7 @@ class TestNormalizeOperators(JitTestCase):
|
||||
"reshape_as",
|
||||
"resize_",
|
||||
"resize_as_",
|
||||
"special.zeta",
|
||||
"to_sparse",
|
||||
"view",
|
||||
"view_as",
|
||||
|
@ -802,6 +802,16 @@
|
||||
self: grad * log1p(other.toDouble())
|
||||
result: auto_element_wise
|
||||
|
||||
- name: special_zeta(Tensor self, Tensor other) -> Tensor
|
||||
self: not_implemented("zeta")
|
||||
other: grad * -self * special_zeta(self + 1., other)
|
||||
|
||||
- name: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor
|
||||
other: grad * -self * special_zeta(self.toDouble() + 1., other)
|
||||
|
||||
- name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor
|
||||
self: not_implemented("zeta")
|
||||
|
||||
- name: logdet(Tensor self) -> Tensor
|
||||
self: logdet_backward(grad, self, result)
|
||||
|
||||
|
@ -238,6 +238,39 @@ inline Tensor& xlog1py_out(Tensor& result, const Tensor& self, const Scalar& oth
|
||||
return torch::special_xlog1py_out(result, self, other);
|
||||
}
|
||||
|
||||
/// Computes the Hurwitz zeta function of the inputs, elementwise.
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.zeta.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto x = torch::randn({128}, torch::kDouble);
|
||||
/// auto y = torch::randn({128}, torch::kDouble);
|
||||
/// torch::special::zeta(x, y);
|
||||
/// ```
|
||||
inline Tensor zeta(const Tensor& self, const Tensor& other) {
|
||||
return torch::special_zeta(self, other);
|
||||
}
|
||||
|
||||
/// Overload of zeta taking a Scalar `self` and a Tensor `other`;
/// forwards both unchanged to torch::special_zeta.
inline Tensor zeta(const Scalar& self, const Tensor& other) {
|
||||
return torch::special_zeta(self, other);
|
||||
}
|
||||
|
||||
/// Overload of zeta taking a Tensor `self` and a Scalar `other`;
/// forwards both unchanged to torch::special_zeta.
inline Tensor zeta(const Tensor& self, const Scalar& other) {
|
||||
return torch::special_zeta(self, other);
|
||||
}
|
||||
|
||||
/// Out-variant of zeta for two Tensor inputs; forwards `result`, `self` and
/// `other` to torch::special_zeta_out and returns its result.
inline Tensor& zeta_out(Tensor& result, const Tensor& self, const Tensor& other) {
|
||||
return torch::special_zeta_out(result, self, other);
|
||||
}
|
||||
|
||||
/// Out-variant of zeta with a Scalar `self`; forwards `result`, `self` and
/// `other` to torch::special_zeta_out and returns its result.
inline Tensor& zeta_out(Tensor& result, const Scalar& self, const Tensor& other) {
|
||||
return torch::special_zeta_out(result, self, other);
|
||||
}
|
||||
|
||||
/// Out-variant of zeta with a Scalar `other`; forwards `result`, `self` and
/// `other` to torch::special_zeta_out and returns its result.
inline Tensor& zeta_out(Tensor& result, const Tensor& self, const Scalar& other) {
|
||||
return torch::special_zeta_out(result, self, other);
|
||||
}
|
||||
|
||||
/// Computes the zeroth order modified Bessel function of the first kind of input, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.i0
|
||||
///
|
||||
|
@ -918,6 +918,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.special.ndtri: lambda input: -1,
|
||||
torch.special.ndtr: lambda input: -1,
|
||||
torch.special.xlog1py: lambda input, other, out=None: -1,
|
||||
torch.special.zeta: lambda self, other, out=None: -1,
|
||||
torch.t: lambda input: -1,
|
||||
torch.take: lambda input, index: -1,
|
||||
torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
|
||||
|
@ -504,3 +504,33 @@ round(input, *, out=None) -> Tensor
|
||||
|
||||
Alias for :func:`torch.round`.
|
||||
""")
|
||||
|
||||
zeta = _add_docstr(_special.special_zeta,
|
||||
r"""
|
||||
zeta(input, other, *, out=None) -> Tensor
|
||||
|
||||
Computes the Hurwitz zeta function, elementwise.
|
||||
|
||||
.. math::
|
||||
\zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x}
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
input (Tensor): the input tensor corresponding to `x`.
|
||||
other (Tensor): the input tensor corresponding to `q`.
|
||||
|
||||
.. note::
|
||||
The Riemann zeta function corresponds to the case when `q = 1`
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
>>> x = torch.tensor([2., 4.])
|
||||
>>> torch.special.zeta(x, 1)
|
||||
tensor([1.6449, 1.0823])
|
||||
>>> torch.special.zeta(x, torch.tensor([1., 2.]))
|
||||
tensor([1.6449, 0.0823])
|
||||
>>> torch.special.zeta(2, torch.tensor([1., 2.]))
|
||||
tensor([1.6449, 0.6449])
|
||||
""".format(**common_args))
|
||||
|
@ -1507,6 +1507,7 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
|
||||
|
||||
return list(generator())
|
||||
|
||||
|
||||
def sample_inputs_logsumexp(self, device, dtype, requires_grad):
|
||||
inputs = (
|
||||
((), (0,), True),
|
||||
@ -3857,6 +3858,18 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
|
||||
low=low,
|
||||
requires_grad=requires_grad)))
|
||||
|
||||
|
||||
# Builds two SampleInputs for zeta: one whose extra argument is a tensor and
# one whose extra argument is the Python float 3.
# NOTE(review): presumably consumed by the 'grad' OpInfo variant, whose op
# lambda swaps its arguments, so the SampleInput input would play the role of
# `q` and args[0] the role of `x` - verify against the OpInfo entry.
def sample_inputs_zeta(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
# The first tensor of each sample is created with low=1; the tensor argument
# is created with low=2 and never requires grad.
samples = (SampleInput(make_arg((S,), low=1, requires_grad=requires_grad),
|
||||
args=(make_arg((S,), low=2, requires_grad=False),)),
|
||||
SampleInput(make_arg((S,), low=1, requires_grad=requires_grad),
|
||||
args=(3.,)),
|
||||
)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`,
|
||||
# supports `exclude` argument.
|
||||
# For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617
|
||||
@ -7312,6 +7325,25 @@ op_db: List[OpInfo] = [
|
||||
safe_casts_outputs=True,
|
||||
supports_forward_ad=True,
|
||||
sample_inputs_func=sample_inputs_xlog1py),
|
||||
OpInfo('special.zeta',
|
||||
aten_name='special_zeta',
|
||||
dtypes=all_types_and(torch.bool),
|
||||
supports_autograd=False,
|
||||
safe_casts_outputs=True,
|
||||
sample_inputs_func=sample_inputs_binary_pwise),
|
||||
# OpInfo entry to verify the gradient formula of `other`/`q`
|
||||
OpInfo('special.zeta',
|
||||
op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs),
|
||||
aten_name='special_zeta',
|
||||
variant_test_name='grad',
|
||||
dtypes=all_types_and(torch.bool),
|
||||
supports_autograd=True,
|
||||
safe_casts_outputs=True,
|
||||
skips=(
|
||||
# Lambda doesn't work in JIT test
|
||||
SkipInfo("TestJit", "test_variant_consistency_jit"),
|
||||
),
|
||||
sample_inputs_func=sample_inputs_zeta),
|
||||
OpInfo('logsumexp',
|
||||
dtypes=floating_types_and(torch.bfloat16),
|
||||
dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
|
||||
|
Reference in New Issue
Block a user