Generalization of distributed test cases for non-CUDA devices (#138216)

# Motivation
This pr is an extension of #131758. As described in #131758, these changes are looking to make distributed UTs more accessible to users of all device types.

It is a demonstration of a few changes discussed by @kwen2501 and @jgong5 in the discussion for #131758(https://github.com/pytorch/pytorch/pull/131758#discussion_r1762422784)

This PR contains two types of changes, the first is to the common distributed folder where we have added a new class derived from MultiProcessTestCase which helps abstracts out the process group creation /deletion and other functionality for a given device.

The new generalized content can be added by deriving from this base class.
Also includes other misc changes for gaudi support

The second changed file is test_functional_api. a test file in common distributed. This file is a POC for how we can use this new class to write more device agnostic distributed test cases.

The following changes have been made to test_functional_api.py:
-Functionality has been added to test for non cuda devices using intel HPU as an example
-Multiple set up steps previously required by MultiProcessTestCase have been abstracted out
-Misc adaptations to allow for general call to accelerators while adding test skips instead explicitly skipping for multiple GPUs
-Skipifhpu flags have been added to enable skipping a few Multithreaded test cases which are as yet not supported on HPUs

NOTE: Within test functional api, there are tests which require the use of some multithreading functions which are as yet not supported on HPUs. These have been skipped for hpu using skipHPU decorator.

I will be raising a separate PR to improve usability pf said decorators in a device agnostic setting in the manner suggested by @kwen2501 in a comment on this PR.

This pr is a cleaned up version of a previous PR(#136988) which I closed due to human error. I have addressed some of the comments made by @kwen2501 in this as well

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138216
Approved by: https://github.com/kwen2501, https://github.com/guangyey
This commit is contained in:
Anant Gulati
2024-11-18 09:38:00 +00:00
committed by PyTorch MergeBot
parent 06dde8c157
commit b379a28a95
2 changed files with 161 additions and 127 deletions

View File

@ -43,6 +43,7 @@ from torch.testing._internal.common_utils import (
TEST_WITH_TSAN,
TestCase,
run_tests,
TEST_HPU,
)
from torch.testing._internal.distributed.multi_threaded_pg import (
_install_threaded_pg,
@ -82,6 +83,7 @@ TEST_SKIPS = {
86, "Test skipped at subprocess level, look at subprocess log for skip reason"
),
"importerror": TestSkip(88, "Test skipped due to missing import"),
"no_accelerator": TestSkip(89, "accelerator is not available."),
}
@ -101,6 +103,8 @@ class DistTestCases:
backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
backend_feature["plugin"] = set()
if TEST_HPU:
backend_feature["hpu"] = {"hccl"}
def skip_if_no_gpu(func):
@ -114,6 +118,8 @@ def skip_if_no_gpu(func):
world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
if TEST_HPU and torch.hpu.device_count < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
return func(*args, **kwargs)
@ -191,6 +197,8 @@ def skip_if_lt_x_gpu(x):
def wrapper(*args, **kwargs):
if torch.cuda.is_available() and torch.cuda.device_count() >= x:
return func(*args, **kwargs)
if TEST_HPU and torch.hpu.device_count() >= x:
return func(*args, **kwargs)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
return wrapper
@ -500,6 +508,9 @@ def init_multigpu_helper(world_size: int, backend: str):
divided to subsets, each process only uses a subset.
"""
nGPUs = torch.cuda.device_count()
if TEST_HPU:
nGPUs = torch.hpu.device_count()
visible_devices = range(nGPUs)
# If rank is less than or equal to number of available GPU's
@ -900,6 +911,47 @@ class MultiProcessTestCase(TestCase):
def is_master(self) -> bool:
return self.rank == 0
# Utility base class for distributed Multi Process Test cases
# This abstracts the PG creation and deletion, the backends are selected based
# on device type. The tests functions can be instantiated per device type using
# common_device_type.instantiate_device_type_tests
# other backends can add entry in backend() function
class DistributedTestBase(MultiProcessTestCase):
def setUp(self):
super().setUp()
self._spawn_processes()
def tearDown(self):
try:
os.remove(self.file_name)
except OSError:
pass
def backend(self, device) -> str:
if "cuda" in device:
return "nccl"
elif "hpu" in device : # intel gaudi
return "hccl"
else :
return "gloo"
def create_pg(self, device):
num_visible_devices = torch.get_device_module(device).device_count()
store = torch.distributed.FileStore(self.file_name, num_visible_devices)
torch.distributed.init_process_group(
backend=self.backend(device),
world_size=self.world_size,
rank=self.rank,
store=store
)
if "nccl" in self.backend(device):
torch.cuda.set_device(self.rank)
return torch.distributed.distributed_c10d._get_default_group()
def rank_to_device(self, device):
num_visible_devices = torch.get_device_module(device).device_count()
return {i: [i % num_visible_devices] for i in range(self.world_size)}
def run_subtests(
cls_inst,