mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[special] Add i0e
(#54409)
Summary: Reference: https://github.com/pytorch/pytorch/issues/50345 Changes: * Add `i0e` * Move some kernels from `UnaryOpsKernel.cu` to `UnarySpecialOpsKernel.cu` to decrease compilation time per file. Time taken by i0e_vs_scipy tests: around 6.33.s <details> <summary>Test Run Log</summary> ``` (pytorch-cuda-dev) kshiteej@qgpu1:~/Pytorch/pytorch_module_special$ pytest test/test_unary_ufuncs.py -k _i0e_vs ======================================================================= test session starts ======================================================================== platform linux -- Python 3.8.6, pytest-6.1.2, py-1.9.0, pluggy-0.13.1 rootdir: /home/kshiteej/Pytorch/pytorch_module_special, configfile: pytest.ini plugins: hypothesis-5.38.1 collected 8843 items / 8833 deselected / 10 selected test/test_unary_ufuncs.py ...sss.... [100%] ========================================================================= warnings summary ========================================================================= ../../.conda/envs/pytorch-cuda-dev/lib/python3.8/site-packages/torch/backends/cudnn/__init__.py:73 test/test_unary_ufuncs.py::TestUnaryUfuncsCUDA::test_special_i0e_vs_scipy_cuda_bfloat16 /home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/site-packages/torch/backends/cudnn/__init__.py:73: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system. warnings.warn( -- Docs: https://docs.pytest.org/en/stable/warnings.html ===================================================================== short test summary info ====================================================================== SKIPPED [3] test/test_unary_ufuncs.py:1182: not implemented: Could not run 'aten::_copy_from' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_copy_from' is only available for these backends: [BackendSelect, Named, InplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode]. BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback] Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback] InplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:56 [backend fallback] AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] UNKNOWN_TENSOR_TYPE_ID: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] AutogradMLC: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:8761 [autograd kernel] Tracer: registered at ../torch/csrc/autograd/generated/TraceType_4.cpp:9348 [kernel] Autocast: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:250 [backend fallback] Batched: registered at ../aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback] VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback] ==================================================== 7 passed, 3 skipped, 8833 deselected, 2 warnings in 6.33s ===================================================== ``` </details> TODO: * [x] Check rendered docs (https://11743402-65600975-gh.circle-artifacts.com/0/docs/special.html) Pull Request resolved: https://github.com/pytorch/pytorch/pull/54409 Reviewed By: jbschlosser Differential Revision: D27760472 Pulled By: mruberry fbshipit-source-id: bdfbcaa798b00c51dc9513c34626246c8fc10548
This commit is contained in:
committed by
Facebook GitHub Bot
parent
2f895f790a
commit
50057e560b
@ -508,6 +508,7 @@ filegroup(
|
||||
"aten/src/ATen/native/cuda/TensorTransformations.cu.cc",
|
||||
"aten/src/ATen/native/cuda/TriangularOps.cu.cc",
|
||||
"aten/src/ATen/native/cuda/UnaryOpsKernel.cu.cc",
|
||||
"aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu.cc",
|
||||
"aten/src/ATen/native/cuda/Unique.cu.cc",
|
||||
"aten/src/ATen/native/cuda/UpSampleBicubic2d.cu.cc",
|
||||
"aten/src/ATen/native/cuda/UpSampleBilinear2d.cu.cc",
|
||||
|
@ -327,6 +327,7 @@ namespace c10 {
|
||||
_(aten, special_expm1) \
|
||||
_(aten, exp2) \
|
||||
_(aten, special_exp2) \
|
||||
_(aten, special_i0e) \
|
||||
_(aten, has_torch_function) \
|
||||
FORALL_ATEN_BASE_SYMBOLS(_) \
|
||||
_(onnx, Add) \
|
||||
|
@ -392,6 +392,9 @@ public:
|
||||
Vec256<T> i0() const {
|
||||
return map(calc_i0);
|
||||
}
|
||||
Vec256<T> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
Vec256<T> igamma(const Vec256<T> &x) const {
|
||||
Vec256<T> ret;
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
|
@ -315,6 +315,22 @@ public:
|
||||
auto o2 = _mm256_loadu_ps(tmp2);
|
||||
return cvtfp32_bf16(o1, o2);
|
||||
}
|
||||
Vec256<BFloat16> i0e() const {
|
||||
__m256 lo, hi;
|
||||
cvtbf16_fp32(values, lo, hi);
|
||||
auto sz = size();
|
||||
__at_align32__ float tmp1[sz / 2], tmp2[sz / 2];
|
||||
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
|
||||
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
|
||||
|
||||
for (decltype(sz) i = 0; i < sz / 2; i++) {
|
||||
tmp1[i] = calc_i0e(tmp1[i]);
|
||||
tmp2[i] = calc_i0e(tmp2[i]);
|
||||
}
|
||||
auto o1 = _mm256_loadu_ps(tmp1);
|
||||
auto o2 = _mm256_loadu_ps(tmp2);
|
||||
return cvtfp32_bf16(o1, o2);
|
||||
}
|
||||
Vec256<BFloat16> igamma(const Vec256<BFloat16> &x) const {
|
||||
__m256 lo, hi;
|
||||
__m256 xlo, xhi;
|
||||
|
@ -164,6 +164,9 @@ public:
|
||||
Vec256<double> i0() const {
|
||||
return map(calc_i0);
|
||||
}
|
||||
Vec256<double> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
Vec256<double> igamma(const Vec256<double> &x) const {
|
||||
__at_align32__ double tmp[size()];
|
||||
__at_align32__ double tmp_x[size()];
|
||||
|
@ -202,6 +202,9 @@ public:
|
||||
Vec256<float> i0() const {
|
||||
return map(calc_i0);
|
||||
}
|
||||
Vec256<float> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
Vec256<float> igamma(const Vec256<float> &x) const {
|
||||
__at_align32__ float tmp[size()];
|
||||
__at_align32__ float tmp_x[size()];
|
||||
|
@ -362,6 +362,9 @@ public:
|
||||
Vec256<float> i0() const {
|
||||
return map(calc_i0);
|
||||
}
|
||||
Vec256<float> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
Vec256<float> igamma(const Vec256<float> &x) const {
|
||||
__at_align32__ float tmp[size()];
|
||||
__at_align32__ float tmp_x[size()];
|
||||
|
@ -351,6 +351,10 @@ class Vec256<double> {
|
||||
return map(calc_i0);
|
||||
}
|
||||
|
||||
Vec256<double> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
|
||||
DEFINE_MEMBER_OP(operator==, double, vec_cmpeq)
|
||||
DEFINE_MEMBER_OP(operator!=, double, vec_cmpne)
|
||||
DEFINE_MEMBER_OP(operator<, double, vec_cmplt)
|
||||
|
@ -637,6 +637,10 @@ class Vec256<float> {
|
||||
return map(calc_i0);
|
||||
}
|
||||
|
||||
Vec256<float> i0e() const {
|
||||
return map(calc_i0e);
|
||||
}
|
||||
|
||||
DEFINE_MEMBER_OP(operator==, float, vec_cmpeq)
|
||||
DEFINE_MEMBER_OP(operator!=, float, vec_cmpne)
|
||||
DEFINE_MEMBER_OP(operator<, float, vec_cmplt)
|
||||
|
@ -129,6 +129,7 @@ IMPLEMENT_VML_BUG(exp)
|
||||
IMPLEMENT_VML_BUG(expm1)
|
||||
IMPLEMENT_VML_BUG(floor)
|
||||
IMPLEMENT_VML(i0)
|
||||
IMPLEMENT_VML(i0e)
|
||||
IMPLEMENT_VML(reciprocal)
|
||||
IMPLEMENT_VML_BUG(log)
|
||||
IMPLEMENT_VML_BUG(log10)
|
||||
|
@ -1161,7 +1161,7 @@ calc_gcd(T a, T b) {
|
||||
*/
|
||||
template <typename T>
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
chbevl(T x, T array[], size_t len) {
|
||||
chbevl(const T x,const T array[], size_t len) {
|
||||
T b0, b1, b2;
|
||||
|
||||
b0 = array[0];
|
||||
@ -1186,87 +1186,98 @@ static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
* of all inputs to convert them into the domain of the approximation.
|
||||
*/
|
||||
template <typename T>
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
calc_i0(T _x) {
|
||||
T x = std::abs(_x);
|
||||
inline const T* chebyshev_coefficients_A(){
|
||||
/* Chebyshev coefficients for exp(-x) I0(x)
|
||||
* in the interval [0,8].
|
||||
*
|
||||
* lim(x->0){ exp(-x) I0(x) } = 1.
|
||||
*/
|
||||
static T A[] = {
|
||||
-4.41534164647933937950E-18,
|
||||
3.33079451882223809783E-17,
|
||||
-2.43127984654795469359E-16,
|
||||
1.71539128555513303061E-15,
|
||||
-1.16853328779934516808E-14,
|
||||
7.67618549860493561688E-14,
|
||||
-4.85644678311192946090E-13,
|
||||
2.95505266312963983461E-12,
|
||||
-1.72682629144155570723E-11,
|
||||
9.67580903537323691224E-11,
|
||||
-5.18979560163526290666E-10,
|
||||
2.65982372468238665035E-9,
|
||||
-1.30002500998624804212E-8,
|
||||
6.04699502254191894932E-8,
|
||||
-2.67079385394061173391E-7,
|
||||
1.11738753912010371815E-6,
|
||||
-4.41673835845875056359E-6,
|
||||
1.64484480707288970893E-5,
|
||||
-5.75419501008210370398E-5,
|
||||
1.88502885095841655729E-4,
|
||||
-5.76375574538582365885E-4,
|
||||
1.63947561694133579842E-3,
|
||||
-4.32430999505057594430E-3,
|
||||
1.05464603945949983183E-2,
|
||||
-2.37374148058994688156E-2,
|
||||
4.93052842396707084878E-2,
|
||||
-9.49010970480476444210E-2,
|
||||
1.71620901522208775349E-1,
|
||||
-3.04682672343198398683E-1,
|
||||
6.76795274409476084995E-1
|
||||
};
|
||||
static const T coeff[] = {
|
||||
-4.41534164647933937950E-18, 3.33079451882223809783E-17,
|
||||
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
|
||||
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
|
||||
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
|
||||
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
|
||||
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
|
||||
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
|
||||
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
|
||||
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
|
||||
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
|
||||
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
|
||||
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
|
||||
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
|
||||
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
|
||||
-3.04682672343198398683E-1, 6.76795274409476084995E-1};
|
||||
return coeff;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline const T* chebyshev_coefficients_B(){
|
||||
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
|
||||
* in the inverted interval [8,infinity].
|
||||
*
|
||||
* lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
|
||||
*/
|
||||
static T B[] = {
|
||||
-7.23318048787475395456E-18,
|
||||
-4.83050448594418207126E-18,
|
||||
4.46562142029675999901E-17,
|
||||
3.46122286769746109310E-17,
|
||||
-2.82762398051658348494E-16,
|
||||
-3.42548561967721913462E-16,
|
||||
1.77256013305652638360E-15,
|
||||
3.81168066935262242075E-15,
|
||||
-9.55484669882830764870E-15,
|
||||
-4.15056934728722208663E-14,
|
||||
1.54008621752140982691E-14,
|
||||
3.85277838274214270114E-13,
|
||||
7.18012445138366623367E-13,
|
||||
-1.79417853150680611778E-12,
|
||||
-1.32158118404477131188E-11,
|
||||
-3.14991652796324136454E-11,
|
||||
1.18891471078464383424E-11,
|
||||
4.94060238822496958910E-10,
|
||||
3.39623202570838634515E-9,
|
||||
2.26666899049817806459E-8,
|
||||
2.04891858946906374183E-7,
|
||||
2.89137052083475648297E-6,
|
||||
6.88975834691682398426E-5,
|
||||
3.36911647825569408990E-3,
|
||||
8.04490411014108831608E-1
|
||||
};
|
||||
static const T coeff[] = {
|
||||
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
|
||||
4.46562142029675999901E-17, 3.46122286769746109310E-17,
|
||||
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
|
||||
1.77256013305652638360E-15, 3.81168066935262242075E-15,
|
||||
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
|
||||
1.54008621752140982691E-14, 3.85277838274214270114E-13,
|
||||
7.18012445138366623367E-13, -1.79417853150680611778E-12,
|
||||
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
|
||||
1.18891471078464383424E-11, 4.94060238822496958910E-10,
|
||||
3.39623202570838634515E-9, 2.26666899049817806459E-8,
|
||||
2.04891858946906374183E-7, 2.89137052083475648297E-6,
|
||||
6.88975834691682398426E-5, 3.36911647825569408990E-3,
|
||||
8.04490411014108831608E-1};
|
||||
|
||||
return coeff;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
calc_i0(T _x) {
|
||||
T x = std::abs(_x);
|
||||
|
||||
if (x <= 8.0) {
|
||||
const auto A = chebyshev_coefficients_A<T>();
|
||||
T y = (x / 2.0) - 2.0;
|
||||
return static_cast<T>(std::exp(x) * chbevl(y, A, 30));
|
||||
}
|
||||
|
||||
const auto B = chebyshev_coefficients_B<T>();
|
||||
return static_cast<T>(std::exp(x) * chbevl(static_cast<T>(32.0 / x - 2.0), B, 25) / std::sqrt(x));
|
||||
}
|
||||
|
||||
// Upcast bfloat16 input to float for numerical accuracy purposes
|
||||
inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
|
||||
|
||||
/*
|
||||
* This function is derived from the implementation of the i0e function in the Cephes Math Library.
|
||||
* See note [3-Clause BSD License for the Cephes Math Library].
|
||||
*
|
||||
* Computes an approximation of the exponentially scaled zeroth order modified Bessel function of the first kind.
|
||||
* The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
|
||||
* One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
|
||||
* of all inputs to convert them into the domain of the approximation.
|
||||
*/
|
||||
template <typename T>
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
calc_i0e(T _x) {
|
||||
T x = std::abs(_x);
|
||||
|
||||
if (x <= 8.0) {
|
||||
auto A = chebyshev_coefficients_A<T>();
|
||||
T y = (x / 2.0) - 2.0;
|
||||
return static_cast<T>(chbevl(y, A, 30));
|
||||
}
|
||||
|
||||
auto B = chebyshev_coefficients_B<T>();
|
||||
return static_cast<T>(
|
||||
chbevl(static_cast<T>(32.0 / x - 2.0), B, 25) / std::sqrt(x));
|
||||
}
|
||||
|
||||
// Upcast bfloat16 input to float for numerical accuracy purposes
|
||||
inline c10::BFloat16 calc_i0e(c10::BFloat16 a) { return calc_i0e(static_cast<float>(a)); }
|
||||
|
@ -40,6 +40,7 @@ CREATE_UNARY_META_FUNC(sinh)
|
||||
CREATE_UNARY_META_FUNC(cosh)
|
||||
CREATE_UNARY_META_FUNC(acosh)
|
||||
CREATE_UNARY_META_FUNC(cos)
|
||||
CREATE_UNARY_META_FUNC(special_i0e)
|
||||
|
||||
} // namespace meta
|
||||
|
||||
@ -354,6 +355,10 @@ Tensor& i0_out(const Tensor& self, Tensor& result) { return unary_op_impl_out(re
|
||||
Tensor i0(const Tensor& self) { return unary_op_impl(self, at::i0_out); }
|
||||
Tensor& i0_(Tensor& self) { return unary_op_impl_(self, at::i0_out); }
|
||||
|
||||
TORCH_IMPL_FUNC(special_i0e_out) (const Tensor& self, const Tensor& result) {
|
||||
i0e_stub(device_type(), *this);
|
||||
}
|
||||
|
||||
Tensor& log_out(const Tensor& self, Tensor& result) { return unary_op_impl_float_out(result, self, log_stub); }
|
||||
Tensor log(const Tensor& self) { return unary_op_impl_float(self, log_stub); }
|
||||
Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); }
|
||||
@ -804,6 +809,7 @@ DEFINE_DISPATCH(floor_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-
|
||||
DEFINE_DISPATCH(frac_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(frexp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(i0e_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(log_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(log10_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
@ -829,5 +835,6 @@ DEFINE_DISPATCH(tanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-v
|
||||
DEFINE_DISPATCH(trigamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(trunc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_DISPATCH(lgamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -43,6 +43,7 @@ DECLARE_DISPATCH(unary_fn, floor_stub);
|
||||
DECLARE_DISPATCH(unary_fn, frac_stub);
|
||||
DECLARE_DISPATCH(unary_fn, frexp_stub);
|
||||
DECLARE_DISPATCH(unary_fn, i0_stub);
|
||||
DECLARE_DISPATCH(structured_unary_fn, i0e_stub);
|
||||
DECLARE_DISPATCH(unary_fn, log_stub);
|
||||
DECLARE_DISPATCH(unary_fn, log10_stub);
|
||||
DECLARE_DISPATCH(unary_fn, log1p_stub);
|
||||
|
@ -631,6 +631,17 @@ static void frexp_kernel(TensorIterator& iter) {
|
||||
});
|
||||
}
|
||||
|
||||
static void i0e_kernel(TensorIteratorBase& iter) {
|
||||
TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
kBFloat16, iter.common_dtype(), "i0e_cpu", [&]() {
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[](scalar_t x) { return calc_i0e(x); },
|
||||
[](Vec256<scalar_t> x) { return x.i0e(); });
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: Disable cont. branch to test more risky code
|
||||
|
||||
#define IMPLEMENT_ITERATOR_LAMBDA(op) \
|
||||
@ -736,6 +747,7 @@ REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel);
|
||||
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel);
|
||||
REGISTER_DISPATCH(entr_stub, &entr_kernel);
|
||||
REGISTER_DISPATCH(frexp_stub, &frexp_kernel);
|
||||
REGISTER_DISPATCH(i0e_stub, &i0e_kernel);
|
||||
|
||||
|
||||
IMPLEMENT_COMPLEX_KERNEL(acos)
|
||||
|
@ -10,7 +10,7 @@ namespace native {
|
||||
* For licensing information, please refer to the the cpu implementation located in "ATen/native/Math.h".
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
|
||||
static inline C10_HOST_DEVICE scalar_t zeta(scalar_t _x, scalar_t _q) {
|
||||
using accscalar_t = at::acc_type<scalar_t, true>;
|
||||
static const accscalar_t MACHEP = 1.11022302462515654042E-16;
|
||||
const accscalar_t A[] = {
|
||||
@ -92,7 +92,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
|
||||
* For licensing information, please refer to the the cpu implementation located in "ATen/native/Math.h".
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
|
||||
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
|
||||
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
|
||||
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
|
||||
static const double PI_f64 = 3.14159265358979323846;
|
||||
@ -151,7 +151,7 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static inline __host__ __device__ scalar_t calc_trigamma(scalar_t in) {
|
||||
static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) {
|
||||
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
|
||||
const accscalar_t PI = 3.14159265358979323846;
|
||||
accscalar_t x = static_cast<accscalar_t>(in);
|
||||
@ -174,7 +174,7 @@ static inline __host__ __device__ scalar_t calc_trigamma(scalar_t in) {
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static inline __host__ __device__ scalar_t calc_polygamma(int n, scalar_t x) {
|
||||
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
|
||||
// already blocked if n <= 1
|
||||
return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) * zeta(static_cast<scalar_t>(n + 1), x);
|
||||
}
|
||||
@ -216,6 +216,58 @@ static inline C10_HOST_DEVICE scalar_t chbevl(scalar_t _x, const scalar_t array[
|
||||
/*
|
||||
* For licensing information and documentation, please refer to the the cpu implementation located in "ATen/native/Math.h".
|
||||
*/
|
||||
template <typename T>
|
||||
C10_HOST_DEVICE inline const T* chebyshev_coefficients_A() {
|
||||
/* Chebyshev coefficients for exp(-x) I0(x)
|
||||
* in the interval [0,8].
|
||||
*
|
||||
* lim(x->0){ exp(-x) I0(x) } = 1.
|
||||
*/
|
||||
static const T coefficients[] = {
|
||||
-4.41534164647933937950E-18, 3.33079451882223809783E-17,
|
||||
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
|
||||
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
|
||||
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
|
||||
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
|
||||
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
|
||||
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
|
||||
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
|
||||
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
|
||||
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
|
||||
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
|
||||
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
|
||||
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
|
||||
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
|
||||
-3.04682672343198398683E-1, 6.76795274409476084995E-1};
|
||||
|
||||
return coefficients;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
C10_HOST_DEVICE inline const T* chebyshev_coefficients_B() {
|
||||
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
|
||||
* in the inverted interval [8,infinity].
|
||||
*
|
||||
* lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
|
||||
*/
|
||||
static const T coefficients[] = {
|
||||
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
|
||||
4.46562142029675999901E-17, 3.46122286769746109310E-17,
|
||||
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
|
||||
1.77256013305652638360E-15, 3.81168066935262242075E-15,
|
||||
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
|
||||
1.54008621752140982691E-14, 3.85277838274214270114E-13,
|
||||
7.18012445138366623367E-13, -1.79417853150680611778E-12,
|
||||
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
|
||||
1.18891471078464383424E-11, 4.94060238822496958910E-10,
|
||||
3.39623202570838634515E-9, 2.26666899049817806459E-8,
|
||||
2.04891858946906374183E-7, 2.89137052083475648297E-6,
|
||||
6.88975834691682398426E-5, 3.36911647825569408990E-3,
|
||||
8.04490411014108831608E-1};
|
||||
|
||||
return coefficients;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) {
|
||||
using accscalar_t = at::acc_type<scalar_t, true>;
|
||||
@ -224,84 +276,33 @@ static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) {
|
||||
// Needed for accurate results if input is bfloat16 or float16
|
||||
accscalar_t x = ::abs(static_cast<accscalar_t>(_x));
|
||||
|
||||
/* Chebyshev coefficients for exp(-x) I0(x)
|
||||
* in the interval [0,8].
|
||||
*
|
||||
* lim(x->0){ exp(-x) I0(x) } = 1.
|
||||
*/
|
||||
const accscalar_t A[] = {
|
||||
-4.41534164647933937950E-18,
|
||||
3.33079451882223809783E-17,
|
||||
-2.43127984654795469359E-16,
|
||||
1.71539128555513303061E-15,
|
||||
-1.16853328779934516808E-14,
|
||||
7.67618549860493561688E-14,
|
||||
-4.85644678311192946090E-13,
|
||||
2.95505266312963983461E-12,
|
||||
-1.72682629144155570723E-11,
|
||||
9.67580903537323691224E-11,
|
||||
-5.18979560163526290666E-10,
|
||||
2.65982372468238665035E-9,
|
||||
-1.30002500998624804212E-8,
|
||||
6.04699502254191894932E-8,
|
||||
-2.67079385394061173391E-7,
|
||||
1.11738753912010371815E-6,
|
||||
-4.41673835845875056359E-6,
|
||||
1.64484480707288970893E-5,
|
||||
-5.75419501008210370398E-5,
|
||||
1.88502885095841655729E-4,
|
||||
-5.76375574538582365885E-4,
|
||||
1.63947561694133579842E-3,
|
||||
-4.32430999505057594430E-3,
|
||||
1.05464603945949983183E-2,
|
||||
-2.37374148058994688156E-2,
|
||||
4.93052842396707084878E-2,
|
||||
-9.49010970480476444210E-2,
|
||||
1.71620901522208775349E-1,
|
||||
-3.04682672343198398683E-1,
|
||||
6.76795274409476084995E-1
|
||||
};
|
||||
|
||||
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
|
||||
* in the inverted interval [8,infinity].
|
||||
*
|
||||
* lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
|
||||
*/
|
||||
const accscalar_t B[] = {
|
||||
-7.23318048787475395456E-18,
|
||||
-4.83050448594418207126E-18,
|
||||
4.46562142029675999901E-17,
|
||||
3.46122286769746109310E-17,
|
||||
-2.82762398051658348494E-16,
|
||||
-3.42548561967721913462E-16,
|
||||
1.77256013305652638360E-15,
|
||||
3.81168066935262242075E-15,
|
||||
-9.55484669882830764870E-15,
|
||||
-4.15056934728722208663E-14,
|
||||
1.54008621752140982691E-14,
|
||||
3.85277838274214270114E-13,
|
||||
7.18012445138366623367E-13,
|
||||
-1.79417853150680611778E-12,
|
||||
-1.32158118404477131188E-11,
|
||||
-3.14991652796324136454E-11,
|
||||
1.18891471078464383424E-11,
|
||||
4.94060238822496958910E-10,
|
||||
3.39623202570838634515E-9,
|
||||
2.26666899049817806459E-8,
|
||||
2.04891858946906374183E-7,
|
||||
2.89137052083475648297E-6,
|
||||
6.88975834691682398426E-5,
|
||||
3.36911647825569408990E-3,
|
||||
8.04490411014108831608E-1
|
||||
};
|
||||
|
||||
if (x <= 8.0) {
|
||||
const auto A = chebyshev_coefficients_A<accscalar_t>();
|
||||
accscalar_t y = static_cast<accscalar_t>((x / 2.0) - 2.0);
|
||||
return static_cast<scalar_t>(::exp(x) * chbevl(y, A, 30));
|
||||
}
|
||||
|
||||
const auto B = chebyshev_coefficients_B<accscalar_t>();
|
||||
return static_cast<scalar_t>(::exp(x) * chbevl(static_cast<accscalar_t>(32.0 / x - 2.0), B, 25) / ::sqrt(x));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static inline C10_HOST_DEVICE scalar_t calc_i0e(scalar_t _x) {
|
||||
using accscalar_t = at::acc_type<scalar_t, true>;
|
||||
|
||||
// Upcast input for numerical accuracy purposes
|
||||
// Needed for accurate results if input is bfloat16 or float16
|
||||
accscalar_t x = ::abs(static_cast<accscalar_t>(_x));
|
||||
|
||||
if (x <= 8.0) {
|
||||
const auto A = chebyshev_coefficients_A<accscalar_t>();
|
||||
accscalar_t y = static_cast<accscalar_t>((x / 2.0) - 2.0);
|
||||
return static_cast<scalar_t>(chbevl(y, A, 30));
|
||||
}
|
||||
|
||||
const auto B = chebyshev_coefficients_B<accscalar_t>();
|
||||
return static_cast<scalar_t>(chbevl(static_cast<accscalar_t>(32.0 / x - 2.0), B, 25) / ::sqrt(x));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -41,14 +41,6 @@ void exp_kernel_cuda(TensorIterator& iter) {
|
||||
});
|
||||
}
|
||||
|
||||
void exp2_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "exp2_cuda", [&]() {
|
||||
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return ::exp2(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void expm1_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "expm1_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
@ -57,14 +49,6 @@ void expm1_kernel_cuda(TensorIterator& iter) {
|
||||
});
|
||||
}
|
||||
|
||||
void i0_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "i0_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return calc_i0(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// We manually overload rsqrt because std::rsqrt does not work with complex types.
|
||||
template<typename scalar_t>
|
||||
__host__ __device__ static inline scalar_t rsqrt_wrapper(scalar_t v) {
|
||||
@ -95,80 +79,6 @@ void sqrt_kernel_cuda(TensorIterator& iter) {
|
||||
});
|
||||
}
|
||||
|
||||
void sigmoid_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "sigmoid_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
scalar_t one = scalar_t(1);
|
||||
return one / (one + std::exp(- a));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void sinc_kernel_cuda(TensorIteratorBase& iter) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "sinc_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
if (a == scalar_t(0)) {
|
||||
return scalar_t(1);
|
||||
} else {
|
||||
// NVCC says constexpr var is not accessible from device
|
||||
scalar_t product = c10::detail::pi<scalar_t>() * a;
|
||||
return std::sin(product) / product;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void logit_kernel_cuda(TensorIterator& iter, const Scalar& eps_scalar) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
iter.common_dtype(),
|
||||
"logit_cuda",
|
||||
[&]() {
|
||||
using T_ACC = acc_type<scalar_t, true>;
|
||||
const T_ACC eps = eps_scalar.to<T_ACC>();
|
||||
if (eps < T_ACC(0)) {
|
||||
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
|
||||
const T_ACC x_acc = static_cast<T_ACC>(x);
|
||||
return c10::cuda::compat::log(x_acc / (T_ACC(1) - x_acc));
|
||||
});
|
||||
} else {
|
||||
const T_ACC lo = eps;
|
||||
const T_ACC hi = T_ACC(1) - eps;
|
||||
gpu_kernel(
|
||||
iter, [lo, hi] GPU_LAMBDA(scalar_t x) -> scalar_t {
|
||||
const T_ACC x_acc = static_cast<T_ACC>(x);
|
||||
T_ACC z = x_acc < lo ? lo : (x_acc > hi ? hi : x_acc);
|
||||
return c10::cuda::compat::log(z / (T_ACC(1) - z));
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void erf_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return ::erf(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void erfc_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfc_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return ::erfc(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void erfinv_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return ::erfinv(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void clamp_kernel_cuda(TensorIterator& iter, const Scalar& min_value, const Scalar& max_value) {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_cuda", [&]() {
|
||||
auto lower = min_value.to<scalar_t>();
|
||||
@ -238,40 +148,6 @@ void nan_to_num_kernel_cuda(
|
||||
});
|
||||
}
|
||||
|
||||
void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta_){
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
|
||||
using T_ACC = acc_type<scalar_t, true>;
|
||||
const T_ACC inv_alpha = static_cast<T_ACC>(2.0 / (window_length - 1));
|
||||
const T_ACC beta = static_cast<T_ACC>(beta_);
|
||||
const T_ACC inv_i0_beta = 1.0 / calc_i0(beta);
|
||||
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
T_ACC x = static_cast<T_ACC>(a) * inv_alpha - 1;
|
||||
T_ACC y = std::max<T_ACC>(0, 1 - x * x);
|
||||
return calc_i0(beta * ::sqrt(y)) * inv_i0_beta;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void entr_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
ScalarType::Half,
|
||||
ScalarType::BFloat16,
|
||||
iter.common_dtype(),
|
||||
"entr_cuda",
|
||||
[&]() {
|
||||
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
|
||||
if (at::_isnan(x)) {
|
||||
return x;
|
||||
} else if (x > 0) {
|
||||
return -x * std::log(x);
|
||||
} else if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<scalar_t>(-INFINITY);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void frexp_kernel_cuda(TensorIterator& iter) {
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
// Reference: https://rocmdocs.amd.com/en/latest/ROCm_API_References/HIP-MATH.html
|
||||
@ -295,23 +171,13 @@ void frexp_kernel_cuda(TensorIterator& iter) {
|
||||
|
||||
REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda);
|
||||
REGISTER_DISPATCH(exp_stub, &exp_kernel_cuda);
|
||||
REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
|
||||
REGISTER_DISPATCH(expm1_stub, &expm1_kernel_cuda);
|
||||
REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda);
|
||||
REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda);
|
||||
REGISTER_DISPATCH(sqrt_stub, &sqrt_kernel_cuda);
|
||||
REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda);
|
||||
REGISTER_DISPATCH(sinc_stub, &sinc_kernel_cuda);
|
||||
REGISTER_DISPATCH(logit_stub, &logit_kernel_cuda);
|
||||
REGISTER_DISPATCH(erf_stub, &erf_kernel_cuda);
|
||||
REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
|
||||
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
|
||||
REGISTER_DISPATCH(clamp_stub, &clamp_kernel_cuda);
|
||||
REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel_cuda);
|
||||
REGISTER_DISPATCH(clamp_max_stub, &clamp_max_kernel_cuda);
|
||||
REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_cuda);
|
||||
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
|
||||
REGISTER_DISPATCH(entr_stub, &entr_kernel_cuda);
|
||||
REGISTER_DISPATCH(frexp_stub, &frexp_kernel_cuda);
|
||||
|
||||
} // namespace native
|
||||
|
167
aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu
Normal file
167
aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu
Normal file
@ -0,0 +1,167 @@
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/Math.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/Math.cuh>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
void exp2_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "exp2_cuda", [&]() {
|
||||
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return ::exp2(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void i0_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "i0_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return calc_i0(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void i0e_kernel_cuda(TensorIteratorBase& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0e_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return calc_i0e(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void sigmoid_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "sigmoid_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
scalar_t one = scalar_t(1);
|
||||
return one / (one + std::exp(- a));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void sinc_kernel_cuda(TensorIteratorBase& iter) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "sinc_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
if (a == scalar_t(0)) {
|
||||
return scalar_t(1);
|
||||
} else {
|
||||
// NVCC says constexpr var is not accessible from device
|
||||
scalar_t product = c10::detail::pi<scalar_t>() * a;
|
||||
return std::sin(product) / product;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void logit_kernel_cuda(TensorIterator& iter, const Scalar& eps_scalar) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
iter.common_dtype(),
|
||||
"logit_cuda",
|
||||
[&]() {
|
||||
using T_ACC = acc_type<scalar_t, true>;
|
||||
const T_ACC eps = eps_scalar.to<T_ACC>();
|
||||
if (eps < T_ACC(0)) {
|
||||
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
|
||||
const T_ACC x_acc = static_cast<T_ACC>(x);
|
||||
return c10::cuda::compat::log(x_acc / (T_ACC(1) - x_acc));
|
||||
});
|
||||
} else {
|
||||
const T_ACC lo = eps;
|
||||
const T_ACC hi = T_ACC(1) - eps;
|
||||
gpu_kernel(
|
||||
iter, [lo, hi] GPU_LAMBDA(scalar_t x) -> scalar_t {
|
||||
const T_ACC x_acc = static_cast<T_ACC>(x);
|
||||
T_ACC z = x_acc < lo ? lo : (x_acc > hi ? hi : x_acc);
|
||||
return c10::cuda::compat::log(z / (T_ACC(1) - z));
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void erf_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return ::erf(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void erfc_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfc_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return ::erfc(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void erfinv_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
|
||||
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
return ::erfinv(a);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta_){
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
|
||||
using T_ACC = acc_type<scalar_t, true>;
|
||||
const T_ACC inv_alpha = static_cast<T_ACC>(2.0 / (window_length - 1));
|
||||
const T_ACC beta = static_cast<T_ACC>(beta_);
|
||||
const T_ACC inv_i0_beta = 1.0 / calc_i0(beta);
|
||||
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t {
|
||||
T_ACC x = static_cast<T_ACC>(a) * inv_alpha - 1;
|
||||
T_ACC y = std::max<T_ACC>(0, 1 - x * x);
|
||||
return calc_i0(beta * ::sqrt(y)) * inv_i0_beta;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void entr_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
ScalarType::Half,
|
||||
ScalarType::BFloat16,
|
||||
iter.common_dtype(),
|
||||
"entr_cuda",
|
||||
[&]() {
|
||||
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
|
||||
if (at::_isnan(x)) {
|
||||
return x;
|
||||
} else if (x > 0) {
|
||||
return -x * std::log(x);
|
||||
} else if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<scalar_t>(-INFINITY);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
|
||||
REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda);
|
||||
REGISTER_DISPATCH(i0e_stub, &i0e_kernel_cuda);
|
||||
REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda);
|
||||
REGISTER_DISPATCH(sinc_stub, &sinc_kernel_cuda);
|
||||
REGISTER_DISPATCH(logit_stub, &logit_kernel_cuda);
|
||||
REGISTER_DISPATCH(erf_stub, &erf_kernel_cuda);
|
||||
REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
|
||||
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
|
||||
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
|
||||
REGISTER_DISPATCH(entr_stub, &entr_kernel_cuda);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -8418,6 +8418,18 @@
|
||||
- func: special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
|
||||
- func: special_i0e(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
structured_delegate: special_i0e.out
|
||||
|
||||
- func: special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA: special_i0e_out
|
||||
|
||||
- func: special_logit(Tensor self, float? eps=None) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
@ -26,4 +26,5 @@ Functions
|
||||
.. autofunction:: expm1
|
||||
.. autofunction:: exp2
|
||||
.. autofunction:: gammaln
|
||||
.. autofunction:: i0e
|
||||
.. autofunction:: logit
|
||||
|
@ -1195,6 +1195,38 @@ class TestUnaryUfuncs(TestCase):
|
||||
t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype)
|
||||
self.assertTrue(torch.i0(t).isnan().all())
|
||||
|
||||
@dtypesIfCUDA(*torch.testing.get_all_fp_dtypes())
|
||||
@dtypes(torch.bfloat16, torch.float32, torch.float64)
|
||||
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
|
||||
def test_special_i0e_vs_scipy(self, device, dtype):
|
||||
def check_equal(t):
|
||||
# Test by comparing to scipy
|
||||
actual = torch.special.i0e(t)
|
||||
if dtype is torch.bfloat16:
|
||||
t = t.to(torch.float32)
|
||||
expected = scipy.special.i0e(t.cpu().numpy())
|
||||
|
||||
# Casting down for dtype float16 is required since scipy upcasts to float32
|
||||
if dtype is torch.bfloat16 or dtype is torch.float16:
|
||||
expected = torch.from_numpy(expected).to(dtype)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
t = torch.tensor([], device=device, dtype=dtype)
|
||||
check_equal(t)
|
||||
|
||||
range = (-1e7, 1e7)
|
||||
if dtype == torch.half:
|
||||
range = (-65000, 65000)
|
||||
|
||||
t = torch.linspace(*range, int(1e4), device=device, dtype=dtype)
|
||||
check_equal(t)
|
||||
|
||||
# NaN, inf, -inf are tested in reference_numerics tests.
|
||||
info = torch.finfo(dtype)
|
||||
min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
|
||||
t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
|
||||
check_equal(t)
|
||||
|
||||
# TODO: allow large opinfo values to be opted-into via metadata
|
||||
@dtypes(torch.long)
|
||||
def test_abs_big_number(self, device, dtype):
|
||||
|
@ -149,4 +149,20 @@ inline Tensor& expm1_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_expm1_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the exponentially scaled zeroth order modified Bessel function of the first kind
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.i0e.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto t = torch::randn(128, dtype=kDouble);
|
||||
/// torch::special::i0e(t);
|
||||
/// ```
|
||||
inline Tensor i0e(const Tensor& self) {
|
||||
return torch::special_i0e(self);
|
||||
}
|
||||
|
||||
inline Tensor i0e_out(const Tensor& self) {
|
||||
return torch::special_i0e(self);
|
||||
}
|
||||
|
||||
}} // torch::special
|
||||
|
@ -848,6 +848,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.special.expm1: lambda input: -1,
|
||||
torch.special.expit: lambda input: -1,
|
||||
torch.special.gammaln: lambda input: -1,
|
||||
torch.special.i0e: lambda input: -1,
|
||||
torch.special.logit: lambda input: -1,
|
||||
torch.t: lambda input: -1,
|
||||
torch.take: lambda input, index: -1,
|
||||
|
@ -229,3 +229,22 @@ Example::
|
||||
>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
|
||||
tensor([ 0., 1.])
|
||||
""".format(**common_args))
|
||||
|
||||
i0e = _add_docstr(_special.special_i0e,
|
||||
r"""
|
||||
i0e(input, *, out=None) -> Tensor
|
||||
Computes the exponentially scaled zeroth order modified Bessel function of the first kind (as defined below)
|
||||
for each element of :attr:`input`.
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2}
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
Keyword args:
|
||||
{out}
|
||||
Example::
|
||||
>>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
|
||||
tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
|
||||
""".format(**common_args))
|
||||
|
@ -3302,6 +3302,16 @@ op_db: List[OpInfo] = [
|
||||
dtypesIfCPU=floating_types_and(torch.bfloat16),
|
||||
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
|
||||
supports_autograd=False),
|
||||
UnaryUfuncInfo('special.i0e',
|
||||
aten_name='special_i0e',
|
||||
ref=scipy.special.i0e,
|
||||
decorators=(precisionOverride({torch.bfloat16: 3e-1,
|
||||
torch.float16: 3e-1}),),
|
||||
dtypes=all_types_and(torch.bool, torch.bfloat16),
|
||||
dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
|
||||
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
|
||||
supports_autograd=False,
|
||||
safe_casts_outputs=True),
|
||||
OpInfo('floor_divide',
|
||||
dtypes=all_types_and(torch.half, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_floor_divide,
|
||||
|
Reference in New Issue
Block a user