[CI][CUDA] Re-enable the test-nan-assert on CUDA12 (#154448)

We need to re-enable this test because there are recent changes that could be relevant to test_nan_assert.

I've already verified that the test hangs if we don't remove the `pg._allgather_base(output, nan_tensor)` call sitting between the `backend._set_enable_nan_check` calls.

Why was it "working" previously? Because only the cu118 distributed jobs were running, so the `backend._set_enable_nan_check` change was never exercised during the merge process (the skip logic is: if not CUDA 12 and above, skip).
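
To make the skip logic concrete, here is a minimal sketch of the old and new skip predicates. The plain booleans stand in for the real `TEST_MULTIGPU` and `CUDA_12_AND_ABOVE` flags computed by the test harness, so this is illustrative rather than the actual harness code:

```python
# Plain booleans stand in for the harness-computed flags
# (assumption: a CUDA 12 runner with 2+ GPUs).
TEST_MULTIGPU = True
CUDA_12_AND_ABOVE = True

# Old predicate: the trailing `and False` forces the inner expression to
# False, so `not (...)` is always True and the test is skipped everywhere.
old_skip = not (TEST_MULTIGPU and CUDA_12_AND_ABOVE and False)
assert old_skip

# New predicate: skip only when we lack 2+ GPUs or run below CUDA 12.
new_skip = not (TEST_MULTIGPU and CUDA_12_AND_ABOVE)
assert not new_skip

# On cu118 jobs CUDA_12_AND_ABOVE is False, so the test is skipped there
# either way -- which is why the change was never exercised on cu118.
cu118_skip = not (TEST_MULTIGPU and False)
assert cu118_skip
```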

Workaround for #153479

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154448
Approved by: https://github.com/kwen2501
commit a01bb9da14
parent 5e03433443
Author: Wei Wang
Date: 2025-06-05 02:09:31 +00:00
Committed by: PyTorch MergeBot


@@ -285,10 +285,9 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        # These tests are expected to throw SIGABRT(6); adding the negative sign
-        # bc the test return code is actually -6
+        # These tests are expected to throw SIGABRT(6);
+        # But if we are in Sandcastle, `skip_but_pass_in_sandcastle` would return 0.
-        TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else -signal.SIGABRT
+        TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else signal.SIGABRT
         self.special_return_code_checks = {
             self.test_nan_assert_float16.__wrapped__: TEST_NAN_ASSERT_RETURN,
             self.test_nan_assert_float32.__wrapped__: TEST_NAN_ASSERT_RETURN,
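
The sign flip in `TEST_NAN_ASSERT_RETURN` follows from how the child process now exits: a process killed by SIGABRT shows up in `multiprocessing` with exit code -6 (hence the old negative sign), while the new test body below calls `sys.exit(signal.SIGABRT)` and exits with status 6. A Unix-only sketch of the two conventions (illustrative, not part of the test suite):

```python
import multiprocessing as mp
import os
import signal
import sys

def die_by_signal():
    # Killed by SIGABRT: the parent observes exitcode -6.
    os.kill(os.getpid(), signal.SIGABRT)

def die_by_sys_exit():
    # Plain sys.exit with the signal number: the parent observes exitcode 6.
    sys.exit(signal.SIGABRT)

if __name__ == "__main__":
    for target, expected in [
        (die_by_signal, -signal.SIGABRT),
        (die_by_sys_exit, signal.SIGABRT),
    ]:
        p = mp.Process(target=target)
        p.start()
        p.join()
        assert p.exitcode == expected, (target.__name__, p.exitcode)
```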
@@ -485,7 +484,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(
-        # skip for cu126 as well due to https://github.com/pytorch/pytorch/issues/153479
-        not (TEST_MULTIGPU and CUDA_12_AND_ABOVE and False),
+        not (TEST_MULTIGPU and CUDA_12_AND_ABOVE),
         "NCCL test requires 2+ GPUs and Device side assert could cause unexpected errors in lower versions of CUDA",
     )
     @parametrize(
@@ -539,10 +538,15 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
         backend._set_enable_nan_check(False)
         # Note: using all-gather here bc some NCCL/SM version does not support
         # FP8 reduction
-        pg._allgather_base(output, nan_tensor)
+        # temporarily skip due to https://github.com/pytorch/pytorch/issues/153479
+        # pg._allgather_base(output, nan_tensor)
         backend._set_enable_nan_check(True)
-        pg._allgather_base(output, nan_tensor)
+        try:
+            pg._allgather_base(output, nan_tensor)
+        except Exception:
+            sys.exit(signal.SIGABRT)
         dist.destroy_process_group()
         # reset env
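
The new try/except makes the expected exit status deterministic even when the NaN check surfaces as a Python exception instead of a hard device-side abort. Here is a hedged, standalone sketch of that pattern; `_collective_with_nan` is a hypothetical stand-in for `pg._allgather_base(output, nan_tensor)`:

```python
import signal
import sys

def _collective_with_nan():
    # Hypothetical stand-in for pg._allgather_base(output, nan_tensor);
    # here we simply raise, the way an enabled NaN check might.
    raise RuntimeError("NaN found in input tensor")

try:
    _collective_with_nan()
except Exception:
    # Mirror the diff: convert any Python-level failure into exit status 6
    # so the parent's special_return_code_checks entry still matches.
    sys.exit(signal.SIGABRT)
```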