fix MPCT destroy_pg call (#157952)

I was seeing hangs and exceptions not being raised in some cases. Only call `c10d.destroy_process_group()` for `MultiProcContinousTest` in the clean exit case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157952
Approved by: https://github.com/fduwjj
ghstack dependencies: #157589
This commit is contained in:
Howard Huang
2025-07-09 11:47:22 -07:00
committed by PyTorch MergeBot
parent 7444debaca
commit f4406689b8
2 changed files with 11 additions and 3 deletions

View File

@ -1616,6 +1616,7 @@ class MultiProcContinousTest(TestCase):
@classmethod
def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
raised_exception = False
# Sub tests are going to access these values, check first
assert 0 <= rank < world_size
# set class variables for the test class
@ -1641,6 +1642,7 @@ class MultiProcContinousTest(TestCase):
cls._run_test_given_id(test_id)
completion_queue.put(test_id)
except BaseException as ex:
raised_exception = True
# Send the exception and stack trace back to the dispatcher
exc_info = sys.exc_info()
tb_str = "".join(traceback.format_exception(*exc_info))
@ -1651,7 +1653,12 @@ class MultiProcContinousTest(TestCase):
# Termination
logger.info("Terminating ...")
c10d.destroy_process_group()
# Calling destroy_process_group when workers have exceptions
# while others are doing collectives will cause a deadlock since
# it waits for enqueued collectives to finish.
# Only call this on a clean exit path
if not raised_exception:
c10d.destroy_process_group()
@classmethod
def _spawn_processes(cls, world_size) -> None: