Revert "Turn some const variables into constexpr in C++ code (#165401)"

This reverts commit 5b2afe4c5dc87786ca65bf22ca9a78f7c21a33a4.

Reverted https://github.com/pytorch/pytorch/pull/165401 on behalf of https://github.com/seemethere due to This is breaking test/distributions/test_distributions.py::TestDistributions::test_binomial_sample on HUD, see 5b2afe4c5d ([comment](https://github.com/pytorch/pytorch/pull/165401#issuecomment-3414023134))
This commit is contained in:
PyTorch MergeBot
2025-10-17 06:14:09 +00:00
parent 364624e209
commit 9e94ec76b8
20 changed files with 79 additions and 79 deletions

View File

@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (
namespace at::native {
static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;
DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);

View File

@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
auto constexpr intmax = std::numeric_limits<int>::max();
auto intmax = std::numeric_limits<int>::max();
return n <= intmax && incx <= intmax;
}
@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
int64_t incx,
[[maybe_unused]] float beta,
int64_t incy) {
auto constexpr intmax = std::numeric_limits<int>::max();
auto intmax = std::numeric_limits<int>::max();
return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
(incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

View File

@ -127,7 +127,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
constexpr static scalar_t kTailValues[] = {
const static scalar_t kTailValues[] = {
0.0810614667953272,
0.0413406959554092,
0.0276779256849983,
@ -139,7 +139,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
0.00925546218271273,
0.00833056343336287
};
if (k <= sizeof(kTailValues)/sizeof(scalar_t)) {
if (k <= 9) {
return kTailValues[static_cast<size_t>(k)];
}
scalar_t kp1sq = (k + 1) * (k + 1);

View File

@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
template <typename scalar_t>
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
static const scalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
template <typename scalar_t>
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
static constexpr scalar_t d[25][25] =
static const scalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,

View File

@ -62,7 +62,7 @@
#include <utility>
#include <vector>
static constexpr int MIOPEN_DIM_MAX = 5;
static const int MIOPEN_DIM_MAX = 5;
namespace at::meta {

View File

@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
// We keep this structure for BC and consider as deprecated.
// See HelperInterpNearestExact as replacement
static constexpr int interp_size = 1;
static const int interp_size = 1;
static inline void init_indices_weights(
at::ScalarType output_type,
@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {
struct HelperInterpLinear : public HelperInterpBase {
static constexpr int interp_size = 2;
static const int interp_size = 2;
// Compute indices and weights for each interpolated dimension
// indices_weights = {
@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {
struct HelperInterpCubic : public HelperInterpBase {
static constexpr int interp_size = 4;
static const int interp_size = 4;
// Compute indices and weights for each interpolated dimension
// indices_weights = {

View File

@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
}
static constexpr int BLOCK_THREADS = 256;
static const int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)

View File

@ -36,9 +36,9 @@ namespace at::native {
namespace {
#if defined(USE_ROCM)
static constexpr int BLOCKDIMY = 16;
static const int BLOCKDIMY = 16;
#else
static constexpr int BLOCKDIMY = 32;
static const int BLOCKDIMY = 32;
#endif
template

View File

@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t ax, fac, res, num, numfac;
constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
7.09782712893383996843E2 : 88.72283905206835;
constexpr accscalar_t EXP1 = 2.718281828459045;
constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
static const accscalar_t EXP1 = 2.718281828459045;
static const accscalar_t lanczos_g = 6.024680040776729583740234375;
if (::fabs(a - x) > 0.4 * ::fabs(a)) {
ax = a * ::log(x) - x - ::lgamma(a);
@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
constexpr int MAXITER = 2000;
static const int MAXITER = 2000;
int i;
accscalar_t ans, ax, c, r;
@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
accscalar_t fac = 1;
accscalar_t sum = 0;
accscalar_t term, logx;
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) {
@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
constexpr accscalar_t d[25][25] =
static const accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
int k, n, sgn;
int maxpow = 0;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
accscalar_t lambda = x / a;
accscalar_t sigma = (x - a) / a;
@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
int i;
accscalar_t ans, ax, c, yc, r, t, y, z;
accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
4.503599627370496e15 : 16777216.;
constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x);
@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
// boundary values following SciPy
if ((x < 0) || (a < 0)) {

View File

@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
const auto digamma_string = jiterator_stringify(
template <typename T>
T digamma(T x) {
static constexpr double PI_f64 = 3.14159265358979323846;
static const double PI_f64 = 3.14159265358979323846;
// Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
if (x == 0) {
@ -3072,9 +3072,9 @@ template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static constexpr double PI_f64 = 3.14159265358979323846;
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
constexpr accscalar_t A[] = {
static const double PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,

View File

@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
return 0;
}
static constexpr int size = 2;
static const int size = 2;
};
// taken from
@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
return 0;
}
static constexpr int size = 4;
static const int size = 4;
};
template <typename accscalar_t>

View File

@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
// else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
// only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
if (mat1.dim() == 1 && mat2.dim() == 1) {
// aten::dot
return mat1.size(0) > mkldnn_gemm_min_size;

View File

@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
#if defined(__ARM_NEON__) || defined(__aarch64__)
constexpr static int PARALLEL_THRESHOLD = 1 << 20;
const static int PARALLEL_THRESHOLD = 1 << 20;
// Generic template defaults to naive quantize implementation
template <typename T>

View File

@ -1388,7 +1388,7 @@ namespace at::native {
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
static std::optional<at::Tensor> other = std::nullopt;
constexpr std::string_view binary_post_op = "none";
static const std::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp,

View File

@ -16,8 +16,8 @@ namespace {
#ifdef USE_PYTORCH_QNNPACK
constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
constexpr static int qnnpack_softmax_output_zero_point = 0;
const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
const static int qnnpack_softmax_output_zero_point = 0;
bool is_qnnpack_compatible(
const Tensor& qx,

View File

@ -110,9 +110,9 @@ class ApplyLogSumExp {
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int constexpr kElementsPerAccess = ElementsPerAccess;
static int constexpr kCount = kElementsPerAccess;
static constexpr ScaleType::Kind kScale =
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;