Allow MultiProcContinuousTest to set world_size (#155920)

`MultiProcContinuousTest` will automatically set world_size to number of devices. This change allows this attribute to be modified by the derived test class

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155920
Approved by: https://github.com/fduwjj
This commit is contained in:
Howard Huang
2025-06-13 09:50:14 -07:00
committed by PyTorch MergeBot
parent 9bd42c1570
commit 8e1471bdc9

View File

@ -1649,9 +1649,11 @@ class MultiProcContinousTest(TestCase):
# Use device count as world size
device_type = cls.device_type()
cls.world_size = torch.get_device_module(device_type).device_count()
if cls.world_size == 0:
raise unittest.SkipTest(f"No {device_type} devices available")
# If world_size is not set, use device count
if cls.world_size == -2:
cls.world_size = torch.get_device_module(device_type).device_count()
if cls.world_size == 0:
raise unittest.SkipTest(f"No {device_type} devices available")
logger.info(
f"Testing class {cls.__name__} on {cls.world_size} {device_type}" # noqa: G004