[special] add zeta (#59623)

Summary:
Reference https://github.com/pytorch/pytorch/issues/50345

`zeta` was already present in the codebase to support the computation of `polygamma`.

However, before this PR `zeta` only had a `double(double, double)` signature **for CPU**, which meant that `polygamma` computations were always upcast to `double` for the zeta part.

With this PR, float computations take place in float and double computations in double.
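For illustration, a minimal sketch (assuming a build that includes this PR) of calling the new op at single precision; the kernels now evaluate the series at the input precision instead of upcasting to `double` internally:

```python
import torch

# float32 inputs: the CPU/CUDA kernels now run the zeta series in float;
# previously the internal helper only had a double(double, double) signature.
x = torch.tensor([2.0, 4.0], dtype=torch.float32)
q = torch.tensor([1.0, 2.0], dtype=torch.float32)

out = torch.special.zeta(x, q)
print(out, out.dtype)  # tensor([1.6449, 0.0823]) torch.float32
```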

I have also refactored the code and moved the duplicated implementation from `Math.cuh` to `Math.h`.

**Note**: For scipy, `q` is optional, and if it is `None` it defaults to `1`, which corresponds to the Riemann zeta function. For `torch.special.zeta`, however, I made `q` mandatory, because it felt odd that without `q` the function would compute the Riemann zeta while with `q` it computes the general Hurwitz zeta. Sticking to just the general form makes more sense, since passing `1` for `q` is trivial.
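A small comparison sketch (not part of the PR) of scipy's optional `q` against the explicit `q` required here:

```python
import torch
from scipy import special

x = torch.tensor([2.0, 4.0], dtype=torch.float64)

# scipy: q is optional and defaults to 1, i.e. the Riemann zeta function.
print(special.zeta(x.numpy()))                    # [1.6449... 1.0823...]

# torch.special.zeta: q must always be passed explicitly; q = 1 recovers
# the Riemann zeta values above.
print(torch.special.zeta(x, torch.ones_like(x)))  # tensor([1.6449, 1.0823], dtype=torch.float64)
```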

Verify:
* [x] Docs https://14234587-65600975-gh.circle-artifacts.com/0/docs/special.html#torch.special.zeta

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59623

Reviewed By: ngimel

Differential Revision: D29348269

Pulled By: mruberry

fbshipit-source-id: a3f9ebe1f7724dbe66de2b391afb9da1cfc3e4bb
kshitij12345
2021-06-23 23:59:03 -07:00
committed by Facebook GitHub Bot
parent 26cdec6ce4
commit dfd2edc025
18 changed files with 296 additions and 126 deletions

View File

@ -351,6 +351,7 @@ namespace c10 {
_(aten, special_i0e) \
_(aten, special_i1) \
_(aten, special_i1e) \
_(aten, special_zeta) \
_(aten, has_torch_function) \
_(aten, hardswish) \
_(aten, hardswish_) \

View File

@ -57,6 +57,10 @@ TORCH_META_FUNC(special_xlog1py) (const Tensor& self, const Tensor& other) {
build_borrowing_binary_float_op(maybe_get_output(), self, other);
}
TORCH_META_FUNC(special_zeta) (const Tensor& self, const Tensor& other) {
build_borrowing_binary_float_op(maybe_get_output(), self, other);
}
TORCH_META_FUNC2(copysign, Tensor) (
const Tensor& self, const Tensor& other
) {
@ -221,6 +225,7 @@ DEFINE_DISPATCH(copysign_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(xlogy_stub);
DEFINE_DISPATCH(xlog1py_stub);
DEFINE_DISPATCH(zeta_stub);
TORCH_IMPL_FUNC(add_out) (
const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result
@ -262,6 +267,10 @@ TORCH_IMPL_FUNC(special_xlog1py_out) (const Tensor& self, const Tensor& other, c
xlog1py_stub(device_type(), *this);
}
TORCH_IMPL_FUNC(special_zeta_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
zeta_stub(device_type(), *this);
}
#define CREATE_BINARY_TORCH_IMPL_FUNC(func_out, func_stub) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& other, const Tensor& result) { \
func_stub(device_type(), *this); \
@ -297,6 +306,22 @@ Tensor& special_xlog1py_out(const Tensor& self, const Scalar& other, Tensor& res
return at::special_xlog1py_out(result, self, wrapped_scalar_tensor(other));
}
Tensor special_zeta(const Scalar& x, const Tensor& y) {
return at::special_zeta(wrapped_scalar_tensor(x), y);
}
Tensor special_zeta(const Tensor& x, const Scalar& y) {
return at::special_zeta(x, wrapped_scalar_tensor(y));
}
Tensor& special_zeta_out(const Scalar& self, const Tensor& other, Tensor& result) {
return at::special_zeta_out(result, wrapped_scalar_tensor(self), other);
}
Tensor& special_zeta_out(const Tensor& self, const Scalar& other, Tensor& result) {
return at::special_zeta_out(result, self, wrapped_scalar_tensor(other));
}
TORCH_IMPL_FUNC(atan2_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
atan2_stub(device_type(), *this);
}

View File

@ -93,5 +93,6 @@ DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
DECLARE_DISPATCH(binary_fn, xlogy_stub);
DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
}} // namespace at::native

View File

@ -10,6 +10,7 @@
#include <c10/util/Half.h>
#include <c10/util/MathConstants.h>
#include <c10/util/math_compat.h>
#include <ATen/AccumulateType.h>
/* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
@ -148,9 +149,14 @@ Date: February 1996
* This function is derived from the implementation of the zeta function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*/
static inline double zeta(double x, double q) {
static double MACHEP = 1.11022302462515654042E-16;
static double A[] = {
template <typename scalar_t, bool is_cuda=false>
C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) {
using acc_t = at::acc_type<scalar_t, is_cuda>;
const acc_t MACHEP = acc_t{1.11022302462515654042E-16};
constexpr acc_t zero = acc_t{0.0};
constexpr acc_t half = acc_t{0.5};
constexpr acc_t one = acc_t{1.0};
static const acc_t A[] = {
12.0,
-720.0,
30240.0,
@ -166,58 +172,58 @@ static inline double zeta(double x, double q) {
};
int i = 0;
double a, b, k, s, t, w;
if (x == 1.0) {
return INFINITY;
acc_t a, b, k, s, t, w;
if (x == one) {
return std::numeric_limits<scalar_t>::infinity();
}
if (x < 1.0) {
return std::numeric_limits<double>::quiet_NaN();
if (x < one) {
return std::numeric_limits<scalar_t>::quiet_NaN();
}
if (q <= 0.0) {
if (q == floor(q)) {
return INFINITY;
if (q <= zero) {
if (q == ::floor(q)) {
return std::numeric_limits<scalar_t>::infinity();
}
if (x != floor(x)) {
return std::numeric_limits<double>::quiet_NaN();
if (x != ::floor(x)) {
return std::numeric_limits<scalar_t>::quiet_NaN();
}
}
s = std::pow(q, -x);
s = ::pow(q, -x);
a = q;
i = 0;
b = 0.0;
while ((i < 9) || (a <= 9.0)) {
b = zero;
while ((i < 9) || (a <= acc_t{9.0})) {
i += 1;
a += 1.0;
b = std::pow(a, -x);
a += one;
b = ::pow(a, -x);
s += b;
if ((-MACHEP * s < b) && (b < MACHEP * s)) {
return s;
return static_cast<scalar_t>(s);
}
};
w = a;
s += b * w / (x - 1.0);
s -= 0.5 * b;
a = 1.0;
k = 0.0;
s += b * w / (x - one);
s -= half * b;
a = one;
k = zero;
for (int i = 0; i < 12; i++) {
a *= x + k;
b /= w;
t = a * b / A[i];
s = s + t;
t = std::abs(t / s);
t = ::abs(t / s);
if (t < MACHEP) {
return s;
return static_cast<scalar_t>(s);
}
k += 1.0;
k += one;
a *= x + k;
b /= w;
k += 1.0;
k += one;
}
return s;
return static_cast<scalar_t>(s);
}
/*
@ -397,16 +403,12 @@ static inline float calc_digamma(float x) {
return result + logf(x) - (0.5f / x) - y;
}
static inline double calc_polygamma(int64_t n, double x) {
template <typename scalar_t, bool is_cuda=false>
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
// already blocked if n <= 1
return ((n % 2) ? 1.0 : -1.0) * std::exp(lgamma(double(n) + 1.0)) *
zeta(double(n + 1), x);
}
static inline float calc_polygamma(int64_t n, float x) {
// already blocked if n <= 1
return ((n % 2) ? 1.0f : -1.0f) * std::exp(lgamma(double(n) + 1.0)) *
zeta(double(n + 1), x);
return ((n % 2) ? 1.0 : -1.0) *
::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) *
zeta<scalar_t, is_cuda>(static_cast<scalar_t>(n + 1), x);
}
// regularized lower incomplete gamma

View File

@ -989,6 +989,14 @@ void xlog1py_kernel(TensorIteratorBase& iter) {
});
}
void zeta_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cpu", [&]() {
cpu_kernel(iter, [](scalar_t x, scalar_t q) -> scalar_t {
return zeta(x, q);
});
});
}
} // namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -1082,6 +1090,7 @@ REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel);
REGISTER_DISPATCH(zeta_stub, &zeta_kernel);
} // namespace native
} // namespace at

View File

@ -3,6 +3,8 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/Math.h>
#include <ATen/NumericUtils.h>
// NOTE: CUDA on Windows requires that the enclosing function
@ -67,11 +69,20 @@ void xlog1py_kernel_cuda(TensorIteratorBase& iter) {
});
}
void zeta_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t x, scalar_t q) -> scalar_t {
return zeta<scalar_t, /*is_cuda=*/true>(x, q);
});
});
}
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(huber_stub, &huber_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel_cuda);
REGISTER_DISPATCH(zeta_stub, &zeta_kernel_cuda);
// DO NOT ADD ANY NEW KERNELS HERE
// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel.

View File

@ -6,88 +6,6 @@
namespace at {
namespace native {
/*
* For licensing information, please refer to the the cpu implementation located in "ATen/native/Math.h".
*/
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t zeta(scalar_t _x, scalar_t _q) {
using accscalar_t = at::acc_type<scalar_t, true>;
static const accscalar_t MACHEP = 1.11022302462515654042E-16;
const accscalar_t A[] = {
12.0,
-720.0,
30240.0,
-1209600.0,
47900160.0,
-1.8924375803183791606e9, /*1.307674368e12/691*/
7.47242496e10,
-2.950130727918164224e12, /*1.067062284288e16/3617*/
1.1646782814350067249e14, /*5.109094217170944e18/43867*/
-4.5979787224074726105e15, /*8.028576626982912e20/174611*/
1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/
-7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/
};
accscalar_t x = static_cast<accscalar_t>(_x);
accscalar_t q = static_cast<accscalar_t>(_q);
int i = 0;
accscalar_t a, b, k, s, t, w;
if( x == 1.0 ) {
return static_cast<scalar_t>(INFINITY);
}
if( x < 1.0 ){
std::numeric_limits<scalar_t>::quiet_NaN();
}
bool q_is_integer = q == ::floor(q);
if(q <= 0.0) {
if(q_is_integer) {
return static_cast<scalar_t>(INFINITY);
}
else {
std::numeric_limits<scalar_t>::quiet_NaN();
}
}
s = ::pow(q, -x);
a = q;
i = 0;
b = 0.0;
while ((i < 9) || (a <= 9.0)) {
i += 1;
a += 1.0;
b = ::pow( a, -x );
s += b;
if ((-MACHEP < (b / s)) && ((b / s) < MACHEP)) {
return static_cast<scalar_t>(s);
}
};
w = a;
s += b * w / (x - 1.0);
s -= 0.5 * b;
a = 1.0;
k = 0.0;
for (int i=0; i < 12; i++) {
a *= x + k;
b /= w;
t = a * b / A[i];
s = s + t;
t = t / s;
if (t < 0){
t = -t;
}
if ((-MACHEP <t) && (t < MACHEP)){
return static_cast<scalar_t>(s);
}
k += 1.0;
a *= x + k;
b /= w;
k += 1.0;
}
return static_cast<scalar_t>(s);
}
/*
* For licensing information, please refer to the the cpu implementation located in "ATen/native/Math.h".
*/
@ -177,12 +95,6 @@ static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) {
return static_cast<scalar_t>(sign * result);
}
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
// already blocked if n <= 1
return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) * zeta(static_cast<scalar_t>(n + 1), x);
}
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
scalar_t a = ::abs(a_in);

View File

@ -7,6 +7,7 @@
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/Math.h>
namespace at { namespace native {
@ -34,7 +35,7 @@ void polygamma_kernel_cuda(TensorIteratorBase& iter, int64_t n) {
} else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "polygamma_cuda", [&]() {
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
return calc_polygamma(int(n), a);
return calc_polygamma<scalar_t, /*is_cuda=*/true>(int(n), a);
});
});
}

View File

@ -9705,6 +9705,51 @@
dispatch:
CompositeExplicitAutograd: special_xlog1py_out
- func: special_zeta(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
python_module: special
variants: function
structured_delegate: special_zeta.out
dispatch:
CompositeExplicitAutograd: special_zeta
- func: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
python_module: special
variants: function
dispatch:
CompositeExplicitAutograd: special_zeta
- func: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor
device_check: NoCheck # TensorIterator
python_module: special
variants: function
dispatch:
CompositeExplicitAutograd: special_zeta
- func: special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
python_module: special
variants: function
dispatch:
CPU, CUDA: special_zeta_out
- func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
python_module: special
variants: function
dispatch:
CompositeExplicitAutograd: special_zeta_out
- func: special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
python_module: special
variants: function
dispatch:
CompositeExplicitAutograd: special_zeta_out
- func: special_i0(Tensor self) -> Tensor
python_module: special
variants: function

View File

@ -40,3 +40,4 @@ Functions
.. autofunction:: round
.. autofunction:: sinc
.. autofunction:: xlog1py
.. autofunction:: zeta

View File

@ -2810,6 +2810,60 @@ class TestBinaryUfuncs(TestCase):
_compare_helper(t, zeros, *xlog1py_fns)
_compare_helper(t, 0., *xlog1py_fns)
@dtypes(*product(torch.testing.get_all_dtypes(include_complex=False,
include_half=False, include_bfloat16=False),
torch.testing.get_all_dtypes(include_complex=False,
include_half=False, include_bfloat16=False)))
@skipIf(not TEST_SCIPY, "Scipy required for the test.")
def test_zeta(self, device, dtypes):
x_dtype, q_dtype = dtypes
def test_helper(x, q):
x_np = x if isinstance(x, float) else x.cpu().numpy()
q_np = q if isinstance(q, float) else q.cpu().numpy()
expected = torch.from_numpy(scipy.special.zeta(x_np, q_np))
actual = torch.special.zeta(x, q)
rtol, atol = None, None
if self.device_type == 'cpu':
rtol, atol = 1e-6, 1e-6
self.assertEqual(expected, actual, rtol=rtol, atol=atol, exact_dtype=False)
# x tensor - q tensor same size
x = make_tensor((2, 3, 4), device, x_dtype)
q = make_tensor((2, 3, 4), device, q_dtype)
test_helper(x, q)
# x tensor - q tensor broadcast lhs
x = make_tensor((2, 1, 4), device, x_dtype)
q = make_tensor((2, 3, 4), device, q_dtype)
test_helper(x, q)
# x tensor - q tensor broadcast rhs
x = make_tensor((2, 3, 4), device, x_dtype)
q = make_tensor((2, 1, 4), device, q_dtype)
test_helper(x, q)
# x tensor - q tensor broadcast all
x = make_tensor((2, 3, 1), device, x_dtype)
q = make_tensor((2, 1, 4), device, q_dtype)
test_helper(x, q)
# x scalar - q tensor
for x in np.linspace(-5, 5, num=10).tolist():
if not q_dtype.is_floating_point:
q_dtype = torch.get_default_dtype()
q = make_tensor((2, 3, 4), device, q_dtype)
test_helper(x, q)
# x tensor - q scalar
for q in np.linspace(-5, 5, num=10).tolist():
if not x_dtype.is_floating_point:
x_dtype = torch.get_default_dtype()
x = make_tensor((2, 3, 4), device, x_dtype)
test_helper(x, q)
tensor_binary_ops = [
'__lt__', '__le__',
'__gt__', '__ge__',

View File

@ -2661,6 +2661,7 @@ class TestOperatorSignatures(JitTestCase):
'reshape_as',
'resize_',
'resize_as_',
'special.zeta',
'stack',
'to_sparse',
'view',

View File

@ -1471,6 +1471,7 @@ class TestNormalizeOperators(JitTestCase):
"reshape_as",
"resize_",
"resize_as_",
"special.zeta",
"to_sparse",
"view",
"view_as",

View File

@ -802,6 +802,16 @@
self: grad * log1p(other.toDouble())
result: auto_element_wise
- name: special_zeta(Tensor self, Tensor other) -> Tensor
self: not_implemented("zeta")
other: grad * -self * special_zeta(self + 1., other)
- name: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor
other: grad * -self * special_zeta(self.toDouble() + 1., other)
- name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor
self: not_implemented("zeta")
- name: logdet(Tensor self) -> Tensor
self: logdet_backward(grad, self, result)
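The backward formula registered for `other` above, `grad * -self * special_zeta(self + 1., other)`, matches the term-by-term derivative of the Hurwitz series:

```math
\frac{\partial}{\partial q}\,\zeta(x, q)
  = \frac{\partial}{\partial q} \sum_{k=0}^{\infty} (k + q)^{-x}
  = \sum_{k=0}^{\infty} (-x)\,(k + q)^{-x - 1}
  = -x\,\zeta(x + 1, q)
```

The derivative with respect to `x` is a log-weighted series with no similarly simple closed form, which is presumably why it is left as `not_implemented("zeta")` for `self`.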

View File

@ -238,6 +238,39 @@ inline Tensor& xlog1py_out(Tensor& result, const Tensor& self, const Scalar& oth
return torch::special_xlog1py_out(result, self, other);
}
/// Computes Hurwitz Zeta function for inputs, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.zeta.
///
/// Example:
/// ```
/// auto x = torch::randn(128, dtype=kDouble);
/// auto y = torch::randn(128, dtype=kDouble);
/// torch::special::zeta(x, y);
/// ```
inline Tensor zeta(const Tensor& self, const Tensor& other) {
return torch::special_zeta(self, other);
}
inline Tensor zeta(const Scalar& self, const Tensor& other) {
return torch::special_zeta(self, other);
}
inline Tensor zeta(const Tensor& self, const Scalar& other) {
return torch::special_zeta(self, other);
}
inline Tensor& zeta_out(Tensor& result, const Tensor& self, const Tensor& other) {
return torch::special_zeta_out(result, self, other);
}
inline Tensor& zeta_out(Tensor& result, const Scalar& self, const Tensor& other) {
return torch::special_zeta_out(result, self, other);
}
inline Tensor& zeta_out(Tensor& result, const Tensor& self, const Scalar& other) {
return torch::special_zeta_out(result, self, other);
}
/// Computes the zeroth order modified Bessel function of the first kind of input, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.i0
///

View File

@ -918,6 +918,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.special.ndtri: lambda input: -1,
torch.special.ndtr: lambda input: -1,
torch.special.xlog1py: lambda input, other, out=None: -1,
torch.special.zeta: lambda self, other, out=None: -1,
torch.t: lambda input: -1,
torch.take: lambda input, index: -1,
torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,

View File

@ -504,3 +504,33 @@ round(input, *, out=None) -> Tensor
Alias for :func:`torch.round`.
""")
zeta = _add_docstr(_special.special_zeta,
r"""
zeta(input, other, *, out=None) -> Tensor
Computes the Hurwitz zeta function, elementwise.
.. math::
\zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x}
""" + r"""
Args:
input (Tensor): the input tensor corresponding to `x`.
other (Tensor): the input tensor corresponding to `q`.
.. note::
The Riemann zeta function corresponds to the case when `q = 1`
Keyword args:
{out}
Example::
>>> x = torch.tensor([2., 4.])
>>> torch.special.zeta(x, 1)
tensor([1.6449, 1.0823])
>>> torch.special.zeta(x, torch.tensor([1., 2.]))
tensor([1.6449, 0.0823])
>>> torch.special.zeta(2, torch.tensor([1., 2.]))
tensor([1.6449, 0.6449])
""".format(**common_args))

View File

@ -1507,6 +1507,7 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
return list(generator())
def sample_inputs_logsumexp(self, device, dtype, requires_grad):
inputs = (
((), (0,), True),
@ -3857,6 +3858,18 @@ def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
low=low,
requires_grad=requires_grad)))
def sample_inputs_zeta(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
samples = (SampleInput(make_arg((S,), low=1, requires_grad=requires_grad),
args=(make_arg((S,), low=2, requires_grad=False),)),
SampleInput(make_arg((S,), low=1, requires_grad=requires_grad),
args=(3.,)),
)
return samples
# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`,
# supports `exclude` argument.
# For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617
@ -7312,6 +7325,25 @@ op_db: List[OpInfo] = [
safe_casts_outputs=True,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_xlog1py),
OpInfo('special.zeta',
aten_name='special_zeta',
dtypes=all_types_and(torch.bool),
supports_autograd=False,
safe_casts_outputs=True,
sample_inputs_func=sample_inputs_binary_pwise),
# OpInfo entry to verify the gradient formula of `other`/`q`
OpInfo('special.zeta',
op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs),
aten_name='special_zeta',
variant_test_name='grad',
dtypes=all_types_and(torch.bool),
supports_autograd=True,
safe_casts_outputs=True,
skips=(
# Lambda doesn't work in JIT test
SkipInfo("TestJit", "test_variant_consistency_jit"),
),
sample_inputs_func=sample_inputs_zeta),
OpInfo('logsumexp',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),