Add isnan exit condition to special ops (#157464)

They might have been slow on CUDA-11.3, but this version of CUDA is long gone. More fundamental underlying issue were linear complexity of the recursive polynomial definitions for higher order polynomials, for example see this loop from implementation of Chebyshev polynomial of the first kind
7081b8233a/aten/src/ATen/native/Math.h (L2969-L2973)
which were tested by `test_compare_cpu` using following values (as sample index 16)
7081b8233a/torch/testing/_internal/opinfo/core.py (L2079)

Luckily chebyshev polynomials for absolute values higher than 1 pretty quickly reach infinity, see below
```
python3 -c "import torch;print(torch.special.chebyshev_polynomial_v(torch.nextafter(torch.tensor(1.0), torch.tensor(2.0)), torch.tensor(1e6)))"
tensor(nan)
```
Which is not the case for Laguerre polynomials, but it's probably fine to just limit it to 1e7

Before
```
$ PYTORCH_TEST_WITH_SLOW=1 python test_ops.py -k chebyshev_polynomial_
ssssssss..ssssss..ssssss..ssssssssssssssssssssss..ssssss/home/ubuntu/py3.10-nightly/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: This API is going to be deprecated, please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.)
  return torch._C._get_cublas_allow_tf32()
....ssssssssssss..ssssss..ssssss............ssssssssssssssssssssssssssssssssssss..ssssssssssssss..ssssss..ssssssssssssssssssssssssssssss..ssssss....ssssssssssss..ssssss..ssssss............ssssssssssssssssssssssssssssssssssss..ssssss..ssssssssssssss..ssssss..ssssss..ssssssssssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssss..ssssssssssssss
----------------------------------------------------------------------
Ran 432 tests in 8.575s

OK (skipped=344)
```
After
```
$ PYTORCH_TEST_WITH_SLOW=1 python test_ops.py -k chebyshev_polynomial_
ssssssss........................ssssssssssssssss......../home/ubuntu/pytorch/torch/backends/cuda/__init__.py:131: UserWarning: This API is going to be deprecated, please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /home/ubuntu/pytorch/aten/src/ATen/Context.cpp:78.)
  return torch._C._get_cublas_allow_tf32()
........................................................................................xxxxxxxx................ssssssssssssssssssssssss........................................................................................................ssssssss........................ssssssss........................................................................................ssssssss
----------------------------------------------------------------------
Ran 432 tests in 45.580s

OK (skipped=72, expected failures=8)
```

Fixes https://github.com/pytorch/pytorch/issues/79528

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157464
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157488
This commit is contained in:
Nikita Shulga
2025-07-04 16:35:51 -07:00
committed by PyTorch MergeBot
parent 63e87d6d05
commit a952956d05
7 changed files with 66 additions and 109 deletions

View File

@ -2862,7 +2862,7 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
T q = x;
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -2910,7 +2910,7 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
T q = x + x;
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -2966,7 +2966,7 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
T q = x + x - T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -3026,7 +3026,7 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
T q = x + x + T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -3150,7 +3150,7 @@ inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
T q = T(1.0) - x;
T r;
for (int64_t k = 1; k < n; k++) {
for (int64_t k = 1; (k < n) && !std::isnan(q); k++) {
r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1);
p = q;
q = r;
@ -3190,7 +3190,7 @@ inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
T q = x;
T r;
for (int64_t k = 1; k < n; k++) {
for (int64_t k = 1; (k < n) && !std::isnan(q); k++) {
r = ((k + k + 1) * x * q - k * p) / (k + 1);
p = q;
q = r;
@ -3733,7 +3733,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n)
T q = x + x - T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
p = q;
q = r;
@ -3785,7 +3785,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n)
T q = x + x - T(1.0) + (x + x - T(1.0));
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
p = q;
q = r;
@ -3841,7 +3841,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
p = q;
q = r;
@ -3897,7 +3897,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n)
T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
p = q;
q = r;

View File

@ -1946,7 +1946,7 @@ const auto chebyshev_polynomial_t_string = jiterator_stringify(
T q = x;
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -1996,7 +1996,7 @@ const auto chebyshev_polynomial_u_string = jiterator_stringify(
T q = x + x;
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -2054,7 +2054,7 @@ const auto chebyshev_polynomial_v_string = jiterator_stringify(
T q = x + x - T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -2116,7 +2116,7 @@ const auto chebyshev_polynomial_w_string = jiterator_stringify(
T q = x + x + T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -2252,7 +2252,7 @@ const auto laguerre_polynomial_l_string = jiterator_stringify(
T q = T(1.0) - x;
T r;
for (int64_t k = 1; k < n; k++) {
for (int64_t k = 1; (k < n) && !isnan(q); k++) {
r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1);
p = q;
q = r;
@ -2294,7 +2294,7 @@ const auto legendre_polynomial_p_string = jiterator_stringify(
T q = x;
T r;
for (int64_t k = 1; k < n; k++) {
for (int64_t k = 1; (k < n) && !isnan(q); k++) {
r = ((k + k + 1) * x * q - k * p) / (k + 1);
p = q;
q = r;
@ -2851,7 +2851,7 @@ const auto shifted_chebyshev_polynomial_t_string = jiterator_stringify(
T q = x + x - T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !isnan(q); k++) {
r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
p = q;
q = r;
@ -2905,7 +2905,7 @@ const auto shifted_chebyshev_polynomial_u_string = jiterator_stringify(
T q = x + x - T(1.0) + (x + x - T(1.0));
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !isnan(q); k++) {
r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
p = q;
q = r;
@ -2963,7 +2963,7 @@ const auto shifted_chebyshev_polynomial_v_string = jiterator_stringify(
T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !isnan(q); k++) {
r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
p = q;
q = r;
@ -3021,7 +3021,7 @@ const auto shifted_chebyshev_polynomial_w_string = jiterator_stringify(
T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0);
T r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !isnan(q); k++) {
r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
p = q;
q = r;

View File

@ -1559,7 +1559,7 @@ float chebyshev_polynomial_t_forward(T x, int64_t n) {
float q = x;
float r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) {
r = (x + x) * q - p;
p = q;
q = r;
@ -1603,7 +1603,7 @@ float chebyshev_polynomial_u_forward(T x, int64_t n) {
auto p = 1.0;
float r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) {
r = 2 * x * q - p;
p = q;
q = r;
@ -1656,7 +1656,7 @@ float chebyshev_polynomial_v_forward(T x, int64_t n) {
auto p = 1.0;
float r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) {
r = 2 * x * q - p;
p = q;
q = r;
@ -1713,7 +1713,7 @@ float chebyshev_polynomial_w_forward(T x, int64_t n) {
auto p = 1.0;
float r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) {
r = 2.0 * x * q - p;
p = q;
q = r;
@ -1757,7 +1757,7 @@ float shifted_chebyshev_polynomial_t_forward(T x, int64_t n) {
float q = xpxm1;
float r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) {
r = (xpxm1 + xpxm1) * q - p;
p = q;
q = r;
@ -1806,7 +1806,7 @@ float shifted_chebyshev_polynomial_u_forward(T x, int64_t n) {
float q = xpxm1 + xpxm1;
float r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) {
r = (xpxm1 + xpxm1) * q - p;
p = q;
q = r;
@ -1860,7 +1860,7 @@ float shifted_chebyshev_polynomial_v_forward(T x, int64_t n) {
float q = xpxm1 + xpxm1 - 1.0;
float r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) {
r = (xpxm1 + xpxm1) * q - p;
p = q;
q = r;
@ -1914,7 +1914,7 @@ float shifted_chebyshev_polynomial_w_forward(T x, int64_t n) {
float q = xpxm1 + xpxm1 + 1.0;
float r;
for (int64_t k = 2; k <= n; k++) {
for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) {
r = (xpxm1 + xpxm1) * q - p;
p = q;
q = r;

View File

@ -385,6 +385,8 @@ dtensor_fails = {
xfail("special.bessel_y1"),
xfail("special.chebyshev_polynomial_t"),
xfail("special.chebyshev_polynomial_u"),
xfail("special.chebyshev_polynomial_v"),
xfail("special.chebyshev_polynomial_w"),
xfail("special.entr"),
xfail("special.erfcx"),
xfail("special.hermite_polynomial_h"),
@ -393,6 +395,7 @@ dtensor_fails = {
xfail("special.i1"),
xfail("special.i1e"),
xfail("special.laguerre_polynomial_l"),
xfail("special.legendre_polynomial_p"),
xfail("special.log_ndtr"),
xfail("special.modified_bessel_i0"),
xfail("special.modified_bessel_i1"),
@ -401,6 +404,10 @@ dtensor_fails = {
xfail("special.ndtri"),
xfail("special.scaled_modified_bessel_k0"),
xfail("special.scaled_modified_bessel_k1"),
xfail("special.shifted_chebyshev_polynomial_t"),
xfail("special.shifted_chebyshev_polynomial_u"),
xfail("special.shifted_chebyshev_polynomial_v"),
xfail("special.shifted_chebyshev_polynomial_w"),
xfail("special.spherical_bessel_j0"),
xfail("special.xlog1py"),
xfail("special.zeta"),

View File

@ -4534,13 +4534,21 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail("clamp_min", ""),
xfail("sparse.sampled_addmm"),
xfail("sparse.mm", "reduce"),
xfail("special.chebyshev_polynomial_t"),
xfail("special.chebyshev_polynomial_v"),
xfail("special.chebyshev_polynomial_u"),
xfail("special.chebyshev_polynomial_w"),
xfail("special.shifted_chebyshev_polynomial_t"),
xfail("special.shifted_chebyshev_polynomial_v"),
xfail("special.shifted_chebyshev_polynomial_u"),
xfail("special.shifted_chebyshev_polynomial_w"),
xfail("_segment_reduce", "offsets"),
xfail("index_reduce", "prod"),
xfail("index_reduce", "mean"),
xfail("index_reduce", "amin"),
xfail("index_reduce", "amax"),
xfail("special.laguerre_polynomial_l"),
xfail("special.legendre_polynomial_p"),
xfail("special.hermite_polynomial_h"),
xfail("jiterator_binary", device_type="cuda"),
xfail("jiterator_4inputs_with_extra_args", device_type="cuda"),
@ -4548,7 +4556,6 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail("lu_solve", ""),
xfail("special.hermite_polynomial_he"),
xfail("nn.functional.dropout3d", ""),
xfail("special.chebyshev_polynomial_t"),
xfail("as_strided_scatter", ""),
xfail("equal", ""),
xfail("linalg.lu", ""),

View File

@ -465,6 +465,7 @@ if torch.backends.mps.is_available():
"special.airy_ai": None,
"special.erfcx": None,
"special.laguerre_polynomial_l": None,
"special.legendre_polynomial_p": None,
"special.log_ndtr": None,
"special.ndtri": None,
"svd_lowrank": None,

View File

@ -394,11 +394,8 @@ op_db: list[OpInfo] = [
skips=(
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -410,11 +407,8 @@ op_db: list[OpInfo] = [
skips=(
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -424,13 +418,10 @@ op_db: list[OpInfo] = [
dtypes=all_types_and(torch.bool),
promotes_int_to_float=True,
skips=(
DecorateInfo(
unittest.skip(
"Skipping - testing takes an unreasonably long time, #79528"
)
),
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -440,13 +431,10 @@ op_db: list[OpInfo] = [
dtypes=all_types_and(torch.bool),
promotes_int_to_float=True,
skips=(
DecorateInfo(
unittest.skip(
"Skipping - testing takes an unreasonably long time, #79528"
)
),
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -471,11 +459,8 @@ op_db: list[OpInfo] = [
skips=(
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: inf
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -487,11 +472,8 @@ op_db: list[OpInfo] = [
skips=(
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -501,18 +483,10 @@ op_db: list[OpInfo] = [
dtypes=all_types_and(torch.bool),
promotes_int_to_float=True,
skips=(
DecorateInfo(
unittest.skip(
"Skipping - testing takes an unreasonably long time, #79528"
)
),
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -606,18 +580,10 @@ op_db: list[OpInfo] = [
dtypes=all_types_and(torch.bool),
promotes_int_to_float=True,
skips=(
DecorateInfo(
unittest.skip(
"Skipping - testing takes an unreasonably long time, #79528"
)
),
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -627,18 +593,10 @@ op_db: list[OpInfo] = [
dtypes=all_types_and(torch.bool),
promotes_int_to_float=True,
skips=(
DecorateInfo(
unittest.skip(
"Skipping - testing takes an unreasonably long time, #79528"
)
),
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -648,18 +606,10 @@ op_db: list[OpInfo] = [
dtypes=all_types_and(torch.bool),
promotes_int_to_float=True,
skips=(
DecorateInfo(
unittest.skip(
"Skipping - testing takes an unreasonably long time, #79528"
)
),
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,
@ -669,18 +619,10 @@ op_db: list[OpInfo] = [
dtypes=all_types_and(torch.bool),
promotes_int_to_float=True,
skips=(
DecorateInfo(
unittest.skip(
"Skipping - testing takes an unreasonably long time, #79528"
)
),
DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
DecorateInfo(
unittest.skip("testing takes an unreasonably long time, #79528"),
"TestCommon",
"test_compare_cpu",
),
# Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
),
supports_one_python_scalar=True,
supports_autograd=False,